module Math.Prime.LinearSieve where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import My.Prelude

-- | Result of a linear sieve, as produced by 'buildLinearSieve'.
data LinearSieve = LinearSieve
  { leastPrimeFactor :: !(U.Vector Int)
  -- ^ Table indexed by the integer itself; entry @k@ holds the smallest
  -- prime dividing @k@ for every sieved @k >= 2@.
  , primes :: !(U.Vector Int)
  -- ^ The primes found by the sieve, in the order they were discovered.
  }

{- | /O(N)/ Build a linear sieve for the integers up to and including @n@.

The least-prime-factor table is allocated with @n + 1@ entries, all
initialised to 0; every candidate visited by the main loop ends up with a
non-zero entry. The prime buffer is sized by 'primeCountUpperBound' and
trimmed to the number of primes actually found before being frozen.

>>> primes $ buildLinearSieve 32
[2,3,5,7,11,13,17,19,23,29,31]
-}
buildLinearSieve :: Int -> LinearSieve
buildLinearSieve n = runST $ do
  lpf <- UM.replicate (n + 1) 0
  -- Writes into @ps@ are unchecked, so this relies on
  -- 'primeCountUpperBound' really being an upper bound on the prime count.
  ps <- UM.replicate (primeCountUpperBound n) 0
  cntP <-
    MS.foldM'
      ( \cnt x -> do
          lpfx <- UM.unsafeRead lpf x
          -- An entry that is still 0 was never hit as a multiple of a
          -- smaller prime, so @x@ is a new prime.
          let !cnt'
                | lpfx > 0 = cnt
                | otherwise = cnt + 1
          when (cnt < cnt') $ do
            UM.unsafeWrite lpf x x
            UM.unsafeWrite ps cnt x
          -- Re-read: the entry may have just been set to @x@ above.
          lpfx' <- UM.unsafeRead lpf x
          -- Mark @p * x@ for each known prime @p@ in buffer order, stopping
          -- at the first @p@ that exceeds the least prime factor of @x@ or
          -- pushes the product past @n@. Each composite is therefore
          -- written with its least prime factor.
          fix
            ( \loop !i -> when (i < cnt') $ do
                p <- UM.unsafeRead ps i
                when (p <= lpfx' && p * x <= n) $ do
                  UM.unsafeWrite lpf (p * x) p
                  loop (i + 1)
            )
            0
          pure cnt'
      )
      0
      -- NOTE(review): '..<' comes from My.Prelude; presumably the ascending
      -- half-open range 2, 3, ..., n as a monadic stream -- confirm.
      (2 ..< n + 1)
  LinearSieve
    <$> U.unsafeFreeze lpf
    <*> U.unsafeFreeze (UM.take cntP ps)

{- | Prime factorisation of @n@, with multiplicity, read off the
least-prime-factor table by repeatedly dividing out the least prime factor.
Inputs below 2 yield the empty list.

The table lookup is unchecked ('U.unsafeIndex'), so @n@ must be a valid index
into the sieve's 'leastPrimeFactor' vector.

>>> ls = buildLinearSieve 100
>>> primeFactors ls 60
[2,2,3,5]
>>> primeFactors ls 0
[]
>>> primeFactors ls 1
[]
>>> primeFactors ls 2
[2]
>>> primeFactors ls 4
[2,2]
-}
primeFactors :: LinearSieve -> Int -> [Int]
primeFactors LinearSieve{leastPrimeFactor = lpf} = go
  where
    go n
      | n < 2 = []
      -- The pattern guard always matches; it only names the table entry.
      | p <- U.unsafeIndex lpf n = p : go (quot n p)

{- | Primality test by table lookup: @n@ is reported prime when @n >= 2@ and
its least-prime-factor entry equals @n@ itself.

The lookup is unchecked ('U.unsafeIndex') and is only skipped for @n < 2@, so
any @n >= 2@ must be a valid index into the sieve's 'leastPrimeFactor' vector.

>>> ls = buildLinearSieve 100
>>> isPrime ls 97
True
>>> isPrime ls 100
False
>>> isPrime ls 0
False
>>> isPrime ls 1
False
>>> isPrime ls 2
True
-}
isPrime :: LinearSieve -> Int -> Bool
isPrime LinearSieve{leastPrimeFactor = lpf} n =
  n >= 2 && U.unsafeIndex lpf n == n

{- | Capacity used for the prime buffer of a sieve up to @n@.

Piecewise: @n@ divided (truncating) by a divisor that grows from 4 to 10 as
@n@ crosses the listed thresholds, and a flat 32 for every @n@ below 120.

>>> primeCountUpperBound 100
32
>>> primeCountUpperBound (10 ^ 6)
100000
-}
primeCountUpperBound :: Int -> Int
primeCountUpperBound n
  | n >= 64720 = quot n 10
  | n >= 24300 = quot n 9
  | n >= 8472 = quot n 8
  | n >= 3094 = quot n 7
  | n >= 1134 = quot n 6
  | n >= 360 = quot n 5
  | n >= 120 = quot n 4
  | otherwise = 32