{-# LANGUAGE LambdaCase #-}

module Data.Graph.Sparse.TopSort where

import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer
import Data.Graph.Sparse

{- | Topological sort of a directed graph using Kahn's algorithm.

Returns the vertices ordered so that every edge goes from an earlier vertex
to a later one. Returns 'Nothing' when some vertex can never reach in-degree
zero, i.e. when the graph contains a cycle (a self-loop included).

>>> topSort $ buildDirectedGraph 4 4 $ U.fromList [(0,1),(0,2),(1,3),(2,3)]
Just [0,1,2,3]
>>> topSort $ buildDirectedGraph 2 0 U.empty
Just [0,1]
>>> topSort $ buildDirectedGraph 2 2 $ U.fromList [(0,1),(1,0)]
Nothing
-}
topSort :: SparseGraph w -> Maybe (U.Vector Int)
topSort gr = runST $ do
  let n = numVerticesSG gr
      -- Each entry of the flattened adjacency array is the target of one
      -- edge, so counting occurrences per target gives every in-degree.
      -- NOTE(review): unsafeAccumulate skips bounds checks; this assumes
      -- every target stored in 'adjacentSG' lies in [0, n) - confirm.
      inDegree =
        U.unsafeAccumulate (+) (U.replicate n (0 :: Int))
          . U.map (\v -> (v, 1))
          $ adjacentSG gr
  -- Capacity n: each vertex is pushed at most once (see the inner loop).
  q <- newBufferAsQueue n
  -- Seed the queue with every source vertex (in-degree 0), in index order.
  U.forM_ (U.findIndices (== 0) inDegree) $ \v -> pushBack v q
  -- 'inDegree' is not read after this point, so reusing its storage as the
  -- mutable counter array via unsafeThaw is safe.
  inDeg <- U.unsafeThaw inDegree
  fix $ \loop ->
    popFront q >>= \case
      Nothing -> pure ()
      Just v -> do
        U.forM_ (gr `adj` v) $ \u ->
          UM.unsafeRead inDeg u >>= \case
            -- This was u's last unprocessed incoming edge, so u is ready.
            -- Its counter is not zeroed because it is never read again.
            1 -> pushBack u q
            d -> UM.unsafeWrite inDeg u (d - 1)
        loop
  -- The buffer's internal storage still holds every vertex ever pushed,
  -- popped ones included; the doctests above rely on this. If fewer than
  -- n vertices were emitted, the rest lie on or behind a cycle.
  ordered <- unsafeFreezeInternalBuffer q
  pure $ if U.length ordered == n then Just ordered else Nothing