{-# LANGUAGE CPP             #-}
{-# LANGUAGE RecordWildCards #-}
module PureSAT.LitSet where

#define ASSERTING(x)

import Data.Primitive.PrimVar (readPrimVar)

import PureSAT.Base
import PureSAT.Clause2
import PureSAT.LitVar
import PureSAT.Prim
import PureSAT.SparseSet

-------------------------------------------------------------------------------
-- LitSet
-------------------------------------------------------------------------------

-- | A mutable set of literals in the 'ST' monad, stored as a 'SparseSet'
-- over the underlying 'Int' code of each 'Lit' (the 'MkLit' payload).
newtype LitSet s = LS (SparseSet s)

-- | The literal at position @i@ of the underlying 'SparseSet', as returned
-- by 'indexSparseSet' (the 'Int' result is coerced to 'Lit').
--
-- NOTE(review): bounds checking, if any, is whatever 'indexSparseSet' does.
indexLitSet :: forall s. LitSet s -> Int -> ST s Lit
indexLitSet (LS xs) i = coerce (indexSparseSet @s xs i)

-- | Allocate a new 'LitSet' by wrapping the result of 'newSparseSet'.
--
-- NOTE(review): @n@ is passed straight to 'newSparseSet'; presumably it is
-- the capacity (an upper bound on the literal codes that may be inserted)
-- and the new set starts empty. Confirm against "PureSAT.SparseSet".
newLitSet :: Int -> ST s (LitSet s)
newLitSet n = LS <$> newSparseSet n

-- | Insert a literal into the set via 'insertSparseSet', keyed by the
-- literal's underlying 'Int' code.
insertLitSet :: Lit -> LitSet s -> ST s ()
insertLitSet (MkLit l) (LS ls) = insertSparseSet ls l

-- | Delete a literal from the set via 'deleteSparseSet', keyed by the
-- literal's underlying 'Int' code.
deleteLitSet :: Lit -> LitSet s -> ST s ()
deleteLitSet (MkLit l) (LS ls) = deleteSparseSet ls l

-- | Continuation-passing pop: delegates to 'popSparseSet_', running @no@
-- when it reports no element and otherwise handing the element, coerced to
-- a 'Lit', to @yes@.
--
-- NOTE(review): which element is chosen (and whether it is removed) is
-- decided by 'popSparseSet_'; the "min" in the name is not established
-- here. Confirm against "PureSAT.SparseSet".
{-# INLINE minViewLitSet #-}
minViewLitSet :: LitSet s -> ST s r -> (Lit -> ST s r) -> ST s r
minViewLitSet (LS xs) no yes = popSparseSet_ xs no (coerce yes)

-- | Empty the set via 'clearSparseSet'.
clearLitSet :: LitSet s -> ST s ()
clearLitSet (LS xs) = clearSparseSet xs

-- | The literals currently in the set, as listed by 'elemsSparseSet'
-- (the element order is whatever that function produces).
elemsLitSet :: LitSet s -> ST s [Lit]
elemsLitSet (LS xs) = coerce (elemsSparseSet xs)

-- | Whether the given literal is in the set, as answered by
-- 'memberSparseSet' on the literal's underlying 'Int' code.
memberLitSet :: LitSet s -> Lit -> ST s Bool
memberLitSet (LS xs) (MkLit x) = memberSparseSet xs x

-- | Number of literals in the set, as reported by 'sizeofSparseSet'.
sizeofLitSet :: LitSet s -> ST s Int
sizeofLitSet (LS xs) = sizeofSparseSet xs

-- | The single literal of a set that must contain exactly one literal:
-- reads slot 0 of the dense array and wraps it in 'MkLit'.
--
-- The size check only exists when @ASSERTING@ expands its argument; with
-- the empty @ASSERTING@ macro defined at the top of this file it is
-- compiled out and the slot is read unconditionally.
unsingletonLitSet :: LitSet s -> ST s Lit
unsingletonLitSet (LS SS {..}) = do
    -- The size must be bound inside ASSERTING as well, otherwise enabling
    -- assertions leaves n out of scope in the check that follows.
    ASSERTING(n <- readPrimVar size)
    ASSERTING(assertST "size == 1" (n == 1))
    MkLit <$> readPrimArray dense 0

-- | Freeze the set into a 'Clause2'. Dense slots 0 and 1 become the two
-- individually stored literals, and slots 2 through n-1 are copied with
-- 'freezePrimArray' into the array of remaining literals.
--
-- Requires at least two literals. This is only checked when @ASSERTING@
-- expands its argument; otherwise fewer than two literals leads to reads
-- past the live part of the dense array and a negative freeze length.
--
-- NOTE(review): the 'True' passed to 'MkClause2' is presumably the
-- learned-clause flag (see the TODO). Confirm against "PureSAT.Clause2".
litSetToClause :: LitSet s -> ST s Clause2
litSetToClause (LS SS {..}) = do
    n <- readPrimVar size
    ASSERTING(assertST "size >= 2" (n >= 2))
    l1 <- readPrimArray dense 0
    l2 <- readPrimArray dense 1
    ls <- freezePrimArray dense 2 (n - 2)
    -- TODO: learned clauses only
    return $! MkClause2 True (coerce l1) (coerce l2) (coercePrimArrayLit ls)