{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.MaxFlow where

import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer

-- | Sentinel meaning "no value". 'runMaxFlow' resets every entry of
-- 'levelMF' to it before each BFS, so it marks vertices the BFS has not
-- reached. 'buildMaxFlow' also pre-fills the arc head / reverse-arc arrays
-- with it before real indices are written.
nothingMF :: Int
nothingMF = -1

-- | A vertex is an 'Int' index. This module uses it directly to index
-- per-vertex arrays, so it is expected to lie in @[0, numVertices)@.
type Vertex = Int

--

{- |
Maximum flow from @src@ to @sink@ using Dinic's algorithm, /O(V^2E)/.

The callback receives a fresh 'MaxFlowBuilder' for @numVertices@ vertices
and should add the network's directed edges with 'addEdgeMFB'. The network
is then built and solved inside a single 'ST' computation.
@src@ and @sink@ must be distinct vertices in @[0, numVertices)@.

>>> :{
maxFlow @Int 5 0 4 $ \builder -> do
    addEdgeMFB builder (0, 1, 10)
    addEdgeMFB builder (0, 2, 2)
    addEdgeMFB builder (1, 2, 6)
    addEdgeMFB builder (1, 3, 6)
    addEdgeMFB builder (3, 2, 2)
    addEdgeMFB builder (2, 4, 5)
    addEdgeMFB builder (3, 4, 8)
:}
11
>>> maxFlow @Int 2 0 1 $ const (return ())
0
-}
maxFlow ::
  (U.Unbox cap, Num cap, Ord cap, Bounded cap) =>
  -- | number of vertices
  Int ->
  -- | source
  Vertex ->
  -- | sink
  Vertex ->
  -- | adds the edges of the network to the builder
  (forall s. MaxFlowBuilder s cap -> ST s ()) ->
  cap
maxFlow numVertices src sink run = runST $ do
  builder <- newMaxFlowBuilder numVertices
  run builder
  buildMaxFlow builder >>= runMaxFlow src sink

-- | A flow network in compressed arc-array form, together with the mutable
-- scratch state used by Dinic's algorithm. Produced by 'buildMaxFlow' and
-- consumed by 'runMaxFlow'.
--
-- Arcs are indexed @0 .. numEdgesMF - 1@. The arcs leaving vertex @v@ are
-- those with indices in @[offsetMF ! v, offsetMF ! (v + 1))@.
data MaxFlow s cap = MaxFlow
  { numVerticesMF :: !Int
  -- ^ number of vertices
  , numEdgesMF :: !Int
  -- ^ number of arcs, counting each added edge and its reverse arc
  , offsetMF :: !(U.Vector Int)
  -- ^ first arc index of each vertex, plus a trailing total
  -- (length @numVerticesMF + 1@)
  , dstMF :: !(U.Vector Vertex)
  -- ^ head vertex of each arc
  , residualMF :: !(UM.MVector s cap)
  -- ^ remaining (residual) capacity of each arc
  , levelMF :: !(UM.MVector s Int)
  -- ^ BFS distance from the source in the current phase, or 'nothingMF'
  -- for vertices not reached
  , revEdgeMF :: !(U.Vector Int)
  -- ^ index of the paired reverse arc of each arc
  , iterMF :: !(UM.MVector s Int)
  -- ^ per vertex, the next arc the DFS will examine in the current phase
  , queueMF :: !(Queue s Vertex)
  -- ^ work queue for the BFS
  }

-- | Runs Dinic's algorithm on a network produced by 'buildMaxFlow' and
-- returns the value of the flow pushed from @src@ to @sink@.
--
-- Residual capacities in @mf@ are updated in place. Afterwards they
-- describe the residual network of the computed flow.
--
-- Each phase does the following:
--
-- * resets the levels and rebuilds the level graph with 'bfsMF';
-- * if @sink@ was reached, repeatedly augments along paths found by
--   'dfsMF' until it reports @0@, then starts a new phase.
--
-- The accumulated flow is returned once @sink@ is no longer reachable.
--
-- Calls 'error' in two cases, both of which are checked before any
-- unchecked array access:
--
-- * @src@ or @sink@ lies outside @[0, numVerticesMF)@;
-- * @src == sink@, where 'dfsMF' would report 'maxBound' on every call
--   and the loop would never terminate.
runMaxFlow ::
  (U.Unbox cap, Num cap, Ord cap, Bounded cap, PrimMonad m) =>
  Vertex ->
  Vertex ->
  MaxFlow (PrimState m) cap ->
  m cap
runMaxFlow src sink mf@MaxFlow{..}
  | not (isVertex src) || not (isVertex sink) =
      error "runMaxFlow: source or sink out of range"
  | src == sink = error "runMaxFlow: source and sink must differ"
  | otherwise = flip fix 0 $ \loopBFS !flow -> do
      -- Build this phase's level graph from scratch.
      UM.set levelMF nothingMF
      clearBuffer queueMF
      bfsMF src sink mf
      lsink <- UM.unsafeRead levelMF sink
      if lsink == nothingMF
        then return flow -- sink unreachable: no augmenting path remains
        else do
          -- Restart every vertex's arc scan at its first arc. iterMF has
          -- one entry per vertex while offsetMF has one extra trailing
          -- total, so copy only the matching prefix.
          U.unsafeCopy iterMF (U.take (UM.length iterMF) offsetMF)
          flip fix flow $ \loopDFS !f -> do
            df <- dfsMF src sink maxBound mf
            if df > 0
              then loopDFS (f + df)
              else loopBFS f
  where
    isVertex v = 0 <= v && v < numVerticesMF

-- | Breadth-first search from @src@ over arcs with positive residual
-- capacity. Each newly reached vertex gets level @level(parent) + 1@ in
-- 'levelMF', with @src@ at level @0@.
--
-- Expects 'levelMF' to be filled with 'nothingMF' and 'queueMF' to be
-- empty on entry; 'runMaxFlow' arranges both before calling it.
-- Once @sink@ has been assigned a level, no further vertices are expanded.
bfsMF ::
  (Num cap, Ord cap, U.Unbox cap, PrimMonad m) =>
  Vertex ->
  Vertex ->
  MaxFlow (PrimState m) cap ->
  m ()
bfsMF src sink MaxFlow{..} = do
  UM.unsafeWrite levelMF src 0
  pushBack src queueMF
  fix $ \loop ->
    popFront queueMF >>= \case
      Nothing -> return ()
      Just v -> do
        lsink <- UM.unsafeRead levelMF sink
        -- Early exit: stop popping once the sink has a level.
        when (lsink == nothingMF) $ do
          let start = U.unsafeIndex offsetMF v
              end = U.unsafeIndex offsetMF (v + 1)
          lv <- UM.unsafeRead levelMF v
          -- The arcs leaving v are the indices [start, end).
          U.forM_ (U.generate (end - start) (+ start)) $ \e -> do
            let nv = U.unsafeIndex dstMF e
            res <- UM.unsafeRead residualMF e
            lnv <- UM.unsafeRead levelMF nv
            when (res > 0 && lnv == nothingMF) $ do
              UM.unsafeWrite levelMF nv (lv + 1)
              pushBack nv queueMF
          loop
{-# INLINE bfsMF #-}

-- | Searches the level graph for one augmenting path from @v0@ to @sink@
-- that carries at most @flow0@. The flow found is pushed along the path.
--
-- Returns the amount pushed, or @0@ if no such path remains from @v0@ in
-- the current phase.
--
-- An arc is usable when it has positive residual capacity and its head
-- has a higher level than its tail. 'iterMF' holds, per vertex, the next
-- arc to examine. It moves past an arc only when that arc is unusable or
-- leads to a dead end. An arc that still has capacity after an
-- augmentation is therefore tried again on the next call. This
-- "current arc" rule is what bounds the work per phase in Dinic's
-- algorithm.
--
-- Written in continuation-passing style: residual updates are applied by
-- the continuations on the way back from @sink@.
dfsMF ::
  (U.Unbox cap, Num cap, Ord cap, Bounded cap, PrimMonad m) =>
  Vertex ->
  Vertex ->
  cap ->
  MaxFlow (PrimState m) cap ->
  m cap
dfsMF v0 sink flow0 MaxFlow{..} = dfs v0 flow0 return
  where
    dfs !v !flow k
      | v == sink = k flow
      | otherwise = fix $ \loop -> do
          e <- UM.unsafeRead iterMF v
          if e >= U.unsafeIndex offsetMF (v + 1)
            then k 0 -- every arc of v is exhausted: dead end
            else do
              let nv = U.unsafeIndex dstMF e
                  -- Arc e cannot carry more flow in this phase: skip it
                  -- permanently and try v's next arc.
                  skipArc = UM.unsafeWrite iterMF v (e + 1) >> loop
              res <- UM.unsafeRead residualMF e
              lv <- UM.unsafeRead levelMF v
              lnv <- UM.unsafeRead levelMF nv
              if res > 0 && lv < lnv
                then dfs nv (min flow res) $ \f ->
                  if f > 0
                    then do
                      -- Keep iterMF at e: the arc may still have capacity.
                      UM.unsafeModify residualMF (subtract f) e
                      UM.unsafeModify
                        residualMF
                        (+ f)
                        (U.unsafeIndex revEdgeMF e)
                      k f
                    else skipArc
                else skipArc
{-# INLINE dfsMF #-}

-- | Mutable collector for the edges of a flow network. 'addEdgeMFB' fills
-- it, and 'buildMaxFlow' lays its contents out as a 'MaxFlow'.
data MaxFlowBuilder s cap = MaxFlowBuilder
  { numVerticesMFB :: !Int
  -- ^ number of vertices
  , inDegreeMFB :: !(UM.MVector s Int)
  -- ^ per-vertex arc count. Despite the name, 'addEdgeMFB' increments both
  -- endpoints, so this counts the forward and reverse arcs stored at each
  -- vertex.
  , edgesMFB :: !(Buffer s (Vertex, Vertex, cap))
  -- ^ added edges as @(src, dst, capacity)@;
  -- default buffer size: /1024 * 1024/
  }

-- | Creates an empty builder for a network with @n@ vertices. All arc
-- counts start at zero, and the edge buffer is allocated with capacity
-- @1024 * 1024@.
--
-- NOTE(review): whether 'pushBack' grows a full 'Buffer' is decided in
-- "Data.Buffer". Confirm this before adding more edges than that capacity.
newMaxFlowBuilder ::
  (U.Unbox cap, PrimMonad m) =>
  -- | number of vertices
  Int ->
  m (MaxFlowBuilder (PrimState m) cap)
newMaxFlowBuilder n =
  MaxFlowBuilder n
    <$> UM.replicate n 0
    <*> newBuffer defaultEdgeCapacity
  where
    -- initial capacity of the edge buffer
    defaultEdgeCapacity = 1024 * 1024

-- | Converts a builder into the arc-array representation used by
-- 'runMaxFlow'.
--
-- Arcs are grouped by tail vertex. The arcs of @v@ occupy indices
-- @[offsetMF ! v, offsetMF ! (v + 1))@, where 'offsetMF' is the prefix sum
-- of the builder's per-vertex arc counts. Each added edge @(src, dst, c)@
-- becomes:
--
-- * a forward arc at @src@ with residual @c@;
-- * a reverse arc at @dst@ with residual @0@.
--
-- The two arcs point at each other through 'revEdgeMF'.
--
-- The builder's count vector and edge buffer are frozen in place
-- ('U.unsafeFreeze', 'unsafeFreezeBuffer'), so the builder should be
-- treated as consumed.
buildMaxFlow ::
  (Num cap, U.Unbox cap, PrimMonad m) =>
  MaxFlowBuilder (PrimState m) cap ->
  m (MaxFlow (PrimState m) cap)
buildMaxFlow MaxFlowBuilder{..} = do
  offsetMF <- U.scanl' (+) 0 <$> U.unsafeFreeze inDegreeMFB
  let numVerticesMF = numVerticesMFB
      -- total arc count; scanl' output always has at least one element
      numEdgesMF = U.last offsetMF

  -- moffset ! v is the next free arc slot of v while arcs are placed.
  moffset <- U.thaw offsetMF
  mdstMF <- UM.replicate numEdgesMF nothingMF
  mrevEdgeMF <- UM.replicate numEdgesMF nothingMF
  residualMF <- UM.replicate numEdgesMF 0

  edges <- unsafeFreezeBuffer edgesMFB
  U.forM_ edges $ \(src, dst, cap) -> do
    -- Claim src's slot before reading dst's. For a self-loop
    -- (src == dst) both arcs must get distinct slots. Otherwise one
    -- reserved slot keeps head nothingMF and is later used as a vertex.
    srcOffset <- UM.unsafeRead moffset src
    UM.unsafeModify moffset (+ 1) src
    dstOffset <- UM.unsafeRead moffset dst
    UM.unsafeModify moffset (+ 1) dst
    UM.unsafeWrite mdstMF srcOffset dst
    UM.unsafeWrite mdstMF dstOffset src
    UM.unsafeWrite mrevEdgeMF srcOffset dstOffset
    UM.unsafeWrite mrevEdgeMF dstOffset srcOffset
    UM.unsafeWrite residualMF srcOffset cap

  dstMF <- U.unsafeFreeze mdstMF
  revEdgeMF <- U.unsafeFreeze mrevEdgeMF
  levelMF <- UM.replicate numVerticesMF nothingMF
  -- One entry per vertex. Drop offsetMF's trailing total instead of copying
  -- vectors of different lengths.
  iterMF <- U.thaw (U.init offsetMF)
  queueMF <- newBufferAsQueue numVerticesMF
  return MaxFlow{..}

-- | Records a directed edge @src -> dst@ with capacity @cap@.
--
-- Both endpoints' arc counts are incremented because 'buildMaxFlow' stores
-- each edge twice: a forward arc at @src@ and a zero-capacity reverse arc
-- at @dst@. Vertices are not range-checked (unchecked 'UM.unsafeModify'),
-- so they must lie in @[0, numVerticesMFB)@.
addEdgeMFB ::
  (U.Unbox cap, PrimMonad m) =>
  MaxFlowBuilder (PrimState m) cap ->
  (Vertex, Vertex, cap) ->
  m ()
addEdgeMFB MaxFlowBuilder{..} (!src, !dst, !cap) = do
  UM.unsafeModify inDegreeMFB (+ 1) src
  UM.unsafeModify inDegreeMFB (+ 1) dst
  pushBack (src, dst, cap) edgesMFB
{-# INLINE addEdgeMFB #-}