module Data.Vector.Sort.Radix where

import Data.Bits
import qualified Data.Foldable as F
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
import Data.Word
import Unsafe.Coerce

--
import Data.Word64

-- | Sort a vector of 'Int' in ascending (signed) order using the 16-bit-digit
-- LSD radix sort 'radixSort64'.
--
-- Each value is turned into a 'Word64' key by flipping its sign bit, which
-- maps signed order onto unsigned order (minBound becomes the smallest key,
-- maxBound the largest). A plain bit reinterpretation would sort negative
-- numbers after all non-negative ones, because their two's-complement
-- patterns have the top bit set.
radixSortInt :: U.Vector Int -> U.Vector Int
radixSortInt = U.map fromKey . radixSort64 . U.map toKey
  where
    -- Bit 63 of the 64-bit key; xor-ing it is its own inverse.
    signBit :: Word64
    signBit = 0x8000000000000000

    -- 'fromIntegral' wraps modulo 2^64, so this also works when 'Int' is
    -- narrower than 64 bits.
    toKey :: Int -> Word64
    toKey x = fromIntegral x `xor` signBit

    -- Undo 'toKey'; 'fromIntegral' wraps back into the 'Int' range.
    fromKey :: Word64 -> Int
    fromKey w = fromIntegral (w `xor` signBit)

-- | Sort a vector of 'Word32' in ascending order with a least-significant-digit
-- radix sort: two stable counting-sort passes, one per 16-bit digit
-- (bits 0-15 first, then bits 16-31).
radixSort32 :: U.Vector Word32 -> U.Vector Word32
radixSort32 input = F.foldl' countingPass input [0, 16]
  where
    -- The 16-bit digit of @w@ that starts at bit position @shift@.
    digitAt :: Int -> Word32 -> Int
    digitAt shift w = fromIntegral (unsafeShiftR w shift .&. 0xffff)

    -- One stable counting-sort pass keyed on the digit at @shift@.
    countingPass :: U.Vector Word32 -> Int -> U.Vector Word32
    countingPass vec shift = U.create $ do
      -- Count occurrences of every digit (65536 buckets), then take exclusive
      -- prefix sums so that offsets[d] is the first output slot for digit d.
      let histogram =
            U.unsafeAccumulate (+) (U.replicate 0x10000 0) $
              U.map (\w -> (digitAt shift w, 1)) vec
      offsets <- U.unsafeThaw (U.prescanl' (+) 0 histogram)
      out <- UM.unsafeNew (U.length vec)
      -- Walk the input front to back so equal digits keep their relative
      -- order; this stability is what makes the LSD passes compose.
      U.forM_ vec $ \w -> do
        let d = digitAt shift w
        slot <- UM.unsafeRead offsets d
        UM.unsafeWrite offsets d (slot + 1)
        UM.unsafeWrite out slot w
      pure out
{-# INLINE radixSort32 #-}

-- | Sort a vector of 'Word64' in ascending order with a least-significant-digit
-- radix sort: four stable counting-sort passes over the 16-bit digits at bit
-- offsets 0, 16, 32 and 48.
radixSort64 :: U.Vector Word64 -> U.Vector Word64
radixSort64 input = F.foldl' countingPass input [0, 16, 32, 48]
  where
    -- The 16-bit digit of @w@ that starts at bit position @shift@.
    digitAt :: Int -> Word64 -> Int
    digitAt shift w = fromIntegral (unsafeShiftR w shift .&. 0xffff)

    -- One stable counting-sort pass keyed on the digit at @shift@.
    countingPass :: U.Vector Word64 -> Int -> U.Vector Word64
    countingPass vec shift = U.create $ do
      -- Count occurrences of every digit (65536 buckets), then take exclusive
      -- prefix sums so that offsets[d] is the first output slot for digit d.
      let histogram =
            U.unsafeAccumulate (+) (U.replicate 0x10000 0) $
              U.map (\w -> (digitAt shift w, 1)) vec
      offsets <- U.unsafeThaw (U.prescanl' (+) 0 histogram)
      out <- UM.unsafeNew (U.length vec)
      -- Walk the input front to back so equal digits keep their relative
      -- order; this stability is what makes the LSD passes compose.
      U.forM_ vec $ \w -> do
        let d = digitAt shift w
        slot <- UM.unsafeRead offsets d
        UM.unsafeWrite offsets d (slot + 1)
        UM.unsafeWrite out slot w
      pure out
{-# INLINE radixSort64 #-}

-- | Radix-sort a vector of any 'Word64Encode' type. Each element is encoded
-- with 'encode64', the keys are sorted by 'radixSort64', and each key is
-- decoded back with 'decode64'.
--
-- NOTE(review): the result is in the element type's natural order only if
-- 'encode64' is order-preserving and 'decode64' inverts it; those instances
-- live in "Data.Word64" and are not visible here - confirm there.
radixSort ::
  (U.Unbox a, Word64Encode a) =>
  U.Vector a ->
  U.Vector a
radixSort xs = U.map decode64 (radixSort64 (U.map encode64 xs))
{-# INLINE radixSort #-}

-- | Radix-sort a vector of any 'Word64Encode' type using the
-- 'encodeNonNegative64' / 'decodeNonNegative64' key mapping, with
-- 'radixSort64' sorting the keys in between.
--
-- NOTE(review): judging by the name, this variant presumably assumes every
-- element is non-negative so a cheaper encoding can be used; the actual
-- contract lives in "Data.Word64" - confirm there.
radixSortNonNegative ::
  (U.Unbox a, Word64Encode a) =>
  U.Vector a ->
  U.Vector a
radixSortNonNegative xs =
  U.map decodeNonNegative64 (radixSort64 (U.map encodeNonNegative64 xs))
{-# INLINE radixSortNonNegative #-}