{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.CSR where

import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer

-- | Compressed sparse row storage: a jagged array of 'numRowsCSR' rows
-- packed into a single flat vector.
--
-- Row @i@ is the half-open range
-- @[offsetCSR ! i, offsetCSR ! (i + 1))@ of 'bufferCSR' (see 'rowAt'),
-- so 'offsetCSR' is expected to hold @numRowsCSR + 1@ non-decreasing
-- entries starting at 0.
data CSR a = CSR
  { numRowsCSR :: !Int
  -- ^ Number of rows.
  , offsetCSR :: !(U.Vector Int)
  -- ^ Prefix offsets into 'bufferCSR'; entry @i@ is where row @i@ starts
  -- and entry @i + 1@ is where it ends.
  , bufferCSR :: !(U.Vector a)
  -- ^ All row elements, concatenated row after row.
  }

-- | The elements of row @i@, as a slice of 'bufferCSR' (no copy is made).
--
-- Unchecked: @i@ must lie in @[0, numRowsCSR)@. Out-of-range indices are
-- not detected because both the offset lookups and the slice use the
-- @unsafe@ vector primitives.
rowAt :: (U.Unbox a) => CSR a -> Int -> U.Vector a
rowAt CSR{offsetCSR = offsets, bufferCSR = buf} i =
  let start = U.unsafeIndex offsets i
      end = U.unsafeIndex offsets (i + 1)
   in U.unsafeSlice start (end - start) buf
{-# INLINE rowAt #-}

-- | Build a 'CSR' with @n@ rows from a vector of @(row, value)@ pairs.
--
-- Each pair @(i, x)@ in the input adds @x@ to row @i@. The row index
-- must lie in @[0, n)@. Within a row, values presumably keep their input
-- order (this assumes 'pushBack' / 'popFront' behave as a FIFO queue —
-- TODO confirm against "Data.Buffer").
--
-- @m@ is treated as a minimum queue capacity. The queue is always
-- allocated with room for at least @U.length v@ pairs, so an undersized
-- hint can never push the buffer past its capacity. The resulting 'CSR'
-- depends only on the pushed pairs, not on the capacity.
accumulateToCSR ::
  (U.Unbox a) =>
  -- | num rows
  Int ->
  -- | buffer size (lower bound; raised to the input length if smaller)
  Int ->
  U.Vector (Int, a) ->
  CSR a
accumulateToCSR n m v =
  createCSR n (max m (U.length v)) $ \builder ->
    U.forM_ v $ \e -> pushCSRB e builder

-- | Fill a fresh builder inside 'ST' and freeze it into a 'CSR'.
--
-- A 'CSRBuilder' with @n@ rows is allocated (with @m@ passed on as the
-- buffer size), handed to @run@ to be populated (e.g. via 'pushCSRB'),
-- and then frozen with 'buildCSR'. The rank-2 type of @run@ keeps the
-- mutable builder from escaping the 'runST' computation.
createCSR ::
  (U.Unbox a) =>
  -- | num rows
  Int ->
  -- | buffer size
  Int ->
  (forall s. CSRBuilder s a -> ST s ()) ->
  CSR a
createCSR n m run =
  runST (newCSRBuilder n m >>= \builder -> run builder >> buildCSR builder)

-- | Mutable accumulator for a 'CSR' under construction, living in state
-- thread @s@.
--
-- 'pushCSRB' appends pairs to 'queueCSRB' and bumps the matching entry of
-- 'outDegCSRB'. 'buildCSR' turns those counts into offsets and places the
-- queued values row by row.
data CSRBuilder s a = CSRBuilder
  { numRowsCSRB :: !Int
  -- ^ Number of rows (the length of 'outDegCSRB' when created by
  -- 'newCSRBuilder').
  , queueCSRB :: !(Buffer s (Int, a))
  -- ^ Pending @(row, value)@ pairs appended by 'pushCSRB'.
  , outDegCSRB :: !(UM.MVector s Int)
  -- ^ Per-row count of pending elements.
  }

-- | Allocate an empty builder for @numRows@ rows.
--
-- The pair queue is created with @newBuffer bufferSize@. Every per-row
-- count starts at zero.
newCSRBuilder ::
  (U.Unbox a, PrimMonad m) =>
  -- | num rows
  Int ->
  -- | buffer size
  Int ->
  m (CSRBuilder (PrimState m) a)
newCSRBuilder numRows bufferSize = do
  queue <- newBuffer bufferSize
  counts <- UM.replicate numRows 0
  pure (CSRBuilder numRows queue counts)
{-# INLINE newCSRBuilder #-}

-- | Freeze a 'CSRBuilder' into an immutable 'CSR' (counting-sort layout).
--
-- The per-row counts in 'outDegCSRB' are prefix-summed into the offset
-- table. Each queued @(row, value)@ pair is then popped and written to
-- the next free slot of its row.
--
-- The queue is drained by this call: the loop only stops once 'popFront'
-- returns 'Nothing'. The per-row counts are therefore reset to zero
-- afterwards so they agree with the now-empty queue. Without the reset, a
-- later 'buildCSR' on the same builder would produce offsets that point
-- past the end of its (smaller) buffer, and 'rowAt' would read out of
-- bounds. With it, a later build only sees elements pushed after this
-- call.
buildCSR ::
  (U.Unbox a, PrimMonad m) =>
  CSRBuilder (PrimState m) a ->
  m (CSR a)
buildCSR CSRBuilder{..} = do
  len <- lengthBuffer queueCSRB
  offsetCSR <- U.scanl' (+) 0 <$> U.freeze outDegCSRB
  -- Mutable copy of the offsets, used as a per-row write cursor.
  moffset <- U.thaw offsetCSR
  mbuffer <- UM.unsafeNew len
  fix $ \loop ->
    popFront queueCSRB >>= \case
      Just (i, x) -> do
        pos <- UM.unsafeRead moffset i
        UM.unsafeWrite moffset i (pos + 1)
        UM.unsafeWrite mbuffer pos x
        loop
      Nothing -> return ()
  -- The queue is empty now; clear the counts so the builder stays
  -- consistent if it is reused.
  UM.set outDegCSRB 0
  bufferCSR <- U.unsafeFreeze mbuffer
  return $ CSR{numRowsCSR = numRowsCSRB, ..}
{-# INLINE buildCSR #-}

-- | Append the pair @(i, x)@ to the builder: the pair is queued and row
-- @i@'s element count is incremented.
--
-- The row index is validated against the size of 'outDegCSRB' before
-- anything is modified. The count update below and the placement step in
-- 'buildCSR' both use unchecked vector access, so an out-of-range row
-- would otherwise corrupt memory silently. It now raises an error and
-- leaves the builder untouched.
pushCSRB ::
  (U.Unbox a, PrimMonad m) =>
  (Int, a) ->
  CSRBuilder (PrimState m) a ->
  m ()
pushCSRB e@(i, _) CSRBuilder{..}
  | i < 0 || i >= rows =
      error $
        "pushCSRB: row index "
          ++ show i
          ++ " out of range [0, "
          ++ show rows
          ++ ")"
  | otherwise = do
      pushBack e queueCSRB
      UM.unsafeModify outDegCSRB (+ 1) i
  where
    rows = UM.length outDegCSRB
{-# INLINE pushCSRB #-}