module Data.ByteString.ZAlgorithm where

import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Data.Function (fix)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

{- |
Z-algorithm over the bytes of a 'B.ByteString'.

@z[i] = lcp s $ drop i s@, i.e. the length of the longest common prefix of
the whole string and its suffix starting at index @i@. By convention
@z[0]@ is the length of the string.

time complexity: /O(n)/

>>> :set -XOverloadedStrings
>>> zAlgorithm "ababab"
[6,0,4,0,2,0]
>>> zAlgorithm "abc$xabcxx"
[10,0,0,0,0,3,0,0,0,0]
>>> zAlgorithm ""
[]
-}
zAlgorithm :: B.ByteString -> U.Vector Int
zAlgorithm bs = U.create $ do
  -- Every index is written below (1..n-1 by the loop, 0 afterwards), so
  -- an uninitialised allocation is safe.
  z <- UM.unsafeNew n
  -- Loop state: @[zl, zr)@ is the rightmost match window found so far
  -- (@bs[zl .. zr-1] == bs[0 .. zr-zl-1]@), and @l@ is the index to fill.
  -- Whenever @l < zr@ we also have @zl < l@, so @l - zl >= 1@ and the
  -- not-yet-written @z[0]@ is never read.
  fix
    ( \loop !zl !zr !l -> when (l < n) $
        if zr <= l
          then do
            -- No window covers @l@: match against the prefix from scratch.
            let !r = lcp 0 l
            UM.unsafeWrite z l (r - l)
            loop l r (l + 1)
          else do
            -- @l@ is inside the window, so @bs[l ..]@ mirrors @bs[l - zl ..]@.
            zk <- UM.unsafeRead z (l - zl)
            if zk < zr - l
              then do
                -- The mirrored match ends strictly inside the window.
                UM.unsafeWrite z l zk
                loop zl zr (l + 1)
              else do
                -- The match reaches the window edge; keep comparing past @zr@.
                let !r = lcp (zr - l) zr
                UM.unsafeWrite z l (r - l)
                loop l r (l + 1)
    )
    0
    1
    1
  when (n > 0) $ UM.write z 0 n
  pure z
  where
    !n = B.length bs

    -- Advance a prefix index @l@ and a text index @r@ together while the
    -- bytes agree, returning the first @r@ where they differ (or @n@).
    -- Every call site passes @l < r@, and @r < n@ is checked first, so both
    -- 'B.unsafeIndex' calls stay in bounds.
    lcp :: Int -> Int -> Int
    lcp = fix $ \loop !l !r ->
      if r < n && B.unsafeIndex bs l == B.unsafeIndex bs r
        then loop (l + 1) (r + 1)
        else r