module Data.Vector.Sort.Quick where

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Function
import qualified Data.Vector.Generic.Mutable as GM
import System.Random

import System.Random.Utils

{- | Random Pivot Quick Sort using the element type's 'Ord' instance.

 Sorts the mutable vector in place; this is exactly @'quickSortBy' 'compare'@.

 /O(n log n)/ expected
-}
quickSort ::
  (Ord a, PrimMonad m, GM.MVector mv a) =>
  mv (PrimState m) a ->
  m ()
quickSort mvec = quickSortBy compare mvec
{-# INLINE quickSort #-}

{- | Random Pivot Quick Sort, parameterised by a comparison function.

 Sorts the mutable vector in place according to @cmp@. For every slice
 longer than 32 elements a pivot is read from a random index (drawn with
 'genWord64R' from a generator obtained via 'newStdGenPrim'), the slice is
 split with 'pivotPartitionBy', and both halves are sorted recursively.
 Slices of at most 32 elements are finished with 'insertionSortBy'.

 /O(n log n)/ expected
-}
quickSortBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m ()
quickSortBy cmp mv0 = newStdGenPrim >>= void . sortSlice mv0
  where
    -- Sort one slice, threading the generator through the left half and
    -- then the right half; returns the final generator state.
    sortSlice !mvec !gen
      | len <= 32 = insertionSortBy cmp mvec >> return gen
      | otherwise =
          case genWord64R (fromIntegral (len - 1)) gen of
            (w, gen') -> do
              pivot <- GM.read mvec (fromIntegral w)
              cut <- pivotPartitionBy cmp mvec pivot
              gen'' <- sortSlice (GM.take cut mvec) gen'
              sortSlice (GM.drop cut mvec) gen''
      where
        len = GM.length mvec
{-# INLINE quickSortBy #-}

{- | Random Pivot Quick Select using the element type's 'Ord' instance.

 Returns the element that would sit at the 0-based index @k@ if the vector
 were sorted; this is exactly @'quickSelectBy' 'compare'@. The vector is
 reordered in place as a side effect. @k@ is expected to satisfy
 @0 <= k < length@.

 /O(n)/ expected
-}
quickSelect ::
  (Ord a, PrimMonad m, GM.MVector mv a) =>
  mv (PrimState m) a ->
  Int ->
  m a
quickSelect mvec k = quickSelectBy compare mvec k
{-# INLINE quickSelect #-}

{- | Random Pivot Quick Select, parameterised by a comparison function.

 Returns the element that would sit at the 0-based index @k@ if the vector
 were sorted by @cmp@. The vector is reordered in place as a side effect
 (partition swaps and a final insertion sort of a small slice).

 Calls 'error' if @k@ is negative or not less than the vector's length
 (this includes every @k@ for an empty vector).

 /O(n)/ expected
-}
quickSelectBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  Int ->
  m a
quickSelectBy cmp mv0 k0
  | k0 < 0 || k0 >= n0 =
      error $
        "Data.Vector.Sort.Quick.quickSelectBy: index "
          ++ show k0
          ++ " out of range for length "
          ++ show n0
  | otherwise = newStdGenPrim >>= select mv0 k0
  where
    n0 = GM.length mv0

    -- Invariant: 0 <= k < GM.length mvec. It holds initially thanks to the
    -- guard above and is preserved by both recursive calls, so the final
    -- unchecked read is always in bounds.
    select !mvec !k !gen
      | len <= 32 = do
          insertionSortBy cmp mvec
          GM.unsafeRead mvec k
      | otherwise =
          case genWord64R (fromIntegral (len - 1)) gen of
            (w, gen') -> do
              pivot <- GM.read mvec (fromIntegral w)
              cut <- pivotPartitionBy cmp mvec pivot
              -- [0, cut) <= pivot <= [cut, len): recurse only into the side
              -- that contains position k.
              if k < cut
                then select (GM.take cut mvec) k gen'
                else select (GM.drop cut mvec) (k - cut) gen'
      where
        len = GM.length mvec
{-# INLINE quickSelectBy #-}

{- | Hoare-style partition of the whole vector around the value @pivot@.

 Returns a cut index @c@ such that, afterwards, every element in @[0, c)@ is
 not greater than @pivot@ and every element in @[c, length)@ is not less than
 @pivot@ (according to @cmp@).

 Precondition: @pivot@ must compare equal to some element of the vector.
 That element stops both scans on the first sweep, and later sweeps are
 stopped by the elements just swapped. The scans use 'GM.unsafeRead' with no
 bounds checks, so this sentinel is what keeps them in range.
-}
pivotPartitionBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  a ->
  m Int
pivotPartitionBy cmp vec !pivot = sweep 0 (GM.length vec)
  where
    -- First index >= i whose element is not less than the pivot.
    skipSmaller !i = do
      x <- GM.unsafeRead vec i
      if cmp x pivot == LT then skipSmaller (i + 1) else return i

    -- First index <= i whose element is not greater than the pivot.
    skipLarger !i = do
      x <- GM.unsafeRead vec i
      if cmp pivot x == LT then skipLarger (i - 1) else return i

    -- Invariant: elements in [0, lo) are <= pivot, elements in [hi, length)
    -- are >= pivot.
    sweep !lo !hi = do
      !lo' <- skipSmaller lo
      !hi' <- skipLarger (hi - 1)
      if lo' < hi'
        then GM.unsafeSwap vec lo' hi' >> sweep (lo' + 1) hi'
        else return lo'
{-# INLINE pivotPartitionBy #-}

{- | Median-of-three pivot candidate.

 Reads the first element, the element at index @length \`unsafeShiftR\` 1@
 and the last element, in that order, and returns their median under @cmp@
 (see 'medianBy').

 The vector must be non-empty: the reads use 'GM.unsafeRead', so nothing
 checks the indices.
-}
getMedian3PivotBy ::
  (PrimMonad m, GM.MVector mv a) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m a
getMedian3PivotBy cmp vec = do
  let len = GM.length vec
  lo <- GM.unsafeRead vec 0
  mid <- GM.unsafeRead vec (len `unsafeShiftR` 1)
  hi <- GM.unsafeRead vec (len - 1)
  return (medianBy cmp lo mid hi)
{-# INLINE getMedian3PivotBy #-}

{- | Median of three values under the given comparison.

 Ties are resolved by a fixed decision tree; for example, when all three
 values compare equal the second argument is returned.

>>> medianBy compare 3 1 2
2
-}
medianBy :: (a -> a -> Ordering) -> a -> a -> a -> a
medianBy cmp x y z
  | x `lt` y = if y `lt` z then y else if x `lt` z then z else x
  | x `lt` z = x
  | y `lt` z = z
  | otherwise = y
  where
    lt p q = cmp p q == LT
{-# INLINE medianBy #-}

{- | Stable in-place insertion sort by @cmp@.

 Each new element is first compared with the element at index 0. If it is
 strictly smaller, the whole sorted prefix slides one slot to the right and
 the element is written to the front. Otherwise the element at index 0 is
 not greater than it and serves as a sentinel, so the backwards scan needs
 no bounds test. Only strictly greater elements are moved past the one being
 inserted, which keeps the sort stable.

 /O(n^2)/; 'quickSortBy' and 'quickSelectBy' use it for small slices.
-}
insertionSortBy ::
  (GM.MVector mv a, PrimMonad m) =>
  (a -> a -> Ordering) ->
  mv (PrimState m) a ->
  m ()
insertionSortBy cmp mvec = insertFrom 1
  where
    !n = GM.length mvec

    -- Insert element i into the already-sorted prefix [0, i), then continue.
    insertFrom !i
      | i >= n = return ()
      | otherwise = do
          front <- GM.unsafeRead mvec 0
          x <- GM.unsafeRead mvec i
          if cmp front x == GT
            then do
              shiftRight i
              GM.unsafeWrite mvec 0 x
            else sinkInto i x
          insertFrom (i + 1)

    -- Copy mvec[j-1] into mvec[j] for j = i, i-1, .., 1.
    shiftRight !j = when (j > 0) $ do
      GM.unsafeRead mvec (j - 1) >>= GM.unsafeWrite mvec j
      shiftRight (j - 1)

    -- Walk left from slot j, moving strictly greater elements one slot right,
    -- and write x into the first slot whose left neighbour is not greater.
    sinkInto !j x = do
      y <- GM.unsafeRead mvec (j - 1)
      if cmp y x == GT
        then GM.unsafeWrite mvec j y >> sinkInto (j - 1) x
        else GM.unsafeWrite mvec j x
{-# INLINE insertionSortBy #-}