{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.BipartiteMatching where

import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

--
import Data.Buffer
import My.Prelude (rep)

-- | Vertex index. Both sides of the graph share a single index space: a
-- graph built with @n@ vertices indexes its per-vertex vectors with
-- @0 .. n - 1@.
type Vertex = Int

-- | Size of a maximum matching of a bipartite graph on @n@ vertices.
--
-- The callback receives an empty builder and adds the edges with
-- 'addEdgeBMB'. Both sides share the index range @0 .. n - 1@, and each edge
-- should point from a left-side vertex to a right-side vertex, because the
-- search only follows edges in that direction.
--
-- @
-- bipartiteMatching 4 $ \\b -> do
--   addEdgeBMB b 0 2
--   addEdgeBMB b 1 2
--   addEdgeBMB b 1 3
-- -- evaluates to 2
-- @
bipartiteMatching ::
  -- | number of vertices
  Int ->
  (forall s. BipartiteMatchingBuilder s -> ST s ()) ->
  Int
bipartiteMatching n addEdges = runST $ do
  builder <- newBipartiteMatchingBuilder n
  addEdges builder
  matching <- buildBipartiteMatching builder
  runBipartiteMatching matching

-- | A frozen graph together with the mutable state of the matching search.
--
-- Adjacency is stored in compressed-sparse-row form: the out-neighbours of
-- vertex @v@ are the entries of 'adjacentBM' at positions
-- @offsetBM[v] .. offsetBM[v + 1] - 1@.
data BipartiteMatching s = BipartiteMatching
  { numVerticesBM :: !Int
  -- ^ number of vertices, counting both sides
  , matchBM       :: !(UM.MVector s Int)
  -- ^ partner of each vertex, or 'nothingBM' when unmatched; kept
  -- symmetric, so @matchBM[v] == w@ implies @matchBM[w] == v@
  , usedBM        :: !(UM.MVector s Bool)
  -- ^ vertices already entered by the search in the current phase
  , offsetBM      :: !(U.Vector Int)
  -- ^ row offsets into 'adjacentBM' (one more entry than vertices)
  , adjacentBM    :: !(U.Vector Int)
  -- ^ edge destinations, grouped by source vertex
  }

-- | Sentinel stored in a matching vector for a vertex that has no partner.
nothingBM :: Int
nothingBM = negate 1
{-# INLINE nothingBM #-}

-- | One augmenting-path search from a vertex, in continuation-passing style.
--
-- The continuation receives 'True' once an augmenting path has been found
-- and 'matchBM' has been rewritten along it, and 'False' when no path
-- exists. Every vertex the search enters is marked in 'usedBM', and an
-- already-marked vertex is rejected immediately. This function never clears
-- 'usedBM'; that is left to the caller.
dfsBM ::
  (PrimMonad m) =>
  BipartiteMatching (PrimState m) ->
  Vertex ->
  (Bool -> m ()) ->
  m ()
dfsBM BipartiteMatching{..} = search
  where
    -- Pair u and w with each other, keeping matchBM symmetric.
    link u w = do
      UM.unsafeWrite matchBM u w
      UM.unsafeWrite matchBM w u

    search !v cont = do
      visited <- UM.unsafeRead usedBM v
      if visited
        then cont False
        else do
          UM.unsafeWrite usedBM v True
          let lo = U.unsafeIndex offsetBM v
              hi = U.unsafeIndex offsetBM (v + 1)
              -- Try v's out-edges at positions i, i + 1, .., hi - 1 in order.
              tryEdge !i
                | i >= hi = cont False
                | otherwise = do
                    let nv = U.unsafeIndex adjacentBM i
                    partner <- UM.unsafeRead matchBM nv
                    if partner == nothingBM
                      then link v nv >> cont True
                      else search partner $ \found ->
                        -- nv's current partner found another vertex, so
                        -- nv is free for v; otherwise try v's next edge.
                        if found
                          then link v nv >> cont True
                          else tryEdge (i + 1)
          tryEdge lo
{-# INLINE dfsBM #-}

-- | Grow the matching in phases until a phase adds no new pair, then return
-- the total number of pairs added.
--
-- Each phase starts a 'dfsBM' search from every vertex that is still
-- unmatched. 'usedBM' is shared by all searches within one phase and is
-- cleared only when another phase follows.
runBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatching (PrimState m) ->
  m Int
runBipartiteMatching bm@BipartiteMatching{..} = do
  -- One-cell mutable slots, because results come back through callbacks.
  matched <- UM.replicate 1 0
  progressed <- UM.replicate 1 True
  let phase = do
        UM.unsafeWrite progressed 0 False
        rep numVerticesBM $ \v -> do
          partner <- UM.unsafeRead matchBM v
          when (partner == nothingBM) $
            dfsBM bm v $ \found ->
              when found $ do
                UM.unsafeWrite progressed 0 True
                UM.unsafeModify matched (+ 1) 0
        again <- UM.unsafeRead progressed 0
        if again
          then UM.set usedBM False >> phase
          else UM.unsafeRead matched 0
  phase
{-# INLINE runBipartiteMatching #-}

-- | Mutable edge accumulator. Fill it with 'addEdgeBMB', then convert it
-- with 'buildBipartiteMatching'.
data BipartiteMatchingBuilder s = BipartiteMatchingBuilder
  { numVerticesBMB :: !Int
  -- ^ number of vertices, counting both sides
  , inDegreeBMB    :: UM.MVector s Int
  -- ^ per-vertex edge count. Despite the name, 'addEdgeBMB' increments the
  -- entry of each edge's /source/, so this is an out-degree.
  , edgesBMB       :: Buffer s (Vertex, Vertex)
  -- ^ every edge passed to 'addEdgeBMB', as a @(source, destination)@ pair
  }

-- | Create an empty builder for a graph on @n@ vertices.
--
-- The edge buffer is created with @newBuffer (1024 * 1024)@. This is the
-- same as @'newBipartiteMatchingBuilderWithCapacity' n (1024 * 1024)@; use
-- that function directly to choose a different buffer size.
newBipartiteMatchingBuilder ::
  (PrimMonad m) =>
  -- | number of vertices
  Int ->
  m (BipartiteMatchingBuilder (PrimState m))
newBipartiteMatchingBuilder n =
  newBipartiteMatchingBuilderWithCapacity n (1024 * 1024)

-- | Create an empty builder for a graph on @n@ vertices, passing the given
-- size to 'newBuffer' for the edge buffer.
--
-- NOTE(review): if 'Data.Buffer' is a fixed-capacity buffer, adding more
-- edges than @capacity@ with 'addEdgeBMB' would overflow it. Size this to at
-- least the number of edges you intend to add, and confirm the behaviour
-- against 'Data.Buffer'.
newBipartiteMatchingBuilderWithCapacity ::
  (PrimMonad m) =>
  -- | number of vertices
  Int ->
  -- | size passed to 'newBuffer' for the edge buffer
  Int ->
  m (BipartiteMatchingBuilder (PrimState m))
newBipartiteMatchingBuilderWithCapacity n capacity =
  BipartiteMatchingBuilder n
    <$> UM.replicate n 0
    <*> newBuffer capacity

-- | Record an edge from @src@ to @dst@.
--
-- Edges should go from a left-side vertex to a right-side vertex, because
-- the matching search only follows edges from their source. Vertex indices
-- are not bounds-checked here or later (unsafe vector access), so both must
-- lie in @0 .. n - 1@ for a builder created with @n@ vertices.
addEdgeBMB ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  Vertex ->
  Vertex ->
  m ()
addEdgeBMB BipartiteMatchingBuilder{inDegreeBMB = deg, edgesBMB = buf} !src !dst = do
  UM.unsafeModify deg (1 +) src
  pushBack (src, dst) buf
{-# INLINE addEdgeBMB #-}

-- | Freeze the builder's edges into compressed-sparse-row adjacency and
-- allocate fresh matching state: every vertex starts unmatched
-- ('nothingBM') and unvisited.
--
-- The builder's degree vector and edge buffer are frozen in place with
-- @unsafeFreeze@ / 'unsafeFreezeBuffer', so the builder must not be modified
-- afterwards.
buildBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  m (BipartiteMatching (PrimState m))
buildBipartiteMatching BipartiteMatchingBuilder{..} = do
  -- Prefix sums of the per-source edge counts. Vertex v's neighbours will
  -- occupy adjacency positions offsets[v] .. offsets[v + 1] - 1.
  degrees <- U.unsafeFreeze inDegreeBMB
  let !offsets = U.scanl' (+) 0 degrees
  -- Counting-sort placement: cursor[src] is the next free slot in src's row.
  edgeList <- unsafeFreezeBuffer edgesBMB
  cursor <- U.thaw offsets
  slots <- UM.unsafeNew (U.last offsets)
  U.forM_ edgeList $ \(src, dst) -> do
    pos <- UM.unsafeRead cursor src
    UM.unsafeWrite slots pos dst
    UM.unsafeWrite cursor src (pos + 1)
  adjacency <- U.unsafeFreeze slots
  matchVec <- UM.replicate numVerticesBMB nothingBM
  usedVec <- UM.replicate numVerticesBMB False
  pure
    BipartiteMatching
      { numVerticesBM = numVerticesBMB
      , matchBM = matchVec
      , usedBM = usedVec
      , offsetBM = offsets
      , adjacentBM = adjacency
      }
{-# INLINE buildBipartiteMatching #-}