module Data.UnionFind.Diff where

import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

{- | Mutable union-find (disjoint-set union) that additionally tracks a
potential @h@ for every element, so that differences @h x - h y@ between
elements of the same component can be recorded ('setDiffUFD') and queried
('diffUFD').
-}
data UnionFindDiff s a = UFD
  { parentOrNegativeSizeUFD :: !(UM.MVector s Int)
  -- ^ For a root, the negated size of its component (always negative);
  -- for any other element, the index of its parent (non-negative).
  , potentialUFD :: !(UM.MVector s a)
  -- ^ Potential of each element relative to its parent,
  -- i.e. @h x - h (parent x)@. 'findUFD' never reads the entry of a root.
  }

{- | Allocate a structure over the elements @0 .. n-1@ in which every element
is its own singleton component (stored as @-1@, i.e. size 1) and every
potential is @0@.
-}
newUnionFindDiff ::
  (U.Unbox a, Num a, PrimMonad m) =>
  Int ->
  m (UnionFindDiff (PrimState m) a)
newUnionFindDiff n = do
  parents <- UM.replicate n (-1)
  potentials <- UM.replicate n 0
  pure (UFD parents potentials)

{- | Locate the representative (root) of @x@ and return it together with the
potential of @x@ relative to that root, i.e. @h x - h root@.

Every element visited on the way up is re-pointed directly at the root and
its stored potential is rewritten to be relative to the root (path
compression). Indices are not bounds-checked ('UM.unsafeRead' /
'UM.unsafeWrite').
-}
findUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  -- | (representative, potential)
  m (Int, a)
findUFD (UFD parents potentials) start = walk start
  where
    -- Yields (root, potential of v relative to root); on the way back down
    -- it links v straight to the root and folds the parent's accumulated
    -- potential into v's own entry.
    walk !v = do
      p <- UM.unsafeRead parents v
      if p < 0
        then pure (v, 0)
        else do
          (root, !hp) <- walk p
          UM.unsafeWrite parents v root
          hv <- UM.unsafeRead potentials v
          let !hv' = hp + hv
          UM.unsafeWrite potentials v hv'
          pure (root, hv')
{-# INLINE findUFD #-}

{- | Number of elements in the component containing @x@.

Follows parent links up to the root without compressing the path, then
negates the root's stored (negative) size.
-}
sizeUFD :: (PrimMonad m) => UnionFindDiff (PrimState m) a -> Int -> m Int
sizeUFD (UFD parents _) start =
  fix
    ( \climb v -> do
        p <- UM.unsafeRead parents v
        if p >= 0 then climb p else pure $! negate p
    )
    start
{-# INLINE sizeUFD #-}

{- | Record the constraint @hx - hy = d@.

Results:

* @'Just' 'True'@ — @x@ and @y@ were in different components, which are now
  merged. The root of the larger component becomes the new root; on equal
  sizes @y@'s root is kept.
* @'Just' 'False'@ — already in the same component and the constraint agrees
  with the known difference. Nothing changes beyond path compression.
* 'Nothing' — already in the same component and the constraint contradicts
  the known difference. Nothing changes beyond path compression.

>>> uf <- newUnionFindDiff @Int 2
>>> setDiffUFD uf 1 0 1
Just True
>>> setDiffUFD uf 1 0 999
Nothing
>>> setDiffUFD uf 1 0 1
Just False
>>> setDiffUFD uf 0 1 (-1)
Just False
-}
setDiffUFD ::
  (Eq a, Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  Int ->
  a ->
  m (Maybe Bool)
setDiffUFD ufd@(UFD parents potentials) x y d = do
  (rootX, hx) <- findUFD ufd x
  (rootY, hy) <- findUFD ufd y
  if rootX == rootY
    then pure $ if hx - hy == d then Just False else Nothing
    else do
      negSizeX <- UM.unsafeRead parents rootX
      negSizeY <- UM.unsafeRead parents rootY
      -- Sizes are stored negated, so the smaller stored value is the larger
      -- component; that root survives.
      if negSizeX < negSizeY
        then hang rootY negSizeY rootX (hx - hy - d)
        else hang rootX negSizeX rootY (hy - hx + d)
      pure (Just True)
  where
    -- Attach the root @child@ under @newRoot@: add child's (negated) size
    -- into newRoot's count, link child to newRoot, and store child's
    -- potential relative to newRoot.
    hang child negSizeChild newRoot delta = do
      UM.unsafeModify parents (+ negSizeChild) newRoot
      UM.unsafeWrite parents child newRoot
      UM.unsafeWrite potentials child delta
{-# INLINE setDiffUFD #-}

{- | Known potential difference @hx - hy@, or 'Nothing' when @x@ and @y@ lie in
different components (no chain of constraints relates them).

>>> uf <- newUnionFindDiff @Int 3
>>> setDiffUFD uf 1 0 1
Just True
>>> diffUFD uf 1 0
Just 1
>>> diffUFD uf 0 1
Just (-1)
>>> diffUFD uf 0 2
Nothing
>>> setDiffUFD uf 2 1 2
Just True
>>> diffUFD uf 2 1
Just 2
>>> diffUFD uf 2 0
Just 3
-}
diffUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  Int ->
  m (Maybe a)
diffUFD ufd x y = combine <$> findUFD ufd x <*> findUFD ufd y
  where
    -- Both potentials are relative to their own roots, so they are only
    -- comparable when the roots coincide.
    combine (rootX, hx) (rootY, hy)
      | rootX == rootY = Just $! hx - hy
      | otherwise = Nothing
{-# INLINE diffUFD #-}