module Data.Graph.BellmanFord where

import Control.Monad
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import My.Prelude (rep)

-- | A graph vertex, represented as a plain 'Int'. 'bellmanFord' uses vertices
-- directly as indices into its distance vector.
type Vertex = Int

{- | Bellman-Ford single-source shortest paths, O(VE)

 @bellmanFord n root edges@ returns a distance vector of length @n@.
 Each edge is a @(src, dst, cost)@ triple.

 dist[v] == maxBound iff v is unreachable

 dist[v] == minBound iff v can still be relaxed after the main rounds, or is
 reachable from such a vertex (i.e. v is affected by a negative cycle; it
 does not have to lie on the cycle itself)

 Edge endpoints are read and written without bounds checks, so every @src@
 and @dst@ must be a valid index into the distance vector.

>>> bellmanFord 2 0 (U.singleton (0, 1, -1))
[0,-1]
>>> bellmanFord 2 0 U.empty
[0,9223372036854775807]
>>> bellmanFord 1 0 (U.singleton (0, 0, -1))
[-9223372036854775808]
>>> bellmanFord 2 0 (U.fromList [(0, 1, -1), (1, 0, -1)])
[-9223372036854775808,-9223372036854775808]
-}
bellmanFord :: Int -> Vertex -> U.Vector (Vertex, Vertex, Int) -> U.Vector Int
bellmanFord n root edges = U.create $ do
  -- maxBound is the "not reached yet" sentinel; only the root starts at 0.
  dist <- UM.replicate n maxBound
  UM.write dist root 0

  -- Phase 1: relax every edge, repeated via rep (n - 1).
  -- NOTE(review): rep comes from My.Prelude; presumably it runs the action
  -- once per index in [0, n - 1) -- confirm against its definition.
  rep (n - 1) $ \_ ->
    U.forM_ edges $ \(src, dst, cost) -> do
      dv <- UM.unsafeRead dist src
      du <- UM.unsafeRead dist dst
      -- Test the sentinel first: an unreached source must never relax an
      -- edge, and maxBound + cost may wrap around.
      when (dv /= maxBound && dv + cost < du) $
        UM.unsafeWrite dist dst (dv + cost)

  -- Phase 2: any edge that can still be relaxed has a destination whose
  -- distance is not final; flag it with minBound.
  U.forM_ edges $ \(src, dst, cost) -> do
    dv <- UM.unsafeRead dist src
    du <- UM.unsafeRead dist dst
    when (dv /= maxBound && dv + cost < du) $
      UM.unsafeWrite dist dst minBound

  -- Phase 3: spread the minBound flag along edges so that everything
  -- downstream of a flagged vertex is flagged as well.
  rep (n - 1) $ \_ ->
    U.forM_ edges $ \(src, dst, _) -> do
      dv <- UM.unsafeRead dist src
      when (dv == minBound) $
        UM.unsafeWrite dist dst minBound

  return dist