{-# LANGUAGE LambdaCase #-}

module Data.Graph.Sparse.SCC where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer
import Data.Graph.Sparse
import My.Prelude (rep)

-- | Identifier of a strongly connected component, as stored per vertex in
-- the vector returned by 'stronglyConnectedComponents'. This is a plain
-- alias for 'Int', so it offers readability only, not type safety.
type ComponentId = Int

-- | Label every vertex of a graph with the id of its strongly connected
-- component, using a single depth-first traversal that tracks preorder
-- numbers and low-links (Tarjan's scheme).
--
-- The result has one entry per vertex (@numVerticesSG gr@ entries); entry
-- @v@ is the component id of vertex @v@. Ids are contiguous, starting at 0.
-- Components are numbered internally in the order in which they are
-- completed and then reversed (@maxCompId - c@), so the component that is
-- completed first receives the largest id.
--
-- NOTE(review): for this kind of traversal the reversal should make the ids
-- a topological numbering of the condensation (an edge between distinct
-- components goes from the smaller id to the larger) -- confirm against
-- callers before relying on it.
--
-- NOTE(review): assumes @adj gr v@ yields the successors of @v@, all within
-- @[0, numVerticesSG gr)@, and that @rep n@ visits @0 .. n - 1@; the unsafe
-- vector accesses below are not bounds-checked.
stronglyConnectedComponents :: SparseGraph w -> U.Vector ComponentId
stronglyConnectedComponents gr = runST $ do
  let numV = numVerticesSG gr
  low <- UM.replicate numV nothing
  preord <- UM.replicate numV nothing
  stack <- newBufferAsStack numV
  component <- UM.replicate numV nothing
  -- Two mutable counters: the next preorder number and the next component id.
  vars <- UM.replicate 2 0

  rep numV $ \root -> do
    rootOrd <- UM.unsafeRead preord root
    -- Start a traversal only from vertices that have not been visited yet.
    when (rootOrd == nothing) $ do
      flip fix root $ \dfs v -> do
        preordId <- UM.unsafeRead vars _preordId
        UM.unsafeWrite vars _preordId (preordId + 1)

        UM.unsafeWrite preord v preordId
        UM.unsafeWrite low v preordId

        pushBack v stack

        U.forM_ (adj gr v) $ \u -> do
          ordU <- UM.unsafeRead preord u
          if ordU == nothing
            then do
              -- Unvisited successor: recurse, then inherit its low-link.
              dfs u
              lowU <- UM.unsafeRead low u
              UM.unsafeModify low (min lowU) v
            else
              -- Visited successor: its preorder number bounds the low-link.
              -- Vertices already assigned to a component carry @numV@ here,
              -- which is larger than any preorder number and so never wins.
              UM.unsafeModify low (min ordU) v

        lowV <- UM.unsafeRead low v
        ordV <- UM.unsafeRead preord v
        -- @v@ is the root of a component: pop the stack down to @v@.
        when (lowV == ordV) $ do
          compId <- UM.unsafeRead vars _compId
          fix $ \loop ->
            popBack stack >>= \case
              Just x -> do
                -- Mark @x@ as finished so later edges into it are ignored.
                UM.unsafeWrite preord x numV
                UM.unsafeWrite component x compId
                when (x /= v) loop
              Nothing ->
                error
                  "stronglyConnectedComponents: impossible, stack exhausted before reaching the component root"
          UM.unsafeWrite vars _compId (compId + 1)

  maxCompId <- subtract 1 <$!> UM.unsafeRead vars _compId
  -- Reverse the numbering: the last completed component becomes id 0.
  U.map (maxCompId -) <$> U.unsafeFreeze component
  where
    -- Sentinel for "not visited yet" / "no component assigned yet".
    nothing :: Int
    nothing = -1
    -- Slot of @vars@ holding the next preorder number.
    _preordId :: Int
    _preordId = 0
    -- Slot of @vars@ holding the next component id.
    _compId :: Int
    _compId = 1