{-# LANGUAGE LambdaCase #-}
module Data.Graph.Sparse.TopSort where
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Buffer
import Data.Graph.Sparse
-- | Topological sort of a directed graph using Kahn's algorithm.
--
-- Returns @'Just' order@ when every one of the 'numVerticesSG' vertices could
-- be emitted. Vertices are emitted in the order they become sources (in-degree
-- zero). Returns 'Nothing' when fewer than 'numVerticesSG' vertices were
-- emitted. A vertex on a cycle, including a self-loop, never reaches
-- in-degree zero.
--
-- NOTE(review): the cycle check assumes 'unsafeFreezeInternalBuffer' yields
-- every element ever pushed onto the queue, not only the still-unpopped part.
-- Confirm against "Data.Buffer".
topSort :: SparseGraph w -> Maybe (U.Vector Int)
topSort gr = runST $ do
  let n :: Int
      n = numVerticesSG gr

      -- In-degree of every vertex: one count per edge target listed in
      -- 'adjacentSG'. Assumes every target lies in [0, n); unsafeAccumulate
      -- performs no bounds check.
      inDegree :: U.Vector Int
      inDegree =
        U.unsafeAccumulate (+) (U.replicate n 0)
          . U.map (flip (,) 1)
          $ adjacentSG gr

  -- Capacity n suffices: a vertex is enqueued either as an initial source or
  -- when its last incoming edge is consumed, never both and never twice.
  q <- newBufferAsQueue n

  -- Seed the queue with all initial sources (in-degree 0).
  U.mapM_ (flip pushBack q . fst)
    . U.filter ((== 0) . snd)
    $ U.indexed inDegree

  -- 'inDegree' is not read purely after this point, so its storage can be
  -- reused in place instead of paying for a copy with 'U.thaw'.
  inDeg <- U.unsafeThaw inDegree

  fix $ \loop ->
    popFront q >>= \case
      Nothing -> return ()
      Just v -> do
        U.forM_ (gr `adj` v) $ \u ->
          UM.unsafeRead inDeg u >>= \case
            -- This was u's last remaining incoming edge, so u becomes a
            -- source. The counter is deliberately left at 1 because no
            -- further edge into u remains to read it.
            1 -> pushBack u q
            i -> UM.unsafeWrite inDeg u (i - 1)
        loop

  buf <- unsafeFreezeInternalBuffer q
  return $ if U.length buf == n then Just buf else Nothing