module Data.UnionFind.Merge where
import Control.Monad
import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
-- | A disjoint-set (union-find) structure over elements @0 .. n-1@ that
-- additionally carries one value per element, combined with '<>' when sets
-- are merged.
--
-- Both fields are strict: they always hold already-allocated mutable
-- vectors, so there is no reason to store them as thunks.
data UnionFind mv s a = UF
  { parentOrNegativeSizeUF :: !(UM.MVector s Int)
    -- ^ For a non-root element, the index of its parent. For a root, the
    -- negated size of its set (a singleton root stores @-1@).
  , mconcatUF :: !(mv s a)
    -- ^ Per-element values. The slot of a root holds the combined value of
    -- its whole set; 'uniteUF' resets the absorbed root's slot to 'mempty'.
  }
-- | Create a union-find over @n@ elements @0 .. n-1@. Each element starts in
-- its own singleton set (stored size @1@) with value 'mempty'.
newUnionFind ::
  (G.Vector v a, Monoid a, PrimMonad m) =>
  Int ->
  m (UnionFind (G.Mutable v) (PrimState m) a)
newUnionFind n = do
  -- every element is a root of size 1, encoded as -1
  parents <- UM.replicate n (-1)
  values <- GM.replicate n mempty
  return $ UF parents values
{-# INLINE newUnionFind #-}
-- | Create a union-find whose elements are the positions of the given vector.
-- Each element starts as a singleton set holding its own value. The input
-- vector is copied ('G.thaw'), not aliased.
buildUnionFind ::
  (G.Vector v a, PrimMonad m) =>
  v a ->
  m (UnionFind (G.Mutable v) (PrimState m) a)
buildUnionFind xs = do
  values <- G.thaw xs
  parents <- UM.replicate (GM.length values) (-1)
  return (UF parents values)
{-# INLINE buildUnionFind #-}
-- | Return the representative (root) of the set containing the given element.
--
-- Every non-root element visited on the way is repointed directly at the
-- root (path compression). Indices are not bounds-checked.
findUF ::
  (PrimMonad m) =>
  UnionFind mv (PrimState m) a ->
  Int ->
  m Int
findUF UF{parentOrNegativeSizeUF = parents} start = do
  root <- climb start
  compress root start
  return root
  where
    -- Follow parent links until reaching a slot holding a negative size.
    climb !i = do
      p <- UM.unsafeRead parents i
      if p < 0 then return i else climb p
    -- Walk the same path again, pointing each non-root node at the root.
    compress root !i = do
      p <- UM.unsafeRead parents i
      when (p >= 0) $ do
        UM.unsafeWrite parents i root
        compress root p
{-# INLINE findUF #-}
-- | Number of elements in the set containing the given element.
--
-- Walks up to the root without compressing the path. Indices are not
-- bounds-checked.
sizeUF ::
  (PrimMonad m) =>
  UnionFind mv (PrimState m) a ->
  Int ->
  m Int
sizeUF UF{parentOrNegativeSizeUF = parents} = climb
  where
    climb i = do
      p <- UM.unsafeRead parents i
      if p >= 0
        then climb p
        else return $! negate p -- roots store their size negated
{-# INLINE sizeUF #-}
-- | Read the accumulated value of the set containing the given element (the
-- value stored at that set's root).
readUF ::
  (G.Vector v a, PrimMonad m) =>
  UnionFind (G.Mutable v) (PrimState m) a ->
  Int ->
  m a
readUF uf x = do
  root <- findUF uf x
  GM.unsafeRead (mconcatUF uf) root
{-# INLINE readUF #-}
-- | Overwrite the accumulated value of the set containing element @i@ (the
-- value stored at that set's root).
writeUF ::
  (G.Vector v a, PrimMonad m) =>
  UnionFind (G.Mutable v) (PrimState m) a ->
  Int ->
  a ->
  m ()
writeUF uf i x = do
  root <- findUF uf i
  GM.unsafeWrite (mconcatUF uf) root x
{-# INLINE writeUF #-}
-- | Apply @f@ to the accumulated value of the set containing element @i@ (the
-- value stored at that set's root).
modifyUF ::
  (G.Vector v a, PrimMonad m) =>
  UnionFind (G.Mutable v) (PrimState m) a ->
  (a -> a) ->
  Int ->
  m ()
modifyUF uf f i = do
  root <- findUF uf i
  GM.unsafeModify (mconcatUF uf) f root
{-# INLINE modifyUF #-}
-- | Merge the sets containing @x@ and @y@.
--
-- Returns 'False' (and changes nothing) when they are already in the same
-- set, 'True' otherwise. Union by size: the root of the smaller set is
-- attached under the root of the larger one; on a tie, @x@'s root goes under
-- @y@'s root. The surviving root receives @valueOf x <> valueOf y@ (always in
-- that order, whichever root survives), and the absorbed root's value is
-- reset to 'mempty'.
uniteUF ::
  (G.Vector v a, Monoid a, PrimMonad m) =>
  UnionFind (G.Mutable v) (PrimState m) a ->
  Int ->
  Int ->
  m Bool
uniteUF uf x y = do
  rootX <- findUF uf x
  rootY <- findUF uf y
  if rootX == rootY
    then return False
    else do
      negSizeX <- UM.unsafeRead parents rootX
      negSizeY <- UM.unsafeRead parents rootY
      -- Sizes are stored negated, so the smaller stored number is the larger set.
      let (root, child, negChildSize)
            | negSizeX < negSizeY = (rootX, rootY, negSizeY)
            | otherwise = (rootY, rootX, negSizeX)
      UM.unsafeModify parents (+ negChildSize) root
      UM.unsafeWrite parents child root
      valX <- GM.unsafeRead vals rootX
      valY <- GM.unsafeRead vals rootY
      GM.unsafeWrite vals root $! valX <> valY
      GM.unsafeWrite vals child $! mempty
      return True
  where
    parents = parentOrNegativeSizeUF uf
    vals = mconcatUF uf
{-# INLINE uniteUF #-}
-- | Like 'uniteUF', but discards whether a merge actually happened.
uniteUF_ ::
  (G.Vector v a, Monoid a, PrimMonad m) =>
  UnionFind (G.Mutable v) (PrimState m) a ->
  Int ->
  Int ->
  m ()
uniteUF_ uf x y = () <$ uniteUF uf x y
{-# INLINE uniteUF_ #-}
-- | Whether @x@ and @y@ currently belong to the same set.
--
-- Looks up @x@'s root first, then @y@'s; both lookups compress their paths.
equivUF ::
  (PrimMonad m) =>
  UnionFind mv (PrimState m) a ->
  Int ->
  Int ->
  m Bool
equivUF uf x y = do
  rootX <- findUF uf x
  rootY <- findUF uf y
  return (rootX == rootY)
{-# INLINE equivUF #-}
-- | Number of disjoint sets, i.e. the number of root slots (slots holding a
-- negative size). Takes an immutable copy of the parent array to count over.
countGroupUF ::
  (PrimMonad m) =>
  UnionFind mv (PrimState m) a ->
  m Int
countGroupUF uf = do
  snapshot <- U.freeze (parentOrNegativeSizeUF uf)
  return (U.foldl' countRoot 0 snapshot)
  where
    countRoot :: Int -> Int -> Int
    countRoot acc p
      | p < 0 = acc + 1
      | otherwise = acc
{-# INLINE countGroupUF #-}