{-# LANGUAGE LambdaCase #-}

module Data.Graph.Sparse.SCC where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer
import Data.Graph.Sparse
import My.Prelude (rep)

-- | Identifier of a strongly connected component, as stored per vertex in
-- the vector returned by 'stronglyConnectedComponents'.
type ComponentId = Int

-- | Strongly connected components of a sparse graph (Tarjan's algorithm).
--
-- The result is indexed by vertex and holds the id of the component that
-- vertex belongs to. For a graph with @k@ components the ids are
-- @0 .. k-1@.
--
-- Tarjan's algorithm finishes components in reverse topological order. The
-- ids are flipped at the end, so they follow a topological order of the
-- condensation: an edge from a vertex in component @a@ to a vertex in a
-- different component @b@ implies @a < b@.
--
-- Each vertex is entered by the DFS exactly once, and its adjacency list is
-- scanned exactly once. The DFS is written recursively, so its call depth
-- grows with the depth of the search tree.
stronglyConnectedComponents :: SparseGraph w -> U.Vector ComponentId
stronglyConnectedComponents gr = runST $ do
  let numV = numVerticesSG gr

  -- Low-link of each vertex: the smallest preorder number reachable from its
  -- DFS subtree through vertices that are still on the Tarjan stack.
  low <- UM.replicate numV nothing
  -- Preorder number of each vertex.
  -- 'nothing' means not yet visited.
  -- 'numV' means already assigned to a component (see the pop loop below).
  preord <- UM.replicate numV nothing
  -- Tarjan stack. Every vertex is pushed at most once, so numV slots suffice.
  stack <- newBufferAsStack numV
  component <- UM.replicate numV nothing
  -- Two counters in one unboxed vector:
  --   vars[_preordId] = next preorder number to hand out
  --   vars[_compId]   = next component id to hand out
  vars <- UM.replicate 2 0

  rep numV $ \root -> do
    rootOrd <- UM.unsafeRead preord root
    when (rootOrd == nothing) $
      flip fix root $ \dfs v -> do
        preordId <- UM.unsafeRead vars _preordId
        UM.unsafeWrite vars _preordId (preordId + 1)

        UM.unsafeWrite preord v preordId
        UM.unsafeWrite low v preordId

        pushBack v stack

        U.forM_ (adj gr v) $ \u -> do
          ordU <- UM.unsafeRead preord u
          if ordU == nothing
            then do
              -- Tree edge: recurse, then inherit the child's low-link.
              dfs u
              lowU <- UM.unsafeRead low u
              UM.unsafeModify low (min lowU) v
            else
              -- Edge to an already visited vertex. If u is finished, its
              -- preord is numV, which exceeds every preorder number, so this
              -- min has no effect.
              UM.unsafeModify low (min ordU) v

        lowV <- UM.unsafeRead low v
        ordV <- UM.unsafeRead preord v
        -- v is the root of a component: pop the stack down to v.
        when (lowV == ordV) $ do
          compId <- UM.unsafeRead vars _compId
          fix $ \popComponent -> do
            top <- popBack stack
            case top of
              Just x -> do
                -- Mark x as finished, so later edges into it cannot lower
                -- any low-link.
                UM.unsafeWrite preord x numV
                UM.unsafeWrite component x compId
                when (x /= v) popComponent
              Nothing ->
                error
                  "Data.Graph.Sparse.SCC.stronglyConnectedComponents: \
                  \stack exhausted before reaching the component root"
          UM.unsafeWrite vars _compId (compId + 1)

  -- Reverse the completion order so that ids follow topological order.
  maxCompId <- subtract 1 <$!> UM.unsafeRead vars _compId
  U.map (maxCompId -) <$> U.unsafeFreeze component
  where
    -- Sentinel for "not yet set" in the per-vertex vectors.
    nothing :: ComponentId
    nothing = -1

    -- Slot indices into the two-element counter vector 'vars'.
    _preordId, _compId :: Int
    _preordId = 0
    _compId = 1