module Data.Graph.Tree.CentroidDecomposition where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Graph.Sparse

-- | Result of 'centroidDecomposition'. Both vectors are indexed by vertex.
data CentroidDecomposition = CentroidDecomposition
    { parentCentroidCD :: U.Vector Int
    -- ^ parent of each vertex in the centroid tree, or 'nothingCD' for the
    -- top-level centroid
    , subtreeSizeCD :: U.Vector Int
    -- ^ for each vertex @v@, the number of vertices in the component for which
    -- @v@ was chosen as centroid
    }

-- | Sentinel meaning "no parent centroid". It is stored in 'parentCentroidCD'
-- for the top-level centroid and is never a valid vertex index.
nothingCD :: Int
nothingCD = -1

-- | /O(1)/. 'True' iff the recorded component size ('subtreeSizeCD') of @v@
-- is strictly smaller than that of @centroid@.
--
-- Because the comparison is strict, @memberCD cd c c@ is 'False'.
--
-- NOTE(review): this is only a size comparison, not a general membership
-- test. It presumably answers "is @v@ inside @centroid@'s component" only for
-- vertices reached from @centroid@ by walking tree edges. Confirm with callers.
memberCD ::
    CentroidDecomposition ->
    -- | centroid
    Vertex ->
    Vertex ->
    Bool
memberCD cd centroid v =
    U.unsafeIndex sizes v < U.unsafeIndex sizes centroid
  where
    sizes = subtreeSizeCD cd
{-# INLINE memberCD #-}

-- | Build the centroid decomposition of a tree, starting from vertex @0@.
--
-- For every vertex @v@, the result records:
--
-- * the centroid that is @v@'s parent in the centroid tree ('nothingCD' for
--   the top-level centroid), and
-- * the number of vertices in the component for which @v@ was the centroid.
--
-- A graph with no vertices yields two empty vectors.
--
-- NOTE(review): this assumes @gr@ is an undirected tree whose edges are
-- stored in both directions. TODO confirm with callers.
centroidDecomposition :: SparseGraph w -> CentroidDecomposition
centroidDecomposition gr = runST $ do
    parent <- UM.replicate numV nothingCD
    subtreeSize <- UM.replicate numV (0 :: Int)
    -- With no vertices, there is no vertex 0 to start from, and
    -- @gr `adj` root@ would read out of bounds.
    when (numV > 0) $ do
        -- Pass 1: subtree sizes of the tree rooted at 'root'.
        -- @pv@ is the DFS parent (initially 'nothingCD').
        fix
            ( \dfs pv v -> do
                U.forM_ (gr `adj` v) $ \nv ->
                    when (nv /= pv) $ dfs v nv
                sz <-
                    U.foldl' (+) 1
                        <$> U.mapM
                            (UM.unsafeRead subtreeSize)
                            (U.filter (/= pv) $ gr `adj` v)
                UM.unsafeWrite subtreeSize v sz
            )
            nothingCD
            root

        -- Pass 2: @v@ is the current root of a component whose size is
        -- @subtreeSize[v]@. Walk towards the heavy side until the centroid is
        -- found, then recurse into the remaining pieces.
        fix
            ( \dfs v -> do
                compSize <- UM.unsafeRead subtreeSize v
                sizes <- U.forM (gr `adj` v) $ \nv ->
                    (,) nv <$> UM.unsafeRead subtreeSize nv
                -- A neighbour holding more than half of the component.
                -- @sz < compSize@ also skips centroids that were already
                -- removed, because their recorded size is at least as large
                -- as any component they border.
                let heavy (_, sz) = compSize < sz * 2 && sz < compSize
                case U.find heavy sizes of
                    Just (nv, sznv) -> do
                        -- Reroot the component at nv: v becomes a child of nv.
                        UM.unsafeWrite subtreeSize v (compSize - sznv)
                        UM.unsafeWrite subtreeSize nv compSize
                        -- nv inherits v's centroid parent.
                        UM.unsafeRead parent v >>= UM.unsafeWrite parent nv
                        UM.unsafeWrite parent v nothingCD
                        dfs nv
                    Nothing ->
                        -- v is the centroid. Every smaller neighbour roots one
                        -- of the pieces left after removing v.
                        U.forM_ (gr `adj` v) $ \nv -> do
                            sz <- UM.unsafeRead subtreeSize nv
                            when (sz < compSize) $ do
                                UM.unsafeWrite parent nv v
                                dfs nv
            )
            root

    CentroidDecomposition
        <$> U.unsafeFreeze parent
        <*> U.unsafeFreeze subtreeSize
  where
    numV = numVerticesSG gr
    root = 0 :: Int