module Data.ByteString.LCP where

import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Unsafe as B
import Data.Function
import Data.Int
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import Data.ByteString.SuffixArray
import My.Prelude

-- | Longest-common-prefix array: a thin wrapper around an unboxed vector of
-- prefix lengths. Unwrap it with 'getLCPArray'.
newtype LCPArray = LCPArray
  { getLCPArray :: U.Vector Int
  }

-- | Shown exactly like the wrapped vector, e.g. @[0,2,0,1]@.
instance Show LCPArray where
  show (LCPArray lengths) = show lengths

{- | Render every entry of an LCP array as the prefix it stands for.

Entry @i@ of the result is the first @lcp[i]@ bytes of the suffix of @bs@
that starts at offset @sa[i]@, unpacked to a 'String'. Intended for
debugging and doctests.

>>> :set -XOverloadedStrings
>>> sa = buildSuffixArray "abab"
>>> viewSuffixArray "abab" sa
["","ab","abab","b","bab"]
>>> lcp = buildLCPArray "abab" sa
>>> lcp
[0,2,0,1]
>>> viewLCPArray "abab" sa lcp
["","ab","","b"]

>>> :set -XOverloadedStrings
>>> bs = " ab ab a"
>>> n = length $ C.words bs
>>> sa = buildSuffixArray bs
>>> take n . tail $ viewSuffixArray bs sa
[" a"," ab a"," ab ab a"]
>>> lcp = buildLCPArray bs sa
>>> take (n - 1) . tail $ viewLCPArray bs sa lcp
[" a"," ab a"]
-}
viewLCPArray :: B.ByteString -> SuffixArray Int32 -> LCPArray -> [String]
viewLCPArray bs (SuffixArray sa) (LCPArray lcp) =
  [ sharedPrefix pos len
  | (pos, len) <- U.toList (U.indexed lcp)
  ]
  where
    -- First @len@ bytes of the suffix stored at suffix-array position @pos@.
    sharedPrefix pos len =
      let start = fromIntegral (sa U.! pos)
       in C.unpack (C.take len (C.drop start bs))

{- | /O(n)/ Build the LCP array of @bs@ from its suffix array.

The suffix array is read at positions @0 .. n@ (where @n = B.length bs@), so
it must hold @n + 1@ entries; the result has @n@ entries.

* @lcp[i]@ is the length of the longest common prefix of the suffixes that
  start at @sa[i]@ and @sa[i+1]@.
* Consequently, for suffix-array positions @l < r@, the longest common prefix
  of the suffixes at @sa[l]@ and @sa[r]@ has length
  @minimum [lcp[l], lcp[l+1] .. lcp[r-1]]@.

For an empty @bs@ the result is an empty array (given a suffix array that
holds just the empty suffix); no element is written in that case.

>>> :set -XOverloadedStrings
>>> bs = "abracadabra"
>>> buildLCPArray bs $ buildSuffixArray bs
[0,1,4,1,1,0,3,0,0,0,2]
-}
buildLCPArray :: B.ByteString -> SuffixArray Int32 -> LCPArray
buildLCPArray bs sa = LCPArray $
  U.create $ do
    -- Zero-filled allocation. The previous 'UM.unsafeNew' followed by an
    -- unconditional @unsafeWrite lcp 0 0@ wrote out of bounds when @n == 0@.
    lcp <- UM.replicate n 0
    U.ifoldM'_
      ( \h i r -> do
          -- i: start offset of the current suffix, r: its position in the
          -- suffix array, h: match length found for the previous offset.
          -- NOTE(review): assumes the empty suffix occupies position 0 of
          -- the suffix array, so r >= 1 for every i < n -- confirm against
          -- buildSuffixArray.
          let !j = fromIntegral $ indexSA sa (r - 1)
              -- Resume from h - 1 instead of 0; this is what keeps the
              -- whole scan linear.
              h' = extendMatch i j (max 0 (h - 1))
          UM.unsafeWrite lcp (r - 1) h'
          pure h'
      )
      0
      -- Drop the last rank entry: it belongs to the empty suffix at offset n.
      $ U.init rank
    return lcp
  where
    !n = B.length bs

    -- Inverse of the suffix array: rank[offset] = position of that suffix.
    !rank = U.create $ do
      buf <- UM.unsafeNew (n + 1)
      rep (n + 1) $ \i -> do
        UM.unsafeWrite buf (fromIntegral $ indexSA sa i) i
      return buf

    -- Starting from a known common length d0 of the suffixes at offsets i
    -- and j, keep comparing bytes until they differ or either suffix ends.
    extendMatch :: Int -> Int -> Int -> Int
    extendMatch !i !j = go
      where
        go !d
          | i + d < n
          , j + d < n
          , B.unsafeIndex bs (i + d) == B.unsafeIndex bs (j + d) =
              go (d + 1)
          | otherwise = d

{- | /O(n)/ Length of the longest common prefix of two byte strings, found by
comparing them byte by byte from the front.

>>> :set -XOverloadedStrings
>>> naiveLCP "abc0" "abc1"
3
>>> naiveLCP "ab" "a"
1
>>> naiveLCP "" ""
0
-}
naiveLCP :: B.ByteString -> B.ByteString -> Int
naiveLCP xs ys = scan 0
  where
    -- Never index past the shorter of the two inputs.
    !limit = min (B.length xs) (B.length ys)

    scan !k
      | k >= limit = k
      | B.unsafeIndex xs k /= B.unsafeIndex ys k = k
      | otherwise = scan (k + 1)