module Data.SparseTable where

import Data.Bits
import Data.Coerce
import Data.Kind
import Data.Semigroup
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U

import Math.Utils (floorLog2)

-- | Range-minimum query table: a 'SparseTable' whose combining operation is
-- the '<>' of 'Min', i.e. 'min'.
type RMQ a = SparseTable Min a

{- | Build a range-minimum table over the given vector.

 /O(n log n)/ time and space.
-}
buildRMQ :: (U.Unbox a, Ord a) => U.Vector a -> RMQ a
buildRMQ = buildSparseTable
{-# INLINE buildRMQ #-}

{- | Element at the given index of the vector the table was built from.

 The index is not bounds-checked.

 /O(1)/
-}
readRMQ :: (U.Unbox a) => RMQ a -> Int -> a
readRMQ = readSparseTable
{-# INLINE readRMQ #-}

{- | Minimum of the half-open range @a[l..r)@.

 The range must be non-empty (@l < r@) and lie within the vector
 (@0 <= l@, @r <= n@); indices are not bounds-checked.

 /O(1)/
-}
queryMin :: (U.Unbox a, Ord a) => RMQ a -> Int -> Int -> a
queryMin = querySparseTable
{-# INLINE queryMin #-}

{- | Sparse table over a vector, combining elements with the 'Semigroup'
 instance of @f a@.

 @f@ is a phantom type parameter: it does not appear in the stored data and
 only selects which '<>' is used. The functions in this module require
 @f a@ to be 'Coercible' to @a@, e.g. @f = 'Min'@.

 Row @k@ holds, at index @i@, the combination of the window
 @a[i .. i + 2^k)@, so row 0 is the input vector itself.
-}
newtype SparseTable (f :: Type -> Type) a = SparseTable
  { getSparseTable :: V.Vector (U.Vector a)
  }
  deriving (Eq, Show)

{- | Build a sparse table from a vector.

 Row 0 is the input. Row @k + 1@ combines pairs of row-@k@ entries that are
 @2^k@ apart. Consequently row @k@ has length @n - 2^k + 1@, and its index
 @i@ holds @a[i] <> ... <> a[i + 2^k - 1]@, computed with the '<>' of @f a@.

 /O(n log n)/ time and space.
-}
buildSparseTable ::
  forall (f :: Type -> Type) a.
  (U.Unbox a, Semigroup (f a), Coercible (f a) a) =>
  U.Vector a ->
  SparseTable f a
buildSparseTable vec = SparseTable $ V.scanl' step vec widths
  where
    -- Window widths 1, 2, 4, ..., one per row after row 0.
    -- Assumes floorLog2 m = floor (log2 m) for m >= 1.
    -- The empty input is handled explicitly so that the result (a single,
    -- empty row) does not depend on what floorLog2 returns for 0.
    widths :: V.Vector Int
    widths
      | U.null vec = V.empty
      | otherwise = V.iterateN (floorLog2 (U.length vec)) (* 2) 1

    -- Combine each entry with the one @w@ positions to its right.
    -- The '<>' of @f a@ is coerced down to operate on @a@ directly.
    step :: U.Vector a -> Int -> U.Vector a
    step prev w = U.zipWith (coerce ((<>) @(f a))) prev (U.drop w prev)

{- | Element at index @i@ of the vector the table was built from (row 0).

 The index is not bounds-checked.

 /O(1)/
-}
readSparseTable :: (U.Unbox a) => SparseTable f a -> Int -> a
readSparseTable st = U.unsafeIndex baseRow
  where
    -- Row 0 always exists: the table is a scan seeded with the input vector.
    baseRow = V.unsafeIndex (getSparseTable st) 0

{- | Combine the half-open range @a[l..r)@ with the '<>' of @f a@.

 The answer is assembled from two windows of width @2^k@, where
 @k = floorLog2 (r - l)@ (assumed to be @floor (log2 (r - l))@): one window
 starts at @l@ and the other ends at @r@.

 The two windows overlap whenever @r - l@ is not a power of two. The result
 is therefore only correct for idempotent semigroups (@x <> x = x@), such as
 'Min' or 'Max'. For a semigroup like 'Sum', the overlapping part would be
 counted twice.

 The range must be non-empty (@l < r@) and lie within the vector
 (@0 <= l@, @r <= n@); indices are not bounds-checked.

 /O(1)/
-}
querySparseTable ::
  forall (f :: Type -> Type) a.
  (U.Unbox a, Semigroup (f a), Coercible (f a) a) =>
  SparseTable f a ->
  Int ->
  Int ->
  a
querySparseTable st l r = coerce ((<>) @(f a)) x y
  where
    logStep = floorLog2 (r - l)
    -- Row whose entries each cover a window of width 2^logStep.
    row = V.unsafeIndex (getSparseTable st) logStep
    -- Window [l, l + 2^logStep).
    x = U.unsafeIndex row l
    -- Window [r - 2^logStep, r).
    y = U.unsafeIndex row (r - unsafeShiftL 1 logStep)