{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeFamilies #-}

module System.Random.Utils where

import Control.Monad
import Control.Monad.Primitive
import Data.Primitive.MutVar
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Generic.Mutable as GM
import System.Random.Stateful

import My.Prelude

-- | Obtain a fresh 'StdGen' in any 'PrimMonad', derived from the
-- process-wide global generator via 'newStdGen'.
--
-- NOTE(review): the 'IO' action is run through 'unsafeIOToPrim', so when @m@
-- is @ST s@ the generator obtained is not determined by the inputs of
-- 'runST': two runs of the same 'ST' computation may see different values.
-- Only use this where that non-determinism is intended.
newStdGenPrim :: (PrimMonad m) => m StdGen
newStdGenPrim = unsafeIOToPrim splitGlobal
  where
    splitGlobal :: IO StdGen
    splitGlobal = newStdGen
{-# INLINE newStdGenPrim #-}

-- | A pure generator wrapped in a newtype so that it can serve as the frozen
-- counterpart of 'PrimGenM' (see the @FrozenGen (PrimGen g) m@ instance in
-- this module). All 'RandomGen' methods are reused unchanged from @g@.
newtype PrimGen g = PrimGen
  { unPrimGen :: g
  }
  deriving newtype (RandomGen)

-- | A mutable generator: a 'MutVar' over state token @s@ that holds the
-- current pure generator @g@.
newtype PrimGenM g s = PrimGenM
  { unPrimGenM :: MutVar s g
  }

-- | Allocate a new mutable generator whose 'MutVar' initially holds the
-- given pure generator.
newPrimGenM :: (PrimMonad m) => g -> m (PrimGenM g (PrimState m))
newPrimGenM = fmap PrimGenM . newMutVar

-- | Run one pure generator step against a mutable generator: read the
-- current generator, apply the step, store the successor and return the
-- produced value.
--
-- Both the produced value and the successor generator are forced to WHNF
-- before the write, so repeated calls do not accumulate a chain of
-- unevaluated generator thunks inside the 'MutVar'.
applyPrimGen ::
  (PrimMonad m) =>
  (g -> (a, g)) ->
  PrimGenM g (PrimState m) ->
  m a
applyPrimGen step (PrimGenM var) = do
  current <- readMutVar var
  case step current of
    (!out, !successor) -> do
      writeMutVar var successor
      pure out
{-# INLINE applyPrimGen #-}

-- | 'PrimGenM' is a stateful generator in any 'PrimMonad' whose state token
-- is @s@. Each method delegates to the matching pure 'RandomGen' method
-- through 'applyPrimGen'.
instance
  (RandomGen g, s ~ PrimState m, PrimMonad m) =>
  StatefulGen (PrimGenM g s) m
  where
  uniformWord32R bound gen = applyPrimGen (genWord32R bound) gen
  {-# INLINE uniformWord32R #-}
  uniformWord64R bound gen = applyPrimGen (genWord64R bound) gen
  {-# INLINE uniformWord64R #-}
  uniformWord8 gen = applyPrimGen genWord8 gen
  {-# INLINE uniformWord8 #-}
  uniformWord16 gen = applyPrimGen genWord16 gen
  {-# INLINE uniformWord16 #-}
  uniformWord32 gen = applyPrimGen genWord32 gen
  {-# INLINE uniformWord32 #-}
  uniformWord64 gen = applyPrimGen genWord64 gen
  {-# INLINE uniformWord64 #-}
  uniformShortByteString len gen = applyPrimGen (genShortByteString len) gen
  {-# INLINE uniformShortByteString #-}

-- | Lets 'PrimGenM' be passed to functions constrained on 'RandomGenM';
-- applying a pure generator step is exactly 'applyPrimGen'.
instance
  (RandomGen g, s ~ PrimState m, PrimMonad m) =>
  RandomGenM (PrimGenM g s) g m
  where
  applyRandomGenM = applyPrimGen
  -- Inlined like every StatefulGen method of PrimGenM, so the call
  -- specialises to the concrete monad at the use site.
  {-# INLINE applyRandomGenM #-}

-- | 'PrimGen' is the frozen form of 'PrimGenM'.
--
-- * 'freezeGen' reads the pure generator currently stored in the 'MutVar'.
-- * 'thawGen' allocates a new 'MutVar' (via 'newPrimGenM') holding the pure
--   generator, so later updates to the mutable generator do not affect the
--   frozen value.
instance
  (RandomGen g, PrimMonad m) =>
  FrozenGen (PrimGen g) m
  where
  type MutableGen (PrimGen g) m = PrimGenM g (PrimState m)
  freezeGen (PrimGenM ref) = PrimGen <$> readMutVar ref
  -- Inlined for consistency with the StatefulGen methods; these run inside
  -- withMutableGen / withMutableGen_ on every with*StdGen call.
  {-# INLINE freezeGen #-}
  thawGen (PrimGen g) = newPrimGenM g
  {-# INLINE thawGen #-}

-- | Run an action with a mutable 'StdGen' freshly obtained from the global
-- generator (see 'newStdGenPrim'). Returns the action's result together with
-- the generator as it stands after the action.
withGlobalStdGen ::
  (PrimMonad m) =>
  (PrimGenM StdGen (PrimState m) -> m a) ->
  m (a, StdGen)
withGlobalStdGen action =
  newStdGenPrim >>= \seedGen ->
    unwrapFinal <$> withMutableGen (PrimGen seedGen) action
  where
    unwrapFinal (result, PrimGen final) = (result, final)

-- | Run an action with a mutable 'StdGen' freshly obtained from the global
-- generator (see 'newStdGenPrim'), returning only the action's result.
withGlobalStdGen_ ::
  (PrimMonad m) =>
  (PrimGenM StdGen (PrimState m) -> m a) ->
  m a
withGlobalStdGen_ action =
  newStdGenPrim >>= \seedGen -> withMutableGen_ (PrimGen seedGen) action

{- | Run an action with a mutable 'StdGen' seeded by @'mkStdGen' seed@.
Returns the action's result together with the generator as it stands after
the action. For a deterministic action the output depends only on the seed.

>>> withFixedStdGen 123 (uniformRM @Int (1, 6))
(1,StdGen {unStdGen = SMGen 3794253433779795923 13032462758197477675})
>>> withFixedStdGen 1 (uniformRM @Int (1, 6))
(6,StdGen {unStdGen = SMGen 4999253871718377453 10451216379200822465})
-}
withFixedStdGen ::
  (PrimMonad m) =>
  Int ->
  (PrimGenM StdGen (PrimState m) -> m a) ->
  m (a, StdGen)
withFixedStdGen seed action =
  unwrapFinal <$> withMutableGen (PrimGen (mkStdGen seed)) action
  where
    unwrapFinal (result, PrimGen final) = (result, final)

{- | Run an action with a mutable 'StdGen' seeded by @'mkStdGen' seed@,
returning only the action's result.

>>> withFixedStdGen_ 123 (uniformRM @Int (1, 6))
1
>>> withFixedStdGen_ 1 (uniformRM @Int (1, 6))
6
-}
withFixedStdGen_ ::
  (PrimMonad m) =>
  Int ->
  (PrimGenM StdGen (PrimState m) -> m a) ->
  m a
withFixedStdGen_ seed = withMutableGen_ (PrimGen (mkStdGen seed))

{- | Shuffle a mutable vector in place, driven by the given pure generator.
For each index @i@ visited, an index @j@ is drawn from @[0, i]@ with
'genWord64R' and positions @i@ and @j@ are swapped (a Fisher–Yates-style
shuffle, assuming indices are visited from the last down to 0). The
generator's final state is discarded.

>>> import qualified Data.Vector.Unboxed as U
>>> U.modify (shuffle (mkStdGen 123)) $ U.fromList "abcdef"
"fcdbea"
-}
shuffle ::
  (PrimMonad m, GM.MVector mv a, RandomGen g) =>
  g ->
  mv (PrimState m) a ->
  m ()
shuffle rng0 mv = void (MS.foldM' swapStep rng0 indices)
  where
    -- NOTE(review): `>..` comes from My.Prelude; presumably `n >.. 0`
    -- streams n-1, n-2, .., 0 — confirm against its definition.
    indices = GM.length mv >.. 0
    -- Pick j in [0, i], swap slots i and j, and thread the new generator.
    swapStep rng i = case genWord64R (fromIntegral i) rng of
      (j, rng') -> GM.unsafeSwap mv i (fromIntegral j) >> pure rng'
{-# INLINE shuffle #-}