{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.MinCostFlow where

import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Bits
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word
import Unsafe.Coerce

import Data.Buffer
import Data.Heap.Binary
import My.Prelude (rep)

-- | Negative sentinel stored in index vectors for \"no vertex\" / \"no edge\".
nothingMCF :: Int
nothingMCF = -1

-- | Vertex identifier; used directly as an index into the per-vertex vectors.
type Vertex = Int
-- | Cost of sending one unit of flow along an edge.
type Cost = Int
-- | Amount of flow an edge can carry.
type Capacity = Int

{- |
Primal Dual /O(FElog V)/

Allocates a builder for the given number of vertices, runs the callback so it
can register edges with 'addEdgeMCFB', and then sends up to the requested
amount of flow from the source to the sink.

The result is @(total cost, flow actually sent)@. The second component is
smaller than the requested flow when the sink becomes unreachable first.

>>> :{
minCostFlow 2 0 1 2 (\builder -> do
    addEdgeMCFB builder 0 1 123 2
    )
:}
(246,2)
>>> :{
minCostFlow 2 0 1 123456789 (\builder -> do
    addEdgeMCFB builder 0 1 123 2
    )
:}
(246,2)
-}
minCostFlow ::
  -- | number of vertices
  Int ->
  -- | source
  Vertex ->
  -- | sink
  Vertex ->
  -- | flow
  Capacity ->
  -- | action that registers the edges on the builder
  (forall s. MinCostFlowBuilder s -> ST s ()) ->
  (Cost, Capacity)
minCostFlow numVertices src sink flow run = runST $ do
  builder <- newMinCostFlowBuilder numVertices
  run builder
  buildMinCostFlow builder >>= runMinCostFlow src sink flow

{- | Residual network in adjacency-array form plus the scratch space used by
the shortest-path search.

Every edge added through the builder occupies two slots: one for the edge
itself and one for its reverse. Vectors described as \"per slot\" are indexed
by slot number; the others are indexed by vertex.
-}
data MinCostFlow s = MinCostFlow
  { numVerticesMCF :: !Int
  -- ^ number of vertices
  , numEdgesMCF :: !Int
  -- ^ total number of slots (length of the per-slot vectors)
  , offsetMCF :: U.Vector Int
  -- ^ slots leaving vertex @v@ are @[offsetMCF ! v, offsetMCF ! (v + 1))@
  , dstMCF :: U.Vector Vertex
  -- ^ per slot: vertex the slot points to
  , costMCF :: U.Vector Cost
  -- ^ per slot: cost per unit of flow (negated on reverse slots)
  , residualMCF :: UM.MVector s Capacity
  -- ^ per slot: remaining capacity
  , potentialMCF :: UM.MVector s Cost
  -- ^ per vertex: potential, accumulated from the distances of each search
  , distMCF :: UM.MVector s Cost
  -- ^ per vertex: distance found by the current search ('maxBound' if unreached)
  , heapMCF :: MinBinaryHeap s Word64
  -- ^ priority queue of packed @(Cost, Vertex)@ keys
  , revEdgeMCF :: U.Vector Int
  -- ^ per slot: index of the paired slot in the opposite direction
  , prevVertexMCF :: UM.MVector s Vertex
  -- ^ per vertex: predecessor on the path found by the search
  , prevEdgeMCF :: UM.MVector s Int
  -- ^ per vertex: slot used to enter the vertex on that path
  }

{- | Repeatedly searches for a cheapest residual path from @source@ to @sink@
and pushes flow along it, until @flow@ units have been sent or the search
reports that the sink is unreachable.

Returns @(total cost, flow actually sent)@.
-}
runMinCostFlow ::
  (PrimMonad m) =>
  Vertex ->
  Vertex ->
  Capacity ->
  MinCostFlow (PrimState m) ->
  m (Cost, Capacity)
runMinCostFlow source sink flow mcf@MinCostFlow{..} = go 0 flow
  where
    -- res: cost accumulated so far; f: flow that still has to be sent
    go !res !f
      | f == 0 = return (res, flow)
      | otherwise = do
          canFlow <- dijkstraMCF source sink mcf
          if canFlow
            then do
              -- Fold the distances of this search into the potentials
              -- (presumably for every vertex 0 .. numVerticesMCF - 1; 'rep'
              -- is defined in My.Prelude).
              rep numVerticesMCF $ \v -> do
                dv <- UM.unsafeRead distMCF v
                UM.unsafeModify potentialMCF (+ dv) v
              flowed <- updateResidualMCF sink f mcf
              -- The sink's potential times the amount pushed is what this
              -- augmentation adds to the total cost.
              hsink <- UM.unsafeRead potentialMCF sink
              go (hsink * flowed + res) (f - flowed)
            else -- sink unreachable: report only what was actually sent
              return (res, flow - f)

{- | cost 48bit / vertex 16bit

Packs a @(cost, vertex)@ pair into one heap key, with the cost in the upper
48 bits and the vertex in the lower 16 bits.

Nothing is range-checked: a vertex that does not fit in 16 bits overlaps the
cost bits. The 'Int' to 'Word64' step is a plain 'fromIntegral', which keeps
the bit pattern without relying on the two types sharing a heap
representation.
-}
encodeMCF :: Cost -> Vertex -> Word64
encodeMCF cost v = fromIntegral (unsafeShiftL cost 16 .|. v)
{-# INLINE encodeMCF #-}

{- | Splits a key produced by 'encodeMCF' back into @(cost, vertex)@.

The cost is recovered with a shift on the unsigned word, so only
non-negative costs round-trip. Both components are forced before the pair is
returned.
-}
decodeMCF :: Word64 -> (Cost, Vertex)
decodeMCF costv = (cost, v)
  where
    -- upper 48 bits
    !cost = fromIntegral (unsafeShiftR costv 16)
    -- lower 16 bits
    !v = fromIntegral (costv .&. 0xffff)
{-# INLINE decodeMCF #-}

{- | Shortest-path search from @source@ over slots with positive residual
capacity, using costs reduced by the current potentials.

Overwrites 'distMCF' and records, for every vertex whose distance improves,
the predecessor vertex and the slot used to reach it. Returns 'True' when the
sink ended up with a distance below 'maxBound', i.e. it was reached.
-}
dijkstraMCF ::
  (PrimMonad m) =>
  Vertex ->
  Vertex ->
  MinCostFlow (PrimState m) ->
  m Bool
dijkstraMCF source sink MinCostFlow{..} = do
  UM.set distMCF maxBound
  UM.unsafeWrite distMCF source 0
  clearBH heapMCF
  insertBH (encodeMCF 0 source) heapMCF

  fix $ \loop -> do
    deleteFindTopBH heapMCF >>= \case
      Just cv -> do
        let (c, v) = decodeMCF cv
        dv <- UM.unsafeRead distMCF v
        -- skip heap entries that were superseded by a shorter distance
        unless (c > dv) $ do
          let start = U.unsafeIndex offsetMCF v
          let end = U.unsafeIndex offsetMCF (v + 1)
          -- the potential of v is the same for all of its outgoing slots
          hv <- UM.unsafeRead potentialMCF v
          U.forM_ (U.generate (end - start) (+ start)) $ \e -> do
            let nv = U.unsafeIndex dstMCF e
            let v2nv = U.unsafeIndex costMCF e
            cap <- UM.unsafeRead residualMCF e
            hnv <- UM.unsafeRead potentialMCF nv
            old <- UM.unsafeRead distMCF nv
            -- distance via v, using the potential-adjusted edge cost
            let dnv = dv + v2nv + hv - hnv
            when (cap > 0 && dnv < old) $ do
              UM.unsafeWrite distMCF nv dnv
              UM.unsafeWrite prevVertexMCF nv v
              UM.unsafeWrite prevEdgeMCF nv e
              insertBH (encodeMCF dnv nv) heapMCF
        loop
      Nothing -> do
        -- heap exhausted: the sink was reached iff its distance was lowered
        cost <- UM.unsafeRead distMCF sink
        return $! cost < maxBound
{-# INLINE dijkstraMCF #-}

{- | Pushes flow along the path recorded in 'prevVertexMCF' / 'prevEdgeMCF',
walking backwards from @sink@ until a vertex with a negative predecessor is
reached.

The amount pushed is the minimum of @flow@ and the residual capacities on
the path. That amount is subtracted from every slot on the path, added to
the paired reverse slot, and returned.
-}
updateResidualMCF ::
  (PrimMonad m) =>
  Vertex ->
  Capacity ->
  MinCostFlow (PrimState m) ->
  m Capacity
updateResidualMCF sink flow MinCostFlow{..} = go sink flow return
  where
    -- Walks towards the start of the path while narrowing the bottleneck f.
    -- The continuation k applies the final bottleneck on the way back.
    go !v !f k = do
      pv <- UM.unsafeRead prevVertexMCF v
      if pv < 0
        then k f
        else do
          pv2v <- UM.unsafeRead prevEdgeMCF v
          f' <- UM.unsafeRead residualMCF pv2v
          go pv (min f f') $ \nf -> do
            UM.unsafeModify residualMCF (subtract nf) pv2v
            UM.unsafeModify residualMCF (+ nf) (U.unsafeIndex revEdgeMCF pv2v)
            k nf
{-# INLINE updateResidualMCF #-}

-- | Mutable accumulator for the edges of a network that is still being described.
data MinCostFlowBuilder s = MinCostFlowBuilder
  { numVerticesMCFB :: !Int
  -- ^ number of vertices
  , inDegreeMCFB :: UM.MVector s Int
  -- ^ per vertex: number of slots it needs (each added edge counts once
  -- for each of its two endpoints)
  , edgesMCFB :: Buffer s (Vertex, Vertex, Cost, Capacity)
  -- ^ default buffer size: /1024 * 1024/
  }

{- | Creates an empty builder for a network with the given number of
vertices: all per-vertex counters start at zero and the edge buffer is
created with size @1024 * 1024@.
-}
newMinCostFlowBuilder ::
  (PrimMonad m) =>
  Int ->
  m (MinCostFlowBuilder (PrimState m))
newMinCostFlowBuilder n =
  MinCostFlowBuilder n
    <$> UM.replicate n 0
    <*> newBuffer (1024 * 1024)

{- | /cost >= 0/

Registers a directed edge @src -> dst@ with the given cost per unit of flow
and capacity. The cost precondition is checked with 'assert' only, so it is
skipped when assertions are compiled out. The vertex indices are not
bounds-checked.
-}
addEdgeMCFB ::
  (PrimMonad m) =>
  MinCostFlowBuilder (PrimState m) ->
  Vertex ->
  Vertex ->
  Cost ->
  Capacity ->
  m ()
addEdgeMCFB MinCostFlowBuilder{..} src dst cost capacity =
  assert (cost >= 0) $ do
    -- one slot at each endpoint: the edge itself and its reverse
    UM.unsafeModify inDegreeMCFB (+ 1) src
    UM.unsafeModify inDegreeMCFB (+ 1) dst
    pushBack (src, dst, cost, capacity) edgesMCFB

{- | Turns the edges collected by a builder into the adjacency-array network
used by the solver.

Every registered edge gets a forward slot at its source (carrying the cost
and the full capacity) and a reverse slot at its destination (negated cost,
zero residual); the two slots reference each other through 'revEdgeMCF'.

The builder's vectors are frozen without copying, so the builder must not be
used again afterwards.
-}
buildMinCostFlow ::
  (PrimMonad m) =>
  MinCostFlowBuilder (PrimState m) ->
  m (MinCostFlow (PrimState m))
buildMinCostFlow MinCostFlowBuilder{..} = do
  -- prefix sums of the per-vertex slot counts; always has n + 1 entries,
  -- so taking the last element below is safe
  offsetMCF <- U.scanl' (+) 0 <$> U.unsafeFreeze inDegreeMCFB
  let numVerticesMCF = numVerticesMCFB
  let numEdgesMCF = U.last offsetMCF

  -- moffset tracks the next free slot of every vertex while filling
  moffset <- U.thaw offsetMCF
  mdstMCF <- UM.replicate numEdgesMCF nothingMCF
  mcostMCF <- UM.replicate numEdgesMCF 0
  mrevEdgeMCF <- UM.replicate numEdgesMCF nothingMCF
  residualMCF <- UM.replicate numEdgesMCF 0

  edges <- unsafeFreezeBuffer edgesMCFB
  U.forM_ edges $ \(src, dst, cost, capacity) -> do
    -- Claim the forward slot before looking up the reverse slot, so that a
    -- self-loop (src == dst) receives two distinct slots instead of writing
    -- both halves into the same one and leaving the next slot uninitialised.
    srcOffset <- UM.unsafeRead moffset src
    UM.unsafeModify moffset (+ 1) src
    dstOffset <- UM.unsafeRead moffset dst
    UM.unsafeModify moffset (+ 1) dst

    UM.unsafeWrite mdstMCF srcOffset dst
    UM.unsafeWrite mdstMCF dstOffset src
    UM.unsafeWrite mcostMCF srcOffset cost
    UM.unsafeWrite mcostMCF dstOffset (-cost)
    UM.unsafeWrite mrevEdgeMCF srcOffset dstOffset
    UM.unsafeWrite mrevEdgeMCF dstOffset srcOffset
    -- only the forward slot starts with capacity; the reverse stays at 0
    UM.unsafeWrite residualMCF srcOffset capacity

  dstMCF <- U.unsafeFreeze mdstMCF
  costMCF <- U.unsafeFreeze mcostMCF
  potentialMCF <- UM.replicate numVerticesMCF 0
  distMCF <- UM.replicate numVerticesMCF 0
  heapMCF <- newMinBinaryHeap (numEdgesMCF + 1)
  revEdgeMCF <- U.unsafeFreeze mrevEdgeMCF
  prevVertexMCF <- UM.replicate numVerticesMCF nothingMCF
  prevEdgeMCF <- UM.replicate numVerticesMCF nothingMCF
  return MinCostFlow{..}