{-# LANGUAGE RecordWildCards #-}

module Math.Prime.Divisor where

import Control.Monad.ST
import Data.Int
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import My.Prelude

{- | Precomputed divisor lists for a contiguous range of non-negative
integers, stored flat: every list is concatenated into 'bufferDC', and
'offsetDC' records where each list begins.
-}
data DivisorCache a = DivisorCache
  { offsetDC :: !(U.Vector Int)
  -- ^ Start offsets into 'bufferDC'. The entries belonging to @i@ occupy the
  -- half-open index range @[offsetDC ! i, offsetDC ! (i + 1))@.
  , bufferDC :: !(U.Vector a)
  -- ^ All per-integer divisor lists laid out back to back.
  }

{- | Build a 'DivisorCache' covering the integers @0 .. n@.

The construction is a two-pass counting layout:

1. count how many divisors every @x@ in @0 .. n@ has;
2. turn the counts into start offsets with a prefix sum;
3. walk the candidate divisors again and drop each one into the next free
   slot of every multiple.

Precondition: @n >= 0@. The first step writes index 0 of a vector of length
@n + 1@, so a negative @n@ raises an error instead of returning a cache.

Divisors are stored via 'fromIntegral', so the element type @a@ must be wide
enough to hold values up to @n@.
-}
buildDivisorCache :: (Integral a, U.Unbox a) => Int -> DivisorCache a
buildDivisorCache n = runST $ do
  -- Pass 1: freq ! x = number of divisors of x. Every slot starts at 1 to
  -- account for the divisor 1; slot 0 is reset so that 0 gets an empty list.
  freq <- UM.replicate (n + 1) (1 :: Int)
  UM.write freq 0 0
  -- NOTE(review): (..<) and stride come from My.Prelude. Presumably
  -- `a ..< b` streams a, a + 1, .. b - 1 and `stride d (n + 1) d` streams
  -- d, 2d, 3d, .. below n + 1 -- confirm against My.Prelude.
  flip MS.mapM_ (2 ..< n + 1) $ \d -> do
    flip MS.mapM_ (stride d (n + 1) d) $ \x -> do
      UM.unsafeModify freq (+ 1) x
  -- Prefix sums of the counts: one more entry than freq, the last one being
  -- the total number of stored divisors. freq is not touched after the
  -- freeze, which is what makes unsafeFreeze acceptable here.
  offsetDC <- U.scanl' (+) 0 <$> U.unsafeFreeze freq
  buf <- UM.unsafeNew (U.last offsetDC)
  -- pos is a mutable copy of the offsets used as per-integer write cursors.
  pos <- U.thaw offsetDC
  -- Pass 2: append d to the list of each of its multiples x.
  flip MS.mapM_ (1 ..< n + 1) $ \d -> do
    flip MS.mapM_ (stride d (n + 1) d) $ \x -> do
      i <- UM.unsafeRead pos x
      UM.unsafeWrite pos x (i + 1)
      UM.unsafeWrite buf i (fromIntegral d)
  bufferDC <- U.unsafeFreeze buf
  return DivisorCache{..}
{-# SPECIALIZE buildDivisorCache :: Int -> DivisorCache Int #-}
{-# SPECIALIZE buildDivisorCache :: Int -> DivisorCache Int32 #-}

{- | Look up the cached divisor list of @i@ as a slice of 'bufferDC'.

The slice shares storage with the cache; nothing is copied.

No bounds check is performed ('U.unsafeIndex' / 'U.unsafeSlice'): the caller
must ensure that both @i@ and @i + 1@ are valid indices of 'offsetDC'.

>>> dc = buildDivisorCache @Int 100
>>> divisors dc 60
[1,2,3,4,5,6,10,12,15,20,30,60]
>>> divisors dc 0
[]
>>> divisors dc 1
[1]
>>> divisors dc 2
[1,2]
>>> divisors dc 100
[1,2,4,5,10,20,25,50,100]
-}
divisors :: (U.Unbox a) => DivisorCache a -> Int -> U.Vector a
divisors DivisorCache{..} i = U.unsafeSlice o (o' - o) bufferDC
  where
    -- Start of the entries for i.
    o :: Int
    o = U.unsafeIndex offsetDC i
    -- Start of the entries for i + 1, i.e. one past the end for i.
    o' :: Int
    o' = U.unsafeIndex offsetDC (i + 1)
{-# INLINE divisors #-}

{- | Enumerate the positive divisors of @n@ by trial division up to its
square root.

Divisors are produced lazily in discovery order: each small divisor @x@ is
immediately followed by its cofactor @n \`quot\` x@, and an exact square
root is emitted once. The result is therefore /not/ sorted. For @n <= 0@
the result is empty.

The stopping test compares @x@ with the quotient @n \`quot\` x@ rather than
@x * x@ with @n@, so it cannot overflow 'Int' for @n@ close to 'maxBound'.

>>> naiveDivisors 60
[1,60,2,30,3,20,4,15,5,12,6,10]
>>> naiveDivisors 0
[]
>>> naiveDivisors 1
[1]
>>> naiveDivisors 100
[1,100,2,50,4,25,5,20,10]
>>> length $ naiveDivisors 720720
240
-}
naiveDivisors :: Int -> [Int]
naiveDivisors n = go 1
  where
    go :: Int -> [Int]
    go x = case quotRem n x of
      (q, r)
        -- x < q  implies  x * x < n: keep searching, emitting the pair on a hit.
        | x < q -> if r == 0 then x : q : go (x + 1) else go (x + 1)
        -- x == q with no remainder means x * x == n: emit the root once.
        | x == q && r == 0 -> [x]
        -- Otherwise (x + 1) * (x + 1) > n, so no divisor pairs remain.
        | otherwise -> []

{- | Number of cached entries for @i@, computed as the difference between two
consecutive entries of 'offsetDC'; 'bufferDC' itself is never read.

No bounds check is performed ('U.unsafeIndex'): the caller must ensure that
both @i@ and @i + 1@ are valid indices of 'offsetDC'.
-}
numDivisors :: (U.Unbox a) => DivisorCache a -> Int -> Int
numDivisors DivisorCache{..} i = o' - o
  where
    -- Start of the entries for i.
    o :: Int
    o = U.unsafeIndex offsetDC i
    -- Start of the entries for i + 1, i.e. one past the end for i.
    o' :: Int
    o' = U.unsafeIndex offsetDC (i + 1)
{-# INLINE numDivisors #-}