module Data.FenwickTree.Sum where

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Function
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

import My.Prelude (floorPowerOf2)

-- | A Fenwick tree (binary indexed tree) holding sums, supporting point
-- additions and prefix-sum queries.
--
-- For a tree over @size@ elements the wrapped vector has length @size + 1@.
-- Node @i@ (for @1 <= i <= size@) stores a partial sum. Slot @0@ is never
-- read by the queries in this module.
newtype SumFenwickTree s a = SumFenwickTree
  { getSumFenwickTree :: UM.MVector s a
  }

-- | Allocate a Fenwick tree over @n@ elements, all initialised to @0@.
--
-- The backing vector has @n + 1@ slots because the tree is 1-indexed
-- internally.
--
-- /O(n)/
newSumFenwickTree ::
  (U.Unbox a, Num a, PrimMonad m) =>
  Int ->
  m (SumFenwickTree (PrimState m) a)
newSumFenwickTree n = do
  buf <- UM.replicate (n + 1) 0
  pure (SumFenwickTree buf)
{-# INLINE newSumFenwickTree #-}

-- | Build a Fenwick tree whose elements are the contents of @vec@.
--
-- The input is copied into slots @1..n@. Then, in increasing index order,
-- each node's value is added to its parent @i + lowbit i@ (when that parent
-- exists). This builds the tree in linear time rather than doing @n@
-- separate point updates.
--
-- /O(n)/
buildSumFenwickTree ::
  (U.Unbox a, Num a, PrimMonad m) =>
  U.Vector a ->
  m (SumFenwickTree (PrimState m) a)
