{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MagicHash #-}

module Math.Matrix where

import Control.Monad
import Control.Monad.ST
import Data.Primitive
import Data.Proxy
import qualified Data.Vector.Fusion.Bundle.Monadic as MB
import qualified Data.Vector.Fusion.Stream.Monadic as MS
import Data.Vector.Fusion.Util
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Primitive as P
import GHC.Exts
import GHC.TypeLits

--
import My.Prelude (rep)

-- | An @n@ x @n@ square matrix whose elements are stored unboxed.
--
-- * The 'Int' field is the dimension @n@ as a runtime value ('createSqMat'
--   fills it from @natVal@ of the type-level @n@).
-- * The 'ByteArray' holds the @n * n@ elements in row-major order: element
--   @(i, j)@ is at element index @i * n + j@ (as written by 'mulSqMat' and
--   sliced by 'viewRowSqMat').
-- * The element type @a@ is not mentioned by either field; elements are
--   read and written through its 'Prim' instance at each use site.
data SqMat (n :: Nat) a = SqMat !Int !ByteArray

-- | Copy row @i@ of a matrix into any generic vector type.
--
-- The row is the contiguous run of @n@ elements starting at element offset
-- @i * n@ (row-major layout). Nothing here checks @i@; presumably callers
-- keep @0 <= i < n@.
viewRowSqMat ::
  (KnownNat n, Prim a, G.Vector v a) =>
  SqMat n a ->
  Int ->
  v a
viewRowSqMat (SqMat dim arr) row = G.unstream (MB.reVector (MB.fromVector rowSlice))
  where
    -- Primitive-vector view over the row's elements inside the shared buffer.
    rowSlice = P.Vector (row * dim) dim arr
{-# INLINE viewRowSqMat #-}

-- | Copy column @j@ of a matrix into any generic vector type.
--
-- Columns are not contiguous in the row-major buffer, so the elements are
-- gathered one by one at stride @n@. Nothing here checks @j@; presumably
-- callers keep @0 <= j < n@.
viewColSqMat ::
  (KnownNat n, Prim a, G.Vector v a) =>
  SqMat n a ->
  Int ->
  v a
viewColSqMat (SqMat dim arr) col = G.unstream (MB.map (indexByteArray arr) colIndices)
  where
    -- Flat indices col, col + n, ..., col + (n - 1) * n (n of them).
    colIndices = MB.iterateN dim (+ dim) col
{-# INLINE viewColSqMat #-}

-- | Allocate an @n@ x @n@ matrix and initialise it inside 'ST'.
--
-- The dimension comes from the type-level @n@ via 'natVal'. The callback
-- receives that dimension and a freshly allocated buffer sized for
-- @n * n@ elements of type @a@. This function does not initialise the
-- buffer, so any element the callback does not write keeps whatever
-- contents the allocation had. Afterwards the buffer is frozen with
-- 'unsafeFreezeByteArray' and wrapped in 'SqMat'.
createSqMat ::
  forall n a.
  (KnownNat n, Prim a) =>
  Proxy n ->
  (forall s. Int -> MutableByteArray s -> ST s ()) ->
  SqMat n a
createSqMat proxy fill = runST build
  where
    dim :: Int
    dim = fromInteger (natVal proxy)

    -- Size in bytes of a single element of type @a@.
    elemBytes :: Int
    elemBytes = I# (sizeOf# (undefined :: a))

    build :: ST s (SqMat n a)
    build = do
      buf <- newByteArray (elemBytes * dim * dim)
      fill dim buf
      frozen <- unsafeFreezeByteArray buf
      return $! SqMat dim frozen
{-# INLINE createSqMat #-}

-- | Promote a runtime dimension to a type-level 'Nat' and hand a proxy for
-- it to the continuation.
--
-- Calls 'error' when 'someNatVal' returns 'Nothing' for the value (base
-- does so for negative integers).
reifyMatDim :: (Integral i) => i -> (forall n. (KnownNat n) => Proxy n -> a) -> a
reifyMatDim dim k = maybe failure (\(SomeNat p) -> k p) (someNatVal asInteger)
  where
    asInteger = toInteger dim
    failure = error ("reifyMatDim: " <> show asInteger)
{-# INLINE reifyMatDim #-}

-- | Stream every element of the matrix in flat (row-major) index order,
-- from index @0@ up to @n * n - 1@.
streamSqMat :: (Prim a, Monad m) => SqMat n a -> MS.Stream m a
streamSqMat (SqMat dim arr) = MS.generateM (dim * dim) readAt
  where
    readAt ix = pure (indexByteArray arr ix)
{-# INLINE [1] streamSqMat #-}

-- | Materialise a pure stream as an @n@ x @n@ matrix.
--
-- Element @k@ of the stream is written to flat index @k@ (row-major).
-- At most @n * n@ elements are consumed. Without that cap, a longer stream
-- would write past the end of the @n * n@-element buffer allocated by
-- 'createSqMat', because 'writeByteArray' does not bounds-check. If the
-- stream is shorter, the remaining elements are not written here.
unstreamSqMat :: forall n a. (KnownNat n, Prim a) => MS.Stream Id a -> SqMat n a
unstreamSqMat s = createSqMat (Proxy :: Proxy n) $ \n mba ->
  MS.mapM_ (\(i, x) -> writeByteArray mba i x)
    . MS.trans (return . unId)
    . MS.indexed
    $ MS.take (n * n) s
{-# INLINE [1] unstreamSqMat #-}

-- Fusion rules for the stream/unstream round trip. Both 'streamSqMat' and
-- 'unstreamSqMat' are INLINE [1], so they stay un-inlined during simplifier
-- phase 2 and these rules get a chance to match first.
--
-- * "streamSqMat/unstreamSqMat": streaming a matrix that was just built from
--   a pure stream yields that stream again, lifted from 'Id' into the target
--   monad, so no intermediate buffer is allocated.
-- * "unstreamSqMat/streamSqMat": rebuilding a matrix from its own stream
--   gives back the same matrix.
--
-- NOTE(review): the first rule is only faithful when @s@ has exactly
-- @n * n@ elements. 'streamSqMat' always emits @n * n@ elements, while the
-- rewritten form emits however many @s@ has. The uses visible here
-- (liftSqMat1 / liftSqMat2) feed streams derived from 'SqMat' values of
-- the same @n@.
{-# RULES
"streamSqMat/unstreamSqMat" forall s.
  streamSqMat (unstreamSqMat s) =
    MS.trans (return . unId) s
"unstreamSqMat/streamSqMat" forall mat.
  unstreamSqMat (streamSqMat mat) =
    mat
  #-}

-- | Embed a scalar as @x@ times the identity matrix: @x@ on the main
-- diagonal and @0@ in every other cell.
liftSqMat0 :: forall n a. (KnownNat n, Num a, Prim a) => a -> SqMat n a
liftSqMat0 x = createSqMat Proxy $ \dim buf -> do
  -- Clear all n * n cells first...
  setByteArray buf 0 (dim * dim) (0 :: a)
  -- ...then overwrite the diagonal; cell (k, k) is at flat index k * (n + 1).
  forM_ [0 .. dim - 1] $ \k ->
    writeByteArray buf (k * (dim + 1)) x
{-# INLINE liftSqMat0 #-}

-- | Apply a function to every element independently.
--
-- The matrix argument is taken by an explicit lambda so that the definition
-- has a single argument on its left-hand side. That keeps partial uses such
-- as @negate = liftSqMat1 negate@ saturated for the INLINE pragma.
liftSqMat1 :: (KnownNat n, Prim a) => (a -> a) -> SqMat n a -> SqMat n a
liftSqMat1 f = \mat -> unstreamSqMat (MS.map f (streamSqMat mat))
{-# INLINE liftSqMat1 #-}

-- | Combine two matrices element by element, pairing cells with the same
-- flat (row-major) index.
liftSqMat2 ::
  (KnownNat n, Prim a) =>
  (a -> a -> a) ->
  SqMat n a ->
  SqMat n a ->
  SqMat n a
liftSqMat2 f x y = unstreamSqMat combined
  where
    combined = MS.zipWith f (streamSqMat x) (streamSqMat y)
{-# INLINE liftSqMat2 #-}

-- | Naive matrix product. Entry @(i, j)@ of the result is the dot product
-- of row @i@ of @x@ with column @j@ of @y@ (@n^3@ multiplications).
mulSqMat ::
  forall n a.
  (KnownNat n, Num a, Prim a) =>
  SqMat n a ->
  SqMat n a ->
  SqMat n a
mulSqMat x y = createSqMat Proxy $ \dim buf ->
  rep dim $ \row -> do
    -- Bound outside the column loop so one row vector is shared by all columns.
    let rowVec = viewRowSqMat x row :: P.Vector a
    rep dim $ \col -> do
      let colVec = viewColSqMat y col :: P.Vector a
          entry = P.sum (P.zipWith (*) rowVec colVec)
      writeByteArray buf (row * dim + col) entry
{-# INLINE mulSqMat #-}

-- | Matrix arithmetic.
--
-- * '+', '-' and 'negate' act element-wise.
-- * '*' is the matrix product.
-- * @'fromInteger' k@ is @k@ times the identity matrix.
--
-- NOTE(review): 'abs' and 'signum' are the identity, so the 'Num' law
-- @abs x * signum x == x@ does not hold (it would require @x * x == x@).
-- Matrices have no natural definition of either; avoid relying on them.
instance (KnownNat n, Num a, Prim a) => Num (SqMat n a) where
  {-# SPECIALIZE instance (KnownNat n) => Num (SqMat n Int) #-}
  {-# SPECIALIZE instance (KnownNat n) => Num (SqMat n Double) #-}
  x + y = liftSqMat2 (+) x y
  {-# INLINE (+) #-}
  x - y = liftSqMat2 (-) x y
  {-# INLINE (-) #-}
  x * y = mulSqMat x y
  {-# INLINE (*) #-}
  negate x = liftSqMat1 negate x
  {-# INLINE negate #-}
  abs x = x
  {-# INLINE abs #-}
  signum x = x
  {-# INLINE signum #-}
  fromInteger k = liftSqMat0 (fromInteger k)
  {-# INLINE fromInteger #-}