module Algorithm.TwoSat where
import Control.Monad.Primitive
import Control.Monad.ST
import qualified Data.Vector.Unboxed as U
import Data.Graph.Sparse
import Data.Graph.Sparse.SCC
-- | Solve a 2-SAT instance via strongly connected components of its
-- implication graph.
--
-- The graph has @2 * numVars@ vertices: vertex @i@ stands for the literal
-- @x_i@ and vertex @i + numVars@ for @not x_i@ (the layout 'addClauseCNF'
-- writes). The callback @run@ adds the clauses; each 'addClauseCNF' call
-- inserts two directed edges, so the graph is sized for @2 * numClauses@
-- edges.
--
-- Returns @Just assignment@ (one 'Bool' per variable, indexed
-- @0 .. numVars - 1@) when the formula is satisfiable, and 'Nothing' when it
-- is not.
twoSat
  :: Int                                    -- ^ number of variables
  -> Int                                    -- ^ number of clauses @run@ adds
  -> (forall s. CNFBuilder s () -> ST s ()) -- ^ adds the clauses to the builder
  -> Maybe (U.Vector Bool)                  -- ^ a satisfying assignment, if any
twoSat numVars numClauses run
  | satisfiable = Just assignment
  | otherwise = Nothing
  where
    -- Component id of every vertex of the implication graph.
    -- NOTE(review): assumes one id per vertex (length @2 * numVars@) — confirm
    -- against 'stronglyConnectedComponents'.
    !comp =
      stronglyConnectedComponents $
        buildSparseGraph (2 * numVars) (2 * numClauses) run

    -- O(1) slices: component ids of the positive literals (x_i) and of the
    -- negative literals (not x_i). Replaces the earlier unchecked indexing.
    (posComp, negComp) = U.splitAt numVars comp

    -- The formula is unsatisfiable iff some x_i and not x_i share a component.
    satisfiable = U.and (U.zipWith (/=) posComp negComp)

    -- x_i is set to True when its component id exceeds that of not x_i.
    -- NOTE(review): this is the right choice only if component ids follow a
    -- topological order of the condensation (as in AtCoder Library's scc);
    -- with reverse-topological (Tarjan-style) numbering the comparison must
    -- flip — confirm against 'stronglyConnectedComponents'.
    assignment = U.zipWith (>) posComp negComp
-- | Builder for a 2-SAT implication graph, consumed by 'twoSat' and filled
-- by 'addClauseCNF'. It is simply 'SparseGraphBuilder' under another name,
-- so every 'SparseGraphBuilder' operation applies to it as well.
type CNFBuilder st weight = SparseGraphBuilder st weight
-- | Add the clause @(x_i == f) || (x_j == g)@ to a 2-SAT builder.
--
-- The clause is encoded as its two implications,
-- @not (x_i == f) ==> (x_j == g)@ and @not (x_j == g) ==> (x_i == f)@,
-- each inserted as one directed edge. Vertex @k@ stands for the literal
-- @x_k@ and vertex @k + n@ for @not x_k@, where @n@ is half of the builder's
-- vertex count.
--
-- NOTE(review): @i@ and @j@ are not range-checked here. They are presumably
-- expected to lie in @[0, n)@. An out-of-range index can alias another
-- literal's vertex (e.g. @i == n@ with @f == True@ maps to @not x_0@) unless
-- 'addDirectedEdge_' rejects it.
addClauseCNF
  :: (PrimMonad m)
  => CNFBuilder (PrimState m) ()
  -> (Int, Bool) -- ^ first literal: variable index and the value it must take
  -> (Int, Bool) -- ^ second literal: variable index and the value it must take
  -> m ()
addClauseCNF builder (i, f) (j, g) = do
  let !offset = quot (numVerticesSGB builder) 2
      -- Vertex representing the literal "x_k == b".
      literal k b = if b then k else k + offset
  -- not (x_i == f) implies (x_j == g)
  addDirectedEdge_ builder (literal i (not f), literal j g)
  -- not (x_j == g) implies (x_i == f)
  addDirectedEdge_ builder (literal j (not g), literal i f)