buildSumFenwickTree vec = do
  let len = U.length vec
  tree <- UM.unsafeNew (len + 1)
  -- Slot 0 is unused by the queries, but give it a defined value anyway.
  UM.write tree 0 0
  U.unsafeCopy (UM.tail tree) vec
  forM_ [1 .. len] $ \node -> do
    let parent = node + (node .&. negate node)
    when (parent <= len) $ do
      x <- UM.unsafeRead tree node
      UM.unsafeModify tree (+ x) parent
  return (SumFenwickTree tree)
{-# INLINE buildSumFenwickTree #-}

{- | Prefix sum @sum [0..k)@ of the elements at 0-based indices @0 .. k-1@.

A @k <= 0@ gives @0@. A @k@ larger than the tree size is clamped to the
size, so the result is the sum of the whole array.

/O(log n)/
-}
sumTo ::
  (PrimMonad m, U.Unbox a, Num a) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  m a
sumTo (SumFenwickTree ft) k = go 0 (min k size)
  where
    -- Largest valid node index. Clamping to it keeps 'UM.unsafeRead'
    -- in bounds when @k > size@.
    size = UM.length ft - 1
    -- Strip the lowest set bit of @i@ each step, accumulating node values.
    go !acc !i
      | i > 0 = do
          xi <- UM.unsafeRead ft i
          go (acc + xi) (i - (i .&. (-i)))
      | otherwise = return acc
{-# INLINE sumTo #-}

{- | Range sum @sum [l..r)@ of the elements at 0-based indices @l .. r-1@.

This is computed as @sumTo r - sumTo l@. Both bounds are clamped to the
tree size, so out-of-range bounds never read past the backing vector. For
example, @sumFromTo ft size (size + 1)@ is @0@. When @l > r@ the result is
the negation of @sum [r..l)@.

/O(log n)/
-}
sumFromTo ::
  (PrimMonad m, U.Unbox a, Num a) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  Int ->
  m a
sumFromTo (SumFenwickTree ft) l0 r0 = goL 0 (clamp l0) (clamp r0)
  where
    -- Largest valid node index. Clamping keeps 'UM.unsafeRead' in bounds.
    size = UM.length ft - 1
    clamp i = min i size
    -- First subtract the prefix @[0..l)@ ...
    goL !acc !l !r
      | l > 0 = do
          xl <- UM.unsafeRead ft l
          goL (acc - xl) (l - (l .&. (-l))) r
      | otherwise = goR acc r
    -- ... then add the prefix @[0..r)@.
    goR !acc !r
      | r > 0 = do
          xr <- UM.unsafeRead ft r
          goR (acc + xr) (r - (r .&. (-r)))
      | otherwise = return acc
{-# INLINE sumFromTo #-}

-- | Add @v@ to the element at 0-based index @k@.
--
-- Indices outside @[0, size)@ leave the tree unchanged. Without this check,
-- a negative @k@ makes the walk start at node @<= 0@. There,
-- @i .&. (-i)@ stops advancing (node 0 loops forever) or the write goes to
-- a negative index. @k = maxBound@ would overflow @k + 1@.
--
-- /O(log n)/
addAt ::
  (U.Unbox a, Num a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  a ->
  m ()
addAt (SumFenwickTree ft) k v =
  when (0 <= k && k < n - 1) $ loop (k + 1)
  where
    !n = UM.length ft
    -- Visit node @k+1@ and every node covering it (i += lowbit i).
    -- Here @i >= 1@, so @i@ strictly increases and the loop terminates.
    loop !i = when (i < n) $ do
      UM.unsafeModify ft (+ v) i
      loop (i + (i .&. (-i)))
{-# INLINE addAt #-}

-- | Read the single element at 0-based index @i@, as the one-element range
-- sum @sum [i..i+1)@.
--
-- /O(log n)/
readSFT ::
  (Num a, U.Unbox a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  m a
readSFT ft i =
  let next = i + 1
   in sumFromTo ft i next
{-# INLINE readSFT #-}

-- | Overwrite the element at 0-based index @i@ with @x@.
--
-- This reads the current value and then adds the difference @x - old@.
--
-- /O(log n)/
writeSFT ::
  (Num a, U.Unbox a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  a ->
  m ()
writeSFT ft i x = do
  old <- readSFT ft i
  addAt ft i (x - old)
{-# INLINE writeSFT #-}

{- |
Smallest @i@ such that @sum [0..i) >= s0@, i.e. @'sumTo' ft i >= s0@.

Returns @0@ when @s0 <= 0@. Returns @size + 1@ when even the whole array
sums to less than @s0@ (see the @10@ example below).

The descent assumes every element is non-negative, so that prefix sums are
monotone. With negative elements the result is not meaningful.

/O(log n)/

>>> ft <- buildSumFenwickTree @Int (U.fromList [1,1,1,1,1])
>>> lowerBoundSFT ft 3
3
>>> sumTo ft 3
3
>>> lowerBoundSFT ft 0
0
>>> lowerBoundSFT ft 1
1
>>> lowerBoundSFT ft 10
6

>>> ft <- buildSumFenwickTree @Int (U.fromList [1,1,0,0,0])
>>> lowerBoundSFT ft 2
2
-}
lowerBoundSFT ::
  (U.Unbox a, Num a, Ord a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  a ->
  m Int
lowerBoundSFT (SumFenwickTree ft) s0
  | s0 <= 0 = return 0
  | otherwise = descend s0 (floorPowerOf2 len) 0
  where
    !len = UM.length ft
    -- Binary lifting: @pos@ is the largest prefix length found so far whose
    -- sum is still < s0, and @rest@ is what is still needed beyond it.
    descend !rest !step !pos
      | step == 0 = return $! pos + 1
      | pos + step >= len = descend rest (step !>>. 1) pos
      | otherwise = do
          node <- UM.unsafeRead ft (pos + step)
          if node < rest
            then descend (rest - node) (step !>>. 1) (pos + step)
            else descend rest (step !>>. 1) pos
{-# INLINE lowerBoundSFT #-}

{- |
Largest @i@ such that @sum [0..i) <= s0@, i.e. @'sumTo' ft i <= s0@.

Returns @size@ when the whole array sums to at most @s0@. When @s0 < 0@ this
returns @0@, even though @sum [0..0) = 0 > s0@. For non-negative data no
valid @i@ exists in that case.

The descent assumes every element is non-negative, so that prefix sums are
monotone. With negative elements the result is not meaningful.

/O(log n)/

>>> ft <- buildSumFenwickTree @Int (U.fromList [1,1,1,1,1])
>>> upperBoundSFT ft 3
3
>>> sumTo ft 3
3
>>> upperBoundSFT ft 0
0
>>> upperBoundSFT ft 1
1
>>> upperBoundSFT ft 10
5

>>> ft <- buildSumFenwickTree @Int (U.fromList [1,1,0,0,1])
>>> upperBoundSFT ft 2
4
-}
upperBoundSFT ::
  (U.Unbox a, Num a, Ord a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  a ->
  m Int
upperBoundSFT (SumFenwickTree ft) s0
  | s0 < 0 = return 0
  | otherwise = descend s0 (floorPowerOf2 len) 0
  where
    !len = UM.length ft
    -- Binary lifting: @pos@ is the largest prefix length found so far whose
    -- sum is <= s0, and @budget@ is how much of s0 is still unused.
    descend !budget !step !pos
      | step == 0 = return pos
      | pos + step >= len = descend budget (step !>>. 1) pos
      | otherwise = do
          node <- UM.unsafeRead ft (pos + step)
          if node <= budget
            then descend (budget - node) (step !>>. 1) (pos + step)
            else descend budget (step !>>. 1) pos
{-# INLINE upperBoundSFT #-}