{-# LANGUAGE TypeFamilies #-}

module Data.Vector.Sort.Merge where

import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Bits
import Data.Function
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as GM

{- | Count the inversions of a vector, i.e. the number of index pairs
@i < j@ whose elements compare as @xs ! i > xs ! j@.

The vector is thawed into a mutable copy which is then sorted with
'mergeSort'; the inversion count reported by the sort is the result.

>>> import qualified Data.Vector.Unboxed as U
>>> inversionNumber $ U.fromList "312"
2
>>> inversionNumber $ U.fromList "100"
2
>>> inversionNumber $ U.fromList "123"
0
-}
inversionNumber :: (G.Vector v a, Ord a) => v a -> Int
inversionNumber xs = runST $ do
  mvec <- G.thaw xs
  mergeSort mvec
{-# INLINE inversionNumber #-}

{- | Sort a mutable vector in place using the element type's 'Ord' instance
and return the inversion number of its original contents.

This is 'mergeSortBy' applied to 'compare'.

>>> import qualified Data.Vector.Unboxed as U
>>> U.modify mergeSort_ $ U.fromList "3610425"
"0123456"
>>> import Data.List (sort)
prop> \(xs::[Int]) -> U.fromList (sort xs) == U.modify mergeSort_ (U.fromList xs)
+++ OK, passed 100 tests.
-}
mergeSort ::
  (GM.MVector mv a, Ord a, PrimMonad m) =>
  mv (PrimState m) a ->
  -- | inversion number
  m Int
mergeSort = mergeSortBy compare
{-# INLINE mergeSort #-}

{- | In-place merge sort with a caller-supplied comparison.

Returns the inversion number of the original contents: the number of index
pairs @i < j@ whose elements compare as 'GT' under @cmp@ (assuming @cmp@ is
a consistent ordering).

Elements that do not compare as 'GT' keep their relative order, so the sort
is stable. One scratch buffer of @(n + 1) `quot` 2@ elements is allocated up
front; slices of at most 16 elements are delegated to 'insertionSortBy'.
-}
mergeSortBy ::
  (GM.MVector mv a, PrimMonad m) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  -- | inversion number
  m Int
mergeSortBy cmp mvec0 = do
  -- Only the left half of a slice is ever moved out of the way, so half the
  -- length (rounded up) is enough scratch space.
  buf <- GM.unsafeNew (quot (GM.length mvec0 + 1) 2)
  go buf mvec0
  where
    -- go scratch v: sort v in place and return its inversion count.
    -- scratch must hold at least (length v + 1) `quot` 2 elements.
    go !buf !mvec
      | n <= 16 = insertionSortBy cmp mvec
      | otherwise = do
          let !numL = unsafeShiftR (n + 1) 1
              !numR = n - numL
              -- buf' is the left half of mvec. Once its contents have been
              -- copied out it doubles as scratch space for the recursion.
              (!buf', !vecR) = GM.splitAt numL mvec
              !vecL = GM.unsafeTake numL buf
          -- Move the left half into the scratch buffer (target, then source).
          GM.unsafeCopy vecL buf'
          !invL <- go buf' vecL
          !invR <- go buf' vecR
          -- Merge vecL and vecR back into mvec. The write index posL + posR
          -- never overtakes the read position numL + posR inside vecR.
          fix
            ( \loop !invNum !posL !posR ->
                if posL < numL && posR < numR
                  then do
                    vl <- GM.unsafeRead vecL posL
                    vr <- GM.unsafeRead vecR posR
                    case cmp vl vr of
                      GT -> do
                        GM.unsafeWrite mvec (posL + posR) vr
                        -- vr jumps ahead of every remaining left element.
                        loop (invNum + numL - posL) posL (posR + 1)
                      _ -> do
                        GM.unsafeWrite mvec (posL + posR) vl
                        loop invNum (posL + 1) posR
                  else do
                    -- Leftover right elements are already in their final
                    -- place; only leftover left elements need copying back.
                    when (posL < numL) $ do
                      GM.unsafeCopy
                        (GM.unsafeSlice (posL + posR) (numL - posL) mvec)
                        (GM.unsafeSlice posL (numL - posL) vecL)
                    return invNum
            )
            (invL + invR)
            0
            0
      where
        !n = GM.length mvec
{-# INLINE mergeSortBy #-}

-- | Like 'mergeSort', but discards the inversion number.
mergeSort_ :: (GM.MVector mv a, Ord a, PrimMonad m) => mv (PrimState m) a -> m ()
mergeSort_ = void . mergeSort
{-# INLINE mergeSort_ #-}

-- | Like 'mergeSortBy', but discards the inversion number.
mergeSortBy_ ::
  (GM.MVector mv a, PrimMonad m) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m ()
mergeSortBy_ cmp = void . mergeSortBy cmp
{-# INLINE mergeSortBy_ #-}

{- | Sort a mutable vector in place by insertion, using the element type's
'Ord' instance, and return the inversion number of its original contents.

This is 'insertionSortBy' applied to 'compare'.

>>> import qualified Data.Vector.Unboxed as U
>>> U.modify insertionSort_ $ U.fromList "3610425"
"0123456"
>>> U.modify insertionSort_ $ U.fromList ""
""
>>> U.modify insertionSort_ $ U.fromList "x"
"x"
>>> U.modify insertionSort_ $ U.fromList "ba"
"ab"
>>> import Data.List (sort)
prop> \(xs::[Int]) -> U.fromList (sort xs) == U.modify insertionSort_ (U.fromList xs)
+++ OK, passed 100 tests.
-}
insertionSort ::
  (GM.MVector mv a, Ord a, PrimMonad m) =>
  mv (PrimState m) a ->
  -- | inversion number
  m Int
insertionSort = insertionSortBy compare
{-# INLINE insertionSort #-}

{- | In-place insertion sort with a caller-supplied comparison.

Returns the inversion number of the original contents, which equals the
total number of single-slot shifts performed. Elements are only moved past
neighbours that compare as 'GT', so the sort is stable.

The element at index 0 is used as a sentinel: each new element is first
compared against it, so the inner scan needs no bounds check.
-}
insertionSortBy ::
  (GM.MVector mv a, PrimMonad m) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  -- | inversion number
  m Int
insertionSortBy cmp mvec = do
  -- Invariant: on entering the loop body, indices [0, i) are sorted.
  fix
    ( \outer !invNum !i ->
        if i < n
          then do
            v0 <- GM.unsafeRead mvec 0
            vi <- GM.unsafeRead mvec i
            case cmp v0 vi of
              GT -> do
                -- vi is smaller than the current minimum: shift the whole
                -- sorted prefix right by one and drop vi at the front.
                fix
                  ( \inner j -> when (j > 0) $ do
                      GM.unsafeRead mvec (j - 1) >>= GM.unsafeWrite mvec j
                      inner (j - 1)
                  )
                  i
                GM.unsafeWrite mvec 0 vi
                outer (invNum + i) (i + 1)
              _ -> do
                -- The element at index 0 is not greater than vi, so this
                -- scan stops at j == 1 at the latest and never reads
                -- index -1.
                fix
                  ( \inner j -> do
                      vj' <- GM.unsafeRead mvec (j - 1)
                      case cmp vj' vi of
                        GT -> do
                          GM.unsafeWrite mvec j vj'
                          inner (j - 1)
                        _ -> do
                          GM.unsafeWrite mvec j vi
                          -- vi moved past i - j larger elements.
                          outer (invNum + i - j) (i + 1)
                  )
                  i
          else return invNum
    )
    0
    1
  where
    !n = GM.length mvec
{-# INLINE insertionSortBy #-}

-- | Like 'insertionSort', but discards the inversion number.
insertionSort_ :: (GM.MVector mv a, Ord a, PrimMonad m) => mv (PrimState m) a -> m ()
insertionSort_ = void . insertionSort
{-# INLINE insertionSort_ #-}

-- | Like 'insertionSortBy', but discards the inversion number.
insertionSortBy_ ::
  (GM.MVector mv a, PrimMonad m) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m ()
insertionSortBy_ cmp = void . insertionSortBy cmp
{-# INLINE insertionSortBy_ #-}