{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Tree.LCA where

import Control.Monad
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

--
import Data.Buffer
import Data.Graph.Sparse
import Data.SparseTable

-- | Lowest-common-ancestor index for a rooted tree, built from an Euler
-- tour of the tree plus a range-minimum-query structure over that tour.
data LCA = LCA
  { firstIndexLCA :: !(U.Vector Int)
  -- ^ @firstIndexLCA ! v@ is the position in the Euler tour at which vertex
  -- @v@ is first visited
  , rmqLCA :: !(RMQ (Int, Vertex))
  -- ^ RMQ over the Euler tour; each entry is a @(depth, vertex)@ pair
  }

-- | Build an 'LCA' index for the tree @gr@ rooted at @root@.
--
-- A depth-first traversal from @root@ does three things:
--
-- * It records an Euler tour of @(depth, vertex)@ pairs. A vertex is emitted
--   on entry and again after returning from each of its children.
-- * It records the tour position of each vertex's first visit.
-- * After the traversal, an RMQ is built over the tour.
--
-- Preconditions (not checked):
--
-- * @gr@ has at least one vertex and @root@ is a valid vertex index.
-- * @gr@ is a tree whose adjacency lists contain every edge in both
--   directions. The traversal only skips the vertex it came from, so cycles
--   and parallel edges are not detected.
-- * Vertices unreachable from @root@ keep an unspecified first index.
--
-- For an @n@-vertex tree the tour has exactly @2n - 1@ entries: @n@ first
-- visits plus one re-emission per parent/child edge. That is the capacity
-- reserved for the buffer.
buildLCA :: (U.Unbox w) => SparseGraph w -> Vertex -> LCA
buildLCA gr root = runST $ do
  let n = numVerticesSG gr
  tour <- newBuffer (2 * n - 1)
  firstVisitM <- UM.unsafeNew n

  fix
    ( \dfs !parent !depth !v -> do
        -- The current tour length is the index at which v is about to be
        -- pushed, i.e. its first-visit position.
        lengthBuffer tour >>= UM.unsafeWrite firstVisitM v
        pushBack (depth, v) tour
        U.forM_ (gr `adj` v) $ \child ->
          when (child /= parent) $ do
            dfs v (depth + 1) child
            -- Re-emit v after each child so the tour segment between the
            -- first visits of two vertices contains their LCA as the
            -- minimum-depth entry.
            pushBack (depth, v) tour
    )
    (-1) -- no parent for the root
    0 -- the root has depth 0
    root

  eulerTour <- unsafeFreezeBuffer tour
  firstVisit <- U.unsafeFreeze firstVisitM
  pure $ LCA firstVisit (buildRMQ eulerTour)

-- | /O(1)/ Lowest common ancestor of @v@ and @u@.
--
-- Returns the vertex of the minimum @(depth, vertex)@ entry in the Euler
-- tour between the first visits of @v@ and @u@. Both vertices must be valid
-- indices of the tree the 'LCA' was built from. This is not checked, because
-- the lookups use 'U.unsafeIndex'.
queryLCA :: LCA -> Vertex -> Vertex -> Vertex
queryLCA LCA{..} v u = snd $ queryMin rmqLCA lo (hi + 1)
  where
    !i = U.unsafeIndex firstIndexLCA v
    !j = U.unsafeIndex firstIndexLCA u
    -- The upper bound is passed as hi + 1. Presumably queryMin takes a
    -- half-open range [l, r); confirm against Data.SparseTable.
    !lo = min i j
    !hi = max i j
{-# INLINE queryLCA #-}

-- | /O(1)/ Depth of vertex @v@, where the root has depth 0.
--
-- Reads the Euler-tour entry at @v@'s first visit, whose first component is
-- its depth. @v@ must be a valid vertex index. This is not checked, because
-- the lookup uses 'U.unsafeIndex'.
queryDepth :: LCA -> Vertex -> Int
queryDepth LCA{..} v = fst (readRMQ rmqLCA firstVisit)
  where
    !firstVisit = U.unsafeIndex firstIndexLCA v
{-# INLINE queryDepth #-}

-- | /O(1)/ Number of edges on the tree path between @v@ and @u@.
--
-- Computed as @depth v + depth u - 2 * depth (lca v u)@. Both vertices must
-- be valid vertex indices of the tree the 'LCA' was built from.
queryDist :: LCA -> Vertex -> Vertex -> Int
queryDist lca v u = depthOf v + depthOf u - 2 * depthOf (queryLCA lca v u)
  where
    depthOf = queryDepth lca
{-# INLINE queryDist #-}