{-# LANGUAGE LambdaCase #-}

module Math.Prime.Sieve where

import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Bits (Bits (clearBit, testBit, unsafeShiftR, (.&.)))
import Data.Function (fix)
import Data.Primitive (
  ByteArray,
  fillByteArray,
  indexByteArray,
  newByteArray,
  readByteArray,
  unsafeFreezeByteArray,
  writeByteArray,
 )
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word (Word64, Word8)

-- | Pass the vector of all primes in @[0 .. n]@, in ascending order, to @f@.
--
-- For @n < 2@ there are no primes, so @f@ receives an empty vector and no
-- sieve is built at all; in particular a negative @n@ never reaches 'sieve'.
withPrimes :: Int -> (U.Vector Int -> a) -> a
withPrimes n f
  | n < 2 = f U.empty
  | otherwise =
      -- The sieve binding lives inside this branch (not in a @where@) so its
      -- bang pattern is only forced once we know n >= 2.
      let !(Sieve table) = sieve n
          -- Number i is bit (i .&. 63) of the Word64 at word index (i >> 6).
          isPrime :: Int -> Bool
          isPrime i =
            testBit (indexByteArray @Word64 table (unsafeShiftR i 6)) (i .&. 0x3f)
       in f (U.filter isPrime (U.generate (n + 1) id))

-- | Bit-packed primality table produced by 'sieve'.
--
-- Number @i@ is represented by bit @i .&. 63@ of the 64-bit word at word
-- index @unsafeShiftR i 6@; the bit is set when @i@ is prime (this is how
-- 'withPrimes' reads it). Only indices within the range requested from
-- 'sieve' are meaningful.
newtype Sieve = Sieve ByteArray

-- | Sieve of Eratosthenes over @[0 .. n]@, packed one bit per integer.
--
-- The table holds @n + 1@ bits rounded up to a whole number of 64-bit words,
-- and always at least one word, so every index in @[0 .. n]@ is valid.
-- Bit @i@ (bit @i .&. 63@ of word @unsafeShiftR i 6@) is set iff @i@ is
-- prime. A negative @n@ is treated like 0.
sieve :: Int -> Sieve
sieve n = runST $ do
  -- Bit length of the table: n + 1 rounded up to a multiple of 64. Clamping
  -- n at 0 keeps the table at least one word long, so the word-0 write below
  -- is always in bounds (a negative n used to make lim 0 or negative).
  let lim :: Int
      lim = (max 0 n + 1 + 63) `quot` 64 * 64
      -- lim is a multiple of 64, so lim bits occupy exactly lim / 8 bytes
      -- (previously lim * 8 bytes were allocated, 64x more than ever read).
      nBytes :: Int
      nBytes = lim `quot` 8
  isp <- newByteArray nBytes
  -- Mark only the odd numbers as candidates: 0xAA sets bits 1, 3, 5, 7 of
  -- every byte. Since all bytes are equal this reads the same as a Word64
  -- on any byte order.
  fillByteArray isp 0 nBytes (0b10101010 :: Word8)
  -- Patch word 0: clear 1 and mark 2 (low byte 0xAC = bits 2, 3, 5, 7).
  -- Writing a whole Word64 rather than byte 0 keeps the bit numbering
  -- independent of the machine's endianness.
  writeByteArray @Word64 isp 0 0xAAAAAAAAAAAAAAAC
  let !sqrtLim = floor (sqrt (fromIntegral lim :: Double)) :: Int
  -- Only odd p are tried; lim >= 64 guarantees sqrtLim >= 8, so starting at
  -- 3 unconditionally is safe.
  flip fix 3 $ \loop !p -> do
    seg <- readByteArray @Word64 isp (unsafeShiftR p 6)
    when (testBit seg (p .&. 0x3f)) $
      -- Multiples below p * p were handled by smaller primes, and even
      -- multiples were never set, so start at p * p and step by 2 * p.
      flip fix (p * p) $ \clearFrom !i ->
        when (i < lim) $ do
          let w = unsafeShiftR i 6
          word <- readByteArray @Word64 isp w
          writeByteArray @Word64 isp w (clearBit word (i .&. 0x3f))
          clearFrom (i + 2 * p)
    when (p + 2 <= sqrtLim) $ loop (p + 2)
  Sieve <$> unsafeFreezeByteArray isp

-- | Table of the Möbius function: element @k@ is @μ(k)@ for @1 <= k <= n@.
--
-- Element 0 has no mathematical meaning and keeps its initial value 1.
-- The result has one element per index in @[0 .. n]@, so @n = 0@ gives a
-- single element and a negative @n@ gives an empty table.
--
-- Method: scanning upward from 2, an index not yet marked composite is a
-- prime @p@. It gets @μ(p) = -1@, and every proper multiple @q@ of @p@ is
-- marked composite and has its entry zeroed if @p^2@ divides @q@, otherwise
-- negated. After all primes are processed each square-free @q@ holds
-- @(-1)^(number of distinct prime factors)@.
buildMoebiusTable :: Int -> U.Vector Int
buildMoebiusTable n = U.create $ do
  -- "still possibly prime" flags. Indices 0 and 1 are never read (the scan
  -- starts at 2) or written (multiples start at 2p), so they need no setup;
  -- the former bounds-checked writes to them crashed for n <= 0.
  isp <- UM.replicate (n + 1) True
  ms <- UM.replicate (n + 1) 1
  flip fix 2 $ \outer !p -> when (p <= n) $ do
    isPrime <- UM.unsafeRead isp p
    when isPrime $ do
      UM.unsafeWrite ms p (-1)
      flip fix (2 * p) $ \inner !q -> when (q <= n) $ do
        UM.unsafeWrite isp q False
        -- q = k * p; p^2 divides q exactly when p divides k.
        if (q `quot` p) `rem` p == 0
          then UM.unsafeWrite ms q 0
          else UM.unsafeModify ms negate q
        inner (q + p)
    outer (p + 1)
  pure ms