{-# LANGUAGE LambdaCase #-}

module Data.Graph.Dense.Dijkstra where

import Control.Monad (when)
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import My.Prelude (rep)

-- | Single-source shortest paths on a dense graph in \(O(V^2)\).
--
-- Array-based Dijkstra: each round scans all vertices for the unvisited
-- one with the smallest finite tentative distance, marks it visited, and
-- relaxes every edge leaving it through the adjacency matrix. No heap is
-- used, which suits graphs with \(E \approx V^2\).
--
-- Edge weights must be non-negative, as Dijkstra's algorithm requires.
-- A matrix entry of 'maxBound' behaves as "no edge": any relaxation whose
-- sum would exceed 'maxBound' is skipped rather than allowed to wrap.
--
-- Returns a vector of length @n@. Entry @i@ is the shortest distance from
-- @src@ to @i@, or 'maxBound' if @i@ is unreachable.
--
-- Calls 'error' when @src@ is not in @[0, n)@ or when the length of @gr@
-- is not @n * n@.
dijkstraDense ::
  (U.Unbox w, Num w, Ord w, Bounded w) =>
  -- | number of vertices @n@
  Int ->
  -- | source vertex, @0 <= src < n@
  Int ->
  -- | row-major adjacency matrix (@n x n@): the weight of edge
  -- @u -> v@ is stored at index @u * n + v@
  U.Vector w ->
  U.Vector w
dijkstraDense n src gr
  | src < 0 || src >= n || U.length gr /= n * n =
    error "dijkstraDense: Invalid Arguments"
  | otherwise = U.create $ do
    -- Exactly n cells; the result is indexed by vertex.
    dist <- UM.replicate n maxBound
    used <- UM.replicate n False
    UM.write dist src 0
    -- Sentinel meaning "no vertex selected"; n is never a valid vertex.
    let nothing = n
    -- One-cell mutable slots that hold the best unvisited vertex found
    -- so far in the current scan, and its tentative distance.
    bestV <- UM.replicate 1 nothing
    bestD <- UM.replicate 1 maxBound
    fix $ \loop -> do
      UM.unsafeWrite bestV 0 nothing
      UM.unsafeWrite bestD 0 maxBound
      -- Selection: pick the unvisited vertex with the smallest tentative
      -- distance. Vertices still at maxBound are never picked.
      rep n $ \i ->
        UM.unsafeRead used i >>= \case
          True -> return ()
          False -> do
            di <- UM.unsafeRead dist i
            bd <- UM.unsafeRead bestD 0
            when (di < bd) $ do
              UM.unsafeWrite bestV 0 i
              UM.unsafeWrite bestD 0 di
      v <- UM.unsafeRead bestV 0
      -- If nothing was selected, no reachable unvisited vertex remains and
      -- every distance is final; otherwise finalize v and relax its edges.
      when (v /= nothing) $ do
        UM.unsafeWrite used v True
        dv <- UM.unsafeRead bestD 0
        rep n $ \i -> do
          let cost = U.unsafeIndex gr (v * n + i)
          -- Skip the edge when dv + cost would exceed maxBound. For
          -- fixed-width integers that sum would wrap around and then win
          -- the min. Skipping is the same as relaxing with +infinity.
          -- dv < maxBound always holds here, so maxBound - dv cannot
          -- overflow when dv > 0.
          when (dv <= 0 || cost <= maxBound - dv) $
            UM.unsafeModify dist (min (dv + cost)) i
        loop
    return dist
{-# INLINE dijkstraDense #-}