module Data.Graph.Dense.TSP where

import Control.Monad
import Control.Monad.ST
import Data.Bits
import Data.Coerce
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import Data.Vector.Fusion.Util
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.BitSet
import Data.Graph.Dense
import My.Prelude (rep, (..<), (>..))

-- | Outcome of 'runTSP'.
data TSPResult a = TSPResult
  { -- | Cost of the cheapest closed tour computed by 'runTSP'.
    resultTSP :: !a
  , -- | The DP table filled by 'runTSP', frozen; 'reconstructTSP' reads it back.
    freezedTSP :: U.Vector a
  }

{- | Traveling Salesman Problem

Computes the cost of the cheapest closed tour that starts at the last vertex
(index @numVerticesDG gr - 1@, called the origin below), visits every other
vertex exactly once and returns to the origin.

With @n = numVerticesDG gr - 1@, the entry of 'freezedTSP' at index
@visited * n + v@ is the cheapest cost of a path that leaves the origin,
visits exactly the (non-origin) vertices in @visited@ and ends at @v@.
Entries that are never written keep the value @inf@. Because @inf@ is an
ordinary number, @inf@ plus a negative edge weight can be smaller than @inf@,
which is why the last example below reports 997.

/O(n^2 2^n)/

>>> resultTSP . runTSP 0x3f3f3f3f $ fromListDG @Int [[0,1,999],[999,0,2],[4,999,0]]
7
>>> resultTSP . runTSP 0x3f3f3f3f $ fromListDG @Int [[0,1,1],[1,0,8],[1,999,0]]
10
>>> resultTSP . runTSP 0x3f3f3f3f $ fromListDG @Int [[0]]
0
>>> resultTSP . runTSP 0x3f3f3f3f $ fromListDG @Int [[0,5],[3,0]]
8
>>> resultTSP . runTSP 999 $ fromListDG @Int [[0,-1,999],[999,0,1],[1,999,0]]
1
>>> resultTSP . runTSP 999 $ fromListDG @Int [[0,-1,999],[999,0,-1],[999,999,0]]
997
-}
runTSP ::
  (U.Unbox w, Num w, Eq w, Ord w) =>
  -- | inf @(2 * inf <= maxBound)@
  w ->
  DenseGraph w ->
  TSPResult w
runTSP inf gr = runST $ do
  -- One entry per (subset of non-origin vertices, last vertex); see ixTSP.
  dp <- UM.replicate (shiftL 1 n * n) inf

  -- Base case: the path origin -> v.
  rep n $ \v ->
    UM.unsafeWrite dp (ixTSP (singletonBS v) v) $ matDG gr origin v

  -- Transition. Assumes `rep` visits raw masks in increasing order, so the
  -- smaller set @deleteBS v visited@ is complete before it is read here.
  rep (shiftL 1 n) $ \rawSet -> do
    let visited = BitSet rawSet
    when (popCount visited > 1) $
      flip MS.mapM_ (toStreamBS visited) $ \v -> do
        let rest = deleteBS v visited
        best <-
          MS.foldM'
            ( \acc pv -> do
                -- pv outside @rest@ reads an untouched entry, i.e. inf.
                dpv <- UM.unsafeRead dp (ixTSP rest pv)
                return $ min acc (dpv + matDG gr pv v)
            )
            inf
            (n >.. 0) -- faster than (0 ..< n)
        UM.unsafeWrite dp (ixTSP visited v) best

  -- Close the tour by returning to the origin. Only a graph with no vertex
  -- besides the origin (n == 0) has the empty tour of cost 0; with n == 1 the
  -- tour origin -> 0 -> origin must still be paid for, hence `n > 0`.
  !res <-
    MS.foldM'
      ( \acc v -> do
          dv <- UM.unsafeRead dp (ixTSP visitedAll v)
          return $ min acc (dv + matDG gr v origin)
      )
      (if n > 0 then inf else 0)
      (0 ..< n)

  TSPResult res <$> U.unsafeFreeze dp
  where
    -- Number of vertices other than the origin.
    !n = numVerticesDG gr - 1
    -- The tour starts and ends at the last vertex.
    origin = n
    -- Every non-origin vertex.
    visitedAll = BitSet (shiftL 1 n - 1)

    -- Row-major index into the DP table: row = vertex set, column = last vertex.
    ixTSP :: BitSet -> Int -> Int
    ixTSP visited lastPos = coerce @BitSet @Int visited * n + lastPos
    {-# INLINE ixTSP #-}
{-# INLINE runTSP #-}

{- | Rebuilds a tour whose cost equals 'resultTSP' from the DP table stored
in a 'TSPResult' produced by 'runTSP'.

The result has @numVerticesDG gr + 1@ entries. It starts and ends at the
last vertex (the origin used by 'runTSP') and lists every other vertex
exactly once in between.

The tour is recovered backwards. Starting from the full vertex set, the
origin and the total cost, each step picks a vertex @v@ still in the set
whose stored cost plus the edge weight from @v@ to the current vertex equals
the current cost. It then removes @v@ from the set and continues from @v@.
If at some step no such vertex exists, this calls @error "reconstructTSP"@.

/O(n^2)/

>>> let gr = fromListDG @Int [[0,1,999],[999,0,2],[4,999,0]]
>>> reconstructTSP gr $ runTSP 0x3f3f3f3f gr
[2,0,1,2]
-}
reconstructTSP ::
  (U.Unbox w, Num w, Eq w) =>
  DenseGraph w ->
  TSPResult w ->
  U.Vector Int
reconstructTSP gr TSPResult{freezedTSP = dp, resultTSP} = U.create $ do
  path <- UM.unsafeNew (n + 2)
  -- The tour is closed: both ends are the origin.
  UM.write path 0 origin
  UM.write path (n + 1) origin

  -- Fill positions n, n-1, ..., 1 while walking the DP table backwards.
  -- State: (vertices still to place, vertex placed last, cost of the path
  -- ending at that vertex).
  void $
    MS.foldM'
      ( \(!visited, !nv, !dnv) pos -> do
          -- findIndex over (0 ..< n) yields the vertex itself.
          let !v =
                maybe (error "reconstructTSP") id
                  . unId
                  . MS.findIndex (isPrev visited nv dnv)
                  $ 0 ..< n
          UM.write path pos v
          pure (deleteBS v visited, v, dist visited v)
      )
      (visitedAll, origin, resultTSP)
      ((n + 1) >.. 1)
  return path
  where
    -- Number of vertices other than the origin.
    !n = numVerticesDG gr - 1
    -- Every non-origin vertex.
    visitedAll = BitSet (shiftL 1 n - 1)
    -- The tour starts and ends at the last vertex.
    origin = n

    -- Can @v@ precede @nv@ on an optimal path that covers @visited@ and costs @dnv@?
    isPrev visited nv dnv = \v ->
      memberBS v visited && dist visited v + matDG gr v nv == dnv
    {-# INLINE isPrev #-}

    -- Stored DP cost of the path covering @visited@ and ending at @v@.
    dist visited v = U.unsafeIndex dp (coerce @BitSet visited * n + v)
    {-# INLINE dist #-}