module Data.Graph.Tree.CentroidDecomposition where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Graph.Sparse

-- | Result of 'centroidDecomposition'. Both vectors are indexed by vertex.
data CentroidDecomposition = CentroidDecomposition
    { parentCentroidCD :: !(U.Vector Int)
    -- ^ For each vertex @v@, the centroid one level above @v@ in the
    -- centroid tree, or 'nothingCD' when @v@ is a top-level centroid.
    , subtreeSizeCD :: !(U.Vector Int)
    -- ^ For each vertex @v@, the number of vertices in the component of
    -- which @v@ was chosen as centroid.
    }

-- | Sentinel meaning \"no vertex\". In 'parentCentroidCD' it marks a
-- centroid without a parent centroid; it is also used as the \"previous
-- vertex\" when a traversal starts from a root.
nothingCD :: Int
nothingCD = negate 1

-- | /O(1)/ Size-based test for whether @v@ belongs to the component whose
-- centroid is @centroid@.
--
-- Only the recorded component sizes are compared. The result is 'True' iff
-- the component recorded for @v@ is strictly smaller than the one recorded
-- for @centroid@, so @memberCD cd c c@ is 'False'. Because no structural
-- check is made, a vertex from an unrelated but smaller component also
-- yields 'True'. NOTE(review): callers are presumably expected to ask only
-- about vertices reached by walking out from inside the component; confirm
-- at call sites.
memberCD ::
    CentroidDecomposition ->
    -- | centroid
    Vertex ->
    -- | vertex to test
    Vertex ->
    Bool
memberCD cd centroid v = sizeOf v < sizeOf centroid
  where
    sizeOf = U.unsafeIndex (subtreeSizeCD cd)
{-# INLINE memberCD #-}

-- | Centroid decomposition of a tree, or of every tree in a forest.
--
-- Every vertex ends up as the centroid of exactly one component. For each
-- vertex @v@ the result stores
--
-- * in 'parentCentroidCD': the centroid whose removal produced the component
--   @v@ is centroid of, or 'nothingCD' if @v@ is the first centroid of its
--   tree;
-- * in 'subtreeSizeCD': the number of vertices in that component.
--
-- Each connected component of @gr@ is decomposed on its own, so forests and
-- the empty graph are accepted. NOTE(review): assumes @gr@ is undirected
-- (each edge listed from both endpoints) and acyclic.
centroidDecomposition :: SparseGraph w -> CentroidDecomposition
centroidDecomposition gr = runST $ do
    let n = numVerticesSG gr
    parent <- UM.replicate n nothingCD
    -- A size of 0 marks a vertex that no pass has reached yet; every reached
    -- vertex is given a size of at least 1.
    subtreeSize <- UM.replicate n (0 :: Int)

    -- @computeSizes pv v@ stores, for every vertex of @v@'s tree, the size of
    -- its subtree when the tree is rooted at @v@ (@pv@ is where we came from).
    let computeSizes = fix $ \dfs pv v -> do
            sz <-
                U.foldM'
                    ( \acc nv ->
                        if nv == pv
                            then pure acc
                            else do
                                dfs v nv
                                (acc +) <$> UM.unsafeRead subtreeSize nv
                    )
                    1
                    (gr `adj` v)
            UM.unsafeWrite subtreeSize v sz

    -- @decompose v@ expects @v@ to root an undecomposed component whose sizes
    -- are measured from @v@. It walks to the component's centroid, re-rooting
    -- on the way, then recurses into the pieces left after removing it.
    let decompose = fix $ \dfs v -> do
            compSize <- UM.unsafeRead subtreeSize v
            sizes <- U.forM (gr `adj` v) $ \nv ->
                (,) nv <$> UM.unsafeRead subtreeSize nv
            -- Neighbours outside the component are earlier centroids, whose
            -- recorded component is larger, so @sz < compSize@ keeps the
            -- search inside the current component.
            case U.find (\(_, sz) -> compSize < sz * 2 && sz < compSize) sizes of
                Just (heavy, szHeavy) -> do
                    -- A child holds more than half of the component: re-root
                    -- at it and hand over the parent-centroid label.
                    UM.unsafeWrite subtreeSize v (compSize - szHeavy)
                    UM.unsafeWrite subtreeSize heavy compSize
                    UM.unsafeRead parent v >>= UM.unsafeWrite parent heavy
                    UM.unsafeWrite parent v nothingCD
                    dfs heavy
                Nothing -> do
                    -- v is the centroid; each smaller neighbour roots one of
                    -- the remaining pieces.
                    U.forM_ (gr `adj` v) $ \nv -> do
                        sz <- UM.unsafeRead subtreeSize nv
                        when (sz < compSize) $ do
                            UM.unsafeWrite parent nv v
                            dfs nv

    -- One pair of passes per tree. Scanning all vertices (instead of only
    -- vertex 0) covers forests and makes an empty graph a no-op rather than
    -- an unchecked write to index 0 of an empty vector.
    forM_ [0 .. n - 1] $ \root -> do
        seen <- (/= 0) <$> UM.unsafeRead subtreeSize root
        unless seen $ do
            computeSizes nothingCD root
            decompose root

    CentroidDecomposition
        <$> U.unsafeFreeze parent
        <*> U.unsafeFreeze subtreeSize