module Data.FenwickTree where

import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Coerce
import Data.Function
import Data.Monoid
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

{- | A Fenwick tree (binary indexed tree) backed by a mutable unboxed vector.

 The constructors in this module allocate one slot more than the number of
 elements: slot 0 holds 'mempty' and is never part of a query, while the tree
 nodes live at the 1-based slots @1..n@.
-}
newtype FenwickTree s a = FenwickTree {getFenwickTree :: UM.MVector s a}

{- | Create a Fenwick tree over @n@ elements, each initialised to 'mempty'.

 The backing vector has @n + 1@ slots because the tree nodes are addressed
 with 1-based indices; slot 0 is left as 'mempty'.
-}
newFenwickTree ::
  (U.Unbox a, Monoid a, PrimMonad m) =>
  Int ->
  m (FenwickTree (PrimState m) a)
newFenwickTree n = FenwickTree <$> UM.replicate (n + 1) mempty
{-# INLINE newFenwickTree #-}

{- | Build a Fenwick tree from the elements of a vector.

 Element @k@ (0-based) of the input is first copied to slot @k + 1@; then a
 single ascending pass pushes every node into its parent
 @i + (i .&. (-i))@, provided that parent is still inside the tree.

 The parent is updated with @parent <> child@, so the stored aggregates match
 a left-to-right fold only when '<>' is commutative.

 /O(n)/
-}
buildFenwickTree ::
  (U.Unbox a, Monoid a, PrimMonad m) =>
  U.Vector a ->
  m (FenwickTree (PrimState m) a)
buildFenwickTree vec = do
  let n = U.length vec
  ft <- UM.unsafeNew (n + 1)
  -- Slot 0 is the unused sentinel; it must still be initialised because
  -- 'UM.unsafeNew' leaves the memory undefined.
  UM.write ft 0 mempty
  -- The remaining n slots receive the raw elements.
  U.unsafeCopy (UM.tail ft) vec
  flip fix 1 $ \loop !i -> when (i <= n) $ do
    -- Parent of node i: add the lowest set bit of i.
    let j = i + (i .&. (-i))
    when (j <= n) $ do
      fti <- UM.unsafeRead ft i
      UM.unsafeModify ft (<> fti) j
    loop (i + 1)
  return $ FenwickTree ft
{-# INLINE buildFenwickTree #-}

{- | mappend [0..k)

 Combine the elements at 0-based positions @[0, k)@. A non-positive @k@
 yields 'mempty'.

 The walk starts at node @k@ and repeatedly clears the lowest set bit of the
 index, accumulating with @acc <> node@. Later ranges are therefore combined
 before earlier ones, so the result equals a left-to-right fold only when
 '<>' is commutative.

 Reads are unchecked ('UM.unsafeRead'): @k@ must not exceed the number of
 elements in the tree.

 /O(log n)/
-}
mappendTo ::
  (PrimMonad m, U.Unbox a, Monoid a) =>
  FenwickTree (PrimState m) a ->
  Int ->
  m a
mappendTo (FenwickTree ft) = go mempty
  where
    go !acc !i
      | i > 0 = do
          xi <- UM.unsafeRead ft i
          -- Drop the lowest set bit to jump to the preceding range.
          go (acc <> xi) (i - (i .&. (-i)))
      | otherwise = return acc
{-# INLINE mappendTo #-}

{- | Combine @v@ into the element at 0-based position @k@.

 Every node whose range contains position @k@ is updated with @node <> v@,
 starting at node @k + 1@ and repeatedly adding the lowest set bit of the
 index until it leaves the tree.

 An out-of-range @k@ is a no-op. Indices past the end never enter the loop;
 negative indices are rejected up front, because @k == -1@ would start at
 node 0 (whose lowest set bit is 0, so the loop would never advance) and
 smaller values would write outside the vector through 'UM.unsafeModify'.

 /O(log n)/
-}
mappendAt ::
  (U.Unbox a, Semigroup a, PrimMonad m) =>
  FenwickTree (PrimState m) a ->
  Int ->
  a ->
  m ()
mappendAt (FenwickTree ft) k v = when (k >= 0) $ flip fix (k + 1) $ \loop !i ->
  when (i < n) $ do
    UM.unsafeModify ft (<> v) i
    loop $ i + (i .&. (-i))
  where
    -- Number of slots, i.e. element count + 1 (slot 0 is the sentinel).
    !n = UM.length ft
{-# INLINE mappendAt #-}

-- | A 'FenwickTree' whose elements are wrapped in 'Sum', i.e. combined by addition.
type SumFenwickTree s a = FenwickTree s (Sum a)

{- | Create a 'SumFenwickTree' over the given number of elements.

 This is 'newFenwickTree' at the 'Sum' monoid, so every element starts at
 @mempty :: Sum a@, i.e. zero.
-}
newSumFenwickTree ::
  (Num a, U.Unbox a, PrimMonad m) =>
  Int ->
  m (SumFenwickTree (PrimState m) a)
newSumFenwickTree = newFenwickTree
{-# INLINE newSumFenwickTree #-}

{- | Build a 'SumFenwickTree' from a vector of plain numbers.

 Each element is wrapped in 'Sum' via 'coerce' and the result is handed to
 'buildFenwickTree'.

 /O(n)/
-}
buildSumFenwickTree ::
  (Num a, U.Unbox a, PrimMonad m) =>
  U.Vector a ->
  m (SumFenwickTree (PrimState m) a)
buildSumFenwickTree = buildFenwickTree . U.map coerce
{-# INLINE buildSumFenwickTree #-}

{- | sum [0..k)

 Sum of the elements at 0-based positions @[0, k)@, obtained by unwrapping
 the 'Sum' returned by 'mappendTo'.

 /O(log n)/
-}
sumTo ::
  (Num a, U.Unbox a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  m a
sumTo ft k = coerce <$> mappendTo ft k
{-# INLINE sumTo #-}

{- | sum [l..r)

 Sum of the elements at 0-based positions @[l, r)@, computed as
 @sumTo ft r - sumTo ft l@.

 /O(log n)/
-}
sumFromTo ::
  (Num a, U.Unbox a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  Int ->
  m a
sumFromTo ft l r = (-) <$> sumTo ft r <*> sumTo ft l
{-# INLINE sumFromTo #-}

{- | Read the single element at 0-based position @i@.

 Implemented as the range sum over @[i, i + 1)@.

 /O(log n)/
-}
readSumFenwickTree ::
  (Num a, U.Unbox a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  m a
readSumFenwickTree ft i = sumFromTo ft i (i + 1)
{-# INLINE readSumFenwickTree #-}

{- | Overwrite the element at 0-based position @i@ with @x@.

 The tree only supports additive updates, so the current value is read first
 and the difference @x - current@ is then added at @i@.

 /O(log n)/
-}
writeSumFenwickTree ::
  (Num a, U.Unbox a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  a ->
  m ()
writeSumFenwickTree ft i x = do
  current <- readSumFenwickTree ft i
  addAt ft i (x - current)
{-# INLINE writeSumFenwickTree #-}

{- | Add @x@ to the element at 0-based position @k@.

 Wraps @x@ in 'Sum' via 'coerce' and delegates to 'mappendAt'.

 /O(log n)/
-}
addAt ::
  (U.Unbox a, Num a, PrimMonad m) =>
  SumFenwickTree (PrimState m) a ->
  Int ->
  a ->
  m ()
addAt ft k x = mappendAt ft k (coerce x)
{-# INLINE addAt #-}

{- | max i s.t. sum [0..i) < w

 Returns @0@ immediately when @w <= 0@. Otherwise the tree is descended from
 the largest power of two not exceeding the slot count: a step is taken
 whenever the node it lands on is still smaller than the remaining budget,
 and the budget is reduced by that node's value.

 NOTE(review): the descent presumably relies on all elements being
 non-negative (monotone prefix sums) - confirm against callers.

 findMaxIndexLT [1, 1..1] k == k - 1

 >>> ones <- buildFenwickTree [1, 1, 1, 1, 1]
 >>> findMaxIndexLT ones 3
 2
 >>> findMaxIndexLT ones 0
 0
 >>> ids <- buildFenwickTree [1, 2, 3, 4, 5]
 >>> findMaxIndexLT ids 6
 2
 >>> findMaxIndexLT ids 7
 3
 >>> zeros <- buildFenwickTree [0, 0, 0, 0, 0]
 >>> findMaxIndexLT zeros 1
 5
-}
findMaxIndexLT ::
  (U.Unbox a, Num a, Ord a, PrimMonad m) =>
  FenwickTree (PrimState m) a ->
  a ->
  m Int
findMaxIndexLT (FenwickTree ft) w0
  | w0 <= 0 = return 0
  | otherwise = go w0 highestOneBit 0
  where
    -- Number of slots, i.e. element count + 1 (slot 0 is the sentinel).
    n = UM.length ft
    -- Largest power of two that does not exceed n.
    highestOneBit = until (> n) (* 2) 1 `quot` 2
    go !w !step !i
      | step == 0 = return i
      | i + step < n = do
          u <- UM.unsafeRead ft (i + step)
          if u < w
            then go (w - u) (step `unsafeShiftR` 1) (i + step)
            else go w (step `unsafeShiftR` 1) i
      -- The candidate node lies outside the tree: just halve the step.
      | otherwise = go w (step `unsafeShiftR` 1) i
{-# INLINE findMaxIndexLT #-}