module Data.UnionFind.Diff where

import Control.Monad.Primitive
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

-- | Union-find with potentials (weighted union-find).
--
-- Each element stores a parent link and its potential relative to that
-- parent. This makes it possible to query differences @h x - h y@ between
-- elements of the same set.
data UnionFindDiff s a = UFD
  { parentOrNegativeSizeUFD :: !(UM.MVector s Int)
  -- ^ Parent index for non-root elements; @negate size@ (always negative)
  -- for set representatives.
  , potentialUFD :: !(UM.MVector s a)
  -- ^ Potential of each element relative to its current parent entry.
  }

-- | Allocate a structure over the elements @0 .. n-1@.
--
-- Every element starts as its own singleton set: its parent slot holds @-1@
-- (a root of size 1), and its potential is @0@.
newUnionFindDiff ::
  (U.Unbox a, Num a, PrimMonad m) =>
  Int ->
  m (UnionFindDiff (PrimState m) a)
newUnionFindDiff n = do
  parents <- UM.replicate n (-1)
  pots <- UM.replicate n 0
  pure $ UFD parents pots

-- | Locate the representative of @x0@, together with the potential of @x0@
-- measured from that representative.
--
-- Path compression: every node on the walked path is re-pointed directly at
-- the root, and its stored potential is rewritten to be root-relative.
-- Indices are not bounds-checked (unsafe reads/writes).
findUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  -- | (representative, potential)
  m (Int, a)
findUFD (UFD parents pot) x0 = climb x0
  where
    -- Walk to the root first. The writes happen on the way back, starting
    -- with the node nearest the root and ending with @x0@.
    climb !v = do
      pv <- UM.unsafeRead parents v
      if pv < 0
        then pure (v, 0)
        else do
          (root, !hParent) <- climb pv
          UM.unsafeWrite parents v root
          hv <- UM.unsafeRead pot v
          -- parent's root-relative potential plus v's parent-relative one
          let !hv' = hParent + hv
          UM.unsafeWrite pot v hv'
          pure (root, hv')
{-# INLINE findUFD #-}

-- | Number of elements in the set that contains @x@.
--
-- Follows parent links up to the root without compressing the path. The
-- size is read from the root's negative entry.
sizeUFD :: (PrimMonad m) => UnionFindDiff (PrimState m) a -> Int -> m Int
sizeUFD (UFD parents _) = fix $ \climb v -> do
  pv <- UM.unsafeRead parents v
  if pv >= 0
    then climb pv
    else pure $! negate pv
{-# INLINE sizeUFD #-}

{- | Impose the constraint @h x - h y = d@, where @h@ is the potential.

* @Just True@: @x@ and @y@ were in different sets. The sets are merged
  (the smaller tree is attached under the larger one) so that the
  constraint holds.
* @Just False@: @x@ and @y@ were already in the same set and the constraint
  agrees with the recorded potentials. No merge happens.
* @Nothing@: @x@ and @y@ were already in the same set, but
  @h x - h y /= d@. No merge happens.

>>> uf <- newUnionFindDiff @Int 2
>>> setDiffUFD uf 1 0 1
Just True
>>> setDiffUFD uf 1 0 999
Nothing
>>> setDiffUFD uf 1 0 1
Just False
>>> setDiffUFD uf 0 1 (-1)
Just False
-}
setDiffUFD ::
  (Eq a, Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  Int ->
  a ->
  m (Maybe Bool)
setDiffUFD ufd@(UFD parents pot) x y d = do
  (rootX, hx) <- findUFD ufd x
  (rootY, hy) <- findUFD ufd y
  if rootX /= rootY
    then do
      negSizeX <- UM.unsafeRead parents rootX
      negSizeY <- UM.unsafeRead parents rootY
      -- Hang the root @child@ under the root @parent@: grow the parent's
      -- (negative) size, link the child, and store the child's potential
      -- relative to its new parent.
      let attach parent child childPot negSizeChild = do
            UM.unsafeModify parents (+ negSizeChild) parent
            UM.unsafeWrite parents child parent
            UM.unsafeWrite pot child childPot
      -- A more negative entry means a larger set, which stays the root.
      if negSizeX < negSizeY
        then attach rootX rootY (hx - hy - d) negSizeY
        else attach rootY rootX (hy - hx + d) negSizeX
      pure (Just True)
    else pure $ if hx - hy == d then Just False else Nothing
{-# INLINE setDiffUFD #-}

{- | Potential difference @h x - h y@.

Returns @Just (h x - h y)@ when @x@ and @y@ are in the same set, and
@Nothing@ otherwise. Both lookups go through 'findUFD', so paths are
compressed as a side effect.

>>> uf <- newUnionFindDiff @Int 3
>>> setDiffUFD uf 1 0 1
Just True
>>> diffUFD uf 1 0
Just 1
>>> diffUFD uf 0 1
Just (-1)
>>> diffUFD uf 0 2
Nothing
>>> setDiffUFD uf 2 1 2
Just True
>>> diffUFD uf 2 1
Just 2
>>> diffUFD uf 2 0
Just 3
-}
diffUFD ::
  (Num a, U.Unbox a, PrimMonad m) =>
  UnionFindDiff (PrimState m) a ->
  Int ->
  Int ->
  m (Maybe a)
diffUFD ufd x y = do
  (rootX, hx) <- findUFD ufd x
  (rootY, hy) <- findUFD ufd y
  pure $
    if rootX /= rootY
      then Nothing
      else Just $! hx - hy
{-# INLINE diffUFD #-}