module Data.UnionFind.Diff where
import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
-- | Mutable disjoint-set forest whose nodes additionally carry a value of
-- type @a@ (a \"potential\"), which lets the difference between two nodes of
-- the same set be queried.
--
-- @s@ is the state token of the underlying mutable vectors.
data UnionFindDiff s a = UFD
  { parentOrNegativeSizeUFD :: UM.MVector s Int
  -- ^ For a non-root node: the index of its parent. For a root: the negated
  -- size of its set (so a fresh singleton holds @-1@).
  , potentialUFD :: UM.MVector s a
  -- ^ For a non-root node: its potential relative to its parent. The entry
  -- of a root is never read by 'findUFD', which reports @0@ for roots.
  }
-- | Allocate a structure with @n@ nodes.
--
-- Every parent/size entry starts at @-1@ and every potential starts at @0@,
-- i.e. each node begins as the root of its own one-element set.
newUnionFindDiff ::
  (U.Unbox a, Num a, PrimMonad m) =>
  Int ->
  m (UnionFindDiff (PrimState m) a)
newUnionFindDiff n =
  UFD
    <$> UM.replicate n (-1)
    <*> UM.replicate n 0
-- | Locate the root of the set containing @x0@.
--
-- Returns @(root, h)@ where @h@ is the sum of the stored potentials on the
-- path from @x0@ up to (but excluding) the root; a root reports @0@.
--
-- As a side effect, every non-root node visited is re-pointed directly at
-- the root and its stored potential is replaced by its accumulated value
-- (path compression).
--
-- The index is not bounds-checked ('UM.unsafeRead' / 'UM.unsafeWrite').
findUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  m (Int, a)
findUFD (UFD uf potential) x0 = go x0 return
  where
    -- Continuation-passing walk: first climb to the root, then run the
    -- continuations on the way back down, rewriting each visited node.
    go !x k = do
      px <- UM.unsafeRead uf x
      if px < 0
        then k (x, 0) -- negative entry marks a root
        else go px $ \(root, !hpx) -> do
          UM.unsafeWrite uf x root
          hx <- UM.unsafeRead potential x
          -- hpx is the parent's potential relative to the root, hx is this
          -- node's potential relative to its (old) parent.
          let !hx' = hpx + hx
          UM.unsafeWrite potential x hx'
          k (root, hx')
{-# INLINE findUFD #-}
-- | Number of elements in the set containing the given node.
--
-- Follows parent links until it reaches an entry that is negative (a root)
-- and returns that entry negated. Unlike 'findUFD' this performs no writes.
--
-- The index is not bounds-checked ('UM.unsafeRead').
sizeUFD :: (PrimMonad m) => UnionFindDiff (PrimState m) a -> Int -> m Int
sizeUFD (UFD uf _) = fix $ \loop x -> do
  px <- UM.unsafeRead uf x
  if px < 0
    then return $! negate px
    else loop px
{-# INLINE sizeUFD #-}
-- | Record the constraint that the accumulated potential of @x@ minus that
-- of @y@ equals @d@ (the quantity later reported by 'diffUFD').
--
-- Result:
--
-- * @Just True@  – @x@ and @y@ were in different sets; the sets were merged
--   and the constraint stored.
-- * @Just False@ – @x@ and @y@ were already in the same set and the existing
--   difference already equals @d@; nothing is merged.
-- * @Nothing@    – @x@ and @y@ were already in the same set but the existing
--   difference is not @d@; nothing is merged.
--
-- Both lookups go through 'findUFD', so path compression happens in every
-- case. Indices are not bounds-checked.
setDiffUFD ::
  (Eq a, Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  Int ->
  a ->
  m (Maybe Bool)
setDiffUFD ufd@(UFD uf potential) x y d = do
  (px, hx) <- findUFD ufd x
  (py, hy) <- findUFD ufd y
  if px == py
    then
      -- Same set already: only check consistency with the stored relation.
      if hx - hy == d
        then return $ Just False
        else return Nothing
    else do
      -- Root entries are negated sizes, so the smaller value is the larger set.
      rx <- UM.unsafeRead uf px
      ry <- UM.unsafeRead uf py
      if rx < ry
        then do
          -- Attach root py beneath root px.
          UM.unsafeModify uf (+ ry) px
          UM.unsafeWrite uf py px
          -- Chosen so that afterwards h(y) = hx - d, hence h(x) - h(y) = d.
          UM.unsafeWrite potential py $ hx - hy - d
        else do
          -- Attach root px beneath root py.
          UM.unsafeModify uf (+ rx) py
          UM.unsafeWrite uf px py
          -- Chosen so that afterwards h(x) = hy + d, hence h(x) - h(y) = d.
          UM.unsafeWrite potential px $ hy - hx + d
      return $ Just True
{-# INLINE setDiffUFD #-}
-- | Difference between the accumulated potentials of @x@ and @y@.
--
-- Returns @Just (hx - hy)@ (forced before being wrapped) when both nodes
-- have the same root, and @Nothing@ when they are in different sets.
--
-- Both lookups go through 'findUFD', so this compresses paths as a side
-- effect. Indices are not bounds-checked.
diffUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  Int ->
  m (Maybe a)
diffUFD ufd x y = do
  (px, hx) <- findUFD ufd x
  (py, hy) <- findUFD ufd y
  pure $
    if px == py
      then Just $! hx - hy
      else Nothing
{-# INLINE diffUFD #-}