module Math.Modulus.Sqrt where

import Control.Monad.ST
import Data.Bits (Bits (unsafeShiftR, (.&.)))
import Data.Function (fix)
import System.Random.Stateful

import Math.Modulus (powMod)
import System.Random.Utils

{- | Legendre symbol @(a | p)@ via Euler's criterion, @a^((p - 1) / 2) mod p@.

@p@ must be an /odd/ prime. The symbol is reported as a residue modulo @p@
(presumably in @[0, p)@, as 'powMod' is expected to return), so callers
should compare against:

* @0@ when @p@ divides @a@,
* @1@ when @a@ is a non-zero quadratic residue modulo @p@,
* @p - 1@ (that is, @-1 mod p@) when @a@ is a quadratic non-residue.

@a@ may be any 'Int', including negative values or values @>= p@: it is
reduced modulo @p@ before exponentiation, so the result does not depend on
how 'powMod' treats negative or oversized bases.
-}
legendreSymbol :: Int -> Int -> Int
legendreSymbol a p = powMod (mod a p) (quot (p - 1) 2) p

{- | All square roots of @a@ modulo a prime @p@: every @x@ in @[0, p)@ with
@x^2 = a (mod p)@, in ascending order.

* @[0]@ when @p@ divides @a@;
* @[a mod 2]@ when @p = 2@;
* @[x, p - x]@ (smaller root first) when @a@ is a non-zero quadratic residue;
* @[]@ when @a@ is a quadratic non-residue.

@a@ may be any 'Int' (negative, or @>= p@). It is reduced modulo @p@ once,
up front, so every branch below works on a canonical residue.

>>> sqrtMod 2 1000000007
[59713600,940286407]
>>> sqrtMod 3 1000000007
[82062379,917937628]
>>> sqrtMod 4 1000000007
[2,1000000005]
>>> sqrtMod 5 1000000007
[]
-}
sqrtMod :: Int -> Int -> [Int]
sqrtMod 0 _ = [0]
sqrtMod a0 2 = [mod a0 2]
sqrtMod a0 p = case legendreSymbol a p of
  0 -> [0]
  1
    -- p = 3 (mod 4): a^((p + 1) / 4) is a square root directly.
    | rem p 4 == 3 -> withConjugate $ powMod a (quot (p + 1) 4) p
    -- p = 1 (mod 4): no closed form, use Cipolla's algorithm.
    | rem p 4 == 1 -> withConjugate $ cipolla a p
  x | x == p - 1 -> []
  _ -> error $ "sqrtMod: unreachable (a, p) = " <> show (a0, p)
  where
    -- Canonical representative of the input in [0, p). The `x == p - 1`
    -- check above and `withConjugate` both assume values in this range.
    a = mod a0 p

    -- Pair a root with its negation modulo p, smaller one first.
    withConjugate :: Int -> [Int]
    withConjugate !x =
      let !x' = p - x
       in [min x x', max x x']

{- | One square root of @a@ modulo an odd prime @p@, via Cipolla's algorithm.

Preconditions: @p@ is an odd prime and @a@ is a quadratic residue modulo
@p@ (@(a | p) = 1@, or @a = 0 (mod p)@). For a non-residue @a@ the result
is meaningless. @a@ may be any 'Int'; it is reduced modulo @p@ first.
Products of two residues are formed in 'Int' before reduction, so
@(p - 1)^2@ must not overflow 'Int'.

The result satisfies @cipolla a p ^ 2 = a (mod p)@. Which of the two roots
(@r@ or @p - r@) is returned depends on the randomly drawn @t@. Callers that
need a canonical root must normalise it, as 'sqrtMod' does.

>>> let r = cipolla 9 1000000007 in r * r `mod` 1000000007
9
>>> cipolla 9 998244353 `elem` [3, 998244350]
True
-}
cipolla :: Int -> Int -> Int
cipolla a0 p
  -- No t makes t^2 - 0 a non-residue, so the search below would never end;
  -- 0 is the (only) root anyway.
  | a == 0 = 0
  | otherwise = fst $ pow (t, 1) (quot (p + 1) 2)
  where
    -- Canonical representative of a; the modular helpers below assume
    -- their operands already lie in [0, p).
    a :: Int
    a = mod a0 p

    -- A random t such that t^2 - a is a quadratic non-residue, so that
    -- omega = sqrt (t^2 - a) generates the extension field F_p[omega].
    -- Deliberately lazy (no bang): a strict binding would be forced before
    -- the `a == 0` guard and loop forever in that case.
    t :: Int
    t = runST $ withGlobalStdGen_ $ \rng ->
      fix $ \retry -> do
        x <- uniformRM (0, p - 1) rng
        case legendreSymbol (x *% x -% a) p of
          0 -> retry
          1 -> retry
          _ -> pure x

    -- omega^2, the value the extension field is built over.
    ww :: Int
    ww = t *% t -% a

    -- Multiplication in F_p[omega], with (x, y) representing x + y * omega.
    mul :: (Int, Int) -> (Int, Int) -> (Int, Int)
    mul (x0, y0) (x1, y1) = (re, im)
      where
        !re = x0 *% x1 +% y0 *% y1 *% ww
        !im = x0 *% y1 +% y0 *% x1

    -- Binary exponentiation in F_p[omega].
    pow :: (Int, Int) -> Int -> (Int, Int)
    pow _ 0 = (1, 0)
    pow base0 e0 = go (1, 0) base0 e0
      where
        go !acc !base !e
          | e .&. 1 == 0 = go acc (mul base base) (unsafeShiftR e 1)
          | e == 1 = mul acc base
          | otherwise = go (mul acc base) (mul base base) (unsafeShiftR (e - 1) 1)

    infixl 7 *%
    infixl 6 +%, -%

    -- Addition, subtraction and multiplication modulo p, for operands in [0, p).
    (+%), (-%), (*%) :: Int -> Int -> Int
    x +% y = let r = x + y in if r < p then r else r - p
    {-# INLINE (+%) #-}
    x -% y = let r = x - y in if r < 0 then r + p else r
    {-# INLINE (-%) #-}
    x *% y = x * y `rem` p
    {-# INLINE (*%) #-}