{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
module Data.Graph.BipartiteMatching where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Buffer
import My.Prelude (rep)
-- | Vertex identifier. Vertices are plain 'Int' values that this module uses
-- directly as indices into its per-vertex vectors.
type Vertex = Int
-- | Build a graph with a builder action and return the size of the matching
-- found by 'runBipartiteMatching'.
--
-- The first argument is the number of vertices handed to
-- 'newBipartiteMatchingBuilder'. The second argument receives the fresh
-- builder and is expected to register edges on it (e.g. via 'addEdgeBMB').
-- The result is the number of successful augmentations, i.e. the number of
-- matched pairs.
bipartiteMatching ::
  Int ->
  (forall s. BipartiteMatchingBuilder s -> ST s ()) ->
  Int
bipartiteMatching n run = runST $ do
  builder <- newBipartiteMatchingBuilder n
  run builder
  graph <- buildBipartiteMatching builder
  runBipartiteMatching graph
-- | Frozen graph plus the mutable search state used while matching.
--
-- Adjacency is stored in compressed form: the neighbours of vertex @v@ are
-- @adjacentBM[offsetBM[v] .. offsetBM[v + 1] - 1]@.
data BipartiteMatching s = BipartiteMatching
  { numVerticesBM :: !Int
  -- ^ Number of vertices; length of 'matchBM' and 'usedBM'.
  , matchBM :: !(UM.MVector s Int)
  -- ^ Current partner of each vertex, or 'nothingBM' when unmatched.
  , usedBM :: !(UM.MVector s Bool)
  -- ^ Visited flags for the depth-first search of the current pass.
  , offsetBM :: !(U.Vector Int)
  -- ^ Prefix sums of per-vertex edge counts (@numVerticesBM + 1@ entries).
  , adjacentBM :: !(U.Vector Int)
  -- ^ Edge targets, grouped by source vertex according to 'offsetBM'.
  }
-- | Sentinel stored in the match vector for a vertex that has no partner.
nothingBM :: Int
nothingBM = -1
{-# INLINE nothingBM #-}
-- | Continuation-passing depth-first search for an augmenting path.
--
-- Starting from the given vertex, the search walks its outgoing edges in
-- storage order. The continuation is invoked exactly once: with 'True' after
-- the matching has been extended along a path found from this vertex, or with
-- 'False' when the vertex was already visited in this pass or no edge leads
-- to an augmentation.
--
-- All vector accesses are unchecked; vertices and edge targets must be valid
-- indices into the per-vertex vectors.
dfsBM :: (PrimMonad m) => BipartiteMatching (PrimState m) -> Vertex -> (Bool -> m ()) -> m ()
dfsBM BipartiteMatching{..} = dfs
  where
    dfs !v k =
      UM.unsafeRead usedBM v >>= \case
        -- Already explored during this pass: report failure immediately.
        True -> k False
        False -> do
          UM.unsafeWrite usedBM v True
          let begin = U.unsafeIndex offsetBM v
              end = U.unsafeIndex offsetBM (v + 1)
          flip fix begin $ \loop !i ->
            if i < end
              then do
                let nv = U.unsafeIndex adjacentBM i
                    -- Record v <-> nv as a matched pair and report success.
                    augment = do
                      UM.unsafeWrite matchBM v nv
                      UM.unsafeWrite matchBM nv v
                      k True
                mnv <- UM.unsafeRead matchBM nv
                if mnv == nothingBM
                  then augment
                  else -- nv is taken: try to re-route its current partner.
                    dfs mnv $ \case
                      True -> augment
                      False -> loop (i + 1)
              else k False
{-# INLINE dfsBM #-}
-- | Run the matching to a fixed point and return the number of matched pairs.
--
-- Each pass visits every vertex via 'rep' and starts 'dfsBM' from those whose
-- 'matchBM' entry is still 'nothingBM'; every successful search increments the
-- result. When a pass made progress, 'usedBM' is cleared and another pass
-- runs; the first pass without progress ends the loop.
--
-- NOTE(review): 'rep' comes from @My.Prelude@; presumably it runs the action
-- for indices @0 .. numVerticesBM - 1@ -- confirm against its definition.
runBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatching (PrimState m) ->
  m Int
runBipartiteMatching bm@BipartiteMatching{..} = do
  -- Single-cell mutable vectors act as the counter and the progress flag.
  res <- UM.replicate 1 (0 :: Int)
  updated <- UM.replicate 1 True
  fix $ \loop -> do
    UM.unsafeWrite updated 0 False
    rep numVerticesBM $ \i -> do
      mi <- UM.unsafeRead matchBM i
      when (mi == nothingBM) $
        dfsBM bm i $ \case
          True -> do
            UM.unsafeWrite updated 0 True
            UM.unsafeModify res (+ 1) 0
          False -> pure ()
    UM.unsafeRead updated 0 >>= \case
      True -> do
        -- Visited flags are only valid for one pass; reset before retrying.
        UM.set usedBM False
        loop
      False -> UM.unsafeRead res 0
{-# INLINE runBipartiteMatching #-}
-- | Mutable accumulator for edges, consumed by 'buildBipartiteMatching'.
data BipartiteMatchingBuilder s = BipartiteMatchingBuilder
  { numVerticesBMB :: !Int
  -- ^ Number of vertices; length of 'inDegreeBMB'.
  , inDegreeBMB :: UM.MVector s Int
  -- ^ Per-vertex edge counter. 'addEdgeBMB' bumps the entry of the edge's
  -- source vertex, so despite the name it counts edges leaving each vertex.
  , edgesBMB :: Buffer s (Vertex, Vertex)
  -- ^ Recorded @(source, target)@ pairs in insertion order.
  }
-- | Create an empty builder for a graph with the given number of vertices.
--
-- The degree counters start at zero and the edge buffer is created with
-- @newBuffer (1024 * 1024)@.
--
-- NOTE(review): the argument of 'newBuffer' is presumably the buffer's
-- capacity, which would bound the number of edges -- confirm in "Data.Buffer".
newBipartiteMatchingBuilder ::
  (PrimMonad m) =>
  Int ->
  m (BipartiteMatchingBuilder (PrimState m))
newBipartiteMatchingBuilder n =
  BipartiteMatchingBuilder n
    <$> UM.replicate n 0
    <*> newBuffer edgeBufferSize
  where
    edgeBufferSize :: Int
    edgeBufferSize = 1024 * 1024
-- | Register an edge from the first vertex to the second.
--
-- Increments the counter of the source vertex and appends the pair to the
-- edge buffer. The counter update is unchecked, so the source must be a valid
-- vertex index for this builder.
addEdgeBMB ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  Vertex ->
  Vertex ->
  m ()
addEdgeBMB BipartiteMatchingBuilder{..} !src !dst = do
  UM.unsafeModify inDegreeBMB (+ 1) src
  pushBack (src, dst) edgesBMB
{-# INLINE addEdgeBMB #-}
-- | Freeze a builder into a 'BipartiteMatching' ready for
-- 'runBipartiteMatching'.
--
-- The per-vertex counters are turned into prefix-sum offsets, and the recorded
-- edges are scattered into a flat adjacency vector so that each vertex's
-- targets are contiguous. Every vertex starts unmatched and unvisited.
--
-- The builder's vectors are frozen without copying.
-- NOTE(review): the builder should presumably not be mutated afterwards --
-- confirm the aliasing behaviour of 'unsafeFreezeBuffer' in "Data.Buffer".
buildBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  m (BipartiteMatching (PrimState m))
buildBipartiteMatching BipartiteMatchingBuilder{..} = do
  let numVerticesBM = numVerticesBMB
  matchBM <- UM.replicate numVerticesBM nothingBM
  usedBM <- UM.replicate numVerticesBM False
  -- scanl' yields one more element than its input, so the vector is never
  -- empty and U.last below is safe; its last entry is the total edge count.
  offsetBM <- U.scanl' (+) 0 <$!> U.unsafeFreeze inDegreeBMB
  madjacentBM <- UM.unsafeNew (U.last offsetBM)
  -- Working copy of the offsets: the next free slot for each source vertex.
  moffset <- U.thaw offsetBM
  edges <- unsafeFreezeBuffer edgesBMB
  U.forM_ edges $ \(src, dst) -> do
    slot <- UM.unsafeRead moffset src
    UM.unsafeWrite moffset src (slot + 1)
    UM.unsafeWrite madjacentBM slot dst
  adjacentBM <- U.unsafeFreeze madjacentBM
  return BipartiteMatching{..}
{-# INLINE buildBipartiteMatching #-}