{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

module Math.Combinatrics where

import Data.Coerce
import Data.Proxy
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import GHC.Exts
import GHC.TypeLits

import Data.GaloisField (GF (GF), natValAsInt)
import My.Prelude (rep, rep1)

-- | Table of factorials modulo @p@; as built by 'buildFactCache', entry @k@
-- holds @k!@.
newtype FactCache p = FactCache (U.Vector (GF p))
-- | Access to a 'FactCache' through the implicit parameter @?factCache@
-- (bound by 'withFactCache').
type HasFactCache (p :: Nat) = (?factCache :: FactCache p)

-- | Table of inverse factorials modulo @p@; as built by 'buildRecipFactCache',
-- entry @k@ holds @1 / k!@.
newtype RecipFactCache p = RecipFactCache (U.Vector (GF p))
-- | Access to a 'RecipFactCache' through the implicit parameter
-- @?recipFactCache@ (bound by 'withRecipFactCache').
type HasRecipFactCache (p :: Nat) = (?recipFactCache :: RecipFactCache p)

-- | Both caches, as required by 'perm' and 'comb' (bound by 'withCombCache').
type HasCombCache (p :: Nat) = (HasFactCache p, HasRecipFactCache p)

{- | Look up @k!@ (mod @p@) in the factorial cache bound to @?factCache@.

No bounds check is performed. For a cache built by 'withFactCache' with size
@m@, the valid indices are @0 .. min m (p - 1)@.

/O(1)/

>>> :set -XDataKinds
>>> withFactCache @1000000007 10 $ fact 10
3628800
-}
fact :: (HasFactCache p, KnownNat p) => Int -> GF p
fact k = case ?factCache of
  FactCache factorials -> U.unsafeIndex factorials k
{-# INLINE fact #-}

{- | Look up @1 / k!@ (mod @p@) in the inverse-factorial cache bound to
@?recipFactCache@.

No bounds check is performed. For a cache built by 'withRecipFactCache' with
size @m@, the valid indices are @0 .. min m (p - 1)@.

/O(1)/
-}
recipFact :: (HasRecipFactCache p, KnownNat p) => Int -> GF p
recipFact k = case ?recipFactCache of
  RecipFactCache inverses -> U.unsafeIndex inverses k
{-# INLINE recipFact #-}

{- | Number of ordered selections of @k@ items out of @n@, \(n! / (n - k)!\),
modulo @p@.

Returns 0 when @k < 0@ or @k > n@. Otherwise @n@ must be a valid index into
both caches: at most the cache size, and therefore @n < p@. This is not
checked.

/O(1)/
-}
perm :: (HasFactCache p, HasRecipFactCache p, KnownNat p) => Int -> Int -> GF p
perm n k
  | k < 0 || n < k = GF 0
  | otherwise = fact n * recipFact (n - k)
{-# INLINE perm #-}

{- | Binomial coefficient \(\binom{n}{k} = n! / ((n - k)! \, k!)\) modulo @p@.

Returns 0 when @k < 0@ or @k > n@. Otherwise @n@ must be a valid index into
both caches: at most the cache size, and therefore @n < p@. This is not
checked; use 'combSmall' for @n >= p@.

/O(1)/
-}
comb :: (HasFactCache p, HasRecipFactCache p, KnownNat p) => Int -> Int -> GF p
comb n k
  | k < 0 || n < k = GF 0
  -- In range, 'perm' n k is exactly fact n * recipFact (n - k).
  | otherwise = perm n k * recipFact k
{-# INLINE comb #-}

{- | Exact binomial coefficient \(\binom{n}{r}\) computed in machine words.

Runs @acc <- acc * (n - i + 1) / i@ for @i = 1 .. k@, where
@k = min r (n - r)@. Each division is exact because @acc@ equals
\(\binom{n}{i}\) after step @i@. The product is formed as a 128-bit
(high, low) pair with 'timesWord2#', so only the quotient has to fit in a
'Word'.

Stopping at @min r (n - r)@ uses the symmetry
\(\binom{n}{r} = \binom{n}{n - r}\). It keeps every intermediate
\(\binom{n}{i}\) no larger than the result, so the function is correct
whenever the answer itself fits: below @2^63@ for a non-negative 'Int'.
Without it, @combNaive 100 98@ would pass through \(\binom{100}{50}\). That
overflows 64 bits and breaks the @high < divisor@ precondition of
'quotRemWord2#'.

Returns 0 when @r < 0@ or @r > n@.

/O(min(r, n - r))/

>>> combNaive 64 32
1832624140942590534
>>> combNaive 123456789 2
7620789313366866
>>> combNaive 123 456
0
>>> combNaive 100 98
4950
-}
combNaive :: Int -> Int -> Int
combNaive n@(I# ni#) r
  | 0 <= r, r <= n = go# 1## 1##
  | otherwise = 0
  where
    n# :: Word#
    n# = int2Word# ni#
    -- Number of factors actually needed (symmetry, see above). Only used when
    -- 0 <= r <= n, so it is non-negative there.
    k# :: Word#
    k# = case min r (n - r) of I# ki# -> int2Word# ki#
    -- acc# holds C(n, i - 1) on entry.
    go# :: Word# -> Word# -> Int
    go# acc# i#
      | isTrue# (leWord# i# k#) =
          -- Full 128-bit product acc# * (n - i + 1), then exact division by i.
          case timesWord2# acc# (minusWord# n# (minusWord# i# 1##)) of
            (# hi#, lo# #) -> case quotRemWord2# hi# lo# i# of
              (# q#, _ #) -> go# q# (plusWord# i# 1##)
      | otherwise = I# (word2Int# acc#)

{- | Factorials modulo @p@ for @k = 0 .. m@, where @m = min n (p - 1)@.

Entry @k@ of the table is @k!@. Tabulation stops at @p - 1@ because every
larger factorial has @p@ as a factor and is therefore 0 modulo @p@.
-}
buildFactCache :: forall p. (KnownNat p) => Int -> FactCache p
buildFactCache n = FactCache factorials
  where
    upper :: Int
    upper = min n (natValAsInt (Proxy @p) - 1)

    -- The multipliers 1, 2, .., upper (index i maps to i + 1).
    multipliers :: U.Vector Int
    multipliers = U.generate upper succ

    -- Running products seeded with 0! = 1, one entry longer than multipliers.
    factorials :: U.Vector (GF p)
    factorials = U.scanl' (\acc k -> acc * coerce k) (GF 1) multipliers

{- | Build the factorial cache for @0 .. min n (p - 1)@ (see 'buildFactCache')
and run @body@ with it bound to @?factCache@.

The cache is evaluated to WHNF before @body@ runs.
-}
withFactCache :: forall p r. (KnownNat p) => Int -> ((HasFactCache p) => r) -> r
withFactCache n body = table `seq` (let ?factCache = table in body)
  where
    table = buildFactCache @p n
{-# INLINE withFactCache #-}

{- | Inverse factorials modulo @p@ for @k = 0 .. m@, where
@m = min n (p - 1)@.

Entry @k@ of the table is @1 / k!@. Only one modular inverse is computed,
@1 / m!@ taken from the factorial cache. Every other entry follows from
@1 / (k - 1)! = k * (1 / k!)@, walking right to left. The factorial cache in
scope must therefore contain index @m@; this is not checked.
-}
buildRecipFactCache :: forall p. (HasFactCache p, KnownNat p) => Int -> RecipFactCache p
buildRecipFactCache n = RecipFactCache inverses
  where
    upper :: Int
    upper = min n (natValAsInt (Proxy @p) - 1)

    -- 1 / upper!, the rightmost entry.
    seed :: GF p
    seed = 1 / fact upper

    -- The multipliers 1, 2, .., upper (index i maps to i + 1).
    multipliers :: U.Vector Int
    multipliers = U.generate upper succ

    inverses :: U.Vector (GF p)
    inverses = U.scanr' (\k acc -> coerce k * acc) seed multipliers

{- | Build the inverse-factorial cache for @0 .. min n (p - 1)@ (see
'buildRecipFactCache') and run @body@ with it bound to @?recipFactCache@.

Requires a factorial cache already in scope. The new cache is evaluated to
WHNF before @body@ runs.
-}
withRecipFactCache :: forall p r. (HasFactCache p, KnownNat p) => Int -> ((HasRecipFactCache p) => r) -> r
withRecipFactCache n body = table `seq` (let ?recipFactCache = table in body)
  where
    table = buildRecipFactCache @p n
{-# INLINE withRecipFactCache #-}

{- | Bind both the factorial and the inverse-factorial caches, each sized by
@n@, and run @body@ with them.

This is what 'perm' and 'comb' need.
-}
withCombCache :: forall p r. (KnownNat p) => Int -> ((HasCombCache p) => r) -> r
withCombCache n body = withFactCache @p n (withRecipFactCache @p n body)
{-# INLINE withCombCache #-}

{- | Binomial coefficient \(\binom{n}{r}\) modulo @p@ via Lucas's theorem.

Writes @n@ and @r@ in base @p@ and multiplies the per-digit coefficients
@C(n_i, r_i)@ taken from 'combSmallTable'. Unlike 'comb', this accepts
@n >= p@. It relies on the \(O(p^2)\) table, so it only suits small @p@.
Lucas's theorem assumes @p@ is prime (presumably guaranteed by the use of
'GF' — confirm).

Returns 0 when @r < 0@ or @r > n@, matching 'comb'. Without this guard,
negative arguments produce negative base-@p@ digits, and those would index
outside the table.

/O(log N)/ per call, plus building 'combSmallTable'.
-}
combSmall :: forall p. (KnownNat p) => Int -> Int -> GF p
combSmall = lucas
  where
    p :: Int
    p = natValAsInt (Proxy @p)

    -- Hoisted out of the digit loop: fetched once rather than per digit.
    table :: U.Vector (GF p)
    table = combSmallTable @p

    lucas :: Int -> Int -> GF p
    lucas n r
      | 0 <= r, r <= n = go (GF 1) n r
      | otherwise = GF 0

    -- Invariant: 0 <= r <= n, so every digit is in [0, p) and the quotients
    -- stay ordered. Entries with r_i > n_i are 0 in the table.
    go :: GF p -> Int -> Int -> GF p
    go !acc 0 0 = acc
    go !acc !n !r = go (acc * c) qn qr
      where
        (qn, rn) = quotRem n p
        (qr, rr) = quotRem r p
        c = U.unsafeIndex table (rn * p + rr)

{- | Pascal's triangle modulo @p@ for rows and columns @0 .. p - 1@, stored
row-major.

Entry @x * p + y@ holds @C(x, y) mod p@, and entries with @y > x@ stay 0.
This assumes @rep m@ visits @0 .. m - 1@ and @rep1 m@ visits @1 .. m@, as
suggested by the index arithmetic (see @My.Prelude@ — confirm).

/O(p ^ 2)/ time and space, so only suitable for small @p@.

NOTE(review): because this binding carries a @KnownNat p@ constraint, GHC may
treat it as a function of the dictionary. In that case NOINLINE alone does
not guarantee the table is built once and shared; check the Core if reuse
matters.
-}
combSmallTable :: forall p. (KnownNat p) => U.Vector (GF p)
combSmallTable = U.create $ do
  pascal <- UM.replicate (n * n) (GF 0)
  -- Edges of the triangle: C(i, 0) = C(i, i) = 1.
  rep n $ \i -> do
    UM.unsafeWrite pascal (ix i 0) (GF 1)
    UM.unsafeWrite pascal (ix i i) (GF 1)
  -- Interior: C(x, y) = C(x - 1, y - 1) + C(x - 1, y).
  rep1 (n - 1) $ \x ->
    rep1 (x - 1) $ \y -> do
      upLeft <- UM.unsafeRead pascal (ix (x - 1) (y - 1))
      up <- UM.unsafeRead pascal (ix (x - 1) y)
      UM.unsafeWrite pascal (ix x y) (upLeft + up)
  pure pascal
  where
    n :: Int
    n = natValAsInt (Proxy @p)
    ix :: Int -> Int -> Int
    ix row col = row * n + col
{-# NOINLINE combSmallTable #-}