{-# LANGUAGE CPP #-}
-- | Wrappers around the array operations of @Data.Primitive@,
-- specialised to @ST@.
--
-- Names exported with the @P.@ qualifier are re-exported unchanged.
-- The unqualified names are defined in this module: each one forwards to
-- the operation of the same name in @Data.Primitive@ and, when bounds
-- checking is compiled in, first validates its index or range arguments.
module PureSAT.Prim (
    -- * ByteArray
    P.MutableByteArray,
    P.newByteArray,
    P.getSizeofMutableByteArray,
    readByteArray,
    writeByteArray,
    shrinkMutableByteArray,
    fillByteArray,
    copyMutableByteArray,
    resizeMutableByteArray,
    -- * Array of primitive values
    P.Prim,
    P.PrimArray (..),
    P.MutablePrimArray (..),
    P.newPrimArray,
    P.getSizeofMutablePrimArray,
    P.resizeMutablePrimArray,
    P.primArrayFromList,
    P.primArrayToList,
    P.foldrPrimArray,
    readPrimArray,
    writePrimArray,
    setPrimArray,
    indexPrimArray,
    P.sizeofPrimArray,
    freezePrimArray,
    P.emptyPrimArray,
    copyMutablePrimArray,
    -- * Array
    P.MutableArray,
    P.newArray,
    P.sizeofMutableArray,
    readArray,
    writeArray,
    copyMutableArray,
) where

-- Bounds checking is switched on unconditionally by the define below;
-- remove that line to compile the wrappers without any checks.
#define PureSAT_PRIM_BOUNDS_CHECK

-- Constraint prefix spliced in front of every wrapper signature. It has to
-- expand either to a complete context (including the arrow) or to nothing
-- at all; a bare class name without the arrow would be parsed as a type
-- applied to the first argument and the signatures would not type check.
#ifdef PureSAT_PRIM_BOUNDS_CHECK
#define BOUNDS_CHECK_CTX HasCallStack =>
#else
#define BOUNDS_CHECK_CTX
#endif

import qualified Data.Primitive as P

import PureSAT.Base

-------------------------------------------------------------------------------
-- ByteArray
-------------------------------------------------------------------------------


-- | Read the byte stored at index @i@ of a mutable byte array.
--
-- When bounds checking is compiled in, @0 <= i < size@ is asserted with
-- @assertST@ before the read is performed.
readByteArray :: BOUNDS_CHECK_CTX P.MutableByteArray s -> Int -> ST s Word8
readByteArray arr i = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutableByteArray arr
    assertST "readByteArray" $ 0 <= i && i < n
#endif
    P.readByteArray arr i

-- | Write the byte @x@ at index @i@ of a mutable byte array.
--
-- When bounds checking is compiled in, @0 <= i < size@ is asserted with
-- @assertST@ before the write is performed.
writeByteArray :: BOUNDS_CHECK_CTX P.MutableByteArray s -> Int -> Word8 -> ST s ()
writeByteArray arr i x = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutableByteArray arr
    -- The label names this function so that a failure points at the write.
    assertST "writeByteArray" $ 0 <= i && i < n
#endif
    P.writeByteArray arr i x

-- | Shrink a mutable byte array in place to @m@ bytes.
--
-- When bounds checking is compiled in, the new size is asserted to lie
-- between zero and the current size (both inclusive).
shrinkMutableByteArray :: BOUNDS_CHECK_CTX P.MutableByteArray s -> Int -> ST s ()
shrinkMutableByteArray arr m = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutableByteArray arr
    assertST "shrinkMutableByteArray" $ 0 <= m && m <= n
#endif
    P.shrinkMutableByteArray arr m

-- | Fill @len@ bytes of a mutable byte array, starting at offset @off@,
-- with the byte @x@.
--
-- When bounds checking is compiled in, the range is asserted to be
-- well formed: @0 <= off <= size@ and @0 <= len <= size - off@.
-- An empty range that starts at the end of the array is accepted.
fillByteArray :: BOUNDS_CHECK_CTX P.MutableByteArray s -> Int -> Int -> Word8 -> ST s ()
fillByteArray arr off len x = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutableByteArray arr
    assertST "fillByteArray" $ 0 <= off && off <= n
    -- Comparing against (n - off) instead of computing (off + len) keeps
    -- the check meaningful even when the sum would overflow Int, and the
    -- explicit lower bound rejects negative lengths.
    assertST "fillByteArray" $ 0 <= len && len <= n - off
#endif
    P.fillByteArray arr off len x

-- | Copy @len@ bytes from @src@ (starting at offset @off'@) into @dst@
-- (starting at offset @off@).
--
-- Argument order follows @Data.Primitive@: destination, destination
-- offset, source, source offset, length.
--
-- When bounds checking is compiled in, the length is asserted to be
-- non-negative and both ranges are asserted to lie inside their arrays.
-- An empty range that starts at the end of an array is accepted.
copyMutableByteArray :: BOUNDS_CHECK_CTX P.MutableByteArray s -> Int -> P.MutableByteArray s -> Int -> Int -> ST s ()
copyMutableByteArray dst off src off' len = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    assertST "copyMutableByteArray" $ 0 <= len

    -- Destination range. The subtraction form cannot overflow once the
    -- offset is known to be within [0, n].
    n <- P.getSizeofMutableByteArray dst
    assertST "copyMutableByteArray" $ 0 <= off && off <= n
    assertST "copyMutableByteArray" $ len <= n - off

    -- Source range.
    m <- P.getSizeofMutableByteArray src
    assertST "copyMutableByteArray" $ 0 <= off' && off' <= m
    assertST "copyMutableByteArray" $ len <= m - off'
#endif
    P.copyMutableByteArray dst off src off' len

-- | Resize a mutable byte array to @len@ bytes, returning the array that
-- must be used from then on.
--
-- When bounds checking is compiled in, the requested size is asserted to
-- be non-negative.
resizeMutableByteArray :: BOUNDS_CHECK_CTX P.MutableByteArray s -> Int -> ST s (P.MutableByteArray s)
resizeMutableByteArray arr len = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    assertST "resizeMutableByteArray" $ 0 <= len
#endif
    P.resizeMutableByteArray arr len

-------------------------------------------------------------------------------
-- PrimArray
-------------------------------------------------------------------------------

-- | Read the element stored at index @i@ of a mutable primitive array.
--
-- When bounds checking is compiled in, @0 <= i < size@ (size counted in
-- elements) is asserted with @assertST@ before the read is performed.
readPrimArray :: BOUNDS_CHECK_CTX P.Prim a => P.MutablePrimArray s a -> Int -> ST s a
readPrimArray arr i = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutablePrimArray arr
    assertST "readPrimArray" $ 0 <= i && i < n
#endif
    P.readPrimArray arr i

-- | Write the element @x@ at index @i@ of a mutable primitive array.
--
-- When bounds checking is compiled in, @0 <= i < size@ (size counted in
-- elements) is asserted with @assertST@ before the write is performed.
writePrimArray :: BOUNDS_CHECK_CTX P.Prim a => P.MutablePrimArray s a -> Int -> a -> ST s ()
writePrimArray arr i x = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutablePrimArray arr
    assertST "writePrimArray" $ 0 <= i && i < n
#endif
    P.writePrimArray arr i x

-- | Set @len@ consecutive elements of a mutable primitive array, starting
-- at element offset @off@, to the value @x@.
--
-- When bounds checking is compiled in, the range is asserted to be
-- well formed: @0 <= off <= size@ and @0 <= len <= size - off@.
-- An empty range that starts at the end of the array is accepted.
setPrimArray :: BOUNDS_CHECK_CTX P.Prim a => P.MutablePrimArray s a -> Int -> Int -> a -> ST s ()
setPrimArray arr off len x = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutablePrimArray arr
    assertST "setPrimArray" $ 0 <= off && off <= n
    -- Subtraction form: rejects negative lengths and cannot overflow.
    assertST "setPrimArray" $ 0 <= len && len <= n - off
#endif
    P.setPrimArray arr off len x

-- | Produce an immutable copy of @len@ elements of a mutable primitive
-- array, starting at element offset @off@.
--
-- When bounds checking is compiled in, the range is asserted to be
-- well formed: @0 <= off <= size@ and @0 <= len <= size - off@.
-- An empty slice that starts at the end of the array is accepted.
freezePrimArray :: BOUNDS_CHECK_CTX P.Prim a => P.MutablePrimArray s a -> Int -> Int -> ST s (P.PrimArray a)
freezePrimArray arr off len = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    n <- P.getSizeofMutablePrimArray arr
    -- The label names this function so that a failure points at the freeze.
    assertST "freezePrimArray" $ 0 <= off && off <= n
    assertST "freezePrimArray" $ 0 <= len && len <= n - off
#endif
    P.freezePrimArray arr off len

-- | Read the element at index @i@ of an immutable primitive array.
--
-- When bounds checking is compiled in, an index outside @0 <= i < size@
-- raises @error "indexPrimArray"@ instead of reading out of bounds.
indexPrimArray :: BOUNDS_CHECK_CTX P.Prim a => P.PrimArray a -> Int -> a
indexPrimArray arr i
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    | not (0 <= i && i < P.sizeofPrimArray arr) = error "indexPrimArray"
#endif
    | otherwise = P.indexPrimArray arr i

-- | Copy @len@ elements from @src@ (starting at element offset @off'@)
-- into @dst@ (starting at element offset @off@).
--
-- Argument order follows @Data.Primitive@: destination, destination
-- offset, source, source offset, length.
--
-- When bounds checking is compiled in, the length is asserted to be
-- non-negative and both ranges are asserted to lie inside their arrays.
-- An empty range that starts at the end of an array is accepted.
copyMutablePrimArray :: BOUNDS_CHECK_CTX P.Prim a => P.MutablePrimArray s a -> Int -> P.MutablePrimArray s a -> Int -> Int -> ST s ()
copyMutablePrimArray dst off src off' len = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    assertST "copyMutablePrimArray" $ 0 <= len

    -- Destination range. The subtraction form cannot overflow once the
    -- offset is known to be within [0, n].
    n <- P.getSizeofMutablePrimArray dst
    assertST "copyMutablePrimArray" $ 0 <= off && off <= n
    assertST "copyMutablePrimArray" $ len <= n - off

    -- Source range.
    m <- P.getSizeofMutablePrimArray src
    assertST "copyMutablePrimArray" $ 0 <= off' && off' <= m
    assertST "copyMutablePrimArray" $ len <= m - off'
#endif
    P.copyMutablePrimArray dst off src off' len

-------------------------------------------------------------------------------
-- Array
-------------------------------------------------------------------------------

-- | Read the element stored at index @i@ of a boxed mutable array.
--
-- When bounds checking is compiled in, @0 <= i < size@ is asserted with
-- @assertST@ before the read is performed.
readArray :: BOUNDS_CHECK_CTX P.MutableArray s a -> Int -> ST s a
readArray arr i = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    assertST "readArray" $ 0 <= i && i < P.sizeofMutableArray arr
#endif
    P.readArray arr i

-- | Write the element @x@ at index @i@ of a boxed mutable array.
--
-- When bounds checking is compiled in, @0 <= i < size@ is asserted with
-- @assertST@ before the write is performed.
writeArray :: BOUNDS_CHECK_CTX P.MutableArray s a -> Int -> a -> ST s ()
writeArray arr i x = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    -- The label names this function so that a failure points at the write.
    assertST "writeArray" $ 0 <= i && i < P.sizeofMutableArray arr
#endif
    P.writeArray arr i x

-- | Copy @len@ elements from the boxed array @src@ (starting at offset
-- @off'@) into the boxed array @dst@ (starting at offset @off@).
--
-- Argument order follows @Data.Primitive@: destination, destination
-- offset, source, source offset, length.
--
-- When bounds checking is compiled in, the length is asserted to be
-- non-negative and both ranges are asserted to lie inside their arrays.
-- An empty range that starts at the end of an array is accepted.
copyMutableArray :: BOUNDS_CHECK_CTX P.MutableArray s a -> Int -> P.MutableArray s a -> Int -> Int -> ST s ()
copyMutableArray dst off src off' len = do
#ifdef PureSAT_PRIM_BOUNDS_CHECK
    assertST "copyMutableArray" $ 0 <= len

    -- Destination range. The subtraction form cannot overflow once the
    -- offset is known to be within [0, n].
    let n = P.sizeofMutableArray dst
    assertST "copyMutableArray" $ 0 <= off && off <= n
    assertST "copyMutableArray" $ len <= n - off

    -- Source range.
    let m = P.sizeofMutableArray src
    assertST "copyMutableArray" $ 0 <= off' && off' <= m
    assertST "copyMutableArray" $ len <= m - off'
#endif
    P.copyMutableArray dst off src off' len