{-# LANGUAGE RecordWildCards #-}

module Math.NTT where

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Function
import qualified Data.List.NonEmpty as NE
import Data.Proxy (Proxy (..))
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import GHC.TypeLits (KnownNat)

import Data.GaloisField (GF (GF), natValAsInt, reifyNat)
import Math.Prime (primeFactors)
import My.Prelude (
  rep,
  unsafeShiftRL,
  (..<),
  (>..),
 )

{- | Number Theoretic Transform over GF(p).

The modulus @p@ must be a prime of the form @c * 2 ^ k + 1@, and the input
length @n@ must be a power of two with @n < 2 ^ k@.

The transform is computed by 'butterfly' on a copy of the input. No
bit-reversal permutation is applied, so the output is in the order that
'butterfly' produces; 'intt' expects exactly that order.

/O(n log n)/

>>> ntt @998244353 [1,1,1,1]
[4,0,0,0]
>>> ntt @469762049 [123,0,0,0]
[123,123,123,123]
-}
ntt ::
  forall p.
  (KnownNat p) =>
  U.Vector (GF p) ->
  U.Vector (GF p)
ntt = U.modify butterfly
{-# INLINE ntt #-}

{- | Inverse of 'ntt'.

For an input whose length is a valid 'ntt' length, @intt (ntt xs) == xs@.
Runs 'invButterfly' on a copy of the input, then multiplies every element by
@1 / n@. An empty input is returned unchanged.

/O(n log n)/

>>> intt @998244353 [4,0,0,0]
[1,1,1,1]
>>> intt @998244353 []
[]
-}
intt :: forall p. (KnownNat p) => U.Vector (GF p) -> U.Vector (GF p)
intt f
  | U.null f = f
  | otherwise = U.map (* invn) $ U.modify invButterfly f
  where
    -- 1 / n in GF(p). Deliberately lazy (no bang): it is only demanded for a
    -- non-empty input, so the inverse of zero is never requested.
    invn :: GF p
    invn = recip (GF $ U.length f)
{-# INLINE intt #-}

{- | Convolution (polynomial product) of two coefficient vectors over GF(p).

For inputs of lengths @n@ and @m@ the result has length @n + m - 1@. If
either input is empty, the result is empty.

Both inputs are zero-padded to the smallest power of two @>= n + m - 1@. That
padded length must be smaller than @2 ^ k@ for @p = c * 2 ^ k + 1@ (see 'ntt').

/O((n + m) log (n + m))/

>>> convolute @998244353 [1,1,1,0] [1,1,1,0]
[1,2,3,2,1,0,0]
>>> convolute @998244353 [1,1,1] [1,1,1,0]
[1,2,3,2,1,0]
>>> convolute @998244353 [] [1,2,3]
[]
-}
convolute ::
  forall p.
  (KnownNat p) =>
  U.Vector (GF p) ->
  U.Vector (GF p) ->
  U.Vector (GF p)
convolute xs ys
  | U.null xs || U.null ys = U.empty
  | otherwise = U.create $ do
      mxs <- UM.replicate len (GF 0)
      U.unsafeCopy (UM.take n mxs) xs
      butterfly mxs
      mys <- UM.replicate len (GF 0)
      U.unsafeCopy (UM.take m mys) ys
      butterfly mys
      -- Pointwise product in the transformed domain.
      rep len $ \i -> do
        yi <- UM.unsafeRead mys i
        UM.unsafeModify mxs (* yi) i
      invButterfly mxs
      -- invButterfly leaves a factor of len. Only the first n + m - 1 entries
      -- are returned, so only those are rescaled.
      rep resultLen $ \i ->
        UM.unsafeModify mxs (* ilen) i
      return $ UM.take resultLen mxs
  where
    n :: Int
    n = U.length xs
    m :: Int
    m = U.length ys
    resultLen :: Int
    resultLen = n + m - 1
    -- Transform length: the next power of two that can hold the full product.
    len :: Int
    len = extendToPowerOfTwo resultLen
    ilen :: GF p
    ilen = recip (GF len)
{-# INLINE convolute #-}

-- | Precomputed twiddle-factor tables for the modulus @p@, built by 'nttRunner'.
data NTTRunner p = NTTRunner
  { sesNR :: !(U.Vector (GF p))
  -- ^ @sesNR ! i = es ! i * ies ! 0 * .. * ies ! (i - 1)@ (tables as in
  -- 'nttRunner'); read by 'invButterfly'.
  , siesNR :: !(U.Vector (GF p))
  -- ^ @siesNR ! i = ies ! i * es ! 0 * .. * es ! (i - 1)@; read by
  -- 'butterfly'.
  }

{- | Builds the twiddle-factor tables for the modulus @p@.

Let @p - 1 = c * 2 ^ ctz@ with @c@ odd, and let @g = primitiveRoot p@. Then:

* @e = g ^ c@ and @ie = 1 / e@;
* @es ! i = e ^ (2 ^ (ctz - 2 - i))@ for @i = 0 .. ctz - 2@, and @ies@ is
  built the same way from @ie@;
* @sesNR ! i = es ! i * ies ! 0 * .. * ies ! (i - 1)@;
* @siesNR ! i = ies ! i * es ! 0 * .. * es ! (i - 1)@.

Every table has length @max 0 (ctz - 1)@.

NOTE(review): 'nttRunner' takes a 'KnownNat' dictionary, so @NOINLINE@ alone
does not make it a shared top-level constant. The tables, and the
factorisation inside 'primitiveRoot', may be rebuilt on every
'butterfly' / 'invButterfly' call. Confirm with Core or profiling.
-}
nttRunner :: forall p. (KnownNat p) => NTTRunner p
nttRunner = NTTRunner{..}
  where
    p :: Int
    p = natValAsInt (Proxy @p)
    g :: Int
    g = primitiveRoot p

    -- exponent of 2 in p - 1
    ctz :: Int
    ctz = countTrailingZeros (p - 1)
    -- g ^ c, where c = (p - 1) >> ctz is the odd part of p - 1
    !e = GF g ^ unsafeShiftR (p - 1) ctz :: GF p
    !ie = recip e :: GF p

    -- [e ^ (2 ^ (ctz - 2)), .., e ^ 2, e]
    es :: U.Vector (GF p)
    es = U.reverse $ U.iterateN (ctz - 1) (\x -> x * x) e
    ies :: U.Vector (GF p)
    ies = U.reverse $ U.iterateN (ctz - 1) (\x -> x * x) ie

    -- zipWith truncates the length-ctz prefix products to length ctz - 1.
    sesNR :: U.Vector (GF p)
    sesNR = U.zipWith (*) es $ U.scanl' (*) 1 ies
    siesNR :: U.Vector (GF p)
    siesNR = U.zipWith (*) ies $ U.scanl' (*) 1 es
{-# NOINLINE nttRunner #-}

{- | In-place forward NTT butterfly over GF(p).

This has the same loop structure as the AtCoder Library @butterfly@. The
length of @mvec@ must be zero or a power of two. No bit-reversal permutation
is applied. 'invButterfly' undoes this transform up to a factor of the length.

The function calls 'error' if the length is not a power of two, or is too
large for the twiddle table of @p@. Without these checks it would read and
write out of bounds.
-}
butterfly ::
  forall p m.
  (KnownNat p, PrimMonad m) =>
  UM.MVector (PrimState m) (GF p) ->
  m ()
butterfly mvec = do
  when (popCount n > 1) $
    error $ "Math.NTT.butterfly: length is not a power of two: " ++ show n
  -- Non-final blocks read siesNR at indices up to h - 2.
  when (h > U.length siesNR + 1) $
    error $ "Math.NTT.butterfly: length too large for this modulus: " ++ show n
  flip MS.mapM_ (1 ..< (h + 1)) $ \ph -> do
    let !w = unsafeShiftL 1 (ph - 1) -- number of blocks in this stage
        !half = unsafeShiftL 1 (h - ph) -- distance between paired elements
    void $
      MS.foldlM'
        ( \acc s -> do
            let offset = unsafeShiftL s (h - ph + 1)
            flip MS.mapM_ (offset ..< (offset + half)) $ \i -> do
              l <- UM.unsafeRead mvec i
              r <- (* acc) <$!> UM.unsafeRead mvec (i + half)
              UM.unsafeWrite mvec (i + half) $ l - r
              UM.unsafeWrite mvec i $ l + r
            -- Advance the twiddle factor only if another block follows. The
            -- value after the last block is never used, and computing it
            -- would index one past the end of siesNR when h is maximal.
            if s + 1 < w
              then return $! acc * U.unsafeIndex siesNR (countTrailingZeros (complement s))
              else return acc
        )
        1
        (0 ..< w)
  where
    n :: Int
    n = UM.length mvec
    -- log2 of the length (0 for an empty vector)
    !h = countTrailingZeros $ extendToPowerOfTwo n
    NTTRunner{..} = nttRunner @p
{-# INLINE butterfly #-}

{- | In-place inverse NTT butterfly over GF(p).

Undoes 'butterfly' up to a factor of the length. The caller is responsible
for dividing by the length afterwards (see 'intt'). The length of @mvec@ must
be zero or a power of two.

The function calls 'error' if the length is not a power of two, or is too
large for the twiddle table of @p@. Without these checks it would read and
write out of bounds.
-}
invButterfly ::
  forall p m.
  (KnownNat p, PrimMonad m) =>
  UM.MVector (PrimState m) (GF p) ->
  m ()
invButterfly mvec = do
  when (popCount n > 1) $
    error $ "Math.NTT.invButterfly: length is not a power of two: " ++ show n
  -- Non-final blocks read sesNR at indices up to h - 2.
  when (h > U.length sesNR + 1) $
    error $ "Math.NTT.invButterfly: length too large for this modulus: " ++ show n
  -- NOTE(review): `>..` comes from My.Prelude. It presumably iterates ph from
  -- h down to 1, i.e. the stages of 'butterfly' in reverse order.
  flip MS.mapM_ ((h + 1) >.. 1) $ \ph -> do
    let !w = unsafeShiftL 1 (ph - 1) -- number of blocks in this stage
        !half = unsafeShiftL 1 (h - ph) -- distance between paired elements
    void $
      MS.foldlM'
        ( \acc s -> do
            let offset = unsafeShiftL s (h - ph + 1)
            flip MS.mapM_ (offset ..< (offset + half)) $ \i -> do
              l <- UM.unsafeRead mvec i
              r <- UM.unsafeRead mvec (i + half)
              UM.unsafeWrite mvec (i + half) $ acc * (l - r)
              UM.unsafeWrite mvec i $ l + r
            -- Advance the twiddle factor only if another block follows. The
            -- value after the last block is never used, and computing it
            -- would index one past the end of sesNR when h is maximal.
            if s + 1 < w
              then return $! acc * U.unsafeIndex sesNR (countTrailingZeros (complement s))
              else return acc
        )
        1
        (0 ..< w)
  where
    n :: Int
    n = UM.length mvec
    -- log2 of the length (0 for an empty vector)
    !h = countTrailingZeros $ extendToPowerOfTwo n
    NTTRunner{..} = nttRunner @p
{-# INLINE invButterfly #-}

{- | Pads a vector with zeros up to the next power of two (see
'extendToPowerOfTwo').

An empty vector becomes a single zero. A vector whose length is already a
power of two is returned as is, without copying.

>>> growToPowerOfTwo (U.fromListN 3 [1::Int,2,3])
[1,2,3,0]
>>> growToPowerOfTwo (U.empty :: U.Vector Int)
[0]
-}
growToPowerOfTwo :: (Num a, U.Unbox a) => U.Vector a -> U.Vector a
growToPowerOfTwo v
  | padding == 0 = v
  | otherwise = v U.++ U.replicate padding 0
  where
    len = U.length v
    padding = extendToPowerOfTwo len - len

{- | Smallest power of two that is @>= x@. Inputs @<= 1@ (including zero and
negative values) give 1.

For @x > 2 ^ 62@ the result does not fit in an 'Int' and wraps to 'minBound'.

>>> extendToPowerOfTwo 0
1
>>> extendToPowerOfTwo 5
8
>>> extendToPowerOfTwo 8
8
-}
extendToPowerOfTwo :: Int -> Int
extendToPowerOfTwo x
  -- The bit length of (x - 1) is the exponent of the next power of two >= x.
  | x > 1 = unsafeShiftL 1 (finiteBitSize x - countLeadingZeros (x - 1))
  | otherwise = 1

{- | Smallest primitive root modulo a prime.

Candidates are tried upward from 2. A candidate @g@ is accepted when
@g ^ ((prime - 1) / q) /= 1@ for every distinct prime factor @q@ of
@prime - 1@. For @prime == 2@ the answer is 1.

NOTE(review): the argument is assumed to be prime. For other inputs the
search is not guaranteed to terminate.

>>> primitiveRoot 2
1
>>> primitiveRoot 998244353
3
-}
primitiveRoot ::
  -- | prime
  Int ->
  Int
primitiveRoot 2 = 1
primitiveRoot prime = reifyNat prime $ \proxy ->
  let go !g
        | all (check (toGF proxy g)) ps = g
        | otherwise = go (g + 1)
   in go 2
  where
    -- distinct prime factors of prime - 1
    !ps = map NE.head . NE.group $ primeFactors (prime - 1)

    -- Fixes the modulus of the GF value to the reified type-level prime.
    toGF :: Proxy p -> Int -> GF p
    toGF _ = GF

    -- g passes for factor q unless g ^ ((prime - 1) / q) == 1.
    check :: (KnownNat p) => GF p -> Int -> Bool
    check g q = g ^ quot (prime - 1) q /= GF 1