{-# LANGUAGE LambdaCase #-}

module Math.Prime.Sieve where

import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Bits (Bits (clearBit, testBit, unsafeShiftR, (.&.)))
import Data.Function (fix)
import Data.Primitive (
  ByteArray,
  fillByteArray,
  indexByteArray,
  newByteArray,
  readByteArray,
  unsafeFreezeByteArray,
  writeByteArray,
 )
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word (Word64, Word8)

-- | Hand the primes in @[0 .. n]@, in ascending order, to the continuation @f@.
--
-- Primality of @i@ is looked up in the packed table built by 'sieve': the
-- table is read as 'Word64' words, and @i@ lives in word @unsafeShiftR i 6@
-- at bit @i .&. 0x3f@.
withPrimes :: Int -> (U.Vector Int -> a) -> a
withPrimes n f = f primes
  where
    -- Strict binding: the sieve is forced before the continuation runs.
    !(Sieve bits) = sieve n

    primes :: U.Vector Int
    primes = U.filter isPrime (U.generate (n + 1) id)

    isPrime :: Int -> Bool
    isPrime i = testBit (wordAt i) (i .&. 0x3f)

    -- The 64-bit word of the table that holds the bit for @i@.
    wordAt :: Int -> Word64
    wordAt i = indexByteArray bits (i `unsafeShiftR` 6)

-- | Packed primality bitmap produced by 'sieve'.
--
-- When the bytes are read back as 'Word64' words, number @i@ is represented
-- by bit @i .&. 0x3f@ of word @unsafeShiftR i 6@; a set bit marks a prime.
newtype Sieve
  = Sieve ByteArray

-- | Bit sieve of Eratosthenes over @[0 .. n]@.
--
-- In the resulting 'Sieve', bit @i .&. 0x3f@ of 'Word64' word
-- @unsafeShiftR i 6@ is set iff @i@ is prime, for @0 <= i <= n@. The number
-- of bits is @n + 1@ rounded up to a whole number of 64-bit words, and never
-- fewer than one word, so the table is safe to build for any @n@ (including
-- negative @n@).
--
-- Only odd numbers are crossed off: the initial byte pattern already marks
-- every even number other than 2 as composite.
--
-- NOTE(review): the initial pattern is written byte-wise but later read and
-- written word-wise, which assumes a little-endian host.
sieve :: Int -> Sieve
sieve n = runST $ do
  -- Number of bits in the table: n + 1 rounded up to a multiple of 64, and
  -- at least 64, because byte 0 and word 0 are touched unconditionally below
  -- (the previous formula gave 0 or a negative size for n < 0).
  let lim :: Int
      lim = max 64 (((n + 1) + 63) `quot` 64 * 64)
      -- One bit per number: lim bits occupy lim / 8 bytes, which is a whole
      -- number of Word64 words since lim is a multiple of 64. (Allocating
      -- lim * 8 bytes, as before, reserved 64x more memory than is used.)
      nBytes :: Int
      nBytes = lim `quot` 8
  isp <- newByteArray nBytes
  -- 0b10101010 sets the odd bit positions of every byte: all odd numbers
  -- start as candidates, all even numbers as composite.
  fillByteArray isp 0 nBytes 0b10101010
  -- Patch numbers 0..7: clear 1 (not prime) and set 2 (prime).
  writeByteArray @Word8 isp 0 0b10101100
  -- lim >= 64 guarantees sqrtLim >= 8, so p = 3 below is always in range.
  let !sqrtLim = floor . sqrt @Double $ fromIntegral lim
  -- Visit odd p = 3, 5, .. up to sqrtLim; every composite below lim has a
  -- prime factor in that range.
  flip fix 3 $ \loop !p -> do
    seg <- readByteArray @Word64 isp (unsafeShiftR p 6)
    when (testBit seg (p .&. 0x3f)) $
      -- Clear the odd multiples p*p, p*p + 2p, ...; even multiples are
      -- already clear from the initial pattern.
      flip fix (p * p) $ \loop' !i ->
        when (i < lim) $ do
          let w = unsafeShiftR i 6
          seg' <- readByteArray @Word64 isp w
          writeByteArray @Word64 isp w $ clearBit seg' (i .&. 0x3f)
          loop' (i + 2 * p)
    when (p + 2 <= sqrtLim) $ loop (p + 2)
  Sieve <$> unsafeFreezeByteArray isp

-- | Table of the Möbius function: index @k@ holds @μ(k)@ for @1 <= k <= n@.
--
-- Index 0 is a placeholder that holds 1 (μ(0) is undefined). The table has
-- @n + 1@ entries, so @n = 0@ yields just the placeholder and @n = -1@ yields
-- an empty vector (previously both crashed with an out-of-bounds write).
--
-- Computed with a sieve: every entry starts at 1. For each prime @p@, entry
-- @p@ becomes -1, and each larger multiple @q@ of @p@ is zeroed if @p^2@
-- divides @q@ and negated otherwise (negating 0 keeps it 0).
buildMoebiusTable :: Int -> U.Vector Int
buildMoebiusTable n = U.create $ do
  -- isp ! k stays True until k is crossed off as a multiple of a smaller
  -- prime. Entries 0 and 1 are never read (the scan starts at 2), so they
  -- are deliberately left unwritten; writing them unconditionally indexed
  -- out of bounds when n < 1.
  isp <- UM.replicate (n + 1) True
  ms <- UM.replicate (n + 1) 1
  let -- Update every multiple q = 2p, 3p, .. <= n of the prime p.
      crossOff !p !q = when (q <= n) $ do
        UM.unsafeWrite isp q False
        -- p^2 | q  <=>  p | (q / p); this form cannot overflow, unlike p * p.
        if (q `quot` p) `rem` p == 0
          then UM.unsafeWrite ms q 0
          else UM.unsafeModify ms negate q
        crossOff p (q + p)
      -- Scan p = 2 .. n, handling each p that was never crossed off (a prime).
      scan !p = when (p <= n) $ do
        isPrime <- UM.unsafeRead isp p
        when isPrime $ do
          UM.unsafeWrite ms p (-1)
          crossOff p (2 * p)
        scan (p + 1)
  scan 2
  pure ms