{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Sparse.BFS where

import Control.Monad
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Buffer
import Data.Graph.Sparse

-- | Breadth-first search over a 'SparseGraph' from a single source. /O(V+E)/
--
-- The result has @numVerticesSG@ entries. The entry for @source@ is @0@;
-- the entry for any other vertex discovered by the search is one more than
-- the entry of the vertex it was discovered from, i.e. its BFS distance in
-- edges. Vertices that are never discovered keep the initial value
-- 'maxBound'.
--
-- Only the initial write for @source@ is bounds-checked ('UM.write'); every
-- later access to the distance table uses the unchecked
-- 'UM.unsafeRead' \/ 'UM.unsafeWrite'.
-- NOTE(review): this assumes every vertex yielded by 'adj' is a valid index
-- into the table (@0 <= v < numVerticesSG@) -- confirm against the
-- 'SparseGraph' construction code.
bfsSG :: Vertex -> SparseGraph w -> U.Vector Int
bfsSG source gr@SparseGraph{..} = U.create $ do
  -- 'maxBound' doubles as the "not visited yet" marker.
  dist <- UM.replicate numVerticesSG maxBound
  -- A vertex is enqueued only while its distance is still 'maxBound', so each
  -- vertex enters the queue at most once.
  -- NOTE(review): the argument is presumably the queue capacity -- confirm
  -- against "Data.Buffer".
  que <- newBufferAsQueue (numEdgesSG + 1)
  UM.write dist source 0
  pushBack source que
  fix $ \loop -> do
    popFront que >>= \case
      Just v -> do
        dv <- UM.unsafeRead dist v
        U.forM_ (gr `adj` v) $ \nv -> do
          dnv <- UM.unsafeRead dist nv
          -- First time we see @nv@: record its distance and schedule it.
          when (dnv == maxBound) $ do
            UM.unsafeWrite dist nv $ dv + 1
            pushBack nv que
        loop
      -- Queue drained: every vertex discovered from @source@ has been processed.
      Nothing -> return ()
  return dist