module Data.ByteString.SuffixArray where

import Control.Monad
import Control.Monad.ST
import Data.Bits
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Unsafe as B
import Data.Coerce
import Data.Function
import Data.Int
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import My.Prelude

-- | A suffix array: an unboxed vector of suffix starting offsets, listed in
-- lexicographic order of the suffixes they denote.
newtype SuffixArray a = SuffixArray
  { getSuffixArray :: U.Vector a
  }
  deriving (Eq)

-- | Renders only the underlying offset vector, e.g. @[3,2,1,0]@.
instance (Show a, U.Unbox a) => Show (SuffixArray a) where
  show (SuffixArray offsets) = show offsets

-- | @indexSA sa i@ is the @i@-th entry of the suffix array.
--
-- Uses 'U.unsafeIndex', so no bounds check is performed: @i@ must lie in
-- @[0, length)@.
indexSA :: (U.Unbox a) => SuffixArray a -> Int -> a
indexSA (SuffixArray offsets) i = U.unsafeIndex offsets i
{-# INLINE indexSA #-}

{- | Find all occurrences of @needle@ in @haystack@.

@sa@ must be the suffix array of @haystack@ (see 'buildSuffixArray').
The result is the contiguous slice of @sa@ whose suffixes start with
@needle@, so the offsets come in suffix order, not position order.

/O(m log n)/, where @m@ is the length of @needle@ and @n@ the length of
@haystack@: two binary searches, each comparing at most @m@ bytes per step.

>>> :set -XOverloadedStrings
>>> bs = "ababab"
>>> sa = buildSuffixArray bs
>>> findSubstringsSA bs sa "ab"
[4,2,0]
>>> findSubstringsSA bs sa "xxx"
[]
>>> findSubstringsSA bs sa ""
[6,4,2,0,5,3,1]
-}
findSubstringsSA :: B.ByteString -> SuffixArray Int32 -> B.ByteString -> U.Vector Int32
findSubstringsSA haystack (SuffixArray sa) needle = U.unsafeSlice lo (hi - lo) sa
  where
    !n = B.length haystack
    !m = B.length needle

    -- The first (at most) m bytes of the suffix stored at rank i.
    prefixAt :: Int -> B.ByteString
    prefixAt i =
      let !start = fromIntegral (U.unsafeIndex sa i)
       in B.take m (B.unsafeDrop start haystack)

    -- Presumably binarySearch (My.Prelude) yields the least rank in the
    -- range satisfying the monotone predicate; the doctests agree.
    -- lo: first rank whose prefix is >= needle.
    !lo = binarySearch 0 (n + 1) (\i -> needle <= prefixAt i)
    -- hi: first rank (from lo on) whose prefix is > needle.
    !hi = binarySearch lo (n + 1) (\i -> needle < prefixAt i)

{- | Build the suffix array of a byte string with SA-IS in /O(n)/.

The result has @n + 1@ entries for an input of length @n@. The empty suffix
(offset @n@) always comes first, followed by the offsets of the non-empty
suffixes in lexicographic order. Every byte value @0..255@ is supported,
including NUL and bytes @>= 0x80@.

>>> :set -XOverloadedStrings
>>> buildSuffixArray "aaa"
[3,2,1,0]
>>> buildSuffixArray "mississippi"
[11,10,7,4,1,0,9,8,6,3,5,2]
>>> buildSuffixArray "abracadabra"
[11,10,7,0,3,5,8,1,4,6,9,2]
>>> buildSuffixArray "ababab"
[6,4,2,0,5,3,1]
>>> buildSuffixArray ""
[0]
>>> buildSuffixArray "sentinel\0"
[9,8,6,1,4,7,5,2,0,3]
>>> buildSuffixArray "\200a"
[2,1,0]
-}
buildSuffixArray :: B.ByteString -> SuffixArray Int32
buildSuffixArray bs = SuffixArray $ U.create $ do
  sa <- UM.replicate (n + 1) (-1)
  if n > 0
    then sais sa alphabetMax lsTyped
    else UM.write sa 0 0
  return sa
  where
    n :: Int
    n = B.length bs

    -- Each byte b is encoded as b + 1, so the characters are 1..256.
    -- This does two things:
    --   * 0 stays reserved for 'sentinelLS', so a NUL byte no longer
    --     collides with the sentinel;
    --   * the sign bit stays free for the L/S tag. With the previous
    --     Word8 -> Int8 narrowing, bytes >= 0x80 became negative and were
    --     misread as L-tagged.
    -- Int32 reuses the existing (Int32, Int32) specialization of 'sais'.
    alphabetMax :: Int32
    alphabetMax = 256

    -- LS-typed input: each character tagged from its successor, with the
    -- sentinel appended at the end (scanr' yields n + 1 elements).
    lsTyped :: U.Vector Int32
    lsTyped =
      U.scanr' setLS sentinelLS $
        U.generate n (\i -> fromIntegral (B.unsafeIndex bs i) + 1)

{- | List the suffixes of @bs@ in the order given by the suffix array,
mainly useful for inspecting results.

>>> :set -XOverloadedStrings
>>> viewSuffixArray "abc" $ buildSuffixArray "abc"
["","abc","bc","c"]
>>> viewSuffixArray " a b c" $ buildSuffixArray " a b c"
[""," a b c"," b c"," c","a b c","b c","c"]
-}
viewSuffixArray :: C.ByteString -> SuffixArray Int32 -> [String]
viewSuffixArray bs (SuffixArray sa) =
  [C.unpack (C.drop (fromIntegral offset) bs) | offset <- U.toList sa]

-- | Integer characters that can carry an SA-IS L/S type tag.
--
-- The instances in this module use the sign bit as the tag: negative values
-- are L-type and non-negative values are S-type.
class (Ord a, Num a, Integral a) => LSInt a where
  -- | Is this tagged character L-type?
  isL :: a -> Bool

  -- | Is this tagged character S-type?
  isS :: a -> Bool

  -- | Strip the tag, recovering the plain character value.
  unLS :: a -> a

  -- | @setLS c next@ tags the plain character @c@, given the already-tagged
  -- character @next@ that follows it.
  setLS :: a -> a -> a

  -- | The character that terminates every input; 0 unless overridden.
  sentinelLS :: a
  sentinelLS = 0

-- | Sign-bit L/S tagging for 'Int'.
instance LSInt Int where
  isL x = x < 0
  {-# INLINE isL #-}
  isS x = x >= 0
  {-# INLINE isS #-}
  unLS x = x .&. maxBound
  {-# INLINE unLS #-}
  -- Smaller than the successor: S. Larger: L. Equal: inherit the
  -- successor's tag bit.
  setLS c next
    | c < base = c
    | c > base = c .|. minBound
    | otherwise = c .|. (next .&. minBound)
    where
      base = unLS next
  {-# INLINE setLS #-}

-- | Sign-bit L/S tagging for 'Int8'.
instance LSInt Int8 where
  isL x = x < 0
  {-# INLINE isL #-}
  isS x = x >= 0
  {-# INLINE isS #-}
  unLS x = x .&. maxBound
  {-# INLINE unLS #-}
  -- Smaller than the successor: S. Larger: L. Equal: inherit the
  -- successor's tag bit.
  setLS c next
    | c < base = c
    | c > base = c .|. minBound
    | otherwise = c .|. (next .&. minBound)
    where
      base = unLS next
  {-# INLINE setLS #-}

-- | Sign-bit L/S tagging for 'Int16'.
instance LSInt Int16 where
  isL x = x < 0
  {-# INLINE isL #-}
  isS x = x >= 0
  {-# INLINE isS #-}
  unLS x = x .&. maxBound
  {-# INLINE unLS #-}
  -- Smaller than the successor: S. Larger: L. Equal: inherit the
  -- successor's tag bit.
  setLS c next
    | c < base = c
    | c > base = c .|. minBound
    | otherwise = c .|. (next .&. minBound)
    where
      base = unLS next
  {-# INLINE setLS #-}

-- | Sign-bit L/S tagging for 'Int32'.
instance LSInt Int32 where
  isL x = x < 0
  {-# INLINE isL #-}
  isS x = x >= 0
  {-# INLINE isS #-}
  unLS x = x .&. maxBound
  {-# INLINE unLS #-}
  -- Smaller than the successor: S. Larger: L. Equal: inherit the
  -- successor's tag bit.
  setLS c next
    | c < base = c
    | c > base = c .|. minBound
    | otherwise = c .|. (next .&. minBound)
    where
      base = unLS next
  {-# INLINE setLS #-}

-- | Sign-bit L/S tagging for 'Int64'.
instance LSInt Int64 where
  isL x = x < 0
  {-# INLINE isL #-}
  isS x = x >= 0
  {-# INLINE isS #-}
  unLS x = x .&. maxBound
  {-# INLINE unLS #-}
  -- Smaller than the successor: S. Larger: L. Equal: inherit the
  -- successor's tag bit.
  setLS c next
    | c < base = c
    | c > base = c .|. minBound
    | otherwise = c .|. (next .&. minBound)
    where
      base = unLS next
  {-# INLINE setLS #-}

-- | Is position @si@ of the LS-typed string a leftmost-S (LMS) position?
--
-- That is, @si > 0@, the character before it is L-type, and the character
-- at @si@ is S-type. Indexing is unchecked, so a positive @si@ must be a
-- valid index of @ls@.
isLMS :: (LSInt a, LSInt b, U.Unbox a, U.Unbox b) => U.Vector a -> b -> Bool
isLMS ls si
  | si <= 0 = False
  | otherwise = isL (U.unsafeIndex ls (i - 1)) && isS (U.unsafeIndex ls i)
  where
    i = fromIntegral si :: Int
{-# INLINE isLMS #-}

-- | Bucket boundaries for the LS-typed string @ls@ over characters
-- @0..maxC@.
--
-- The result has @maxC + 2@ entries. Entry @c@ is the first slot of bucket
-- @c@, and entry @c + 1@ is one past its last slot. It is the exclusive
-- prefix sum of the per-character counts of @unLS@-stripped characters.
buildInitialBucket :: (LSInt a, U.Unbox a) => a -> U.Vector a -> U.Vector Int32
buildInitialBucket maxC ls = U.scanl' (+) 0 counts
  where
    counts :: U.Vector Int32
    counts = U.unsafeAccumulate_ (+) zeros chars ones

    zeros :: U.Vector Int32
    zeros = U.replicate (fromIntegral maxC + 1) 0

    -- Plain character values, used as bucket indices.
    chars :: U.Vector Int
    chars = U.map (fromIntegral . unLS) ls

    ones :: U.Vector Int32
    ones = U.replicate (U.length ls) 1
{-# INLINE buildInitialBucket #-}

-- | All LMS positions of the LS-typed string, in increasing order.
--
-- A position @i >= 1@ qualifies when @ls ! (i - 1)@ is L-type and @ls ! i@
-- is S-type. 'U.tail' is applied to @ls@, so @ls@ must be non-empty.
findLMSIndices :: (LSInt a, U.Unbox a) => U.Vector a -> U.Vector Int
findLMSIndices ls = U.map (+ 1) (U.findIndices id boundary)
  where
    -- boundary ! k is True exactly when position k + 1 is LMS.
    boundary = U.zipWith (\prev cur -> isL prev && isS cur) ls (U.tail ls)
{-# INLINE findLMSIndices #-}

-- | Induce the positions of the L-type suffixes from the current @sa@.
--
-- First @bucket@ is reset to the bucket heads (@'U.init' bucket0@). Then
-- @sa@ is scanned left to right. Whenever slot @i@ holds suffix @j + 1@ and
-- position @j@ is L-type, @j@ is written at the head of bucket
-- @unLS (ls ! j)@ and that head advances by one. Slots holding @-1@ or @0@
-- give @j < 0@ and are skipped.
induceSortL ::
  (LSInt a, LSInt b, U.Unbox a, U.Unbox b) =>
  UM.MVector s a ->
  UM.MVector s Int32 ->
  U.Vector b ->
  U.Vector Int32 ->
  ST s ()
induceSortL sa bucket ls bucket0 = do
  U.copy bucket (U.init bucket0)
  rep (U.length ls) scan
  where
    scan i = do
      suffix <- fromIntegral <$!> UM.unsafeRead sa i
      let j = suffix - 1 :: Int
      when (j >= 0 && isL (U.unsafeIndex ls j)) $ push j

    -- Append j at the current head of its bucket.
    push j = do
      let k = fromIntegral (unLS (U.unsafeIndex ls j))
      slot <- UM.unsafeRead bucket k
      UM.unsafeWrite bucket k (slot + 1)
      UM.unsafeWrite sa (fromIntegral slot) (fromIntegral j)
{-# INLINE induceSortL #-}

-- | Induce the positions of the S-type suffixes from the current @sa@.
--
-- First @bucket@ is reset to the bucket ends (@'U.tail' bucket0@). Then
-- @sa@ is scanned right to left. Whenever slot @i@ holds suffix @j + 1@ and
-- position @j@ is S-type, the end of bucket @unLS (ls ! j)@ is decremented
-- and @j@ is written there. Slots holding @-1@ or @0@ give @j < 0@ and are
-- skipped.
induceSortS ::
  (LSInt a, LSInt b, U.Unbox a, U.Unbox b) =>
  UM.MVector s a ->
  UM.MVector s Int32 ->
  U.Vector b ->
  U.Vector Int32 ->
  ST s ()
induceSortS sa bucket ls bucket0 = do
  U.copy bucket (U.tail bucket0)
  rev (U.length ls) scan
  where
    scan i = do
      suffix <- fromIntegral <$!> UM.unsafeRead sa i
      let j = suffix - 1 :: Int
      when (j >= 0 && isS (U.unsafeIndex ls j)) $ push j

    -- Prepend j just before the current end of its bucket.
    push j = do
      let k = fromIntegral (unLS (U.unsafeIndex ls j))
      slot <- subtract 1 <$!> UM.unsafeRead bucket k
      UM.unsafeWrite bucket k slot
      UM.unsafeWrite sa (fromIntegral slot) (fromIntegral j)
{-# INLINE induceSortS #-}

-- | Name the LMS substrings and build the reduced string for the next level.
--
-- Expects @sa@ to hold the output of the first induced sort. It presumably
-- has the same length @n@ as @ls@ (TODO confirm against callers). All work
-- happens in place in @sa@:
--
--   1. Move the LMS positions, in their current order, to the first @n1@
--      slots, and set every later slot to @-1@.
--   2. Walk those positions in order and assign ranks. A rank grows by one
--      only when the LMS substring differs from the previous one
--      ('neqLMS'). The rank of position @p@ is stored at slot
--      @n1 + p / 2@; LMS positions are never adjacent, so no two collide.
--   3. Write 'sentinelLS' at slot @n - 1@. Then compact the stored ranks
--      right to left into the slots just before it, LS-tagging them with
--      'setLS'. Empty rank slots (@-1@) are skipped, and the sentinel
--      position's own rank slot is never read.
--
-- Returns three things: the first @n1@ slots (room for the reduced suffix
-- array), the largest rank, and the last @n1@ slots frozen as the reduced
-- string.
reduceLMS ::
  (LSInt a, LSInt b, U.Unbox a, U.Unbox b) =>
  UM.MVector s a ->
  U.Vector b ->
  ST s (UM.MVector s a, a, U.Vector a)
reduceLMS sa ls = do
  !n1 <- U.foldM' keepLMS 0 (U.generate n id)
  UM.set (UM.drop n1 sa) (-1)

  !maxRank <- fst <$!> U.foldM' (nameLMS n1) (-1, 0) (U.generate n1 id)

  UM.write sa (n - 1) sentinelLS
  packReduced n1 (n - 2) (n1 + unsafeShiftR (n - 1) 1 - 1)

  reduced <- U.unsafeFreeze (UM.drop (n - n1) sa)
  return (UM.take n1 sa, maxRank, reduced)
  where
    !n = U.length ls

    -- Fold step of (1): keep sa[i] at slot pos if it is an LMS position.
    keepLMS !pos i = do
      sj <- UM.unsafeRead sa i
      if isLMS ls sj
        then do
          UM.unsafeWrite sa pos sj
          return $! pos + 1
        else return pos

    -- Fold step of (2): the accumulator is (current rank, previous position).
    nameLMS n1 (!r, !prev) i = do
      cur <- fromIntegral <$!> UM.unsafeRead sa i
      let !r' = r + neqLMS ls prev cur
      UM.unsafeWrite sa (n1 + unsafeShiftR cur 1) r'
      return (r', cur)

    -- Step (3): i walks the rank slots down to n1; pos is the next output
    -- slot, written from just below the sentinel downwards.
    packReduced n1 = go
      where
        go !pos !i
          | i < n1 = return ()
          | otherwise = do
              r <- UM.unsafeRead sa i
              if r > 0
                then do
                  next <- UM.unsafeRead sa (pos + 1)
                  UM.unsafeWrite sa pos (setLS r next)
                  go (pos - 1) (i - 1)
                else go pos (i - 1)

-- | Compare the LMS substrings starting at @si@ and @sj@: returns 1 if they
-- differ and 0 if they are equal.
--
-- The comparison uses the LS-tagged values, so characters and types must
-- both match. It stops with 0 at the first offset @k >= 1@ where @si + k@
-- is S-type and @si + k - 1@ is L-type, which ends the substring.
-- Indexing is unchecked. It presumably relies on a difference or that end
-- being reached before running off @ls@ (the unique sentinel provides this).
neqLMS :: (LSInt a, LSInt b, U.Unbox a) => U.Vector a -> Int -> Int -> b
neqLMS ls si sj
  | at si /= at sj = 1
  | otherwise = walk 1
  where
    at = U.unsafeIndex ls

    walk !k
      | ci /= cj = 1
      | isS ci && isL (at (si + k - 1)) = 0
      | otherwise = walk (k + 1)
      where
        ci = at (si + k)
        cj = at (sj + k)
{-# INLINE neqLMS #-}

-- | SA-IS core: fill @msa@ with the suffix array of the LS-tagged string @ls@.
--
-- @ls@ is expected to end with 'sentinelLS', and @msa@ to be filled with
-- @-1@ and have the same length as @ls@. The algorithm runs in four steps:
--
--   1. Seed every LMS position at the end of its character's bucket, then
--      induce the L- and S-type suffixes.
--   2. Name the LMS substrings ('reduceLMS'). Sort the reduced string
--      directly when all names are distinct; otherwise recurse.
--   3. Map the sorted reduced suffixes back to LMS positions of @ls@.
--   4. Re-seed the LMS suffixes in that sorted order and induce again.
sais ::
  (LSInt a, LSInt b, U.Unbox a, U.Unbox b) =>
  -- | output, filled with (-1)
  UM.MVector s b ->
  -- | the maximum character value (after 'unLS')
  a ->
  -- | LS-typed input, terminated by the sentinel
  U.Vector a ->
  ST s ()
-- A one-element input is just the sentinel; its suffix is the only entry.
sais msa _ ls | U.length ls == 1 = UM.write msa 0 0
sais msa maxC ls = do
  bkt <- U.thaw (U.tail bucket0)

  -- (1) seed LMS positions at bucket ends, then induce
  U.ifoldM'_ (seedLMS bkt) sentinelLS ls
  induceSortL msa bkt ls bucket0
  induceSortS msa bkt ls bucket0

  -- (2) name LMS substrings and sort the reduced string
  (msa', maxC', ls') <- reduceLMS msa ls
  let namesAreUnique = fromIntegral maxC' >= U.length ls' - 1
  if namesAreUnique
    then -- distinct names: name c sits exactly at index unLS c
      U.imapM_
        (\i c -> UM.unsafeWrite msa' (fromIntegral (unLS c)) (fromIntegral i))
        ls'
    else do
      UM.set msa' (-1)
      sais msa' maxC' ls'

  -- (3) reuse ls' storage as the table "k-th LMS position of ls"
  let !m = UM.length msa'
  lmsPos <- U.unsafeThaw ls'
  U.imapM_ (\k p -> UM.unsafeWrite lmsPos k (fromIntegral p)) (findLMSIndices ls)
  rep m $ \i -> do
    name <- UM.unsafeRead msa' i
    p <- UM.unsafeRead lmsPos (fromIntegral name)
    UM.unsafeWrite msa' i (fromIntegral p)
  UM.set (UM.drop m msa) (-1)

  -- (4) place sorted LMS suffixes at bucket ends (right to left), induce
  U.copy bkt (U.tail bucket0)
  rev m $ \i -> do
    !sj <- UM.unsafeRead msa' i
    UM.unsafeWrite msa' i (-1)
    let k = fromIntegral (unLS (U.unsafeIndex ls (fromIntegral sj)))
    slot <- subtract 1 <$!> UM.unsafeRead bkt k
    UM.unsafeWrite bkt k slot
    UM.unsafeWrite msa (fromIntegral slot) sj

  induceSortL msa bkt ls bucket0
  induceSortS msa bkt ls bucket0
  where
    !bucket0 = buildInitialBucket maxC ls

    -- Fold step over ls for (1). prev is the previous tagged character and
    -- starts as the sentinel, which is S-type, so position 0 is never seeded.
    seedLMS bkt prev si c = do
      when (isL prev && isS c) $ do
        let k = fromIntegral (unLS c)
        slot <- subtract 1 <$!> UM.unsafeRead bkt k
        UM.unsafeWrite bkt k slot
        UM.unsafeWrite msa (fromIntegral slot) (fromIntegral si)
      return c
{-# SPECIALIZE sais :: UM.MVector s Int32 -> Int8 -> U.Vector Int8 -> ST s () #-}
{-# SPECIALIZE sais :: UM.MVector s Int32 -> Int32 -> U.Vector Int32 -> ST s () #-}
{-# SPECIALIZE sais :: UM.MVector s Int -> Int8 -> U.Vector Int8 -> ST s () #-}
{-# SPECIALIZE sais :: UM.MVector s Int -> Int -> U.Vector Int -> ST s () #-}