{-# LANGUAGE RecordWildCards #-}
module Math.NTT where
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Function
import qualified Data.List.NonEmpty as NE
import Data.Proxy (Proxy (..))
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import GHC.TypeLits (KnownNat)
import Data.GaloisField (GF (GF), natValAsInt, reifyNat)
import Math.Prime (primeFactors)
import My.Prelude (
rep,
unsafeShiftRL,
(..<),
(>..),
)
-- | Forward number-theoretic transform over @GF p@.
--
-- Runs 'butterfly' on a private mutable copy of the input (via 'U.modify')
-- and returns the transformed vector; the argument itself is not mutated.
--
-- NOTE(review): 'butterfly' derives the transform size from the vector
-- length passed through 'extendToPowerOfTwo' and uses unchecked
-- reads/writes, so a length that is not a power of two looks like it would
-- index past the end of the buffer. Pad with 'growToPowerOfTwo' first --
-- confirm against callers.
ntt ::
  forall p.
  (KnownNat p) =>
  U.Vector (GF p) ->
  U.Vector (GF p)
ntt = U.modify butterfly
{-# INLINE ntt #-}
-- | Inverse number-theoretic transform over @GF p@.
--
-- Runs 'invButterfly' on a private mutable copy of the input and then
-- multiplies every element by the multiplicative inverse of the vector
-- length (taken in @GF p@).
--
-- NOTE(review): 'invButterfly' uses unchecked indexing sized from
-- 'extendToPowerOfTwo' of the length, so the input is presumably expected
-- to have power-of-two length -- confirm against callers.
intt :: forall p. (KnownNat p) => U.Vector (GF p) -> U.Vector (GF p)
intt f = U.map (* invn) $ U.modify invButterfly f
  where
    -- 1 / length f, evaluated once up front.
    !invn = recip (GF $ U.length f)
{-# INLINE intt #-}
-- | Convolution of two vectors over @GF p@ using the transforms in this
-- module.
--
-- Steps, in order:
--
-- 1. Copy each input into a zero-filled buffer of length @len = 2^h@, where
--    @h = countTrailingZeros (extendToPowerOfTwo (n + m - 1))@.
-- 2. Apply 'butterfly' to both buffers.
-- 3. Multiply the buffers element-wise (result stored in the first one).
-- 4. Apply 'invButterfly' and scale the first @n + m - 1@ entries by
--    @recip (GF len)@.
--
-- The result is the first @n + m - 1@ entries of the buffer, where @n@ and
-- @m@ are the input lengths.
convolute ::
  forall p.
  (KnownNat p) =>
  U.Vector (GF p) ->
  U.Vector (GF p) ->
  U.Vector (GF p)
convolute xs ys = U.create $ do
  -- Zero-padded copy of xs, transformed in place.
  mxs <- UM.replicate len (GF 0)
  U.unsafeCopy (UM.take n mxs) xs
  butterfly mxs
  -- Zero-padded copy of ys, transformed in place.
  mys <- UM.replicate len (GF 0)
  U.unsafeCopy (UM.take m mys) ys
  butterfly mys
  -- Point-wise product, accumulated into mxs.
  rep len $ \i -> do
    yi <- UM.unsafeRead mys i
    UM.unsafeModify mxs (* yi) i
  invButterfly mxs
  -- Only the entries that are returned need the 1/len scaling.
  rep (n + m - 1) $ \i -> do
    UM.unsafeModify mxs (* ilen) i
  return $ UM.take (n + m - 1) mxs
  where
    n = U.length xs
    m = U.length ys
    !h = countTrailingZeros $ extendToPowerOfTwo (n + m - 1)
    !len = unsafeShiftL 1 h
    !ilen = recip (GF len)
{-# INLINE convolute #-}
-- | Precomputed multiplier tables for the transforms over @GF p@.
--
-- Both fields are strict unboxed vectors built by 'nttRunner'.
data NTTRunner p = NTTRunner
  { sesNR :: !(U.Vector (GF p))
  -- ^ Table indexed by 'invButterfly' when advancing its running multiplier.
  , siesNR :: !(U.Vector (GF p))
  -- ^ Table indexed by 'butterfly' when advancing its running multiplier.
  }
-- | Build the multiplier tables used by 'butterfly' and 'invButterfly' for
-- the modulus @p@.
--
-- With @g = primitiveRoot p@ and @ctz = countTrailingZeros (p - 1)@:
--
-- * @e = g ^ ((p - 1) \`shiftR\` ctz)@ and @ie = recip e@;
-- * @es@ / @ies@ hold @ctz - 1@ successive squarings of @e@ / @ie@, stored
--   in reverse order (the un-squared value is last);
-- * @sesNR[i] = es[i] * product (take i ies)@ and symmetrically
--   @siesNR[i] = ies[i] * product (take i es)@.
--
-- Marked @NOINLINE@; presumably so the tables are not rebuilt at every
-- inlined use site of the transforms -- confirm.
nttRunner :: forall p. (KnownNat p) => NTTRunner p
nttRunner = NTTRunner{..}
  where
    p = natValAsInt (Proxy @p)
    g = primitiveRoot p
    ctz = countTrailingZeros (p - 1)
    !e = GF g ^ unsafeShiftR (p - 1) ctz
    !ie = recip e
    es = U.reverse $ U.iterateN (ctz - 1) (\x -> x * x) e
    ies = U.reverse $ U.iterateN (ctz - 1) (\x -> x * x) ie
    -- zipWith truncates to the shorter input, so the extra final element
    -- produced by scanl' is dropped.
    sesNR = U.zipWith (*) es $ U.scanl' (*) 1 ies
    siesNR = U.zipWith (*) ies $ U.scanl' (*) 1 es
{-# NOINLINE nttRunner #-}
-- | In-place forward butterfly passes over a mutable vector of @GF p@.
--
-- With @h = countTrailingZeros (extendToPowerOfTwo (length mvec))@, the
-- stream @1 ..< (h + 1)@ drives one pass per @ph@. In each pass the vector
-- is treated as @w = 2^(ph-1)@ blocks of @2 * p@ elements (@p = 2^(h-ph)@);
-- within a block each pair @(i, i + p)@ is replaced by
-- @(l + r, l - r)@ where @r@ is the upper element multiplied by the running
-- multiplier @acc@. After each block @acc@ is multiplied by
-- @siesNR[countTrailingZeros (complement s)]@.
--
-- All element accesses are unchecked ('UM.unsafeRead' / 'UM.unsafeWrite').
-- NOTE(review): assumes the vector length is a power of two; otherwise the
-- passes appear to touch indices up to @2^h - 1@, beyond the buffer --
-- confirm.
butterfly ::
  (KnownNat p, PrimMonad m) =>
  UM.MVector (PrimState m) (GF p) ->
  m ()
butterfly mvec = do
  flip MS.mapM_ (1 ..< (h + 1)) $ \ph -> do
    let !w = unsafeShiftL 1 (ph - 1)
        !p = unsafeShiftL 1 (h - ph)
    void $
      MS.foldlM'
        ( \acc s -> do
            let offset = unsafeShiftL s (h - ph + 1)
            flip MS.mapM_ (offset ..< (offset + p)) $ \i -> do
              l <- UM.unsafeRead mvec i
              -- Force the product so no thunk is written back.
              r <- (* acc) <$!> UM.unsafeRead mvec (i + p)
              UM.unsafeWrite mvec (i + p) $ l - r
              UM.unsafeWrite mvec i $ l + r
            return $! acc * U.unsafeIndex siesNR (countTrailingZeros (complement s))
        )
        1
        (0 ..< w)
  where
    n = UM.length mvec
    !h = countTrailingZeros $ extendToPowerOfTwo n
    NTTRunner{..} = nttRunner
{-# INLINE butterfly #-}
-- | In-place inverse butterfly passes over a mutable vector of @GF p@.
--
-- Mirrors 'butterfly': the passes are driven by the stream
-- @(h + 1) >.. 1@ (presumably @ph@ descending from @h@ to @1@ -- confirm
-- against the definition of '>..'), each pair @(i, i + p)@ is replaced by
-- @(l + r, acc * (l - r))@, and the running multiplier @acc@ advances by
-- @sesNR[countTrailingZeros (complement s)]@ after each block.
--
-- No final division by the length is performed here; 'intt' and
-- 'convolute' apply that scaling themselves.
--
-- All element accesses are unchecked. NOTE(review): assumes the vector
-- length is a power of two -- confirm.
invButterfly ::
  forall p m.
  (KnownNat p, PrimMonad m) =>
  UM.MVector (PrimState m) (GF p) ->
  m ()
invButterfly mvec = void $ do
  flip MS.mapM_ ((h + 1) >.. 1) $ \ph -> do
    let !w = unsafeShiftL 1 (ph - 1)
        !p = unsafeShiftL 1 (h - ph)
    -- The fold's final accumulator is discarded by the enclosing mapM_.
    MS.foldlM'
      ( \acc s -> do
          let offset = unsafeShiftL s (h - ph + 1)
          flip MS.mapM_ (offset ..< (offset + p)) $ \i -> do
            l <- UM.unsafeRead mvec i
            r <- UM.unsafeRead mvec (i + p)
            UM.unsafeWrite mvec (i + p) $ acc * (l - r)
            UM.unsafeWrite mvec i $ l + r
          return $! acc * U.unsafeIndex sesNR (countTrailingZeros (complement s))
      )
      1
      (0 ..< w)
  where
    n = UM.length mvec
    !h = countTrailingZeros $ extendToPowerOfTwo n
    NTTRunner{..} = nttRunner
{-# INLINE invButterfly #-}
-- | Pad a vector with trailing zeros.
--
-- * An empty vector becomes the singleton @[0]@.
-- * A one-element vector is returned unchanged.
-- * Otherwise zeros are appended until the length is
--   @unsafeShiftRL (-1) (countLeadingZeros (length v - 1)) + 1@, which is
--   presumably the smallest power of two not below the current length
--   (assuming 'unsafeShiftRL' is a logical right shift -- confirm).
growToPowerOfTwo :: (Num a, U.Unbox a) => U.Vector a -> U.Vector a
growToPowerOfTwo v
  | U.null v = U.singleton 0
  | U.length v == 1 = v
  -- Pattern guard used only to name the target length; it always matches.
  | n <- unsafeShiftRL (-1) (countLeadingZeros (U.length v - 1)) + 1 =
      v U.++ U.replicate (n - U.length v) 0
-- | Round a size up for use as a transform length.
--
-- Returns @1@ for every argument @<= 1@ (including zero and negatives).
-- For @x > 1@ returns
-- @unsafeShiftRL (-1) (countLeadingZeros (x - 1)) + 1@, presumably the
-- smallest power of two @>= x@ (assuming 'unsafeShiftRL' is a logical
-- right shift -- confirm).
extendToPowerOfTwo :: Int -> Int
extendToPowerOfTwo x
  | x > 1 = unsafeShiftRL (-1) (countLeadingZeros (x - 1)) + 1
  | otherwise = 1
-- | Search for a generator candidate modulo @prime@.
--
-- Returns @1@ for the argument @2@. Otherwise tries @g = 2, 3, ...@ and
-- returns the first @g@ such that, for every @q@ in @ps@,
-- @g ^ ((prime - 1) \`quot\` q) /= 1@ in @GF prime@. @ps@ is built from
-- @primeFactors (prime - 1)@ with adjacent duplicates collapsed.
--
-- NOTE(review): the argument is presumably required to be prime, and
-- 'primeFactors' presumably returns equal factors adjacently (so that
-- 'NE.group' removes all duplicates) -- confirm. The search has no upper
-- bound, so it does not terminate if no candidate passes.
primitiveRoot ::
  Int ->
  Int
primitiveRoot 2 = 1
primitiveRoot prime = reifyNat prime $ \proxy ->
  flip fix 2 $ \loop !g ->
    if all (check (toGF proxy g)) ps
      then g
      else loop (g + 1)
  where
    -- Distinct factors of prime - 1 (adjacent duplicates collapsed).
    !ps = map NE.head . NE.group $ primeFactors (prime - 1)

    -- Tie the field element's modulus to the reified proxy.
    toGF :: Proxy p -> Int -> GF p
    toGF _ = GF

    -- True when g raised to (prime - 1) / q is not the field's 1.
    check :: (KnownNat p) => GF p -> Int -> Bool
    check g q = g ^ quot (prime - 1) q /= GF 1