{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Unicyclic where

import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer
import Data.Graph.Sparse
import My.Prelude

-- | Decomposition of a graph in which every vertex has exactly one outgoing
-- edge, as produced by 'buildUnicyclicComponentsW'. Each connected component
-- of such a graph contains exactly one cycle.
data UnicyclicComponents w = UnicyclicComponents
  { numComponentsUC :: !Int
  -- ^ the number of components, which is equal to the number of cycles
  , componentIDsUC :: !(U.Vector Int)
  -- ^ the component ID of each vertex, in @[0, numComponentsUC)@
  , onCycleUC :: !(U.Vector Bool)
  -- ^ whether each vertex lies on the cycle of its component
  , inverseUC :: !(SparseGraph w)
  -- ^ the inverse graph: for every vertex @v@, an edge from @next v@ to @v@
  }

{- | Unweighted version of 'buildUnicyclicComponentsW': vertex @v@ has the
single outgoing edge @v -> next ! v@.

>>> uni = buildUnicyclicComponents 5 $ U.fromList [1,2,0,0,4]
>>> numComponentsUC uni
2
>>> componentIDsUC uni
[0,0,0,0,1]
>>> onCycleUC uni
[True,True,True,False,True]
-}
buildUnicyclicComponents ::
  -- | the number of vertices
  Int ->
  -- | next
  U.Vector Int ->
  UnicyclicComponents ()
buildUnicyclicComponents n next =
  -- Attach a unit weight to every edge and defer to the weighted builder.
  buildUnicyclicComponentsW n (U.map (,()) next)

-- | Weighted version of 'buildUnicyclicComponents'.
--
-- Vertex @v@ has the single outgoing edge @v -> fst (next ! v)@ carrying the
-- weight @snd (next ! v)@. Vertices are accessed with unchecked operations,
-- so @next@ is assumed (not checked) to have length @n@ and to contain only
-- targets in @[0, n)@.
buildUnicyclicComponentsW ::
  (U.Unbox w) =>
  -- | the number of vertices
  Int ->
  -- | (next, weight)
  U.Vector (Int, w) ->
  UnicyclicComponents w
buildUnicyclicComponentsW n next = UnicyclicComponents{..}
  where
    -- Inverse graph: for every vertex v, an edge from (next v) back to v.
    -- The labeling below walks it starting from cycle vertices.
    !gr@inverseUC =
      buildDirectedGraphW n n $
        U.imap (\v (nv, w) -> (nv, v, w)) next

    -- Label components with a DFS over the inverse graph, started from each
    -- cycle vertex (in increasing index order) that is still unlabeled.
    -- The fold counter is the next fresh ID, so its final value is the
    -- number of components.
    (numComponentsUC, componentIDsUC) = runST $ do
      -- -1 marks a vertex that has not been labeled yet.
      compIDs <- UM.replicate n (-1 :: Int)
      num <-
        U.foldM'
          ( \compID root -> do
              UM.unsafeRead compIDs root >>= \case
                (-1) -> do
                  UM.unsafeWrite compIDs root compID
                  fix
                    ( \dfs v -> do
                        U.forM_ (gr `adj` v) $ \nv -> do
                          UM.unsafeRead compIDs nv >>= \case
                            (-1) -> do
                              UM.unsafeWrite compIDs nv compID
                              dfs nv
                            _ -> pure ()
                    )
                    root
                  pure (compID + 1)
                -- Already labeled from another vertex of the same cycle.
                _ -> pure compID
          )
          0
          (U.elemIndices True onCycleUC)
      (num,) <$> U.unsafeFreeze compIDs

    -- Cycle detection by repeatedly removing vertices of in-degree 0 through
    -- a queue. A removed vertex ends with in-degree 0; every vertex that is
    -- never removed keeps exactly the one incoming edge from its cycle
    -- predecessor, hence in-degree 1.
    !onCycleUC = runST $ do
      inDeg <- UM.replicate n (0 :: Int)
      U.forM_ next $ \(nv, _) ->
        UM.unsafeModify inDeg (+ 1) nv
      leaves <- newBufferAsQueue n
      rep n $ \v ->
        UM.unsafeRead inDeg v >>= \case
          0 -> pushBack v leaves
          _ -> pure ()
      fix $ \loop -> do
        popFront leaves >>= \case
          Nothing -> pure ()
          Just v -> do
            -- Removing v drops the single edge v -> nv.
            let (nv, _) = U.unsafeIndex next v
            UM.unsafeModify inDeg (subtract 1) nv
            UM.unsafeRead inDeg nv >>= \case
              0 -> pushBack nv leaves
              _ -> pure ()
            loop
      U.map (== 1) <$> U.unsafeFreeze inDeg