{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Tree.HLD where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Graph.Sparse
import My.Prelude ((..<))

-- | Position of a vertex in the DFS preorder computed by 'buildHLD'.
-- That preorder visits the heavy child first, so each heavy path occupies a
-- contiguous range of indices starting at its head.
type HLDIndex = Int

-- | Heavy-light decomposition of a rooted tree, as produced by 'buildHLD'.
--
-- Every vector has one entry per vertex and is indexed by 'Vertex'.
data HLD = HLD
  { indexHLD :: !(U.Vector HLDIndex)
  -- ^ Position of each vertex in the decomposition order.
  , parentHLD :: !(U.Vector Vertex)
  -- ^ Parent of each vertex; @-1@ for the root.
  , pathHeadHLD :: !(U.Vector Vertex)
  -- ^ Topmost vertex (head) of the heavy path containing each vertex.
  }
  deriving (Show)

-- | Lowest common ancestor of two vertices.
--
-- While the two vertices lie on different heavy paths, the one that comes
-- later in the decomposition order is lifted to the parent of its path head.
-- Once both share a path, the earlier of the two is the answer.
--
-- No bounds checking is performed ('U.unsafeIndex'): both arguments must be
-- valid vertices of the tree.
--
-- /O(log V)/
lcaHLD :: HLD -> Vertex -> Vertex -> Vertex
lcaHLD HLD{..} = climb
  where
    climb :: Vertex -> Vertex -> Vertex
    climb !u !v
      | pos u > pos v = climb v u
      | headU == headV = u
      | otherwise = climb u (U.unsafeIndex parentHLD headV)
      where
        headU = pathOf u
        headV = pathOf v

    -- Decomposition index of a vertex.
    pos :: Vertex -> HLDIndex
    pos = U.unsafeIndex indexHLD

    -- Head of the heavy path a vertex belongs to.
    pathOf :: Vertex -> Vertex
    pathOf = U.unsafeIndex pathHeadHLD

-- | Split the path between two vertices into ranges of 'HLDIndex'.
--
-- Each pair @(l, r)@ is a non-empty half-open range @[l, r)@ of decomposition
-- indices. Together the ranges cover every vertex on the path except the
-- lowest common ancestor. NOTE(review): this suits edge queries where each
-- edge is stored at the index of its lower endpoint; that convention is
-- assumed by callers, not enforced here. The ranges are not sorted.
--
-- No bounds checking is performed ('U.unsafeIndex'): both arguments must be
-- valid vertices of the tree.
--
-- /O(log V)/
pathHLD :: HLD -> Vertex -> Vertex -> [(HLDIndex, HLDIndex)]
pathHLD HLD{..} = walk
  where
    walk :: Vertex -> Vertex -> [(HLDIndex, HLDIndex)]
    walk !u !v
      | iu > iv = walk v u
      | headU /= headV =
          -- The whole segment from v up to its path head lies below the LCA.
          let !lo = pos headV
              !hi = iv + 1
           in (lo, hi) : walk u (U.unsafeIndex parentHLD headV)
      | iu == iv = []
      | otherwise =
          -- Same path: u is the LCA, so start one index below it.
          let !lo = iu + 1
              !hi = iv + 1
           in [(lo, hi)]
      where
        !iu = pos u
        !iv = pos v
        headU = U.unsafeIndex pathHeadHLD u
        headV = U.unsafeIndex pathHeadHLD v

    -- Decomposition index of a vertex.
    pos :: Vertex -> HLDIndex
    pos = U.unsafeIndex indexHLD

-- | Build the heavy-light decomposition of the tree @gr@ rooted at @root@.
--
-- @gr@ must be an undirected tree with every edge stored in both directions.
-- Only the edge count is checked: @error "not undirected tree"@ is raised
-- unless @numEdgesSG == 2 * (numVerticesSG - 1)@. A graph that has the right
-- count but is not a tree is not rejected and is not handled.
--
-- Vertices are numbered by a DFS preorder that always visits the heavy child
-- (the child with the largest subtree) first. Every heavy path therefore gets
-- a contiguous range of 'HLDIndex'es starting at its head. A single-vertex
-- tree is supported.
--
-- /O(V)/
buildHLD :: Vertex -> SparseGraph w -> HLD
buildHLD root gr@SparseGraph{..}
  | numEdgesSG /= 2 * (numVerticesSG - 1) = error "not undirected tree"
  | otherwise = runST $ do
      mindexHLD <- UM.unsafeNew numVerticesSG
      mparentHLD <- UM.replicate numVerticesSG nothing
      mpathHeadHLD <- UM.replicate numVerticesSG nothing

      -- Private copy of the adjacency array. Pass 1 reorders each vertex's
      -- neighbours so that its heavy child sits in the first slot.
      madjacent <- U.thaw adjacentSG

      -- Pass 1: record parents and subtree sizes, and move the heavy child
      -- of each vertex to slot @offsetSG U.! v@. Between equally large
      -- children, the one with the larger edge index wins (tuple 'max').
      -- NOTE(review): assumes the first component of each 'iadj' pair is the
      -- neighbour's position in 'adjacentSG', as the swap below requires.
      void $
        fix
          ( \dfs pv v -> do
              UM.write mparentHLD v pv
              (size, (_, heavyId)) <-
                U.foldM'
                  ( \(!sz, !mm) (ei, nv) -> do
                      sz' <- dfs v nv
                      return (sz + sz', max mm (sz', ei))
                  )
                  (1 :: Int, (0, nothing))
                  . U.filter ((/= pv) . snd)
                  $ gr `iadj` v
              when (heavyId /= nothing) $
                UM.swap madjacent heavyId (offsetSG U.! v)
              return size
          )
          nothing
          root

      -- Pass 2: assign preorder indices and path heads, returning the last
      -- index used in each subtree.
      --
      -- The child in the first adjacency slot (the heavy child, after pass 1)
      -- continues the current path @h@. Every other child starts a new path
      -- headed by itself.
      --
      -- We fold over the whole slot range instead of reading the first slot
      -- unconditionally. A vertex with no incident edges (the single-vertex
      -- tree) has an empty range, and reading its first slot would index
      -- past the end of 'madjacent'.
      void $
        fix
          ( \dfs i h pv v -> do
              UM.write mindexHLD v i
              UM.write mpathHeadHLD v h
              let o = offsetSG U.! v
              MS.foldM'
                ( \acc j -> do
                    nv <- UM.read madjacent j
                    if nv == pv
                      then pure acc
                      else dfs (acc + 1) (if j == o then h else nv) v nv
                )
                i
                $ o ..< offsetSG U.! (v + 1)
          )
          0
          root
          nothing
          root

      HLD
        <$> U.unsafeFreeze mindexHLD
        <*> U.unsafeFreeze mparentHLD
        <*> U.unsafeFreeze mpathHeadHLD
  where
    -- Sentinel for "no vertex / no edge" (parent of the root, no heavy child).
    nothing :: Int
    nothing = -1