{-# LANGUAGE CPP #-}
module PureSAT.Trail where

#define ASSERTING(x)

import Data.Primitive.PrimVar   (PrimVar, newPrimVar, readPrimVar, writePrimVar)

import PureSAT.Base
import PureSAT.Prim
import PureSAT.Clause2
import PureSAT.Level
import PureSAT.LitTable
import PureSAT.LitVar
import PureSAT.Utils

-------------------------------------------------------------------------------
-- Trail
-------------------------------------------------------------------------------

-- | A stack of literals, kept in the order they were pushed.
--
-- Only the indices @[0, size)@ of the backing array hold live literals;
-- the allocated size of the array is the trail's capacity.
data Trail s
    = Trail
        !(PrimVar s Int)            -- ^ number of literals currently on the trail
        !(MutablePrimArray s Lit)   -- ^ backing storage; its size is the capacity

-- | Allocate an empty trail whose backing array can hold @capacity@ literals.
--
-- The size counter starts at 0, so no slot of the array is considered live
-- until it has been written by 'pushTrail'.
newTrail :: Int -> ST s (Trail s)
newTrail capacity = do
    size <- newPrimVar 0
    ls <- newPrimArray capacity
    return (Trail size ls)

-- | Make an independent copy of a trail.
--
-- The copy gets its own size counter and a fresh backing array with the
-- same capacity as the original. Only the live prefix (the first @size@
-- literals) is copied. Later pushes and pops on either trail do not affect
-- the other.
cloneTrail :: Trail s -> ST s (Trail s)
cloneTrail (Trail size ls) = do
    capacity <- getSizeofMutablePrimArray ls
    n <- readPrimVar size
    size' <- newPrimVar n
    ls' <- newPrimArray capacity
    copyMutablePrimArray ls' 0 ls 0 n
    return (Trail size' ls')

-- | Ensure the trail can hold at least @newCapacity@ literals.
--
-- If the current backing array is already large enough, the very same
-- trail is returned. Otherwise a new trail is allocated with capacity
-- @nextPowerOf2 newCapacity@, and the live prefix is copied over.
-- @nextPowerOf2@ is defined in another module; it is presumably the
-- smallest power of two @>= newCapacity@ (TODO confirm).
--
-- A grown trail has its own size counter. The caller should continue with
-- the returned trail and stop using the old one.
extendTrail :: Trail s -> Int -> ST s (Trail s)
extendTrail trail@(Trail size ls) newCapacity = do
    oldCapacity <- getSizeofMutablePrimArray ls
    -- Decide on the requested capacity, not on a rounded one. Rounding the
    -- current (possibly non-power-of-two, see 'newTrail') capacity up would
    -- otherwise force a pointless reallocation and copy when no growth was
    -- requested.
    if newCapacity <= oldCapacity
    then return trail
    else do
        n <- readPrimVar size
        size' <- newPrimVar n
        ls' <- newPrimArray (nextPowerOf2 newCapacity)
        copyMutablePrimArray ls' 0 ls 0 n
        return (Trail size' ls')

-- | Read the literal at position @i@. Index 0 is the first literal pushed.
--
-- @i@ is not checked against the trail's current size here. Any bounds
-- checking comes from 'readPrimArray' itself; its signature carries
-- 'HasCallStack', which suggests it may check indices (TODO confirm).
indexTrail :: Trail s -> Int -> ST s Lit
indexTrail (Trail _ ls) i = readPrimArray ls i

-- | Remove and return the most recently pushed literal.
--
-- The non-emptiness check is wrapped in the @ASSERTING@ macro. This module
-- defines that macro to expand to nothing, so popping an empty trail is not
-- detected here.
popTrail :: Trail s -> ST s Lit
popTrail (Trail size ls) = do
    n <- readPrimVar size
    ASSERTING(assertST "non empty trail" (n >= 1))
    writePrimVar size (n - 1)
    readPrimArray ls (n - 1)

-- | Push a literal on top of the trail.
--
-- The backing array is not grown here. The caller must ensure there is
-- room, presumably by calling 'extendTrail' beforehand.
pushTrail :: Lit -> Trail s -> ST s ()
pushTrail l (Trail size ls) = do
    n <- readPrimVar size
    -- Store the literal before publishing the new size. If the write fails
    -- (e.g. a bounds error, should 'writePrimArray' check bounds), the size
    -- never ends up counting a slot that was not filled.
    writePrimArray ls n l
    writePrimVar size (n + 1)

-- | Debug helper: emit the whole trail via 'traceM', one literal per line.
--
-- Literals are listed from index 0 upwards. Each line shows @\@level@ and
-- the literal. A literal whose reason clause is null is marked "Decided".
-- Any other literal is marked "Deduced" and followed by its reason clause.
traceTrail :: forall s. LitTable s Clause2 -> Levels s -> Trail s -> ST s ()
traceTrail reasons levels (Trail size lits) = do
    n <- readPrimVar size
    out <- go 0 n
    traceM $ unlines $ "=== Trail ===" : out
  where
    -- Render entries @i .. n-1@, followed by the closing banner.
    go :: Int -> Int -> ST s [String]
    go i n
        | i >= n
        = return ["=== ===== ==="]

        | otherwise
        = do
            l <- readPrimArray lits i
            Level d <- getLevel levels l
            c <- readLitTable reasons l
            ls <- go (i + 1) n
            let prefix = showChar '@' . shows d
                entry
                    | isNullClause c = prefix . showString " Decided " . showsPrec 11 l
                    | otherwise      = prefix . showString " Deduced " . showsPrec 11 l . showChar ' ' . showsPrec 11 c
            return (entry "" : ls)

-- | Check, via 'assertST', that the trail currently holds no literals.
--
-- Unlike the @ASSERTING@ checks in this module, this call is not compiled
-- out. 'HasCallStack' is propagated, presumably so that 'assertST' can
-- report the caller's location on failure.
assertEmptyTrail :: HasCallStack => Trail s -> ST s ()
assertEmptyTrail (Trail size _) = do
    n <- readPrimVar size
    assertST "n == 0" $ n == 0
    return ()