module Data.UnionFind.Diff where
import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
-- | Mutable weighted union-find (disjoint-set) structure living in state
-- thread @s@. Besides set membership it keeps one value of type @a@ per
-- element (its \"potential\"), which lets 'diffUFD' report the difference
-- between two elements of the same set.
data UnionFindDiff s a = UFD
  { -- | For a non-root element: the index of its parent. For a root: the
    -- negated size of its set (a fresh singleton therefore holds @-1@).
    parentOrNegativeSizeUFD :: UM.MVector s Int
  , -- | For a non-root element: its offset relative to its parent (relative
    -- to the root once 'findUFD' has compressed the path). 'findUFD' never
    -- reads this slot for a root; it reports @0@ for roots directly.
    potentialUFD :: UM.MVector s a
  }
-- | Allocate a structure with @n@ slots in which every element is its own
-- singleton set: each parent slot is initialised to @-1@ (a root of size 1)
-- and each potential to @0@.
newUnionFindDiff ::
  (U.Unbox a, Num a, PrimMonad m) =>
  -- | Number of elements @n@.
  Int ->
  m (UnionFindDiff (PrimState m) a)
newUnionFindDiff n =
  UFD
    <$> UM.replicate n (-1)
    <*> UM.replicate n 0
-- | Find the root of the set containing the given element, together with the
-- element's accumulated potential relative to that root.
--
-- The walk is written in continuation-passing style: 'go' climbs parent
-- links until it reaches a root, then the stacked continuations run on the
-- way back down, re-pointing every visited element directly at the root and
-- replacing its stored potential with the accumulated one (path compression).
--
-- A root is reported with potential @0@. Vector accesses use the unchecked
-- @unsafeRead@ / @unsafeWrite@, so the index is not validated here.
findUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  -- | Element to look up.
  Int ->
  -- | @(root, potential of the element relative to root)@.
  m (Int, a)
findUFD (UFD uf potential) x0 = go x0 return
  where
    go !x k = do
      px <- UM.unsafeRead uf x
      if px < 0
        then -- Negative slot: x is a root, its own offset is zero.
          k (x, 0)
        else go px $ \(root, !hpx) -> do
          -- Compress: hang x directly below the root ...
          UM.unsafeWrite uf x root
          hx <- UM.unsafeRead potential x
          -- ... and fold the parent's offset into x's own offset.
          let !hx' = hpx + hx
          UM.unsafeWrite potential x hx'
          k (root, hx')
{-# INLINE findUFD #-}
-- | Size of the set containing the given element.
--
-- Follows parent links until it reaches a root (a negative slot) and returns
-- the negation of the value stored there. Unlike 'findUFD' it performs no
-- writes, so no path compression happens. Reads are unchecked
-- (@unsafeRead@); the index is not validated here.
sizeUFD :: (PrimMonad m) => UnionFindDiff (PrimState m) a -> Int -> m Int
sizeUFD (UFD uf _) = fix $ \loop x -> do
  px <- UM.unsafeRead uf x
  if px < 0
    then return $! negate px
    else loop px
{-# INLINE sizeUFD #-}
-- | Record the constraint that the potential of @x@ minus the potential of
-- @y@ equals @d@, so that a later @'diffUFD' ufd x y@ yields @Just d@.
--
-- Result:
--
-- * @Just True@  - @x@ and @y@ were in different sets, which are now merged.
-- * @Just False@ - they were already in the same set and the existing
--   difference already equals @d@.
-- * @Nothing@    - they were already in the same set and the existing
--   difference contradicts @d@; no merge is performed.
--
-- Both lookups go through 'findUFD', so paths are compressed in every case.
setDiffUFD ::
  (Eq a, Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  -- | Element @x@.
  Int ->
  -- | Element @y@.
  Int ->
  -- | Required difference @d@ (potential of @x@ minus potential of @y@).
  a ->
  m (Maybe Bool)
setDiffUFD ufd@(UFD uf potential) x y d = do
  (px, hx) <- findUFD ufd x
  (py, hy) <- findUFD ufd y
  if px == py
    then
      if hx - hy == d
        then return $ Just False
        else return Nothing
    else do
      -- Roots store negated sizes, so the smaller number is the larger set.
      rx <- UM.unsafeRead uf px
      ry <- UM.unsafeRead uf py
      if rx < ry
        then do
          -- x's set is larger: hang py below px.
          UM.unsafeModify uf (+ ry) px
          UM.unsafeWrite uf py px
          -- Offset of py w.r.t. px chosen so that hx - (hy + offset) == d.
          UM.unsafeWrite potential py $ hx - hy - d
        else do
          -- Otherwise hang px below py.
          UM.unsafeModify uf (+ rx) py
          UM.unsafeWrite uf px py
          -- Offset of px w.r.t. py chosen so that (hx + offset) - hy == d.
          UM.unsafeWrite potential px $ hy - hx + d
      return $ Just True
{-# INLINE setDiffUFD #-}
-- | Difference between the potentials of @x@ and @y@.
--
-- Returns @Just (hx - hy)@ when both elements have the same root, where
-- @hx@ and @hy@ are the potentials reported by 'findUFD', and @Nothing@
-- when they are in different sets. Both lookups compress paths.
diffUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  -- | Element @x@.
  Int ->
  -- | Element @y@.
  Int ->
  m (Maybe a)
diffUFD ufd x y = do
  (px, hx) <- findUFD ufd x
  (py, hy) <- findUFD ufd y
  pure
    $ if px == py
      then Just $! hx - hy
      else Nothing
{-# INLINE diffUFD #-}