{-# LANGUAGE LambdaCase #-}

module Data.Graph.Dense.Prim where

import Control.Monad (when)
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import My.Prelude (rep, (..<))

-- | Prim's algorithm on a dense graph given as a flattened adjacency matrix.
-- /O(V^2)/.
--
-- Every entry @gr U.! (u * n + v)@ is treated as an edge of that weight from
-- @u@ to @v@; there is no notion of a missing edge. NOTE(review): for the
-- result to be a minimum spanning tree the matrix presumably has to be
-- symmetric — the relaxation step only reads row @v@ of each newly added
-- vertex @v@.
--
-- The result maps every vertex to its parent in the tree grown from @root@
-- (@parent[root] = -1@). When several unused vertices are equally close to
-- the tree, the one with the smallest index is attached first.
--
-- Calls 'error' with @"primDense: Invalid Arguments"@ if @root@ is outside
-- @[0, n)@ or the matrix does not have exactly @n * n@ entries.
primDense ::
  (U.Unbox w, Num w, Ord w) =>
  -- | n
  Int ->
  -- | root
  Int ->
  -- | adjacent matrix (n x n)
  U.Vector w ->
  -- | parent (parent[root] = -1)
  U.Vector Int
primDense n root gr
  | root < 0 || root >= n || U.length gr /= n * n =
      error "primDense: Invalid Arguments"
  | otherwise = U.create $ do
      parent <- UM.replicate n (-1)
      -- dist[u]: weight of the lightest known edge from the tree to u
      -- (only meaningful while u is unused).
      dist <- UM.new n
      used <- UM.replicate n False
      UM.write dist root 0
      UM.write used root True
      -- Initially the tree is just the root, so every other vertex hangs
      -- off it.
      rep n $ \i -> when (i /= root) $ do
        UM.write parent i root
        UM.write dist i $ U.unsafeIndex gr (root * n + i)
      let -- Argmin step over unused vertices. An index of -1 means "no
          -- candidate yet", so no sentinel weight is needed. A sentinel
          -- such as @2 * maximum gr@ breaks for non-positive weights and
          -- can overflow. Ties keep the earlier (smaller) index because
          -- the comparison is strict.
          pick acc@(best, bestI) i =
            UM.unsafeRead used i >>= \case
              True -> return acc
              False -> do
                d <- UM.unsafeRead dist i
                return $! if bestI < 0 || d < best then (d, i) else acc
      -- Each round attaches exactly one of the remaining unused vertices, so
      -- an unused vertex always exists and the fold never yields -1.
      rep (n - 1) $ \_ -> do
        v <- fmap snd . MS.foldM' pick (0, -1) $ 0 ..< n
        UM.write used v True
        -- Relax: v may now be a cheaper attachment point for unused vertices.
        rep n $ \u ->
          UM.unsafeRead used u >>= \case
            True -> return ()
            False -> do
              du <- UM.unsafeRead dist u
              let dvu = U.unsafeIndex gr (v * n + u)
              when (dvu < du) $ do
                UM.unsafeWrite dist u dvu
                UM.unsafeWrite parent u v
      return parent
{-# INLINE primDense #-}