{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module Data.CSR where
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Buffer
-- | Immutable compressed-sparse-row storage.
--
-- Row @i@ is the slice of 'bufferCSR' from index @offsetCSR ! i@
-- (inclusive) to @offsetCSR ! (i + 1)@ (exclusive).
data CSR a = CSR
  { numRowsCSR :: !Int
  -- ^ Number of rows.
  , offsetCSR :: !(U.Vector Int)
  -- ^ Row start offsets into 'bufferCSR'. When built by 'buildCSR' this is a
  -- prefix sum starting at 0, so it has @numRowsCSR + 1@ entries.
  , bufferCSR :: !(U.Vector a)
  -- ^ All row entries, stored contiguously row by row.
  }
-- | @rowAt csr i@ returns the entries of row @i@ as an O(1) slice of the
-- shared buffer (no copy).
--
-- No bounds checking is performed (both the index lookups and the slice are
-- the @unsafe@ variants), so @i@ must satisfy @0 <= i < numRowsCSR csr@.
rowAt :: (U.Unbox a) => CSR a -> Int -> U.Vector a
rowAt csr i = U.unsafeSlice start (end - start) (bufferCSR csr)
  where
    offsets = offsetCSR csr
    -- Row @i@ spans the half-open range [start, end) of the buffer.
    start = U.unsafeIndex offsets i
    end = U.unsafeIndex offsets (i + 1)
{-# INLINE rowAt #-}
-- | Build a 'CSR' with @numRows@ rows from a vector of @(row, value)@ pairs.
--
-- Each pair is pushed into a 'CSRBuilder' in input order. 'buildCSR' then
-- places each value into its row at the next free slot, so within a row the
-- values keep their input order. NOTE(review): this assumes
-- 'pushBack'/'popFront' in "Data.Buffer" behave as a FIFO; confirm.
--
-- Row indices are not checked: every @row@ must satisfy
-- @0 <= row < numRows@.
--
-- The second argument is the requested queue capacity. Exactly
-- @U.length entries@ pushes happen, so the capacity is raised to at least
-- that. A smaller request would otherwise overflow the queue buffer.
-- NOTE(review): assumes 'newBuffer' allocates a fixed capacity.
accumulateToCSR ::
  (U.Unbox a) =>
  Int ->
  Int ->
  U.Vector (Int, a) ->
  CSR a
accumulateToCSR numRows capacity entries =
  createCSR numRows (max capacity (U.length entries)) $ \builder ->
    U.forM_ entries $ \entry -> pushCSRB entry builder
-- | Run a population action against a fresh 'CSRBuilder', then freeze it
-- into a 'CSR'.
--
-- The builder is created by 'newCSRBuilder' with @numRows@ rows and a
-- pending-entry queue of the given capacity. The rank-2 type of the action
-- keeps the builder from escaping 'runST'.
createCSR ::
  (U.Unbox a) =>
  Int ->
  Int ->
  (forall s. CSRBuilder s a -> ST s ()) ->
  CSR a
createCSR numRows capacity populate =
  runST
    ( newCSRBuilder numRows capacity
        >>= \builder -> populate builder >> buildCSR builder
    )
-- | Mutable accumulator used to assemble a 'CSR' inside 'ST' / a
-- 'PrimMonad'.
data CSRBuilder s a = CSRBuilder
  { numRowsCSRB :: !Int
  -- ^ Number of rows of the 'CSR' being built.
  , queueCSRB :: !(Buffer s (Int, a))
  -- ^ Pending @(row, value)@ entries, appended by 'pushCSRB' and drained by
  -- 'buildCSR'.
  , outDegCSRB :: !(UM.MVector s Int)
  -- ^ Per-row count of pushed entries (incremented by 'pushCSRB').
  }
-- | Allocate an empty 'CSRBuilder'.
--
-- The first argument is the number of rows. The second is the capacity
-- passed to 'newBuffer' for the pending-entry queue. Every per-row count
-- starts at zero.
newCSRBuilder ::
  (U.Unbox a, PrimMonad m) =>
  Int ->
  Int ->
  m (CSRBuilder (PrimState m) a)
newCSRBuilder numRows bufferSize = do
  queue <- newBuffer bufferSize
  degrees <- UM.replicate numRows 0
  pure
    CSRBuilder
      { numRowsCSRB = numRows
      , queueCSRB = queue
      , outDegCSRB = degrees
      }
{-# INLINE newCSRBuilder #-}
-- | Freeze a 'CSRBuilder' into a 'CSR'.
--
-- Row offsets are the prefix sums of the per-row counts. Queued entries are
-- then popped one at a time, and each value is written to the next free slot
-- of its row.
--
-- This drains the builder's queue ('popFront' until 'Nothing'), so the
-- builder should not be reused afterwards.
buildCSR ::
  (U.Unbox a, PrimMonad m) =>
  CSRBuilder (PrimState m) a ->
  m (CSR a)
buildCSR CSRBuilder{..} = do
  total <- lengthBuffer queueCSRB
  offsets <- U.scanl' (+) 0 <$> U.freeze outDegCSRB
  -- Mutable copy of the offsets: cursor ! row is the next free slot in row.
  cursor <- U.thaw offsets
  out <- UM.unsafeNew total
  fix $ \drain ->
    popFront queueCSRB >>= \case
      Nothing -> pure ()
      Just (row, val) -> do
        slot <- UM.unsafeRead cursor row
        UM.unsafeWrite cursor row (slot + 1)
        UM.unsafeWrite out slot val
        drain
  entries <- U.unsafeFreeze out
  pure
    CSR
      { numRowsCSR = numRowsCSRB
      , offsetCSR = offsets
      , bufferCSR = entries
      }
{-# INLINE buildCSR #-}
-- | Append a @(row, value)@ entry to the builder and bump that row's count.
--
-- The row index is not bounds-checked ('UM.unsafeModify'), so it must
-- satisfy @0 <= row < numRowsCSRB builder@.
pushCSRB ::
  (U.Unbox a, PrimMonad m) =>
  (Int, a) ->
  CSRBuilder (PrimState m) a ->
  m ()
pushCSRB entry@(row, _) builder = do
  pushBack entry (queueCSRB builder)
  UM.unsafeModify (outDegCSRB builder) (+ 1) row
{-# INLINE pushCSRB #-}