module Data.UnionFind where
import Control.Monad
import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
-- | Disjoint-set forest stored in a single mutable unboxed vector.
--
-- As used by the operations in this module, a negative slot marks a root and
-- holds the negated size of that root's set, while a non-negative slot holds
-- the index of the element's parent.
newtype UnionFind s = UF {getUnionFind :: UM.MVector s Int}
-- | Create a union-find structure with @n@ slots, every slot initialised to
-- @-1@ (i.e. each element is the root of a singleton set).
newUnionFind :: (PrimMonad m) => Int -> m (UnionFind (PrimState m))
newUnionFind n = UF <$> UM.replicate n (-1)
{-# INLINE newUnionFind #-}
-- | Return an immutable copy of the underlying parent/size vector.
--
-- The result exposes the raw encoding: negative entries are roots (negated
-- set size), non-negative entries are parent indices.
freezeUnionFind :: (PrimMonad m) => UnionFind (PrimState m) -> m (U.Vector Int)
freezeUnionFind = U.freeze . getUnionFind
-- | Find the root (representative) of the set containing @x0@.
--
-- Every non-root element visited on the way is re-pointed directly at the
-- root (path compression), so this operation mutates the structure.
--
-- Uses 'UM.unsafeRead' / 'UM.unsafeWrite': the index is not bounds-checked,
-- so the caller must pass a valid index.
findUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> m Int
findUF uf x0 = go x0 return
  where
    -- Continuation-passing walk towards the root: the continuation @k@ is
    -- invoked with the root once it is known, which lets each frame rewrite
    -- its own slot on the way back.
    go !x k = do
      px <- UM.unsafeRead (getUnionFind uf) x
      if px < 0
        then k x -- negative slot: @x@ is the root
        else go px $ \ppx -> do
          -- Compress: make @x@ point straight at the root.
          UM.unsafeWrite (getUnionFind uf) x ppx
          k ppx
{-# INLINE findUF #-}
-- | Size of the set containing @x@.
--
-- Follows parent links up to the root and returns the negation of the value
-- stored there. Only reads the vector (no path compression is performed).
-- Reads are not bounds-checked ('UM.unsafeRead').
sizeUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> m Int
sizeUF uf = fix $ \loop x -> do
  px <- UM.unsafeRead (getUnionFind uf) x
  if px < 0
    then return $! negate px -- root slot stores the negated size
    else loop px
{-# INLINE sizeUF #-}
-- | Merge the sets containing @x@ and @y@.
--
-- Returns 'False' when both elements already share a root (nothing is
-- merged) and 'True' when two distinct sets were joined. The root with the
-- smaller stored value (i.e. the larger set, since sizes are stored negated)
-- becomes the root of the merged set.
--
-- Accesses are not bounds-checked.
uniteUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> Int -> m Bool
uniteUF uf x y = do
  px <- findUF uf x
  py <- findUF uf y
  if px == py
    then return False
    else do
      rx <- UM.unsafeRead (getUnionFind uf) px
      ry <- UM.unsafeRead (getUnionFind uf) py
      if rx < ry
        then do
          -- @px@ holds the larger set: absorb @py@'s size, then hang @py@
          -- beneath @px@.
          UM.unsafeModify (getUnionFind uf) (+ ry) px
          UM.unsafeWrite (getUnionFind uf) py px
        else do
          -- Otherwise @py@ becomes the root of the merged set.
          UM.unsafeModify (getUnionFind uf) (+ rx) py
          UM.unsafeWrite (getUnionFind uf) px py
      return True
{-# INLINE uniteUF #-}
-- | Like 'uniteUF', but discards the flag saying whether a merge happened.
uniteUF_ :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> Int -> m ()
uniteUF_ uf x y = void $ uniteUF uf x y
{-# INLINE uniteUF_ #-}
-- | Test whether @x@ and @y@ belong to the same set, i.e. whether 'findUF'
-- yields the same root for both.
--
-- Because it is implemented with 'findUF', the paths from both elements to
-- their roots are compressed as a side effect.
equivUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> Int -> m Bool
equivUF uf x y = (==) <$> findUF uf x <*> findUF uf y
{-# INLINE equivUF #-}
-- | Number of disjoint sets currently represented.
--
-- Freezes the structure via 'freezeUnionFind' and counts the negative
-- entries, each of which marks one root.
countGroupUF :: (PrimMonad m) => UnionFind (PrimState m) -> m Int
countGroupUF uf = U.length . U.filter (< 0) <$> freezeUnionFind uf
{-# INLINE countGroupUF #-}