{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module Data.Graph.BipartiteMatching where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Buffer
import My.Prelude (rep)
-- | A vertex is an index in @[0, n)@, where @n@ is the vertex count given to
-- 'bipartiteMatching' / 'newBipartiteMatchingBuilder'; it is used directly as
-- an index into the per-vertex vectors.
type Vertex = Int
-- | Size of a maximum matching of a bipartite graph on @n@ vertices
-- (@0 .. n - 1@).
--
-- The callback receives a fresh 'BipartiteMatchingBuilder' and should add the
-- graph's edges with 'addEdgeBMB'. An edge @src -> dst@ is only stored in the
-- adjacency list of @src@. NOTE(review): the usual convention is to add each
-- edge once, from the left part to the right part — confirm against callers.
--
-- The matching itself is computed by repeated DFS augmenting-path searches
-- (Kuhn's algorithm); see 'runBipartiteMatching'.
bipartiteMatching ::
  Int ->
  (forall s. BipartiteMatchingBuilder s -> ST s ()) ->
  Int
bipartiteMatching n run = runST $ do
  builder <- newBipartiteMatchingBuilder n
  run builder
  buildBipartiteMatching builder >>= runBipartiteMatching
-- | Solver state produced by 'buildBipartiteMatching'.
--
-- Adjacency is stored in compressed (CSR) form: the neighbours of vertex @v@
-- are the entries of 'adjacentBM' at positions
-- @offsetBM[v] .. offsetBM[v + 1] - 1@.
data BipartiteMatching s = BipartiteMatching
  { numVerticesBM :: !Int
  -- ^ number of vertices @n@
  , matchBM :: !(UM.MVector s Int)
  -- ^ current partner of each vertex, or 'nothingBM' if it is unmatched
  , usedBM :: !(UM.MVector s Bool)
  -- ^ visited marks set by 'dfsBM'; cleared by 'runBipartiteMatching'
  -- between rounds
  , offsetBM :: !(U.Vector Int)
  -- ^ prefix sums of per-vertex edge counts; length @n + 1@
  , adjacentBM :: !(U.Vector Int)
  -- ^ all adjacency lists concatenated, sliced via 'offsetBM'
  }
-- | Sentinel stored in 'matchBM' for a vertex without a partner. It is
-- negative, so it can never be confused with a vertex index.
nothingBM :: Int
nothingBM = -1
{-# INLINE nothingBM #-}
-- | Augmenting-path search starting at vertex @v@, in continuation-passing
-- style.
--
-- The continuation is invoked exactly once:
--
-- * with 'True' if an augmenting path was found; by then the matching along
--   that path has already been flipped in 'matchBM';
-- * with 'False' otherwise.
--
-- Every vertex entered is marked in 'usedBM' and is not entered again until
-- the marks are cleared by the caller. Indices are not bounds-checked
-- ('UM.unsafeRead' / 'U.unsafeIndex'), so @v@ must be in @[0, n)@.
dfsBM :: (PrimMonad m) => BipartiteMatching (PrimState m) -> Vertex -> (Bool -> m ()) -> m ()
dfsBM BipartiteMatching{..} = dfs
  where
    dfs !v k =
      UM.unsafeRead usedBM v >>= \case
        -- already entered since the marks were last cleared
        True -> k False
        False -> do
          UM.unsafeWrite usedBM v True
          let begin = U.unsafeIndex offsetBM v
              end = U.unsafeIndex offsetBM (v + 1)
          flip fix begin $ \loop !i ->
            if i < end
              then do
                let nv = U.unsafeIndex adjacentBM i
                mnv <- UM.unsafeRead matchBM nv
                if mnv == nothingBM
                  then do
                    -- nv is free: the augmenting path ends here
                    link v nv
                    k True
                  else
                    -- nv is taken: try to re-route its current partner
                    dfs mnv $ \case
                      True -> do
                        link v nv
                        k True
                      False -> loop (i + 1)
              else k False
    -- pair u and w with each other in matchBM
    link u w = do
      UM.unsafeWrite matchBM u w
      UM.unsafeWrite matchBM w u
{-# INLINE dfsBM #-}
-- | Repeats augmenting-path searches until none succeeds. Returns the number
-- of successful augmentations: the size of the matching grown from whatever
-- 'matchBM' held on entry ('buildBipartiteMatching' leaves it all unmatched).
--
-- The work is organised in rounds. Within a round, every currently unmatched
-- vertex is tried once, and 'usedBM' is /not/ cleared between tries: a vertex
-- that failed once cannot lead to an augmenting path while the matching is
-- unchanged. The marks are cleared only before starting another round, and
-- the rounds stop as soon as one round finds nothing.
--
-- NOTE(review): assumes 'usedBM' is all 'False' on entry, as
-- 'buildBipartiteMatching' leaves it. A second call on the same value would
-- start with the marks left over from the final round of the first call.
runBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatching (PrimState m) ->
  m Int
runBipartiteMatching bm@BipartiteMatching{..} = do
  -- single-cell mutable counters, written from inside the dfsBM continuation
  res <- UM.replicate 1 0
  updated <- UM.replicate 1 True
  fix $ \loop -> do
    UM.unsafeWrite updated 0 False
    rep numVerticesBM $ \i -> do
      mi <- UM.unsafeRead matchBM i
      when (mi == nothingBM) $
        dfsBM bm i $ \case
          True -> do
            UM.unsafeWrite updated 0 True
            UM.unsafeModify res (+ 1) 0
          False -> return ()
    UM.unsafeRead updated 0 >>= \case
      True -> do
        UM.set usedBM False
        loop
      False -> UM.unsafeRead res 0
{-# INLINE runBipartiteMatching #-}
-- | Mutable edge collector, filled with 'addEdgeBMB' and turned into a
-- solver by 'buildBipartiteMatching'.
--
-- All fields are strict, like those of 'BipartiteMatching'.
data BipartiteMatchingBuilder s = BipartiteMatchingBuilder
  { numVerticesBMB :: !Int
  -- ^ number of vertices @n@
  , inDegreeBMB :: !(UM.MVector s Int)
  -- ^ per-vertex edge count, used to size the adjacency lists. Despite the
  -- name, 'addEdgeBMB' increments the /source/ endpoint, so this holds
  -- out-degrees.
  , edgesBMB :: !(Buffer s (Vertex, Vertex))
  -- ^ every added edge as @(src, dst)@, in insertion order
  }
-- | Fresh builder for a graph on @n@ vertices. The edge buffer is created
-- with capacity @1024 * 1024@; use
-- 'newBipartiteMatchingBuilderWithCapacity' to choose a different one.
newBipartiteMatchingBuilder ::
  (PrimMonad m) =>
  Int ->
  m (BipartiteMatchingBuilder (PrimState m))
newBipartiteMatchingBuilder = newBipartiteMatchingBuilderWithCapacity (1024 * 1024)

-- | Like 'newBipartiteMatchingBuilder', but with an explicit edge buffer
-- capacity.
--
-- NOTE(review): the capacity is passed straight to 'newBuffer'. Whether
-- 'pushBack' beyond it grows the buffer or is undefined depends on
-- "Data.Buffer"; confirm there.
newBipartiteMatchingBuilderWithCapacity ::
  (PrimMonad m) =>
  Int -> -- ^ edge buffer capacity
  Int -> -- ^ number of vertices @n@
  m (BipartiteMatchingBuilder (PrimState m))
newBipartiteMatchingBuilderWithCapacity capacity n =
  BipartiteMatchingBuilder n
    <$> UM.replicate n 0
    <*> newBuffer capacity
-- | Records the edge @src -> dst@. It increments the count of @src@ and
-- appends @(src, dst)@ to the edge buffer, so after
-- 'buildBipartiteMatching' only @src@'s adjacency list contains @dst@.
--
-- Indices are not checked: the count is updated with 'UM.unsafeModify', and
-- @dst@ is later used unchecked as an index into 'matchBM'. Both endpoints
-- must therefore lie in @[0, n)@.
addEdgeBMB ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  Vertex ->
  Vertex ->
  m ()
addEdgeBMB BipartiteMatchingBuilder{..} !src !dst = do
  UM.unsafeModify inDegreeBMB (+ 1) src
  pushBack (src, dst) edgesBMB
{-# INLINE addEdgeBMB #-}
-- | Turns a builder into a solver. It:
--
-- * lays the recorded edges out in CSR form, bucketing them by source vertex
--   with a counting sort;
-- * allocates an all-unmatched 'matchBM' and an all-clear 'usedBM'.
--
-- The builder's count vector is frozen in place with 'U.unsafeFreeze', and
-- its edge buffer with 'unsafeFreezeBuffer' (presumably also without
-- copying), so the builder must not be used after this call.
buildBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  m (BipartiteMatching (PrimState m))
buildBipartiteMatching BipartiteMatchingBuilder{..} = do
  let numVerticesBM = numVerticesBMB
  matchBM <- UM.replicate numVerticesBM nothingBM
  usedBM <- UM.replicate numVerticesBM False
  -- offsetBM[v] = number of edges whose source is < v. scanl' yields
  -- n + 1 elements, so the vector is never empty and U.last below is safe.
  offsetBM <- U.scanl' (+) 0 <$!> U.unsafeFreeze inDegreeBMB
  madjacentBM <- UM.unsafeNew (U.last offsetBM)
  -- next free slot inside each vertex's segment of madjacentBM
  -- (a copy, because offsetBM itself is kept)
  moffset <- U.thaw offsetBM
  edges <- unsafeFreezeBuffer edgesBMB
  U.forM_ edges $ \(src, dst) -> do
    offset <- UM.unsafeRead moffset src
    UM.unsafeWrite moffset src (offset + 1)
    UM.unsafeWrite madjacentBM offset dst
  adjacentBM <- U.unsafeFreeze madjacentBM
  return BipartiteMatching{..}
{-# INLINE buildBipartiteMatching #-}