{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeApplications #-}

-- | Bit-packed sieve of Eratosthenes and a Moebius-function table.
module Math.Prime.Sieve where

import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Bits (Bits (clearBit, testBit, unsafeShiftR, (.&.)))
import Data.Function (fix)
import Data.Primitive
  ( ByteArray,
    fillByteArray,
    indexByteArray,
    newByteArray,
    readByteArray,
    unsafeFreezeByteArray,
    writeByteArray,
  )
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word (Word64, Word8)

-- | @withPrimes n f@ applies @f@ to the ascending vector of all primes in
-- the closed range @[0, n]@.
--
-- For @n < 2@ (including negative @n@) the vector passed to @f@ is empty.
withPrimes :: Int -> (U.Vector Int -> a) -> a
withPrimes n f = f (U.filter isPrime (U.generate (n + 1) id))
  where
    -- Forced up front so the table is built once, before any lookup.
    !(Sieve table) = sieve n

    -- Bit @i .&. 63@ of 64-bit word @i `shiftR` 6@ holds the primality of @i@.
    isPrime :: Int -> Bool
    isPrime i =
      testBit (indexByteArray @Word64 table (unsafeShiftR i 6)) (i .&. 0x3f)

-- | Bit-packed primality table built by 'sieve'.
--
-- Read the array as 64-bit words. Bit @i .&. 63@ of word @i `shiftR` 6@ is
-- set iff @i@ is prime. This holds for every @i@ from @0@ up to @n@ (and up
-- to the end of the last word) for the @n@ passed to 'sieve'.
newtype Sieve = Sieve ByteArray

-- | Sieve of Eratosthenes over the indices @0 .. n@.
--
-- The table is rounded up to a whole number of 64-bit words and always
-- holds at least one word, so it is well-formed even for @n < 0@.
sieve :: Int -> Sieve
sieve n = runST $ do
  let -- Bits needed for indices 0..n, rounded up to whole 64-bit words.
      -- Clamped to at least one word so the seeding write below is in
      -- bounds even when n < 0.
      lim = 64 * max 1 ((n + 64) `quot` 64)
      -- newByteArray / fillByteArray take sizes in BYTES, not bits.
      nbytes = lim `quot` 8
      wordIx i = unsafeShiftR i 6
      bitIx i = i .&. 0x3f
  isp <- newByteArray nbytes
  -- Every byte 0b10101010: all odd positions start as prime candidates and
  -- all even positions start as composite (2 is fixed up below).
  fillByteArray isp 0 nbytes (0b10101010 :: Word8)
  -- Seed the first word: clear bit 1 (1 is not prime) and set bit 2 (2 is
  -- prime). Doing this as one Word64 write, matching how the table is read,
  -- keeps the layout independent of host byte order.
  writeByteArray @Word64 isp 0 0xAAAAAAAAAAAAAAAC
  let clearIx !i = do
        w <- readByteArray @Word64 isp (wordIx i)
        writeByteArray @Word64 isp (wordIx i) (clearBit w (bitIx i))
      -- Clear the odd multiples p*p, p*p + 2p, ... below lim. Smaller
      -- multiples were already cleared via smaller prime factors, and even
      -- positions were never set.
      crossOff !p = go (p * p)
        where
          go !i = when (i < lim) $ clearIx i >> go (i + 2 * p)
      -- Visit odd p while p*p < lim. Every composite below lim has a prime
      -- factor in that range. The bound is exact integer arithmetic, with no
      -- floating-point sqrt.
      sieveFrom !p = when (p * p < lim) $ do
        w <- readByteArray @Word64 isp (wordIx p)
        when (testBit w (bitIx p)) (crossOff p)
        sieveFrom (p + 2)
  sieveFrom 3
  Sieve <$> unsafeFreezeByteArray isp

-- | @buildMoebiusTable n@ is a vector of length @max 0 (n + 1)@.
--
-- Index @k@ holds the Moebius function @mu(k)@ for @1 <= k <= n@:
-- @1@ if @k@ is squarefree with an even number of prime factors, @-1@ if
-- squarefree with an odd number, and @0@ if some prime square divides @k@.
-- Index @0@ holds @1@ (its initial value; @mu(0)@ is not defined).
buildMoebiusTable :: Int -> U.Vector Int
buildMoebiusTable n = U.create $ do
  -- Clamp explicitly so negative n yields an empty table.
  let size = max 0 (n + 1)
  isp <- UM.replicate size True
  ms <- UM.replicate size 1
  -- Only indices >= 2 are ever inspected, so 0 and 1 need no special
  -- marking. This also keeps n = 0 from writing past the end of the table.
  let sieveFrom !p = when (p <= n) $ do
        UM.unsafeRead isp p >>= \case
          False -> pure ()
          True -> do
            UM.unsafeWrite ms p (-1)
            markMultiples p (2 * p)
        sieveFrom (p + 1)
      -- For each proper multiple q of the prime p: mark q composite. If p^2
      -- divides q, set mu(q) to 0. Otherwise flip its sign for the extra
      -- prime factor p.
      markMultiples !p !q = when (q <= n) $ do
        UM.unsafeWrite isp q False
        if (q `quot` p) `rem` p == 0
          then UM.unsafeWrite ms q 0
          else UM.unsafeModify ms negate q
        markMultiples p (q + p)
  sieveFrom 2
  pure ms