module Algorithm.BinarySearch where

import Control.Monad
import Data.Bool
import Data.Functor.Identity

import My.Prelude (unsafeShiftRL)

-- | Monadic binary search: the smallest @i@ in @[low, high]@ for which
-- @p i@ yields 'True'.
--
-- Precondition: @p@ is monotone on the range ('False' … then 'True') and
-- @p high@ holds. @p@ is only evaluated at midpoints below the current upper
-- end, so @p high@ itself is never demanded; if @high <= low@ the result is
-- @high@ and @p@ is not called at all. Each step halves the interval, so @p@
-- runs roughly @log2 (high - low)@ times.
lowerBoundM :: (Monad m) => Int -> Int -> (Int -> m Bool) -> m Int
lowerBoundM low0 high0 p = loop low0 high0
  where
    -- Invariant: the answer lies in [lo, hi], and hi satisfies p (by
    -- assumption initially, by observation after each shrink).
    loop !lo !hi
      | lo >= hi = return hi
      | otherwise = do
          -- NOTE(review): 'unsafeShiftRL' comes from My.Prelude; presumably a
          -- logical (unsigned) right shift, so the half-width stays correct
          -- even when @hi - lo@ wraps past maxBound. Confirm.
          let mid = lo + unsafeShiftRL (hi - lo) 1
          hit <- p mid
          if hit then loop lo mid else loop (mid + 1) hi
{-# INLINE lowerBoundM #-}

-- | Monadic binary search: the largest @i@ in @[low, high]@ for which
-- @p i@ yields 'True'.
--
-- Precondition: @p@ is monotone on the range ('True' … then 'False') and
-- @p low@ holds. @p high@ is checked first and, if it holds, @high@ is
-- returned immediately. Otherwise the first index where @p@ fails is found
-- with 'lowerBoundM' on the negated predicate, and the answer is the index
-- just before it.
upperBoundM :: (Monad m) => Int -> Int -> (Int -> m Bool) -> m Int
upperBoundM low high p =
  p high >>= bool beforeFirstMiss (return high)
  where
    -- Strict fmap so the decrement is applied eagerly rather than as a thunk.
    beforeFirstMiss =
      (\i -> i - 1) <$!> lowerBoundM low high (\x -> not <$> p x)
{-# INLINE upperBoundM #-}

-- | Pure binary search: the smallest @i@ in @[low, high]@ with @p i@.
--
-- Precondition: @p@ is monotone ('False' … then 'True') and @p high@ holds.
-- Runs 'lowerBoundM' in the 'Identity' monad.
lowerBound :: Int -> Int -> (Int -> Bool) -> Int
lowerBound low high p =
  runIdentity $ lowerBoundM low high (Identity . p)
{-# INLINE lowerBound #-}

-- | Pure binary search: the largest @i@ in @[low, high]@ with @p i@.
--
-- Precondition: @p@ is monotone ('True' … then 'False') and @p low@ holds.
-- Runs 'upperBoundM' in the 'Identity' monad.
upperBound :: Int -> Int -> (Int -> Bool) -> Int
upperBound low high p =
  runIdentity $ upperBoundM low high (Identity . p)
{-# INLINE upperBound #-}

-- | Pure 'Integer' binary search: the smallest @x@ in @[low, high]@ with
-- @p x@.
--
-- Precondition: @p@ is monotone ('False' … then 'True') and @p high@ holds.
-- Runs 'lowerBoundIntegerM' in the 'Identity' monad.
lowerBoundInteger :: Integer -> Integer -> (Integer -> Bool) -> Integer
lowerBoundInteger low high p =
  runIdentity (lowerBoundIntegerM low high (Identity . p))
{-# INLINE lowerBoundInteger #-}

-- | Pure 'Integer' binary search: the largest @x@ in @[low, high]@ with
-- @p x@.
--
-- Precondition: @p@ is monotone ('True' … then 'False') and @p low@ holds.
-- Runs 'upperBoundIntegerM' in the 'Identity' monad.
upperBoundInteger :: Integer -> Integer -> (Integer -> Bool) -> Integer
upperBoundInteger low high p =
  runIdentity (upperBoundIntegerM low high (Identity . p))
{-# INLINE upperBoundInteger #-}

-- | Monadic 'Integer' binary search: the smallest @x@ in @[low, high]@ for
-- which @p x@ yields 'True'.
--
-- Precondition: @p@ is monotone on the range ('False' … then 'True') and
-- @p high@ holds. @p@ is only evaluated at midpoints strictly below the
-- current upper end, so @p high@ is never demanded; if @high <= low@ the
-- result is @high@ and @p@ is not called.
lowerBoundIntegerM ::
  (Monad m) => Integer -> Integer -> (Integer -> m Bool) -> m Integer
lowerBoundIntegerM low0 high0 p = go low0 high0
  where
    go !low !high
      | high <= low = return high
      | otherwise = do
          pmid <- p mid
          if pmid
            then go low mid
            else go (mid + 1) high
      where
        -- Arguments are already 'Integer', so no conversions are needed and
        -- there is no overflow to guard against. Here high > low, so the
        -- offset is positive and low <= mid < high.
        mid = low + (high - low) `div` 2
{-# INLINE lowerBoundIntegerM #-}

-- | Monadic 'Integer' binary search: the largest @x@ in @[low, high]@ for
-- which @p x@ yields 'True'.
--
-- Precondition: @p@ is monotone on the range ('True' … then 'False') and
-- @p low@ holds. @p high@ is tested first; if it holds, @high@ is returned
-- directly. Otherwise 'lowerBoundIntegerM' finds the first point where @p@
-- fails (its precondition holds because @p high@ was just seen to be
-- 'False'), and the answer is the point just before it.
upperBoundIntegerM ::
  (Monad m) => Integer -> Integer -> (Integer -> m Bool) -> m Integer
upperBoundIntegerM low high p = do
  phigh <- p high
  if phigh
    then return high
    else
      -- Strict fmap, matching the Int version 'upperBoundM': yield an
      -- evaluated Integer instead of a pending @subtract 1@ thunk.
      subtract 1 <$!> lowerBoundIntegerM low high (fmap not . p)
{-# INLINE upperBoundIntegerM #-}

-- | Bisection on 'Double': approximates the smallest @x@ in @[low, high]@
-- with @p x@.
--
-- Precondition: @p@ is monotone ('False' … then 'True') and @p high@ holds.
-- Exactly 50 halving steps are performed (the bracket shrinks by 2^50), and
-- the upper end of the final bracket is returned, i.e. a point where @p@ was
-- assumed or observed to hold.
lowerBoundDouble :: Double -> Double -> (Double -> Bool) -> Double
lowerBoundDouble low0 high0 p = bisect 0 low0 high0
  where
    bisect :: Int -> Double -> Double -> Double
    bisect !step !lo !hi
      | step >= 50 = hi
      | p mid = bisect (step + 1) lo mid
      | otherwise = bisect (step + 1) mid hi
      where
        mid = 0.5 * (lo + hi)

-- | Bisection on 'Double': approximates the largest @x@ in @[low, high]@
-- with @p x@.
--
-- Precondition: @p@ is monotone ('True' … then 'False') and @p low@ holds.
-- Exactly 50 halving steps are performed (the bracket shrinks by 2^50), and
-- the lower end of the final bracket is returned, i.e. a point where @p@ was
-- assumed or observed to hold.
upperBoundDouble :: Double -> Double -> (Double -> Bool) -> Double
upperBoundDouble low0 high0 p = bisect 0 low0 high0
  where
    bisect :: Int -> Double -> Double -> Double
    bisect !step !lo !hi
      | step >= 50 = lo
      | p mid = bisect (step + 1) mid hi
      | otherwise = bisect (step + 1) lo mid
      where
        mid = 0.5 * (lo + hi)