module Data.Graph.Tree.Rerooting where
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Graph.Sparse
-- | Rerooting DP: for every vertex @v@, compute the DP value that the tree
-- would have if it were rooted at @v@, using two depth-first traversals.
--
-- The first pass ('dp1') roots the tree at vertex @0@ and computes each
-- vertex's value from its subtree only. The second pass ('dp2') pushes the
-- summary of "everything on the parent side" down to each child, so that
-- every vertex sees all of its neighbours as children.
--
-- NOTE(review): assumes @gr@ is a tree on vertices @0 .. n-1@ whose
-- adjacency lists contain every edge in both directions; @-1@ is used as the
-- "no parent" sentinel. The value stored for @v@ combines the parent-side
-- summary /before/ the children, while the summary passed to a child
-- appends it /after/ the siblings, so the monoid should be commutative for
-- the results to be meaningful.
--
-- Returns a vector indexed by vertex; it is empty when the graph has no
-- vertices.
rerootingDP ::
  forall w a m.
  (U.Unbox w, U.Unbox a, U.Unbox m, Monoid m) =>
  -- | the tree
  SparseGraph w ->
  -- | summarise a child's DP value as a monoid element
  (a -> m) ->
  -- | build a vertex's DP value from the combined summaries of its children
  (Vertex -> m -> a) ->
  U.Vector a
rerootingDP gr toM foldChildren
  -- With no vertices there is no root 0 to start from; indexing would fail.
  | n == 0 = U.empty
  | otherwise = dp2
  where
    n :: Int
    n = numVerticesSG gr

    root :: Vertex
    root = 0

    -- Subtree values with the tree rooted at 'root'.
    -- (Not banged: a strict where-binding would be forced before the
    -- @n == 0@ guard and crash on an empty graph.)
    dp1 :: U.Vector a
    dp1 = U.create $ do
      dp <- UM.unsafeNew n
      _ <-
        fix
          ( \dfs p v -> do
              res <-
                foldChildren v
                  . U.foldl' (<>) mempty
                  . U.map toM
                  <$> U.mapM (dfs v) (U.filter (/= p) $ gr `adj` v)
              UM.write dp v res
              return res
          )
          (-1)
          root
      return dp

    -- Values with the tree rerooted at each vertex.
    dp2 :: U.Vector a
    dp2 = U.create $ do
      dp <- UM.unsafeNew n
      fix
        ( \dfs parent !fromParent v -> do
            let children = U.filter (/= parent) $ gr `adj` v
                -- Summaries of each child's subtree (computed once, reused
                -- for the full fold and both prefix/suffix scans).
                childMs = U.map toM $ U.backpermute dp1 children
            -- v as root: parent-side summary followed by every child.
            UM.write dp v . foldChildren v $ U.foldl' (<>) fromParent childMs
            -- For child i: all siblings before i <> all siblings after i.
            let exclusive =
                  U.zipWith
                    (<>)
                    (U.prescanl' (<>) mempty childMs)
                    (U.prescanr' (<>) mempty childMs)
            -- Each child receives v's value computed without that child.
            U.zipWithM_
              (\excl c -> dfs v (toM (foldChildren v (excl <> fromParent))) c)
              exclusive
              children
        )
        (-1)
        mempty
        root
      return dp