module Data.Graph.Dense.TSP where
import Control.Monad
import Control.Monad.ST
import Data.Bits
import Data.Coerce
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import Data.Vector.Fusion.Util
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.BitSet
import Data.Graph.Dense
import My.Prelude (rep, (..<), (>..))
-- | Output of 'runTSP'.
--
-- Both fields are strict so that a 'TSPResult' never carries an unevaluated
-- thunk for either the cost or the table.
data TSPResult a = TSPResult
  { resultTSP :: !a
  -- ^ Minimum tour cost computed by 'runTSP'.
  , freezedTSP :: !(U.Vector a)
  -- ^ Frozen DP table; entry @s * n + v@ (with @n = numVertices - 1@) is the
  -- cheapest path from the origin that visits exactly the vertex set @s@ and
  -- ends at @v@. Consumed by 'reconstructTSP'.
  }
-- | Solve the travelling-salesman problem on a dense graph with the
-- bitmask dynamic programme over (visited set, last vertex).
--
-- The last vertex (index @numVerticesDG gr - 1@) is the fixed start and end
-- of the tour. The remaining @n@ vertices are encoded as bits of a 'BitSet'.
-- The graph is assumed to have at least one vertex.
--
-- @inf@ is the value used for "no path yet". Table entries that are never
-- improved keep it, and it is added to edge weights, so it should leave
-- headroom for one addition without overflowing.
--
-- The result holds the minimum tour cost and the frozen DP table; pass it to
-- 'reconstructTSP' to recover an optimal tour.
runTSP ::
  (U.Unbox w, Num w, Eq w, Ord w) =>
  w ->
  DenseGraph w ->
  TSPResult w
runTSP inf gr = runST $ do
  dp <- UM.replicate (shiftL 1 n * n) inf
  -- Base case: the path origin -> v visits exactly {v}.
  rep n $ \v ->
    UM.unsafeWrite dp (ixTSP (singletonBS v) v) $ matDG gr origin v
  -- Masks are processed in increasing numeric order; removing a bit always
  -- yields a smaller mask, so every subproblem read below is already final.
  rep (shiftL 1 n) $ \mask -> do
    let visited = BitSet mask
    when (popCount visited > 1) $
      flip MS.mapM_ (toStreamBS visited) $ \v -> do
        let prevSet = deleteBS v visited
        -- Only genuine predecessors (members of prevSet) are considered.
        -- Entries for vertices outside the set are still @inf@, and adding
        -- an edge weight to them could overflow or, with negative weights,
        -- undercut real values.
        best <-
          MS.foldM'
            ( \acc pv -> do
                dpv <- UM.unsafeRead dp (ixTSP prevSet pv)
                return $! min acc (dpv + matDG gr pv v)
            )
            inf
            (toStreamBS prevSet)
        UM.unsafeWrite dp (ixTSP visited v) best
  -- Close the tour by returning to the origin. With a single-vertex graph
  -- (n == 0) there is nothing to visit and the tour costs 0. For any other
  -- size the minimum must start from inf; the previous @n > 1@ test made a
  -- two-vertex graph report cost 0.
  !res <-
    MS.foldM'
      ( \acc v -> do
          dv <- UM.unsafeRead dp (ixTSP visitedAll v)
          return $! min acc (dv + matDG gr v origin)
      )
      (if n == 0 then 0 else inf)
      (0 ..< n)
  TSPResult res <$> U.unsafeFreeze dp
  where
    -- Number of non-origin vertices.
    !n = numVerticesDG gr - 1
    origin = n
    visitedAll = BitSet (shiftL 1 n - 1)
    -- Flattened index of the (visited set, last vertex) table cell.
    ixTSP :: BitSet -> Int -> Int
    ixTSP visited lastPos = coerce @BitSet @Int visited * n + lastPos
    {-# INLINE ixTSP #-}
{-# INLINE runTSP #-}
-- | Rebuild an optimal tour from the table stored in a 'TSPResult'.
--
-- The returned vector has @numVerticesDG gr + 1@ entries. Slots @0@ and
-- @numVerticesDG gr@ both hold the origin (the last vertex), and the slots
-- in between hold every other vertex once, in tour order.
--
-- The tour is filled in from the back. At each step the predecessor is the
-- lowest-numbered vertex @u@ still in the remaining set whose table value
-- plus the edge @u -> next@ equals the current cost exactly. If no vertex
-- qualifies, this calls 'error' (for example when the 'TSPResult' did not
-- come from 'runTSP' on this graph).
reconstructTSP ::
  (U.Unbox w, Num w, Eq w) =>
  DenseGraph w ->
  TSPResult w ->
  U.Vector Int
reconstructTSP gr TSPResult{freezedTSP = table, resultTSP = total} =
  U.create $ do
    tour <- UM.unsafeNew (n + 2)
    UM.write tour 0 origin
    UM.write tour (n + 1) origin
    -- Walk positions n, n-1, ..., 1, each time peeling one vertex off the
    -- remaining set and moving the "current end" back to it.
    let fill !pos !remaining !next !cost
          | pos < 1 = return ()
          | otherwise = do
              let !prev = predecessorOf remaining next cost
              UM.write tour pos prev
              fill
                (pos - 1)
                (deleteBS prev remaining)
                prev
                (tableAt remaining prev)
    fill n visitedAll origin total
    return tour
  where
    -- Number of non-origin vertices.
    !n = numVerticesDG gr - 1
    origin = n
    visitedAll = BitSet (shiftL 1 n - 1)
    -- Cost of the best path through exactly the set @s@ ending at @v@.
    tableAt s v = U.unsafeIndex table (coerce @BitSet s * n + v)
    {-# INLINE tableAt #-}
    isPredecessor remaining next cost u =
      memberBS u remaining && tableAt remaining u + matDG gr u next == cost
    {-# INLINE isPredecessor #-}
    -- The stream is 0 ..< n, so the index found is the vertex itself.
    predecessorOf remaining next cost =
      case unId (MS.findIndex (isPredecessor remaining next cost) (0 ..< n)) of
        Just u -> u
        Nothing -> error "reconstructTSP"