{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Graph.Sparse.Dijkstra where

import Control.Monad
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.Graph.Sparse
import Data.Heap.Binary

-- | Single-source shortest distances on a 'SparseGraph' using Dijkstra's
-- algorithm with a binary min-heap and lazy deletion (stale heap entries
-- are skipped when popped rather than removed when superseded).
--
-- Returns a vector with one entry per vertex: the shortest distance from
-- @source@, or 'maxBound' for vertices that are never reached.
--
-- Edge weights are assumed to be non-negative; this is not checked.
-- NOTE(review): the tentative distance @dv + cost@ is computed in @w@
-- without saturation, so for bounded integral weights very long paths
-- can overflow — confirm callers keep distances well below 'maxBound'.
dijkstraSG ::
  (U.Unbox w, Num w, Ord w, Bounded w) =>
  Vertex ->
  SparseGraph w ->
  U.Vector w
dijkstraSG source gr@SparseGraph {..} = U.create $ do
  dist <- UM.replicate numVerticesSG maxBound
  -- Every successful relaxation pushes exactly one entry. With
  -- non-negative weights each edge relaxes at most once, so
  -- numEdgesSG slots plus one for the source entry bound the heap size.
  -- (Presumably newMinBinaryHeap takes a capacity — verify in
  -- Data.Heap.Binary.)
  heap <- newMinBinaryHeap (numEdgesSG + 1)
  UM.write dist source 0
  insertBH (0, source) heap
  fix $ \loop ->
    deleteFindTopBH heap >>= \case
      Nothing -> return ()
      Just (d, v) -> do
        dv <- UM.unsafeRead dist v
        -- A popped entry is stale when a shorter distance to v was
        -- recorded after it was pushed; only the current one is expanded.
        when (dv == d) $
          U.forM_ (gr `adjW` v) $ \(nv, cost) -> do
            let nd = dv + cost
            dnv <- UM.unsafeRead dist nv
            -- Strict '<' ensures each improvement is pushed only once.
            when (nd < dnv) $ do
              UM.unsafeWrite dist nv nd
              insertBH (nd, nv) heap
        loop
  return dist