{-# LANGUAGE LambdaCase #-}
module Data.Graph.Dense.Prim where
import Control.Monad (when)
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import My.Prelude (rep, (..<))
-- | Prim's minimum-spanning-tree algorithm for a dense graph, in \(O(n^2)\).
--
-- The graph is an @n * n@ weight matrix stored row-major in a flat vector.
-- When a vertex @u@ is already in the tree, the entry at index @u * n + v@
-- is read as the weight of the edge from @u@ to the candidate vertex @v@.
-- Presumably the matrix is symmetric; TODO confirm with callers. Every
-- entry is treated as an edge, so callers must encode a missing edge as a
-- large weight.
--
-- The result is the parent array of the spanning tree rooted at @root@.
-- Entry @v@ is the tree vertex through which @v@ was attached, and the
-- root's entry is @-1@.
--
-- Calls 'error' with @"primDense: Invalid Arguments"@ in two cases:
-- @root@ lies outside @[0, n)@, or the matrix does not have exactly
-- @n * n@ entries.
primDense ::
  (U.Unbox w, Num w, Ord w) =>
  -- | number of vertices @n@
  Int ->
  -- | root vertex of the spanning tree
  Int ->
  -- | row-major @n * n@ weight matrix
  U.Vector w ->
  U.Vector Int
primDense n root gr
  | root < 0 || root >= n || U.length gr /= n * n =
      error "primDense: Invalid Arguments"
  | otherwise = U.create $ do
      parent <- UM.replicate n (-1)
      -- dist[v] is the cheapest known edge weight linking v to the tree.
      -- Every entry is written before it is read: the root just below,
      -- and every other vertex in the initialisation loop.
      dist <- UM.new n
      used <- UM.replicate n False
      UM.write dist root 0
      UM.write used root True
      rep n $ \i -> when (i /= root) $ do
        UM.write parent i root
        UM.write dist i $ U.unsafeIndex gr (root * n + i)
      -- Fold step that keeps the unused vertex with the smallest dist.
      -- On ties the lower index wins, because indices are scanned in
      -- increasing order and only a strictly smaller dist replaces acc.
      -- An index of -1 means "nothing chosen yet". Using it avoids a
      -- sentinel weight, which breaks for zero, negative or near-maxBound
      -- weights.
      let pick acc@(bestD, bestI) i =
            UM.unsafeRead used i >>= \case
              True -> return acc
              False -> do
                d <- UM.unsafeRead dist i
                return $! if bestI < 0 || d < bestD then (d, i) else acc
      rep (n - 1) $ \_ -> do
        -- Each of the n - 1 rounds runs while at least one vertex is still
        -- unused, so pick always returns a real index here.
        v <- fmap snd $ MS.foldM' pick (0, -1) (0 ..< n)
        UM.write used v True
        -- Relax the edges out of the newly added vertex v.
        rep n $ \u ->
          UM.unsafeRead used u >>= \case
            True -> return ()
            False -> do
              du <- UM.unsafeRead dist u
              let dvu = U.unsafeIndex gr (v * n + u)
              when (dvu < du) $ do
                UM.unsafeWrite dist u dvu
                UM.unsafeWrite parent u v
      return parent
{-# INLINE primDense #-}