module Data.Graph.Tree.Rerooting where

import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

--
import Data.Graph.Sparse

-- | Rerooting DP on a tree: for every vertex @v@, computes the value the
-- ordinary bottom-up tree DP would produce if the tree were rooted at @v@.
--
-- The DP is described by two functions:
--
-- * @toM@ turns a finished subtree value @a@ into a monoid element @m@ so
--   that sibling contributions can be combined with '<>'.
-- * @foldChildren v acc@ builds the value of vertex @v@ from @acc@, the
--   combined contributions of all of @v@'s children.
--
-- Two DFS passes are made. Each pass visits every vertex once and scans its
-- adjacency a constant number of times.
--
-- 1. The first pass computes the ordinary DP rooted at vertex 0.
-- 2. The second pass pushes the "contribution from the parent side"
--    downwards. It uses exclusive prefix/suffix folds, so no inverse of
--    '<>' is required.
--
-- Preconditions (not checked):
--
-- * @gr@ should be an undirected tree whose adjacency lists contain both
--   directions of every edge. Vertices not reachable from vertex 0 are never
--   written, and their entries are left uninitialised ('UM.unsafeNew').
-- * Children are combined in adjacency order, with the parent-side
--   contribution always placed first. For results that do not depend on
--   this order, @m@ should be commutative.
-- * The DFS is recursive, so stack usage grows with the depth of the tree.
--
-- Returns 'U.empty' for a graph with no vertices.
rerootingDP ::
  forall w a m.
  (U.Unbox w, U.Unbox a, U.Unbox m, Monoid m) =>
  SparseGraph w ->
  (a -> m) ->
  (Vertex -> m -> a) ->
  U.Vector a
rerootingDP gr toM foldChildren
  -- Guard first: with no vertices, even touching 'root' would be out of
  -- bounds.
  | numVertices == 0 = U.empty
  | otherwise = dp2
  where
    numVertices :: Int
    numVertices = numVerticesSG gr

    -- Root of the first pass; any vertex of the tree would do.
    root :: Vertex
    root = 0

    -- Neighbours of @v@ other than the vertex the DFS came from
    -- (@-1@ is used as the "no parent" sentinel for the root).
    childrenOf :: Vertex -> Vertex -> U.Vector Vertex
    childrenOf parent v = U.filter (/= parent) (gr `adj` v)

    -- First pass: ordinary bottom-up DP rooted at 'root'.
    -- @dp1 U.! v@ is the value of the subtree hanging below @v@.
    -- (Deliberately not a strict binding: 'dp2' forces it when needed, and a
    -- bang here would run the DFS even for the empty-graph case.)
    dp1 :: U.Vector a
    dp1 = U.create $ do
      dp <- UM.unsafeNew numVertices
      _ <-
        fix
          ( \dfs parent v -> do
              subtrees <- U.mapM (dfs v) (childrenOf parent v)
              let res = foldChildren v (U.foldl' (<>) mempty (U.map toM subtrees))
              UM.write dp v res
              return res
          )
          (-1)
          root
      return dp

    -- Second pass: top-down. @fromParent@ is the contribution of everything
    -- on the parent's side of the edge (parent, v), i.e. @toM@ of the value
    -- of @parent@ when the tree is rooted at @v@ (@mempty@ at the root).
    dp2 :: U.Vector a
    dp2 = U.create $ do
      dp <- UM.unsafeNew numVertices
      fix
        ( \dfs parent !fromParent v -> do
            let children = childrenOf parent v
                -- First-pass contributions of v's children. Computed once and
                -- shared by the three uses below.
                childMs = U.map toM (U.backpermute dp1 children)
                -- Entry i combines every child contribution except the i-th.
                -- It is built from exclusive prefix and suffix folds, so no
                -- inverse of (<>) is needed.
                prefixes = U.prescanl' (<>) mempty childMs
                suffixes = U.prescanr' (<>) mempty childMs
                allButOne = U.zipWith (<>) prefixes suffixes
            -- Rooted at v, every neighbour (including the parent) is a child.
            -- At the root this equals dp1 U.! root, since fromParent = mempty.
            UM.write dp v (foldChildren v (U.foldl' (<>) fromParent childMs))
            -- Seen from child c, v's subtree consists of the parent side plus
            -- all siblings of c. The parent side goes first, matching the
            -- order used for v itself above.
            U.zipWithM_
              (\siblings c -> dfs v (toM (foldChildren v (fromParent <> siblings))) c)
              allButOne
              children
        )
        (-1)
        mempty
        root
      return dp