module Data.Vector.Sort.Quick where

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Function
import qualified Data.Vector.Generic.Mutable as GM
import System.Random

import System.Random.Utils

{- | Random-pivot quicksort: sorts the mutable vector in place, in ascending
 order according to 'compare'.

 This is 'quickSortBy' specialised to the element type's 'Ord' instance.

 Expected /O(n log n)/
-}
quickSort ::
  (Ord a, PrimMonad m, GM.MVector mv a) =>
  mv (PrimState m) a ->
  m ()
quickSort = quickSortBy compare
{-# INLINE quickSort #-}

{- | Random-pivot quicksort with a caller-supplied comparison: sorts the
 mutable vector in place so that elements are ascending with respect to
 @cmp@.

 Slices longer than 32 elements are split around a pivot read from a
 randomly drawn index and both halves are sorted recursively; shorter slices
 are finished with 'insertionSortBy'. The random generator is threaded
 through the recursion, which is why the inner loop returns it.

 Expected /O(n log n)/
-}
quickSortBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m ()
quickSortBy cmp mv0 = do
  -- NOTE(review): newStdGenPrim comes from System.Random.Utils; presumably it
  -- seeds a fresh generator inside the monad -- confirm against that module.
  rng0 <- newStdGenPrim
  void $
    fix
      ( \loop !mvec !rng ->
          if GM.length mvec > 32
            then do
              -- Draw the pivot index with an upper bound of (length - 1).
              case genWord64R (fromIntegral $ GM.length mvec - 1) rng of
                (w64, rng') -> do
                  pivot <- GM.read mvec (fromIntegral w64)
                  cut <- pivotPartitionBy cmp mvec pivot
                  -- Sort the left part first, then feed the generator it
                  -- hands back into the sort of the right part.
                  loop (GM.take cut mvec) rng'
                    >>= loop (GM.drop cut mvec)
            else do
              -- Small slice: insertion sort, generator passes through unused.
              insertionSortBy cmp mvec
              return rng
      )
      mv0
      rng0
{-# INLINE quickSortBy #-}

{- | Random-pivot quickselect: returns the element that would sit at the
 given 0-based index if the vector were sorted ascending by 'compare'.

 This is 'quickSelectBy' specialised to the element type's 'Ord' instance.
 The vector is reordered in place as a side effect.

 Expected /O(n)/
-}
quickSelect ::
  (Ord a, PrimMonad m, GM.MVector mv a) =>
  mv (PrimState m) a ->
  Int ->
  m a
quickSelect = quickSelectBy compare
{-# INLINE quickSelect #-}

{- | Random-pivot quickselect with a caller-supplied comparison: returns the
 element that would sit at 0-based index @k0@ if the vector were sorted
 ascending with respect to @cmp@.

 Slices longer than 32 elements are partitioned around a pivot read from a
 randomly drawn index and the search continues in the part that contains the
 requested rank; a shorter slice is sorted with 'insertionSortBy' and indexed
 directly. The vector is reordered in place as a side effect.

 The rank must satisfy @0 <= k0 < length@. The final lookup is bounds
 checked, so an out-of-range rank (or an empty vector) is reported by
 'GM.read' rather than performing an unchecked memory access.

 Expected /O(n)/
-}
quickSelectBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  Int ->
  m a
quickSelectBy cmp mv0 k0 = do
  -- NOTE(review): newStdGenPrim comes from System.Random.Utils; presumably it
  -- seeds a fresh generator inside the monad -- confirm against that module.
  rng0 <- newStdGenPrim
  fix
    ( \loop !mvec !k !rng ->
        if GM.length mvec > 32
          then do
            -- Draw the pivot index with an upper bound of (length - 1).
            case genWord64R (fromIntegral $ GM.length mvec - 1) rng of
              (w64, rng') -> do
                pivot <- GM.read mvec (fromIntegral w64)
                cut <- pivotPartitionBy cmp mvec pivot
                if k < cut
                  then loop (GM.take cut mvec) k rng'
                  else -- Re-base the rank onto the right-hand slice.
                    loop (GM.drop cut mvec) (k - cut) rng'
          else do
            insertionSortBy cmp mvec
            -- Checked read: k derives from the caller's k0, which nothing
            -- above validates.
            GM.read mvec k
    )
    mv0
    k0
    rng0
{-# INLINE quickSelectBy #-}

{- | Partition the vector in place around a pivot value and return the cut
 index.

 Two cursors move towards each other: the left one skips elements that
 compare less than the pivot, the right one skips elements that compare
 greater than it. When both have stopped and the left is still before the
 right, the two elements are swapped and the scan resumes.

 On return with result @cut@, no element at an index in @[0, cut)@ is
 greater than the pivot and no element at an index in @[cut, length)@ is
 less than it.

 Precondition: the scans use unchecked reads and only stop on an element
 that is not less (left scan) or not greater (right scan) than the pivot, so
 the vector must contain such elements -- for example because the pivot was
 read from this same vector. The vector must therefore be non-empty.
-}
pivotPartitionBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  a ->
  m Int
pivotPartitionBy cmp vec !pivot =
  fix
    ( \loop !l !r -> do
        -- Advance from l to the first element that is not less than pivot.
        !l' <-
          fix
            ( \loopL !i -> do
                vi <- GM.unsafeRead vec i
                case cmp vi pivot of
                  LT -> loopL (i + 1)
                  _ -> return i
            )
            l
        -- Retreat from r - 1 to the last element that is not greater than
        -- pivot (r itself is an exclusive bound).
        !r' <-
          fix
            ( \loopR !i -> do
                vi <- GM.unsafeRead vec i
                case cmp pivot vi of
                  LT -> loopR (i - 1)
                  _ -> return i
            )
            (r - 1)
        if l' < r'
          then do
            GM.unsafeSwap vec l' r'
            -- r' now holds an element not less than pivot, so it becomes
            -- the new exclusive right bound.
            loop (l' + 1) r'
          else return l'
    )
    0
    (GM.length vec)
{-# INLINE pivotPartitionBy #-}

{- | Pick a pivot as the median, with respect to @cmp@, of the first
 element, the element at index @length `unsafeShiftR` 1@, and the last
 element of the vector.

 Precondition: the vector must be non-empty; all three reads are unchecked.
-}
getMedian3PivotBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m a
getMedian3PivotBy cmp vec =
  medianBy cmp
    <$> GM.unsafeRead vec 0
    <*> GM.unsafeRead vec (unsafeShiftR (GM.length vec) 1)
    <*> GM.unsafeRead vec (GM.length vec - 1)
{-# INLINE getMedian3PivotBy #-}

{- | Median of three values with respect to the given comparison.

 Only the 'LT' outcome of the comparison is distinguished; 'EQ' and 'GT' are
 treated alike, so ties resolve to one of the tied arguments.

>>> medianBy compare 3 1 2
2
-}
medianBy :: (a -> a -> Ordering) -> a -> a -> a -> a
medianBy cmp x y z = case cmp x y of
  -- x < y
  LT -> case cmp y z of
    LT -> y -- x < y < z
    _ -> case cmp x z of
      LT -> z -- x < z <= y
      _ -> x -- z <= x < y
  -- y <= x
  _ -> case cmp x z of
    LT -> x -- y <= x < z
    _ -> case cmp y z of
      LT -> z -- y < z <= x
      _ -> y -- z <= y <= x
{-# INLINE medianBy #-}

{- | Insertion sort: sorts the mutable vector in place so that elements are
 ascending with respect to @cmp@. Empty and single-element vectors are left
 untouched.

 For each index @i@ from 1 upwards the element is first compared with the
 element at index 0:

 * if the front element is greater, the whole prefix @[0, i)@ is shifted
   one slot to the right and the element is written at index 0;

 * otherwise the front element acts as a sentinel, so the backward scan
   that looks for the insertion point needs no bounds test: it stops at
   index 1 at the latest.

 Only elements that compare strictly greater are shifted, so elements that
 compare equal keep their relative order.
-}
insertionSortBy ::
  (GM.MVector mv a, PrimMonad m) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m ()
insertionSortBy cmp mvec = do
  fix
    ( \outer !i -> when (i < n) $ do
        v0 <- GM.unsafeRead mvec 0
        vi <- GM.unsafeRead mvec i
        case cmp v0 vi of
          GT -> do
            -- vi is smaller than everything in the sorted prefix: shift the
            -- prefix right by one and drop vi at the front.
            fix
              ( \inner !j -> when (j > 0) $ do
                  GM.unsafeRead mvec (j - 1) >>= GM.unsafeWrite mvec j
                  inner (j - 1)
              )
              i
            GM.unsafeWrite mvec 0 vi
            outer (i + 1)
          _ -> do
            -- The front element is not greater than vi, so this scan
            -- terminates before j reaches 0.
            fix
              ( \inner !j -> do
                  vj' <- GM.unsafeRead mvec (j - 1)
                  case cmp vj' vi of
                    GT -> do
                      GM.unsafeWrite mvec j vj'
                      inner (j - 1)
                    _ -> do
                      GM.unsafeWrite mvec j vi
                      outer (i + 1)
              )
              i
    )
    1
  where
    !n = GM.length mvec
{-# INLINE insertionSortBy #-}