{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Tree.LCT where

import Control.Monad
import Control.Monad.Primitive
import Data.Coerce
import Data.Function
import Data.Int
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

{- | Link-cut tree over the vertices @0 .. n-1@, where every vertex carries a
value of a /commutative/ monoid @a@.

Every vertex is also a node of an auxiliary splay tree. All per-node data is
kept in parallel unboxed mutable vectors indexed by the vertex id.
-}
data LCT s a = LCT
    { parentLCT :: UM.MVector s Int32
    -- ^ parent pointer of each splay node ('nothingLCT' when absent)
    , leftChildLCT :: UM.MVector s Int32
    -- ^ left child of each splay node ('nothingLCT' when absent)
    , rightChildLCT :: UM.MVector s Int32
    -- ^ right child of each splay node ('nothingLCT' when absent)
    , commMonoidLCT :: UM.MVector s a
    -- ^ value stored at each vertex
    , foldSubtreesLCT :: UM.MVector s a
    -- ^ aggregate of each node's splay subtree (read as the path fold by 'mconcatPathLCT')
    , lazyRevFlagLCT :: UM.MVector s Bool
    -- ^ pending (lazy) reversal flag of each splay node
    }

{- | Create a link-cut tree with @n@ isolated vertices, each holding 'mempty'.

NOTE: vertex ids are stored as 'Int32', so @n@ must fit in an 'Int32'.
-}
newLCT :: (U.Unbox a, Monoid a, PrimMonad m) => Int -> m (LCT (PrimState m) a)
newLCT n =
    LCT
        <$> noLinks
        <*> noLinks
        <*> noLinks
        <*> UM.replicate n mempty
        <*> UM.replicate n mempty
        <*> UM.replicate n False
    where
        -- Every parent/child pointer starts out empty. This is an action, not a
        -- vector: each use above allocates a fresh, distinct vector.
        noLinks = UM.replicate n (fromIntegral nothingLCT)

{- | Create a link-cut tree with one isolated vertex per element of @vs@;
vertex @i@ holds @vs U.! i@.

NOTE: vertex ids are stored as 'Int32', so @U.length vs@ must fit in an 'Int32'.
-}
buildLCT :: (U.Unbox a, Monoid a, PrimMonad m) => U.Vector a -> m (LCT (PrimState m) a)
buildLCT vs =
    LCT
        <$> noLinks
        <*> noLinks
        <*> noLinks
        <*> U.thaw vs
        -- an isolated vertex is alone in its splay tree, so its aggregate is its own value
        <*> U.thaw vs
        <*> UM.replicate n False
    where
        n :: Int
        n = U.length vs
        -- each use runs the action again, yielding a distinct vector
        noLinks = UM.replicate n (fromIntegral nothingLCT)

{- | Make @v@ the root of its tree (re-rooting, a.k.a. evert).

>>> lct <- newLCT @() 3
>>> linkLCT lct 1 0 >> linkLCT lct 2 1
>>> findRootLCT lct 2
0
>>> evertLCT lct 1
>>> findRootLCT lct 2
1
-}
evertLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> m ()
evertLCT lct@LCT{..} v = do
    -- expose v so that the path from the current root to v is the exposed path
    void $ exposeLCT lct sv
    -- reversing that path makes v its shallowest vertex, i.e. the new root
    -- NOTE(review): writes True instead of toggling; this assumes exposeLCT
    -- leaves no pending reversal on v — confirm against exposeLCT.
    UM.unsafeWrite lazyRevFlagLCT v True
    pushLCT lct sv
    where
        sv = asSplayNodeId v
{-# INLINE evertLCT #-}

-- | Remove the edge between @u@ and @v@.
--
-- require: the edge @(u, v)@ exists
cutLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m ()
cutLCT lct u v = do
    evertLCT lct u
    void $ exposeLCT lct sv
    -- With u as the root and (u, v) an edge, the exposed path is exactly
    -- u - v, so u is the left child of v: detach it in both directions.
    setLeftChildLCT lct sv nothingLCT
    setParentLCT lct (asSplayNodeId u) nothingLCT
    -- refresh v's aggregate now that its left subtree is gone
    pullLCT lct sv
    where
        sv = asSplayNodeId v
{-# INLINE cutLCT #-}

{- | Add the edge @(u, v)@, making @u@ a child of @v@ (so @u@'s whole tree
hangs below @v@).

require: @u@ and @v@ are *not connected*
-}
linkLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m ()
linkLCT lct u v = do
    -- u must be the root of its own tree before it can hang below v
    evertLCT lct u
    void $ exposeLCT lct sv
    setParentLCT lct (asSplayNodeId u) sv
    where
        sv = asSplayNodeId v
{-# INLINE linkLCT #-}

{- | Fold of the vertex values on the path between @l@ and @r@, both ends
included. The monoid is assumed commutative, so the direction of the path does
not matter.

Side effect: the tree is re-rooted at @l@ (see 'evertLCT').

require: @l@ and @r@ connected

>>> import Data.Monoid
>>> lct <- buildLCT @(Sum Int) $ U.fromList $ map Sum [0..3]
>>> linkLCT lct 1 0 >> linkLCT lct 2 1 >> linkLCT lct 3 1
>>> mconcatPathLCT lct 0 1
Sum {getSum = 1}
>>> mconcatPathLCT lct 2 3  -- 2 - 1 - 3
Sum {getSum = 6}
>>> mconcatPathLCT lct 0 3  -- 0 - 1 - 3
Sum {getSum = 4}
-}
mconcatPathLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m a
mconcatPathLCT lct@LCT{..} l r = do
    evertLCT lct l
    -- after exposing r, the aggregate stored at r covers the path l .. r
    void $ exposeLCT lct (asSplayNodeId r)
    UM.unsafeRead foldSubtreesLCT r
{-# INLINE mconcatPathLCT #-}

{- | Root of the tree containing @v0@: the leftmost node of the exposed
root-to-@v0@ path.

The standard link-cut find-root splays the node it finds. This implementation
does the same (with the push-then-splay sequence used by 'setCMonLCT'), so that
repeated queries do not re-walk the same long left spine every time.
-}
findRootLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> m Int
findRootLCT lct v0 = do
    u0 <- exposeLCT lct (asSplayNodeId v0)
    -- walk left from exposeLCT's result, resolving pending reversals first
    root <-
        fix
            ( \loop !v -> do
                pushLCT lct v
                lv <- getLeftChildLCT lct v
                if lv /= nothingLCT
                    then loop lv
                    else pure v
            )
            u0
    -- splay the found root to pay for the walk (amortization)
    traverseDownLCT (pushLCT lct) lct root
    splayLCT lct root
    pure $ coerce @SplayNodeId root
{-# INLINE findRootLCT #-}

{- | Set the value stored at vertex @v@ to @x@ (the monoid is assumed
commutative).

>>> import Data.Monoid
>>> lct <- buildLCT @(Sum Int) $ U.fromList $ map Sum [0..2]
>>> linkLCT lct 1 0 >> linkLCT lct 2 1
>>> mconcatPathLCT lct 0 2
Sum {getSum = 3}
>>> setCMonLCT lct 1 (Sum 100)
>>> mconcatPathLCT lct 0 2
Sum {getSum = 102}
-}
setCMonLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> a -> m ()
setCMonLCT lct@LCT{..} v x = do
    -- resolve pending reversals on the way down to v, then splay v up
    traverseDownLCT (pushLCT lct) lct sv
    splayLCT lct sv
    -- overwrite the value and recompute v's aggregate
    UM.unsafeWrite commMonoidLCT v x
    pullLCT lct sv
    where
        sv = asSplayNodeId v
{-# INLINE setCMonLCT #-}

{- | Lowest common ancestor of @u@ and @v@ with respect to the current root
(which 'evertLCT' can change).

require: u and v are connected

>>> lct <- newLCT @() 4
>>> linkLCT lct 1 0 >> linkLCT lct 2 1 >> linkLCT lct 3 2
>>> evertLCT lct 0
>>> lcaLCT lct 0 3
0
>>> evertLCT lct 2
>>> lcaLCT lct 0 3
2
-}
lcaLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> Int -> Int -> m Int
lcaLCT lct u v = do
    void $ exposeLCT lct (asSplayNodeId u)
    -- NOTE(review): relies on exposeLCT returning the last node at which the
    -- access from v joined the path just exposed from u — confirm.
    coerce @SplayNodeId <$> exposeLCT lct (asSplayNodeId v)
{-# INLINE lcaLCT #-}

-- | Id of a node in the auxiliary splay trees. It is the vertex id itself
-- (see 'asSplayNodeId'). All instances are reused from 'Int', so for example
-- 'show' prints the bare number.
newtype SplayNodeId = SplayNodeId {getSplayNodeId :: Int}
    deriving newtype (Eq, Ord, Show, Num, Real, Enum, Integral)

-- | View a vertex id as the id of its splay node (a zero-cost 'coerce').
asSplayNodeId :: Int -> SplayNodeId
asSplayNodeId = coerce
{-# INLINE asSplayNodeId #-}

-- | Read the left child of splay node @v@ ('nothingLCT' when there is none).
-- The index is not bounds-checked ('UM.unsafeRead').
getLeftChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
getLeftChildLCT LCT{leftChildLCT} v =
    fromIntegral <$> UM.unsafeRead leftChildLCT (coerce v)
{-# INLINE getLeftChildLCT #-}

-- | Read the right child of splay node @v@ ('nothingLCT' when there is none).
-- The index is not bounds-checked ('UM.unsafeRead').
getRightChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
getRightChildLCT LCT{rightChildLCT} v =
    fromIntegral <$> UM.unsafeRead rightChildLCT (coerce v)
{-# INLINE getRightChildLCT #-}

-- | Read the parent pointer of splay node @v@ ('nothingLCT' when there is none).
-- The index is not bounds-checked ('UM.unsafeRead').
getParentLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
getParentLCT LCT{parentLCT} v =
    fromIntegral <$> UM.unsafeRead parentLCT (coerce v)
{-# INLINE getParentLCT #-}

-- | Set the left child of splay node @v@ to @lv@.
-- Only this one pointer is written; @lv@'s parent pointer is left untouched.
-- The index is not bounds-checked ('UM.unsafeWrite').
setLeftChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    SplayNodeId ->
    m ()
setLeftChildLCT LCT{leftChildLCT} v lv =
    UM.unsafeWrite leftChildLCT (coerce v) (fromIntegral @Int (coerce lv))
{-# INLINE setLeftChildLCT #-}

-- | Overwrites the right-child pointer of @v@ with @rv@ (which may be
-- 'nothingLCT').
--
-- Only the child slot of @v@ is written; the parent pointer of @rv@ is left
-- untouched. Ids are stored as 'Int32' via 'fromIntegral'.
setRightChildLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    SplayNodeId ->
    m ()
setRightChildLCT lct v rv =
    let slot = coerce v :: Int
        target = fromIntegral (coerce rv :: Int)
     in UM.unsafeWrite (rightChildLCT lct) slot target
{-# INLINE setRightChildLCT #-}

-- | Overwrites the parent pointer of @v@ with @pv@ (which may be
-- 'nothingLCT').
--
-- Only the parent slot of @v@ is written; @pv@'s child pointers are left
-- untouched, which is how path-parent links are represented. Ids are stored
-- as 'Int32' via 'fromIntegral'.
setParentLCT ::
    (PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    SplayNodeId ->
    m ()
setParentLCT lct v pv = UM.unsafeWrite (parentLCT lct) ix val
  where
    ix = coerce v :: Int
    val = fromIntegral (coerce pv :: Int) :: Int32
{-# INLINE setParentLCT #-}

-- | Sentinel id meaning "no node"; it is stored as @-1@ in the pointer arrays.
nothingLCT :: SplayNodeId
nothingLCT = coerce (negate 1 :: Int)

-- | Whether @v@ is the root of its splay tree.
--
-- This holds when @v@ has no parent at all, or when its parent lists @v@ as
-- neither its left nor its right child (the parent pointer is then only a
-- path-parent link to another splay tree).
isSplayTreeRootLCT :: (PrimMonad m) => LCT (PrimState m) a -> SplayNodeId -> m Bool
isSplayTreeRootLCT lct v = do
    pv <- getParentLCT lct v
    if pv == nothingLCT
        then pure True
        else do
            isLeftChild <- (== v) <$> getLeftChildLCT lct pv
            isRightChild <- (== v) <$> getRightChildLCT lct pv
            pure $! not (isLeftChild || isRightChild)
{-# INLINE isSplayTreeRootLCT #-}

-- | Recomputes the cached aggregate of @v@ from its children:
-- @fold(left) <> value(v) <> fold(right)@, using 'mempty' for a missing child.
pullLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> SplayNodeId -> m ()
pullLCT lct@LCT{..} v = do
    leftFold <- getLeftChildLCT lct v >>= subtreeFold
    rightFold <- getRightChildLCT lct v >>= subtreeFold
    own <- UM.unsafeRead commMonoidLCT vi
    UM.unsafeWrite foldSubtreesLCT vi $ leftFold <> own <> rightFold
  where
    vi = coerce v :: Int
    -- cached aggregate of the subtree rooted at @u@, or 'mempty' if absent
    subtreeFold u
        | u == nothingLCT = pure mempty
        | otherwise = UM.unsafeRead foldSubtreesLCT (coerce u)
{-# INLINE pullLCT #-}

-- | Propagates a pending reversal at @v@ one level down.
--
-- If @v@'s lazy reversal flag is set, it is cleared, @v@'s two children are
-- swapped, and the reversal flag of each existing child is toggled. The
-- cached aggregate is not touched: 'LCT' is meant for commutative monoids,
-- for which reversing a subtree does not change its fold.
pushLCT :: (U.Unbox a, Monoid a, PrimMonad m) => LCT (PrimState m) a -> SplayNodeId -> m ()
pushLCT lct@LCT{..} v = do
    pending <- UM.unsafeRead lazyRevFlagLCT vi
    when pending $ do
        UM.unsafeWrite lazyRevFlagLCT vi False
        lv <- getLeftChildLCT lct v
        rv <- getRightChildLCT lct v
        setLeftChildLCT lct v rv
        setRightChildLCT lct v lv
        toggle lv
        toggle rv
  where
    vi = coerce v :: Int
    -- flip the pending-reversal flag of child @c@, if it exists
    toggle c = when (c /= nothingLCT) $ UM.unsafeModify lazyRevFlagLCT not (coerce c)
{-# INLINE pushLCT #-}

-- | Applies @f@ to every node on the path from the root of @v@'s splay tree
-- down to @v@ itself, in root-first order.
--
-- 'exposeLCT' uses this with 'pushLCT' so pending reversals above @v@ are
-- resolved before @v@ is splayed.
traverseDownLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    (SplayNodeId -> m ()) ->
    LCT (PrimState m) a ->
    SplayNodeId ->
    m ()
traverseDownLCT f lct = visit
  where
    -- climb to the splay root first (via recursion), then apply f on the way
    -- back down so ancestors are visited before descendants
    visit v = do
        atRoot <- isSplayTreeRootLCT lct v
        unless atRoot $ getParentLCT lct v >>= visit
        f v
{-# INLINE traverseDownLCT #-}

{-
     pv           v
    /  \         / \
   v   rpv ==> lv  pv
  / \              / \
lv  rv           rv  rpv
-}

-- | Rotates @v@ above its parent @pv@, where @v@ is the left child of @pv@
-- (see the diagram above).
--
-- @v@ inherits @pv@'s upward link: if @pv@ is a real child of the
-- grandparent, that child slot is redirected to @v@; otherwise (no
-- grandparent, or only a path-parent link) just @v@'s parent pointer changes.
-- The aggregates of @pv@ and then @v@ are recomputed with 'pullLCT'.
rotateRightLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    -- | must have a parent node and be its left child
    SplayNodeId ->
    m ()
rotateRightLCT lct v = do
    pv <- getParentLCT lct v
    ppv <- getParentLCT lct pv
    -- v takes over pv's link upwards
    setParentLCT lct v ppv
    when (ppv /= nothingLCT) $ redirectChild ppv pv v
    -- v's right subtree becomes pv's left subtree
    rv <- getRightChildLCT lct v
    setLeftChildLCT lct pv rv
    when (rv /= nothingLCT) $ setParentLCT lct rv pv
    -- pv hangs below v as its right child
    setRightChildLCT lct v pv
    setParentLCT lct pv v
    -- pv is now below v, so its aggregate has to be rebuilt first
    pullLCT lct pv
    pullLCT lct v
  where
    -- replaces @old@ by @new@ in whichever child slot of @p@ holds it; when
    -- neither does, @p@ is only a path-parent and nothing is written
    redirectChild p old new = do
        l <- getLeftChildLCT lct p
        when (l == old) $ setLeftChildLCT lct p new
        r <- getRightChildLCT lct p
        when (r == old) $ setRightChildLCT lct p new

-- No INLINE pragma on purpose: enabling it made GHC abort with
-- "Simplifier ticks exhausted".

{-
   pv         v
  /  \       / \
lpv  v  ==> pv  rv
    / \    /  \
  lv  rv  lpv  lv
-}

-- | Rotates @v@ above its parent @pv@, where @v@ is the right child of @pv@
-- (mirror image of 'rotateRightLCT'; see the diagram above).
--
-- @v@ inherits @pv@'s upward link: if @pv@ is a real child of the
-- grandparent, that child slot is redirected to @v@; otherwise (no
-- grandparent, or only a path-parent link) just @v@'s parent pointer changes.
-- The aggregates of @pv@ and then @v@ are recomputed with 'pullLCT'.
rotateLeftLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    -- | must have a parent node and be its right child
    SplayNodeId ->
    m ()
rotateLeftLCT lct v = do
    pv <- getParentLCT lct v
    ppv <- getParentLCT lct pv
    -- v takes over pv's link upwards
    setParentLCT lct v ppv
    when (ppv /= nothingLCT) $ redirectChild ppv pv v
    -- v's left subtree becomes pv's right subtree
    lv <- getLeftChildLCT lct v
    setRightChildLCT lct pv lv
    when (lv /= nothingLCT) $ setParentLCT lct lv pv
    -- pv hangs below v as its left child
    setLeftChildLCT lct v pv
    setParentLCT lct pv v
    -- pv is now below v, so its aggregate has to be rebuilt first
    pullLCT lct pv
    pullLCT lct v
  where
    -- replaces @old@ by @new@ in whichever child slot of @p@ holds it; when
    -- neither does, @p@ is only a path-parent and nothing is written
    redirectChild p old new = do
        l <- getLeftChildLCT lct p
        when (l == old) $ setLeftChildLCT lct p new
        r <- getRightChildLCT lct p
        when (r == old) $ setRightChildLCT lct p new

-- No INLINE pragma on purpose: enabling it made GHC abort with
-- "Simplifier ticks exhausted".

-- | Moves @v@ to the root of its splay tree with zig / zig-zig / zig-zag
-- rotations, repeating until 'isSplayTreeRootLCT' holds for @v@.
--
-- This function does not call 'pushLCT' itself; 'exposeLCT' resolves pending
-- reversals with @traverseDownLCT (pushLCT lct)@ before splaying.
splayLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m ()
splayLCT lct v = loop
  where
    loop = do
        vIsRoot <- isSplayTreeRootLCT lct v
        unless vIsRoot $ do
            pv <- getParentLCT lct v
            vIsLeft <- (== v) <$> getLeftChildLCT lct pv
            pvIsRoot <- isSplayTreeRootLCT lct pv
            if pvIsRoot
                then rotateUp vIsLeft v
                else do
                    ppv <- getParentLCT lct pv
                    pvIsLeft <- (== pv) <$> getLeftChildLCT lct ppv
                    -- same direction: zig-zig (lift pv first, then v);
                    -- opposite directions: zig-zag (lift v twice)
                    if pvIsLeft == vIsLeft
                        then rotateUp pvIsLeft pv >> rotateUp vIsLeft v
                        else rotateUp vIsLeft v >> rotateUp pvIsLeft v
            loop

    -- a left child is lifted by a right rotation, a right child by a left one
    rotateUp isLeftChild
        | isLeftChild = rotateRightLCT lct
        | otherwise = rotateLeftLCT lct
{-# INLINE splayLCT #-}

{- |
Makes the path from the root down to @v@ the root path (a.k.a. @access@).

Properties after the call:

* @v@ is the root of its splay tree.
* @v@ is the rightmost node of the root path.
* The returned node is the last node at which the climb entered a new splay
  tree, so @expose u >> expose v == lca u v@.

>>> lct <- newLCT @() 2
>>> linkLCT lct 0 1
>>> evertLCT lct 0
>>> exposeLCT lct 1
1
>>> isSplayTreeRootLCT lct 1
True
>>> findRootLCT lct 1
0
-}
exposeLCT ::
    (U.Unbox a, Monoid a, PrimMonad m) =>
    LCT (PrimState m) a ->
    SplayNodeId ->
    m SplayNodeId
exposeLCT lct v0 = climb v0 nothingLCT
  where
    climb !v !prev
        | v == nothingLCT = do
            -- prev is now the root of the single remaining splay tree;
            -- bring v0 to its root and report prev
            splayLCT lct v0
            pure prev
        | otherwise = do
            traverseDownLCT (pushLCT lct) lct v
            splayLCT lct v
            -- the path built so far replaces v's right subtree; the old right
            -- subtree keeps its parent pointer, which becomes a path-parent link
            setRightChildLCT lct v prev
            pullLCT lct v
            next <- getParentLCT lct v
            climb next v
{-# INLINE exposeLCT #-}