module Data.ByteString.Manacher where
import Control.Exception
import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Data.Function (fix)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
-- | Odd-length palindrome radii computed with Manacher's algorithm.
--
-- The result has the same length as the input. Entry @i@ is the largest
-- @r@ such that @i - k@ and @i + k@ are both in bounds and hold equal bytes
-- for every @0 <= k < r@; the palindrome centred at @i@ therefore has
-- length @2 * r - 1@.
--
-- The input length is asserted to be odd via 'assert', so the check only
-- fires when assertions are enabled.
-- NOTE(review): presumably callers pass a separator-interleaved string,
-- which always has odd length -- confirm against call sites.
manacher :: B.ByteString -> U.Vector Int
manacher bs = assert (odd n) $
  U.create $ do
    rad <- UM.replicate n 0
    -- Equivalent to @fix (\loop center radius -> ...) 0 0@: start at index 0
    -- with no positions known to match yet.
    fix `flip` 0 `flip` 0 $ \loop !center !radius -> do
      when (center < n) $ do
        -- The first @radius@ offsets around @center@ are already known to
        -- match, so the naive extension resumes from there.
        let !radius' = naiveRadius center radius
        UM.unsafeWrite rad center radius'
        flip fix 1 $ \inner !r -> do
          -- Move on to the next centre that still needs explicit extension,
          -- carrying over the part of the current palindrome that covers it.
          let advance = loop (center + r) (radius' - r)
          if center - r >= 0 && center + r < n
            then do
              radL <- UM.unsafeRead rad (center - r)
              if r + radL < radius'
                then do
                  -- The mirrored palindrome lies strictly inside the current
                  -- one, so its radius can be copied verbatim.
                  UM.unsafeWrite rad (center + r) radL
                  inner (r + 1)
                else advance
            else advance
    return rad
  where
    !n = B.length bs

    -- Extend the palindrome centred at @c@ byte by byte, starting from the
    -- offset @r@, and return the first offset that is out of bounds or
    -- mismatches.
    naiveRadius :: Int -> Int -> Int
    naiveRadius c r = go r
      where
        go :: Int -> Int
        go !i
          | c - i >= 0
          , c + i < n
          , B.unsafeIndex bs (c - i) == B.unsafeIndex bs (c + i) =
              go (i + 1)
          | otherwise = i