module Math.Modulus.Sqrt where
import Control.Monad.ST
import Data.Bits (Bits (unsafeShiftR, (.&.)))
import Data.Function (fix)
import System.Random.Stateful
import Math.Modulus (powMod)
import System.Random.Utils
-- | Euler's criterion for @a@ modulo @p@: the power @a^((p - 1) `quot` 2)@
-- reduced by 'powMod'.
--
-- For an odd prime @p@ (and assuming @'powMod' b e m@ computes @b^e mod m@)
-- the result is @0@ when @p@ divides @a@, @1@ when @a@ is a non-zero
-- quadratic residue, and @p - 1@ when it is a non-residue.  The non-residue
-- case is reported as @p - 1@ (i.e. @-1 mod p@), not as @-1@; 'sqrtMod'
-- matches on exactly that value.
legendreSymbol :: Int -> Int -> Int
legendreSymbol a p = powMod a halfExponent p
  where
    -- (p - 1) / 2, computed with 'quot' as before.
    halfExponent = (p - 1) `quot` 2
-- | All square roots of @a@ modulo @p@, in ascending order.
--
-- * @[0]@ when @a ≡ 0 (mod p)@;
-- * @[r, p - r]@ (smaller first) when @a@ is a non-zero quadratic residue;
-- * @[]@ when @a@ is a non-residue;
-- * @[a `mod` 2]@ when @p == 2@.
--
-- @a@ may be any 'Int': it is reduced into @[0, p)@ once, up front, so the
-- Legendre test, the @p ≡ 3 (mod 4)@ shortcut and Cipolla's algorithm all
-- operate on the same canonical residue.
--
-- @p@ is expected to be prime (both root-finding methods assume it).  For
-- other moduli the Euler-criterion value need not be @0@, @1@ or @p - 1@, in
-- which case this calls 'error'.
sqrtMod :: Int -> Int -> [Int]
sqrtMod 0 _ = [0]
sqrtMod a 2 = [a `mod` 2]
sqrtMod a p = case legendreSymbol r p of
  0 -> [0]
  1
    -- p ≡ 3 (mod 4): r^((p + 1) / 4) is already a square root of r.
    | p `rem` 4 == 3 -> withConjugate $ powMod r ((p + 1) `quot` 4) p
    -- p ≡ 1 (mod 4): no closed form, use Cipolla's algorithm.
    | p `rem` 4 == 1 -> withConjugate $ cipolla r p
  s | s == p - 1 -> []
  _ -> error $ "sqrtMod: unreachable (a, p) = " <> show (a, p)
  where
    -- Canonical representative of a in [0, p).
    r = a `mod` p

    -- Given one root x, return both roots x and p - x in ascending order.
    withConjugate :: Int -> [Int]
    withConjugate !x =
      let !x' = p - x
       in [min x x', max x x']
-- | Cipolla's algorithm: one square root of @a@ modulo the prime @p@.
--
-- Draws random @t@ from the global generator until 'legendreSymbol' reports
-- @t^2 - a@ as a non-residue (neither @0@ nor @1@).  It then computes
-- @(t + ω)^((p + 1) / 2)@ in @F_p[ω]@ with @ω^2 = t^2 - a@; the @ω@-free
-- component of that power is the root.  Which of the two roots (@r@ or
-- @p - r@) is returned may depend on the random @t@, so callers needing a
-- canonical answer should normalise it (as 'sqrtMod' does).
--
-- @a@ may be any 'Int'; it is reduced into @[0, p)@ first, because the
-- modular helpers below require operands in that range.  Two inputs for
-- which the search could never succeed are answered directly:
--
-- * @p == 2@: the exponent in 'legendreSymbol' is 0, so every candidate
--   scores 0 or 1 and is rejected; the root is @a `mod` 2@.
-- * @a ≡ 0 (mod p)@: @t^2 - 0@ is never a non-residue; the root is @0@.
--
-- Products are formed in 'Int' when @(p - 1)^2@ fits, and in 'Integer'
-- otherwise, so large moduli do not overflow inside this function.
-- NOTE(review): 'legendreSymbol' delegates to 'powMod', whose overflow
-- behaviour for large @p@ is not visible here.
--
-- Precondition: @p@ is prime and @a@ is a quadratic residue mod @p@; the
-- result is meaningless otherwise.
cipolla :: Int -> Int -> Int
cipolla a p
  | p == 2 = a `mod` 2
  | n == 0 = 0
  | otherwise = fst $ pow (t, 1) ((p + 1) `quot` 2)
  where
    -- Canonical residue of a in [0, p).
    n :: Int
    n = a `mod` p

    -- Random t with t^2 - n a non-residue.  For prime p about half of all t
    -- qualify, so only a few draws are expected.  Deliberately lazy (no
    -- bang): the guards above must be able to return without running this.
    t :: Int
    t = runST $ withGlobalStdGen_ $ \rng -> fix $ \loop -> do
      x <- uniformRM (0, p - 1) rng
      case legendreSymbol (x *% x -% n) p of
        0 -> loop
        1 -> loop
        _ -> pure x

    -- ω^2 in the quadratic extension.
    w :: Int
    w = t *% t -% n

    -- Multiplication in F_p[ω]: (x0 + y0ω)(x1 + y1ω) with ω^2 = w.
    mul :: (Int, Int) -> (Int, Int) -> (Int, Int)
    mul (x0, y0) (x1, y1) = (x', y')
      where
        !x' = x0 *% x1 +% y0 *% y1 *% w
        !y' = x0 *% y1 +% y0 *% x1

    -- Square-and-multiply exponentiation in F_p[ω].
    pow :: (Int, Int) -> Int -> (Int, Int)
    pow _ 0 = (1, 0)
    pow xy0 e = go (1, 0) xy0 e
      where
        go !acc !xy !i
          | i .&. 1 == 0 = go acc (mul xy xy) (unsafeShiftR i 1)
          | i == 1 = mul acc xy
          | otherwise = go (mul acc xy) (mul xy xy) (unsafeShiftR (i - 1) 1)

    infixl 7 *%
    infixl 6 +%, -%

    -- True when (p - 1)^2 fits in an Int, i.e. the product of two reduced
    -- residues can be formed in Int without overflow.
    narrow :: Bool
    narrow = toInteger (p - 1) ^ (2 :: Int) <= toInteger (maxBound :: Int)

    -- Modular addition of residues in [0, p).  Written as x - (p - y) so
    -- the intermediate stays in (-p, p) and cannot overflow.
    (+%) :: Int -> Int -> Int
    x +% y = case x - (p - y) of
      r
        | r < 0 -> r + p
        | otherwise -> r
    {-# INLINE (+%) #-}

    -- Modular subtraction of residues in [0, p).
    (-%) :: Int -> Int -> Int
    x -% y = case x - y of
      r
        | r < 0 -> r + p
        | otherwise -> r
    {-# INLINE (-%) #-}

    -- Modular multiplication of residues in [0, p); widened to Integer
    -- only when the Int product could overflow.
    (*%) :: Int -> Int -> Int
    x *% y
      | narrow = x * y `rem` p
      | otherwise = fromInteger (toInteger x * toInteger y `rem` toInteger p)
    {-# INLINE (*%) #-}