module Data.Doubling where

import Data.Bits
import qualified Data.Foldable as F
import Data.Semigroup
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U

-- | A functional graph on states @0 .. M-1@ with a value attached to every
-- step. Entry @i@ of the vector is @(successor of i, value collected when
-- stepping from i)@. Tables are composed with '<>' (left operand applied
-- first) and raised to positive powers with 'stimes'.
newtype Doubling a = Doubling
  { getDoubling :: U.Vector (Int, a)
  }

-- | Build a 'Doubling' over states @0 .. n-1@ from a one-step function that
-- returns, for each state, its successor and the value collected on that step.
--
-- NOTE(review): successors should lie in @[0, n)@; composing tables with
-- '<>' looks them up with 'U.unsafeIndex' and does not check them.
generateDoubling :: (U.Unbox a) => Int -> (Int -> (Int, a)) -> Doubling a
generateDoubling n = Doubling . U.generate n

-- | Value-less variant of 'generateDoubling': builds a 'Doubling' over states
-- @0 .. n-1@ whose only payload is the successor @f i@ of each state @i@.
generateDoubling_ :: Int -> (Int -> Int) -> Doubling ()
generateDoubling_ n f = Doubling (U.generate n (\i -> (f i, ())))

-- | Composition of transition tables. @f <> g@ first takes the step described
-- by @f@ and then the step described by @g@. The carried values are combined
-- left to right (@value from f <> value from g@).
--
-- NOTE(review): successors of @f@ are looked up in @g@ with 'U.unsafeIndex',
-- so they must be valid indices into @g@; this is not checked here.
instance (Semigroup a, U.Unbox a) => Semigroup (Doubling a) where
  Doubling next0 <> Doubling next1 = Doubling $ U.map compose next0
    where
      -- follow one step in next0, then one step in next1 from where we landed
      compose (nv, x) =
        let (nnv, y) = U.unsafeIndex next1 nv
            !z = x <> y
         in (nnv, z)
  {-# INLINE (<>) #-}

  -- Exponentiation by squaring: O(log n) compositions, each O(M) for a table
  -- of M states. There is no identity table, so n <= 0 is an error.
  stimes n x0
    | n > 0 = go x0 x0 (n - 1)
    | otherwise = error "stimes: n must be positive"
    where
      -- acc holds the product so far, x the current power-of-two square,
      -- and i the remaining exponent.
      go !acc !x !i
        | i == 0 = acc
        -- Last step: finish here instead of recursing with i == 0. Recursing
        -- would compute (x <> x), and the bang on x would force that O(M) map
        -- only to throw it away.
        | i == 1 = acc <> x
        | even i = go acc (x <> x) (quot i 2)
        | otherwise = go (acc <> x) (x <> x) (quot i 2)
  {-# INLINE stimes #-}

-- | /O(Mlog N)/ Follow @n@ steps from state @x0@, where @M@ is the number of
-- states. The result is the reached state and @v0@ combined (on the left) with
-- the values collected along the way.
--
-- The @n@-step table is rebuilt with 'stimes' on every call. For many queries
-- against the same table, 'doublingStepNQuery' avoids this rebuild.
doublingStepN
  :: (Semigroup a, U.Unbox a)
  => Int -- ^ n
  -> Int -- ^ initial state
  -> a -- ^ initial value
  -> Doubling a
  -> (Int, a)
doublingStepN n x0 v0 next = case compare n 0 of
  LT -> error "doublingStepN: negative step"
  EQ -> (x0, v0)
  GT ->
    let (xn, vn) = getDoubling (stimes n next) U.! x0
     in (xn, v0 <> vn)
{-# INLINE doublingStepN #-}

-- | /O(Mlog N)/ State reached after @n@ steps from @x0@ in a value-less
-- 'Doubling'. @n == 0@ returns @x0@; a negative @n@ is an error.
doublingStepN_
  :: Int -- ^ n
  -> Int -- ^ initial state
  -> Doubling ()
  -> Int
doublingStepN_ n x0 next = case compare n 0 of
  LT -> error "doublingStepN_: negative step"
  EQ -> x0
  GT -> fst (getDoubling (stimes n next) U.! x0)
{-# INLINE doublingStepN_ #-}

-- | Precomputed powers of a 'Doubling'. Level @i@ is intended to hold the
-- table for @2^i@ steps, as produced by 'buildDoublingTable'.
newtype DoublingTable a
  = DoublingTable (V.Vector (Doubling a))

-- | Precompute 63 levels by repeated squaring. Level 0 is the given table and
-- level @i@ describes @2^i@ steps, for use with 'doublingStepNQuery'.
buildDoublingTable
  :: (Semigroup a, U.Unbox a)
  => Doubling a
  -> DoublingTable a
buildDoublingTable base = DoublingTable (V.iterateN 63 square base)
  where
    square d = d <> d
{-# INLINE buildDoublingTable #-}

-- | /O(log N)/ Follow @n@ steps from state @x0@ using a precomputed
-- 'DoublingTable'. The values collected along the way are combined onto @v0@
-- from left to right. Level @i@ of the table must describe @2^i@ steps (as
-- built by 'buildDoublingTable'); only levels for the set bits of @n@ are read.
--
-- A negative @n@ is an error. An out-of-range state or a missing table level
-- raises an index error rather than reading out of bounds.
doublingStepNQuery
  :: (Semigroup a, U.Unbox a)
  => Int -- ^ n
  -> Int -- ^ initial state
  -> a -- ^ initial value
  -> DoublingTable a
  -> (Int, a)
doublingStepNQuery n x0 v0 (DoublingTable table)
  | n < 0 = error "doublingStepNQuery: negative step"
  | otherwise = go 0 n x0 v0
  where
    -- i: current table level; m: bits of n not yet consumed (stop when 0).
    -- Checked indexing (V.!, U.!) costs at most 63 bounds checks per query.
    -- It guards against a caller-supplied x0 or a hand-built short table.
    go !i !m x v
      | m == 0 = (x, v)
      | m .&. 1 == 1 =
          let (xi, vi) = getDoubling (table V.! i) U.! x
              !v' = v <> vi
           in go (i + 1) (unsafeShiftR m 1) xi v'
      | otherwise = go (i + 1) (unsafeShiftR m 1) x v
{-# INLINE doublingStepNQuery #-}