{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}

-- | Sieves producing the primes up to a bound and a Möbius-function table.
module Math.Prime.Sieve where

import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Bits (Bits (clearBit, testBit, unsafeShiftR, (.&.)))
import Data.Function (fix)
import Data.Primitive
  ( ByteArray,
    fillByteArray,
    indexByteArray,
    newByteArray,
    readByteArray,
    unsafeFreezeByteArray,
    writeByteArray,
  )
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word (Word64, Word8)

-- | @withPrimes n f@ applies @f@ to the ascending vector of all primes
-- @p <= n@. The vector is empty when @n < 2@ (including negative @n@).
withPrimes :: Int -> (U.Vector Int -> a) -> a
withPrimes n f = f . U.filter isP $ U.generate (n + 1) id
  where
    -- Forced up front: the bitset is built once and shared by every lookup.
    !(Sieve sieved) = sieve n

    -- Word @i >> 6@ holds the bit for @i@ at position @i .&. 63@.
    isP :: Int -> Bool
    isP i =
      let seg = indexByteArray @Word64 sieved (unsafeShiftR i 6)
       in testBit seg (i .&. 0x3f)

-- | Packed primality bitset produced by 'sieve'. It is addressed in 'Word64'
-- units: word @i >> 6@, bit @i .&. 63@ is set iff @i@ is prime.
newtype Sieve = Sieve ByteArray

-- | Sieve of Eratosthenes over the odd numbers.
--
-- The result covers at least @0 .. n@, rounded up to a whole number of 64-bit
-- words (always at least one word). Negative @n@ is treated as @0@.
sieve :: Int -> Sieve
sieve n = runST $ do
  let -- Number of bits tracked: a positive multiple of 64 that exceeds max 0 n.
      lim = 64 * ((max 0 n + 64) `quot` 64)
      nBytes = lim `quot` 8
      -- Bits 1,3,5,7 of every byte: all odd numbers start as candidates.
      oddsByte = 0b10101010 :: Word8
      -- Word 0 also clears 1 and marks 2 as prime (low byte 0b10101100).
      -- Written as a Word64 so it matches the Word64 reads below on any
      -- host byte order.
      firstWord = 0xAAAAAAAAAAAAAAAC :: Word64
      !sqrtLim = floor (sqrt (fromIntegral lim :: Double)) :: Int
  isp <- newByteArray nBytes
  fillByteArray isp 0 nBytes oddsByte
  writeByteArray isp 0 firstWord
  let -- Clear the bits of p*p, p*p + 2p, ... (even multiples are already 0).
      crossOff !p !i = when (i < lim) $ do
        w <- readByteArray @Word64 isp (unsafeShiftR i 6)
        writeByteArray isp (unsafeShiftR i 6) (clearBit w (i .&. 0x3f))
        crossOff p (i + 2 * p)
      -- Visit odd p = 3, 5, .. up to sqrtLim; sieve with those still marked.
      sievePrimes !p = do
        w <- readByteArray @Word64 isp (unsafeShiftR p 6)
        when (testBit w (p .&. 0x3f)) $ crossOff p (p * p)
        when (p + 2 <= sqrtLim) $ sievePrimes (p + 2)
  sievePrimes 3
  Sieve <$> unsafeFreezeByteArray isp

-- | @buildMoebiusTable n@ returns a vector of length @n + 1@ (empty when
-- @n < 0@) whose entry @k@ is the Möbius function μ(k) for @1 <= k <= n@.
-- Entry 0 is left at its initial value 1; μ(0) is undefined, so callers
-- should not rely on it.
buildMoebiusTable :: Int -> U.Vector Int
buildMoebiusTable n = U.create $ do
  -- isp[k] stays True until k is hit as a multiple of a smaller prime.
  isp <- UM.replicate (n + 1) True
  ms <- UM.replicate (n + 1) 1
  -- Indices 0 and 1 are never read (the scan starts at 2), so they are left
  -- untouched; this also keeps n <= 0 from indexing out of bounds.
  let outer p = when (p <= n) $ do
        UM.unsafeRead isp p >>= \case
          False -> pure ()
          True -> do
            UM.unsafeWrite ms p (-1)
            inner p (2 * p)
        outer (p + 1)
      -- For each proper multiple q of the prime p: p^2 | q gives μ = 0,
      -- otherwise one more distinct prime factor flips the sign.
      inner p q = when (q <= n) $ do
        UM.unsafeWrite isp q False
        if (q `quot` p) `rem` p == 0
          then UM.unsafeWrite ms q 0
          else UM.unsafeModify ms negate q
        inner p (q + p)
  outer 2
  pure ms