module Data.UnionFind where

import Control.Monad
import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

-- | Disjoint-set forest over the elements @0 .. n-1@, kept in a single
-- mutable unboxed vector.
--
-- For an element @i@, a non-negative entry is the index of @i@'s parent;
-- a negative entry marks @i@ as a root and stores the negated size of the
-- set it represents.
newtype UnionFind s = UF
  { getUnionFind :: UM.MVector s Int
  }

-- | Allocate a structure holding @n@ singleton sets.
--
-- Every entry starts at @-1@: each element is its own root, with size 1.
newUnionFind :: (PrimMonad m) => Int -> m (UnionFind (PrimState m))
newUnionFind n = do
  cells <- UM.replicate n (-1)
  return (UF cells)
{-# INLINE newUnionFind #-}

-- | Take an immutable snapshot of the underlying parent/size array.
--
-- This is an O(n) copy ('U.freeze'), so later mutations of the
-- 'UnionFind' are not reflected in the returned vector. Entries follow the
-- structure's encoding: non-negative means "parent index", negative means
-- "root, with negated set size".
freezeUnionFind :: (PrimMonad m) => UnionFind (PrimState m) -> m (U.Vector Int)
freezeUnionFind uf = U.freeze (getUnionFind uf)
-- INLINE like every other operation in this module, so callers can
-- specialise the PrimMonad dictionary away.
{-# INLINE freezeUnionFind #-}

-- | Return the root (representative) of the set containing @x0@.
--
-- Performs full path compression: every non-root node visited on the way
-- up is re-pointed directly at the root, so later lookups are shorter.
--
-- Precondition: @0 <= x0 < n@. Reads and writes use the unchecked
-- 'UM.unsafeRead' / 'UM.unsafeWrite', so an out-of-range index is not
-- detected.
findUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> m Int
findUF uf x0 = go x0
  where
    parents = getUnionFind uf
    -- No local type signature on purpose: the outer @m@ is not in scope
    -- here (no explicit forall / ScopedTypeVariables), so a signature that
    -- mentions @m@ would refer to a fresh, unconstrained monad.
    go x = do
      px <- UM.unsafeRead parents x
      if px < 0
        then return x
        else do
          root <- go px
          -- Path compression: hang x directly off the root.
          UM.unsafeWrite parents x root
          return root
{-# INLINE findUF #-}

-- | Number of elements in the set containing @x@.
--
-- Follows parent links up to the root without compressing the path, then
-- returns the negation of the (negative) value stored at the root.
sizeUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> m Int
sizeUF uf = climb
  where
    parents = getUnionFind uf
    climb v = do
      parent <- UM.unsafeRead parents v
      if parent >= 0
        then climb parent
        else return $! negate parent
{-# INLINE sizeUF #-}

-- | Merge the sets containing @x@ and @y@ (union by size).
--
-- Returns 'False' if they were already in the same set, and 'True' if a
-- merge happened. The smaller tree is attached beneath the root of the
-- larger one, and the surviving root's stored (negated) size absorbs the
-- other's.
uniteUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> Int -> m Bool
uniteUF uf x y = do
  rootX <- findUF uf x
  rootY <- findUF uf y
  if rootX == rootY
    then return False
    else do
      negSizeX <- UM.unsafeRead parents rootX
      negSizeY <- UM.unsafeRead parents rootY
      -- Sizes are stored negated, so the smaller value is the bigger set.
      if negSizeX < negSizeY
        then attach rootY negSizeY rootX
        else attach rootX negSizeX rootY
      return True
  where
    parents = getUnionFind uf
    -- Hang root @child@ (stored value @childNeg@) beneath root @newRoot@,
    -- first adding the child's negated size onto the new root.
    attach child childNeg newRoot = do
      UM.unsafeModify parents (+ childNeg) newRoot
      UM.unsafeWrite parents child newRoot
{-# INLINE uniteUF #-}

-- | Same as 'uniteUF', but discards whether a merge actually happened.
uniteUF_ :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> Int -> m ()
uniteUF_ uf x y = uniteUF uf x y >> return ()
{-# INLINE uniteUF_ #-}

-- | Whether @x@ and @y@ currently belong to the same set.
--
-- Looks up the root of @x@ first, then of @y@. Both lookups go through
-- 'findUF', so they may rewrite parent links.
equivUF :: (PrimMonad m) => UnionFind (PrimState m) -> Int -> Int -> m Bool
equivUF uf x y = do
  rootX <- findUF uf x
  rootY <- findUF uf y
  return (rootX == rootY)
{-# INLINE equivUF #-}

-- | Number of disjoint sets, i.e. the number of roots (negative entries).
--
-- O(n): takes a full snapshot of the array and scans it once.
countGroupUF :: (PrimMonad m) => UnionFind (PrimState m) -> m Int
countGroupUF uf = fmap countRoots (freezeUnionFind uf)
  where
    countRoots cells = U.foldl' tally 0 cells
    tally acc cell
      | cell < 0 = acc + 1
      | otherwise = acc
{-# INLINE countGroupUF #-}