{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.BipartiteMatching where

import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

--
import Data.Buffer
import My.Prelude (rep)

-- | Vertex identifier. Vertices are used directly as indices into vectors
-- whose length is the vertex count given to the builder, so valid values lie
-- in @[0, n)@.
type Vertex = Int

-- | Build a graph through the supplied callback and return the number of
-- matched pairs found by 'runBipartiteMatching'.
--
-- All vertices (both sides of the bipartition) share the single index space
-- @[0, n)@. The callback receives a fresh builder and is expected to register
-- edges on it, e.g. with 'addEdgeBMB'.
bipartiteMatching ::
  -- | number of vertices
  Int ->
  -- | action that populates the builder with edges
  (forall s. BipartiteMatchingBuilder s -> ST s ()) ->
  -- | number of successful augmentations, i.e. the size of the matching
  Int
bipartiteMatching n run = runST $ do
  builder <- newBipartiteMatchingBuilder n
  run builder
  buildBipartiteMatching builder >>= runBipartiteMatching

-- | A graph frozen into compressed adjacency form, together with the mutable
-- state used while searching for augmenting paths.
data BipartiteMatching s = BipartiteMatching
  { numVerticesBM :: !Int
  -- ^ number of vertices
  , matchBM :: !(UM.MVector s Int)
  -- ^ current partner of each vertex, or 'nothingBM' when unmatched
  , usedBM :: !(UM.MVector s Bool)
  -- ^ per-vertex "already visited" flags consulted by 'dfsBM'
  , offsetBM :: !(U.Vector Int)
  -- ^ the neighbours of @v@ occupy indices
  -- @[offsetBM ! v, offsetBM ! (v + 1))@ of 'adjacentBM'
  , adjacentBM :: !(U.Vector Int)
  -- ^ concatenated adjacency lists, indexed through 'offsetBM'
  }

-- | Sentinel stored in 'matchBM' for a vertex that has no partner. It is
-- negative, so it can never coincide with a vertex index.
nothingBM :: Int
nothingBM = -1
{-# INLINE nothingBM #-}

-- | Depth-first search for an augmenting path starting at the given vertex,
-- written in continuation-passing style.
--
-- The continuation is called with 'True' once the matching has been extended
-- (the affected 'matchBM' entries are already rewritten at that point) and
-- with 'False' when the vertex was already visited or none of its neighbours
-- leads to an augmentation.
--
-- Visited vertices are recorded in 'usedBM'; this function never clears the
-- flags, so the caller decides when to reset them.
--
-- All vector accesses are unchecked: the start vertex and every stored
-- neighbour must be valid indices.
dfsBM :: (PrimMonad m) => BipartiteMatching (PrimState m) -> Vertex -> (Bool -> m ()) -> m ()
dfsBM BipartiteMatching{..} = dfs
  where
    dfs !v k =
      UM.unsafeRead usedBM v >>= \case
        -- Already explored: report failure without rescanning.
        True -> k False
        False -> do
          UM.unsafeWrite usedBM v True
          let begin = U.unsafeIndex offsetBM v
              end = U.unsafeIndex offsetBM (v + 1)
              -- Record v and nv as each other's partner, then report success.
              accept nv = do
                UM.unsafeWrite matchBM v nv
                UM.unsafeWrite matchBM nv v
                k True
          -- Scan the adjacency slice [begin, end) of v.
          flip fix begin $ \loop !i ->
            if i < end
              then do
                let nv = U.unsafeIndex adjacentBM i
                mnv <- UM.unsafeRead matchBM nv
                if mnv == nothingBM
                  then accept nv
                  else -- nv is taken: try to re-route its current partner.
                    dfs mnv $ \case
                      True -> accept nv
                      False -> loop (i + 1)
              else k False
{-# INLINE dfsBM #-}

-- | Run the augmenting-path search to a fixed point and return how many
-- augmentations succeeded.
--
-- The search proceeds in passes. In each pass 'dfsBM' is started from every
-- vertex whose 'matchBM' entry is still 'nothingBM'. The 'usedBM' flags are
-- deliberately kept between starts within a pass and are only cleared after a
-- pass that produced at least one augmentation. A pass with no augmentation
-- ends the loop.
runBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatching (PrimState m) ->
  m Int
runBipartiteMatching bm@BipartiteMatching{..} = do
  -- One-element mutable vectors serve as unboxed mutable cells.
  res <- UM.replicate 1 0
  updated <- UM.replicate 1 True
  fix $ \loop -> do
    UM.unsafeWrite updated 0 False
    -- NOTE(review): 'rep' comes from My.Prelude; presumably it runs the body
    -- for i = 0 .. numVerticesBM - 1 — confirm against its definition.
    rep numVerticesBM $ \i -> do
      mi <- UM.unsafeRead matchBM i
      when (mi == nothingBM) $ do
        dfsBM bm i $ \case
          True -> do
            UM.unsafeWrite updated 0 True
            UM.unsafeModify res (+ 1) 0
          False -> return ()
    UM.unsafeRead updated 0 >>= \case
      True -> do
        -- Something changed: forget the visited marks and try again.
        UM.set usedBM False
        loop
      False -> UM.unsafeRead res 0
{-# INLINE runBipartiteMatching #-}

-- | Mutable accumulator for the edges of a graph before it is frozen by
-- 'buildBipartiteMatching'.
data BipartiteMatchingBuilder s = BipartiteMatchingBuilder
  { numVerticesBMB :: !Int
  -- ^ number of vertices
  , inDegreeBMB :: !(UM.MVector s Int)
  -- ^ per-vertex edge counter. Despite the name, 'addEdgeBMB' increments the
  -- entry of the edge's /source/ vertex, so this holds the number of edges
  -- leaving each vertex.
  , edgesBMB :: !(Buffer s (Vertex, Vertex))
  -- ^ all @(source, destination)@ pairs pushed so far
  }

-- | Create an empty builder for a graph with the given number of vertices.
--
-- The per-vertex edge counters start at zero. The edge buffer is created with
-- @newBuffer (1024 * 1024)@.
-- NOTE(review): that argument is presumably the buffer capacity, which would
-- bound the number of edges that can be added — confirm against "Data.Buffer".
newBipartiteMatchingBuilder ::
  (PrimMonad m) =>
  Int ->
  m (BipartiteMatchingBuilder (PrimState m))
newBipartiteMatchingBuilder n =
  BipartiteMatchingBuilder n
    <$> UM.replicate n 0
    <*> newBuffer (1024 * 1024)

-- | Register a directed edge @src -> dst@ on the builder.
--
-- The counter of @src@ is bumped and the pair is appended to the edge buffer.
-- Only the adjacency list of @src@ will contain this edge after
-- 'buildBipartiteMatching'.
--
-- The counter update is unchecked, so @src@ must lie in @[0, numVerticesBMB)@.
addEdgeBMB ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  Vertex ->
  Vertex ->
  m ()
addEdgeBMB BipartiteMatchingBuilder{..} !src !dst = do
  UM.unsafeModify inDegreeBMB (+ 1) src
  pushBack (src, dst) edgesBMB
{-# INLINE addEdgeBMB #-}

-- | Freeze a builder into a 'BipartiteMatching' with every vertex unmatched
-- and unvisited.
--
-- The adjacency lists are laid out with a counting sort: the prefix sums of
-- the per-source edge counts give each vertex its slice of 'adjacentBM', and
-- a thawed copy of those offsets acts as the per-vertex write cursor while the
-- edges are scattered into place.
--
-- The counter vector is frozen with 'U.unsafeFreeze', so the builder should
-- not be mutated after this call.
buildBipartiteMatching ::
  (PrimMonad m) =>
  BipartiteMatchingBuilder (PrimState m) ->
  m (BipartiteMatching (PrimState m))
buildBipartiteMatching BipartiteMatchingBuilder{..} = do
  let numVerticesBM = numVerticesBMB
  matchBM <- UM.replicate numVerticesBM nothingBM
  usedBM <- UM.replicate numVerticesBM False
  -- scanl' yields one more element than its input; the last one is the total
  -- edge count, so U.last below is never applied to an empty vector.
  offsetBM <- U.scanl' (+) 0 <$!> U.unsafeFreeze inDegreeBMB
  madjacentBM <- UM.unsafeNew (U.last offsetBM)
  -- Private mutable copy of the offsets: next free slot for each source.
  moffset <- U.thaw offsetBM
  -- NOTE(review): presumably returns the pushed pairs as an immutable vector
  -- without copying — confirm against "Data.Buffer".
  edges <- unsafeFreezeBuffer edgesBMB
  U.forM_ edges $ \(src, dst) -> do
    offset <- UM.unsafeRead moffset src
    UM.unsafeWrite moffset src (offset + 1)
    UM.unsafeWrite madjacentBM offset dst
  adjacentBM <- U.unsafeFreeze madjacentBM
  return BipartiteMatching{..}
{-# INLINE buildBipartiteMatching #-}