{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Tree.LCT where

import Control.Monad
import Control.Monad.Primitive
import Data.Coerce
import Data.Function
import Data.Int
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

{- | for commutative monoids

Structure-of-arrays storage for a link-cut tree: every vector is indexed by
node id, and 'newLCT' / 'buildLCT' allocate all of them with the same length.
Links are stored as 'Int32'; a missing link is the 'Int32' image of
'nothingLCT'.
-}
data LCT s a = LCT
    { -- | Parent link of each node. 'isSplayTreeRootLCT' treats a node whose
      -- parent does not list it as a left or right child as a splay-tree root.
      parentLCT :: UM.MVector s Int32
    , -- | Left child link of each node.
      leftChildLCT :: UM.MVector s Int32
    , -- | Right child link of each node.
      rightChildLCT :: UM.MVector s Int32
    , -- | The value attached to each node itself (written by 'setCMonLCT').
      commMonoidLCT :: UM.MVector s a
    , -- | Cached aggregate per node; 'mconcatPathLCT' reads it after exposing
      -- a path. NOTE(review): presumably refreshed by @pullLCT@ as the fold
      -- over the node's splay subtree -- confirm against @pullLCT@.
      foldSubtreesLCT :: UM.MVector s a
    , -- | Pending-reversal flag per node; set by 'evertLCT'.
      -- NOTE(review): presumably consumed by @pushLCT@ -- confirm.
      lazyRevFlagLCT :: UM.MVector s Bool
    }

{- | Allocate a link-cut tree with @n@ isolated nodes.

Every parent/left/right link starts as 'nothingLCT', every node value and
every cached fold starts as 'mempty', and every reversal flag starts as
'False'.
-}
newLCT :: (U.Unbox a, Monoid a, PrimMonad m) => Int -> m (LCT (PrimState m) a)
newLCT n =
    LCT
        <$> UM.replicate n (fromIntegral nothingLCT)
        <*> UM.replicate n (fromIntegral nothingLCT)
        <*> UM.replicate n (fromIntegral nothingLCT)
        <*> UM.replicate n mempty
        <*> UM.replicate n mempty
        <*> UM.replicate n False

{- | Allocate a link-cut tree with one isolated node per element of @vs@.

Node @i@ carries @vs U.! i@ as its value. The cached folds are initialised
from a second, independent copy of @vs@, so each node's fold starts equal to
its own value. All links start as 'nothingLCT' and all reversal flags as
'False'.
-}
buildLCT :: (U.Unbox a, Monoid a, PrimMonad m) => U.Vector a -> m (LCT (PrimState m) a)
buildLCT vs =
    LCT
        <$> UM.replicate n (fromIntegral nothingLCT)
        <*> UM.replicate n (fromIntegral nothingLCT)
        <*> UM.replicate n (fromIntegral nothingLCT)
        <*> U.thaw vs
        <*> U.thaw vs
        <*> UM.replicate n False
    where
        n = U.length vs

{- | make v root

Exposes @v@, marks it with a pending reversal and immediately pushes that
reversal down one level.

>>> lct <- newLCT @() 3
>>> linkLCT lct 1 0 >> linkLCT lct 2 1
>>> findRootLCT lct 2
0
>>> evertLCT lct 1
>>> findRootLCT lct 2
1
-}
evertLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> m ()
evertLCT lct@LCT{lazyRevFlagLCT} v = do
    void $ exposeLCT lct (asSplayNodeId v)
    -- The flag is overwritten, not toggled.
    -- NOTE(review): this assumes exposeLCT leaves v's flag cleared -- confirm.
    UM.unsafeWrite lazyRevFlagLCT v True
    pushLCT lct (asSplayNodeId v)
{-# INLINE evertLCT #-}

{- | require: the edge @(u, v)@ exists

Removes that edge: makes @u@ the root, exposes @v@, then detaches @u@ from
@v@ and refreshes @v@ with @pullLCT@.
-}
cutLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m ()
cutLCT lct u v = do
    evertLCT lct u
    void $ exposeLCT lct (asSplayNodeId v)
    -- u is left child of v (relies on the precondition that the edge exists)
    setLeftChildLCT lct (asSplayNodeId v) nothingLCT
    setParentLCT lct (asSplayNodeId u) nothingLCT
    pullLCT lct (asSplayNodeId v)
{-# INLINE cutLCT #-}

{- | link u to v

require: @u@ and @v@ are *not connected*

Makes @u@ the root of its tree, exposes @v@, and then records @v@ as the
parent of @u@. Only the parent link of @u@ is written; no child link of @v@
is changed.
-}
linkLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m ()
linkLCT lct u v = do
    evertLCT lct u
    void $ exposeLCT lct (asSplayNodeId v)
    setParentLCT lct (asSplayNodeId u) (asSplayNodeId v)
{-# INLINE linkLCT #-}

{- | require: @l@ and @r@ connected

Folds the node values on the path between @l@ and @r@ (both ends included,
see the examples). Makes @l@ the root, exposes @r@, and returns the cached
fold stored at @r@.

>>> import Data.Monoid
>>> lct <- buildLCT @(Sum Int) $ U.fromList $ map Sum [0..3]
>>> linkLCT lct 1 0 >> linkLCT lct 2 1 >> linkLCT lct 3 1
>>> mconcatPathLCT lct 0 1
Sum {getSum = 1}
>>> mconcatPathLCT lct 2 3  -- 2 - 1 - 3
Sum {getSum = 6}
>>> mconcatPathLCT lct 0 3  -- 0 - 1 - 3
Sum {getSum = 4}
-}
mconcatPathLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m a
mconcatPathLCT lct@LCT{foldSubtreesLCT} l r = do
    evertLCT lct l
    void $ exposeLCT lct (asSplayNodeId r)
    UM.unsafeRead foldSubtreesLCT r
{-# INLINE mconcatPathLCT #-}

{- | root is the left most node of the root path

Exposes @v0@ and then follows left-child links, starting from the node that
@exposeLCT@ returned, until a node without a left child is reached; that
node's id is the result.
-}
findRootLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> m Int
findRootLCT lct v0 = do
    u0 <- exposeLCT lct (asSplayNodeId v0)
    -- u0's left child is read without pushing u0 first.
    -- NOTE(review): presumably exposeLCT has already pushed u0 -- confirm.
    lu0 <- getLeftChildLCT lct u0
    if lu0 == nothingLCT
        then pure $ coerce @SplayNodeId u0
        else
            fix
                ( \loop !v -> do
                    -- push before reading the left child of each visited node
                    pushLCT lct v
                    lv <- getLeftChildLCT lct v
                    if lv /= nothingLCT
                        then loop lv
                        else pure $ coerce @SplayNodeId v
                )
                lu0
{-# INLINE findRootLCT #-}

{- | commutative monoids

Overwrites the value attached to node @v@ with @x@. The node is first
prepared with @traverseDownLCT (pushLCT lct)@ and splayed, then its value is
written and @pullLCT@ is run on it.

>>> import Data.Monoid
>>> lct <- buildLCT @(Sum Int) $ U.fromList $ map Sum [0..2]
>>> linkLCT lct 1 0 >> linkLCT lct 2 1
>>> mconcatPathLCT lct 0 2
Sum {getSum = 3}
>>> setCMonLCT lct 1 (Sum 100)
>>> mconcatPathLCT lct 0 2
Sum {getSum = 102}
-}
setCMonLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> a -> m ()
setCMonLCT lct@LCT{commMonoidLCT} v x = do
    -- NOTE(review): presumably this pushes pending reversals on the way to v
    -- before it is splayed -- confirm against traverseDownLCT.
    traverseDownLCT (pushLCT lct) lct (asSplayNodeId v)
    splayLCT lct (asSplayNodeId v)
    UM.unsafeWrite commMonoidLCT v x
    pullLCT lct (asSplayNodeId v)
{-# INLINE setCMonLCT #-}

{- | require: u and v are connected

Lowest common ancestor of @u@ and @v@ with respect to the current root.
Exposes @u@, then exposes @v@ and returns the node id that the second
@exposeLCT@ call yields.

>>> lct <- newLCT @() 4
>>> linkLCT lct 1 0 >> linkLCT lct 2 1 >> linkLCT lct 3 2
>>> evertLCT lct 0
>>> lcaLCT lct 0 3
0
>>> evertLCT lct 2
>>> lcaLCT lct 0 3
2
-}
lcaLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m Int
lcaLCT t u v = do
    void $ exposeLCT t (asSplayNodeId u)
    coerce @SplayNodeId <$> exposeLCT t (asSplayNodeId v)
{-# INLINE lcaLCT #-}

{- | Index of a node inside the 'LCT' vectors.

A thin wrapper around 'Int'; all instances are taken from 'Int' via newtype
deriving, so arithmetic, ordering and 'show' behave exactly like the wrapped
number.
-}
newtype SplayNodeId = SplayNodeId {getSplayNodeId :: Int}
    deriving newtype (Eq, Ord, Show, Num, Real, Enum, Integral)

-- | Reinterpret a plain node index as a 'SplayNodeId' (zero-cost 'coerce').
asSplayNodeId :: Int -> SplayNodeId
asSplayNodeId = coerce
{-# INLINE asSplayNodeId #-}

{- | Read the left-child link of node @v@, widened from the stored 'Int32'.

Uses 'UM.unsafeRead': @v@ is not bounds-checked, so it must be a valid node
index (in particular, not 'nothingLCT').
-}
getLeftChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
getLeftChildLCT LCT{leftChildLCT} v =
    fromIntegral <$> UM.unsafeRead leftChildLCT (coerce v)
{-# INLINE getLeftChildLCT #-}

{- | Read the right-child link of node @v@, widened from the stored 'Int32'.

Uses 'UM.unsafeRead': @v@ is not bounds-checked, so it must be a valid node
index (in particular, not 'nothingLCT').
-}
getRightChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
getRightChildLCT LCT{rightChildLCT} v =
    fromIntegral <$> UM.unsafeRead rightChildLCT (coerce v)
{-# INLINE getRightChildLCT #-}

{- | Read the parent link of node @v@, widened from the stored 'Int32'.

Uses 'UM.unsafeRead': @v@ is not bounds-checked, so it must be a valid node
index (in particular, not 'nothingLCT').
-}
getParentLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
getParentLCT LCT{parentLCT} v =
    fromIntegral <$> UM.unsafeRead parentLCT (coerce v)
{-# INLINE getParentLCT #-}

{- | Overwrite the left-child link of node @v@ with @lv@.

Only this one slot is written; the parent link of @lv@ is not touched. The id
is narrowed to 'Int32' for storage, and the write uses 'UM.unsafeWrite', so
@v@ must be a valid node index.
-}
setLeftChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    SplayNodeId ->
    m ()
setLeftChildLCT LCT{leftChildLCT} v lv =
    UM.unsafeWrite leftChildLCT (coerce v) (fromIntegral @Int $ coerce lv)
{-# INLINE setLeftChildLCT #-}

{- | Overwrite the right-child link of node @v@ with @rv@.

Only this one slot is written; the parent link of @rv@ is not touched. The id
is narrowed to 'Int32' for storage, and the write uses 'UM.unsafeWrite', so
@v@ must be a valid node index.
-}
setRightChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    SplayNodeId ->
    m ()
setRightChildLCT LCT{rightChildLCT} v rv =
    UM.unsafeWrite rightChildLCT (coerce v) (fromIntegral @Int $ coerce rv)
{-# INLINE setRightChildLCT #-}

{- | Overwrite the parent slot of node @v@ with @pv@.

The parent id is narrowed from 'Int' to 'Int32' before being stored in
'parentLCT'. No bounds check is performed ('UM.unsafeWrite'), and the child
slots of @pv@ are not touched.
-}
setParentLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    SplayNodeId ->
    m ()
setParentLCT LCT{parentLCT} v pv =
    UM.unsafeWrite parentLCT (coerce v) (fromIntegral @Int $ coerce pv)
{-# INLINE setParentLCT #-}

-- | Sentinel node id (@-1@) compared against to detect an absent parent or child.
nothingLCT :: SplayNodeId
nothingLCT = SplayNodeId (-1)

{- | Check whether @v@ is the root of its splay tree.

@v@ counts as a root when it has no parent at all, or when its parent does
not list @v@ as either its left or its right child (the parent pointer is
then not a splay-tree edge).
-}
isSplayTreeRootLCT :: (PrimMonad m) => LCT (PrimState m) a -> SplayNodeId -> m Bool
isSplayTreeRootLCT lct v = do
    pv <- getParentLCT lct v
    if pv == nothingLCT
        then pure True
        else do
            lpv <- getLeftChildLCT lct pv
            rpv <- getRightChildLCT lct pv
            -- a parent that does not point back at v is not a splay-tree parent
            pure $! lpv /= v && rpv /= v
{-# INLINE isSplayTreeRootLCT #-}

{- | Recompute the cached fold stored for @v@ in 'foldSubtreesLCT'.

The new value is @fold(left) <> value(v) <> fold(right)@, where the folds of
the children are read from 'foldSubtreesLCT', the node's own value is read
from 'commMonoidLCT', and an absent child ('nothingLCT') contributes
'mempty'. Only the entry of @v@ is written; the children must already be up
to date.
-}
pullLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> SplayNodeId -> m ()
pullLCT lct@LCT{..} v = do
    lv <- getLeftChildLCT lct v
    mlv <-
        if lv /= nothingLCT
            then UM.unsafeRead foldSubtreesLCT $ coerce @SplayNodeId lv
            else pure mempty
    rv <- getRightChildLCT lct v
    mrv <-
        if rv /= nothingLCT
            then UM.unsafeRead foldSubtreesLCT $ coerce @SplayNodeId rv
            else pure mempty
    mv <- UM.unsafeRead commMonoidLCT $ coerce @SplayNodeId v
    UM.unsafeWrite foldSubtreesLCT (coerce @SplayNodeId v) $ mlv <> mv <> mrv
{-# INLINE pullLCT #-}

{- | Resolve the pending lazy reversal flag of @v@, if any.

When the flag in 'lazyRevFlagLCT' is set for @v@, it is cleared, the left and
right children of @v@ are swapped, and the flag of each existing child is
toggled so that the reversal is deferred one level further down. When the
flag is not set this is a no-op. Cached folds are not modified.
-}
pushLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> SplayNodeId -> m ()
pushLCT lct@LCT{..} v = do
    UM.unsafeRead lazyRevFlagLCT (coerce @SplayNodeId v) >>= \case
        False -> pure ()
        True -> do
            UM.unsafeWrite lazyRevFlagLCT (coerce @SplayNodeId v) False
            lv <- getLeftChildLCT lct v
            rv <- getRightChildLCT lct v
            -- swap the two children of v
            setLeftChildLCT lct v rv
            setRightChildLCT lct v lv
            -- hand the reversal down to whichever children exist
            when (lv /= nothingLCT) $ do
                UM.unsafeModify lazyRevFlagLCT not (coerce @SplayNodeId lv)
            when (rv /= nothingLCT) $ do
                UM.unsafeModify lazyRevFlagLCT not (coerce @SplayNodeId rv)
{-# INLINE pushLCT #-}

{- | Run an action on every node from the splay-tree root down to @v@.

The function first climbs parent pointers from @v@ until it reaches a node
for which 'isSplayTreeRootLCT' holds, then applies @f@ while unwinding, so
@f@ is called on the splay-tree root first and on @v@ last. The recursion
depth equals the depth of @v@ inside its splay tree.
-}
traverseDownLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    (SplayNodeId -> m ()) ->
    LCT (PrimState m) a ->
    SplayNodeId ->
    m ()
traverseDownLCT f lct = fix $ \goUp v -> do
    isSplayTreeRootLCT lct v >>= \case
        True -> f v
        False -> do
            pv <- getParentLCT lct v
            -- visit all ancestors first, then the node itself
            goUp pv
            f v
{-# INLINE traverseDownLCT #-}

{-
     pv           v
    /  \         / \
   v   rpv ==> lv  pv
  / \              / \
lv  rv           rv  rpv
-}
{- | Rotate @v@ above its parent @pv@, with @pv@ becoming the right child of @v@.

Steps, in order:

* @v@ takes over the parent pointer of @pv@; if the grandparent lists @pv@
  as a left or right child, that slot is redirected to @v@ (otherwise the
  grandparent's child slots are left alone).
* the former right child of @v@ becomes the left child of @pv@.
* the cached folds of @pv@ and then @v@ are recomputed with 'pullLCT'.

The caller must ensure that @v@ has a parent; this is not checked.
-}
rotateRightLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    -- | has parent node
    SplayNodeId ->
    m ()
rotateRightLCT lct v = do
    pv <- getParentLCT lct v
    ppv <- getParentLCT lct pv
    setParentLCT lct v ppv
    when (ppv /= nothingLCT) $ do
        -- only rewire the grandparent if pv really is one of its children
        lppv <- getLeftChildLCT lct ppv
        when (lppv == pv) $ do
            setLeftChildLCT lct ppv v
        rppv <- getRightChildLCT lct ppv
        when (rppv == pv) $ do
            setRightChildLCT lct ppv v
    setParentLCT lct pv v
    rv <- getRightChildLCT lct v
    setLeftChildLCT lct pv rv
    setRightChildLCT lct v pv
    when (rv /= nothingLCT) $ do
        setParentLCT lct rv pv
    -- pv is now below v, so its fold has to be refreshed first
    pullLCT lct pv
    pullLCT lct v

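{-
NOTE: rotateRightLCT is deliberately left without an INLINE pragma.
Enabling the pragma below made GHC abort with the error shown:

{-# INLINE rotateRightLCT #-}
<no location info>: error:
    Simplifier ticks exhausted
-}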

{-
   pv         v
  /  \       / \
lpv  v  ==> pv  rv
    / \    /  \
  lv  rv  lpv  lv
-}
{- | Rotate @v@ above its parent @pv@, with @pv@ becoming the left child of @v@.

Steps, in order:

* @v@ takes over the parent pointer of @pv@; if the grandparent lists @pv@
  as a left or right child, that slot is redirected to @v@ (otherwise the
  grandparent's child slots are left alone).
* the former left child of @v@ becomes the right child of @pv@.
* the cached folds of @pv@ and then @v@ are recomputed with 'pullLCT'.

The caller must ensure that @v@ has a parent; this is not checked.
-}
rotateLeftLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    -- | has parent node
    SplayNodeId ->
    m ()
rotateLeftLCT lct v = do
    pv <- getParentLCT lct v
    ppv <- getParentLCT lct pv
    setParentLCT lct v ppv
    when (ppv /= nothingLCT) $ do
        -- only rewire the grandparent if pv really is one of its children
        lppv <- getLeftChildLCT lct ppv
        when (lppv == pv) $ do
            setLeftChildLCT lct ppv v
        rppv <- getRightChildLCT lct ppv
        when (rppv == pv) $ do
            setRightChildLCT lct ppv v
    setParentLCT lct pv v
    lv <- getLeftChildLCT lct v
    setRightChildLCT lct pv lv
    setLeftChildLCT lct v pv
    when (lv /= nothingLCT) $ do
        setParentLCT lct lv pv
    -- pv is now below v, so its fold has to be refreshed first
    pullLCT lct pv
    pullLCT lct v

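{-
NOTE: rotateLeftLCT is deliberately left without an INLINE pragma.
Enabling the pragma below made GHC abort with the error shown:

-- {-# INLINE rotateLeftLCT #-}
<no location info>: error:
    Simplifier ticks exhausted
-}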

{- | Splay @v@ to the root of its splay tree.

Repeats until 'isSplayTreeRootLCT' holds for @v@:

* if the parent @pv@ is itself a splay-tree root, a single rotation lifts
  @v@ above it;
* otherwise two rotations are performed, chosen by whether @pv@ is the left
  child of its parent and whether @v@ is the left child of @pv@. When both
  sides agree the parent is rotated first, then @v@; when they differ @v@ is
  rotated twice.

Lazy reversal flags are not pushed here; callers are expected to do so
beforehand (as 'exposeLCT' does via 'traverseDownLCT').
-}
splayLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m ()
splayLCT lct v =
    fix $ \loop -> do
        isRoot <- isSplayTreeRootLCT lct v
        unless isRoot $ do
            pv <- getParentLCT lct v
            ppv <- getParentLCT lct pv
            isRoot' <- isSplayTreeRootLCT lct pv
            if isRoot'
                then do
                    -- parent is the splay root: one rotation finishes the job
                    lpv <- getLeftChildLCT lct pv
                    if lpv == v
                        then rotateRightLCT lct v
                        else rotateLeftLCT lct v
                else do
                    lpv <- getLeftChildLCT lct pv
                    lppv <- getLeftChildLCT lct ppv
                    case (lppv == pv, lpv == v) of
                        -- left-left: rotate the parent first, then v
                        (True, True) -> do
                            rotateRightLCT lct pv
                            rotateRightLCT lct v
                        -- left-right: rotate v twice
                        (True, False) -> do
                            rotateLeftLCT lct v
                            rotateRightLCT lct v
                        -- right-left: rotate v twice
                        (False, True) -> do
                            rotateRightLCT lct v
                            rotateLeftLCT lct v
                        -- right-right: rotate the parent first, then v
                        (False, False) -> do
                            rotateLeftLCT lct pv
                            rotateLeftLCT lct v
            loop
{-# INLINE splayLCT #-}

{- |
Put @v0@ on the root path.

Starting from @v0@, each visited node @v@ has its pending reversals pushed
from its splay-tree root downwards, is splayed to the root of its splay tree,
gets the previously processed node attached as its right child (nothing for
@v0@ itself), and has its cached fold recomputed. The walk then continues at
the parent of @v@ until no parent is left. Finally @v0@ is splayed once more.

The result is the last node visited by the walk before the parent pointer
ran out.

properties

* @v0@ is the root of the splay tree.
* @v0@ is the right most node of the root path.
* @expose u >> expose v == lca u v@.

>>> lct <- newLCT @() 2
>>> linkLCT lct 0 1
>>> evertLCT lct 0
>>> exposeLCT lct 1
1
>>> isSplayTreeRootLCT lct 1
True
>>> findRootLCT lct 1
0
-}
exposeLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
exposeLCT lct v0 = do
    fix
        ( \goUp !v !rv ->
            if v /= nothingLCT
                then do
                    -- resolve lazy reversals above v before rotating anything
                    traverseDownLCT (pushLCT lct) lct v
                    splayLCT lct v
                    -- replace the right subtree of v by the path built so far
                    setRightChildLCT lct v rv
                    pullLCT lct v
                    pv <- getParentLCT lct v
                    goUp pv v
                else do
                    -- rv is the root of splay tree
                    splayLCT lct v0
                    return rv
        )
        v0
        nothingLCT
{-# INLINE exposeLCT #-}