{-# LANGUAGE LambdaCase #-}

module Data.Graph.Dense.Dijkstra where

import Control.Monad (when)
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import My.Prelude (rep)

-- | Single-source shortest paths on a dense graph given as an adjacency
-- matrix, in O(V^2) time.
--
-- Each round scans linearly (no heap) for the unvisited vertex with the
-- smallest tentative distance, marks it visited, and relaxes every edge
-- leaving it. The loop stops when no unvisited vertex has a distance
-- below 'maxBound'.
--
-- * Edge weights are expected to be non-negative (the usual Dijkstra
--   precondition).
-- * 'maxBound' acts as infinity. Unreachable vertices keep 'maxBound'.
--   Relaxation saturates at 'maxBound' instead of overflowing, so a
--   'maxBound' entry in the matrix behaves as "no edge".
--
-- Calls 'error' when @src@ is outside @[0, n)@ or the matrix does not
-- contain exactly @n * n@ entries.
--
-- NOTE(review): the returned vector has length @n + 1@. Assuming 'rep'
-- iterates over @[0, n)@, the last slot is never written and stays
-- 'maxBound'. This is kept as-is for compatibility with existing
-- callers; confirm whether it is intended.
dijkstraDense ::
  (U.Unbox w, Num w, Ord w, Bounded w) =>
  -- | n
  Int ->
  -- | src
  Int ->
  -- | adjacent matrix (n x n)
  U.Vector w ->
  U.Vector w
dijkstraDense n src gr
  | src < 0 || src >= n || U.length gr /= n * n =
    error "dijkstraDense: Invalid Arguments"
  | otherwise = U.create $ do
    dist <- UM.replicate (n + 1) maxBound
    used <- UM.replicate n False
    UM.write dist src 0
    -- Sentinel vertex id meaning "no candidate found in this scan".
    let nothing = n
    -- One-element mutable cells holding the current argmin vertex and its
    -- distance; they are overwritten in place during every scan.
    bestV <- UM.replicate 1 0
    bestD <- UM.replicate 1 0
    fix $ \loop -> do
      UM.unsafeWrite bestV 0 nothing
      UM.unsafeWrite bestD 0 maxBound
      -- Linear scan: pick the unvisited vertex with the smallest distance.
      rep n $ \i -> do
        isUsed <- UM.unsafeRead used i
        when (not isUsed) $ do
          di <- UM.unsafeRead dist i
          dv <- UM.unsafeRead bestD 0
          when (di < dv) $ do
            UM.unsafeWrite bestV 0 i
            UM.unsafeWrite bestD 0 di
      v <- UM.unsafeRead bestV 0
      -- Stop once every remaining unvisited vertex is at distance maxBound.
      when (v /= nothing) $ do
        UM.unsafeWrite used v True
        dv <- UM.unsafeRead bestD 0
        -- Relax all edges out of v; the row of v starts at index v * n.
        rep n $ \i -> do
          let di' = addSat dv (U.unsafeIndex gr (v * n + i))
          UM.unsafeModify dist (min di') i
        loop
    return dist
  where
    -- Addition that clamps to 'maxBound' instead of wrapping around.
    -- For a <= 0 the sum cannot exceed b <= maxBound. For a > 0,
    -- maxBound - a cannot overflow, so the comparison is safe.
    addSat a b
      | a > 0 && b > maxBound - a = maxBound
      | otherwise = a + b
{-# INLINE dijkstraDense #-}