module Algorithm.TwoSat where

import Control.Monad.Primitive
import Control.Monad.ST
import qualified Data.Vector.Unboxed as U

import Data.Graph.Sparse
import Data.Graph.Sparse.SCC

{- | 2-SAT

Each variable @i@ (@0 <= i < numVars@) is represented in the implication
graph by two vertices: vertex @i@ stands for the literal @x_i@ and vertex
@i + numVars@ for @not x_i@. Clauses are added with 'addClauseCNF'.

The instance is satisfiable iff no variable lies in the same strongly
connected component as its negation. In that case a satisfying assignment
is returned, indexed by variable.

@(a0 \\\/ a1) \/\\ not a0@

>>> twoSat 2 2 (\b -> addClauseCNF b (0, True) (1, True) >> addClauseCNF b (0, False) (0, False))
Just [False,True]

@a0 \/\\ not a1 \/\\ (not a0 \\\/ a1)@

>>> twoSat 2 3 (\b -> addClauseCNF b (0, True) (0, True) >> addClauseCNF b (1, False) (1, False) >> addClauseCNF b (0, False) (1, True))
Nothing

@a0 \\\/ not a0@

>>> twoSat 1 1 (\b -> addClauseCNF b (0, True) (0, False))
Just [True]
-}
twoSat ::
  -- | the number of variables
  Int ->
  -- | upper bound on the number of clauses (each clause contributes two
  -- implication edges, so the graph is sized for @2 * numClauses@ edges)
  Int ->
  -- | CNF (Conjunctive Normal Form) builder action adding the clauses
  (forall s. CNFBuilder s () -> ST s ()) ->
  Maybe (U.Vector Bool)
twoSat numVars numClauses run
  | U.and (U.zipWith (/=) posComp negComp) =
      Just (U.zipWith (>) posComp negComp)
  | otherwise = Nothing
  where
    -- Component id of every vertex of the implication graph
    -- (2 * numVars vertices: positive literals first, then negations).
    !comp =
      stronglyConnectedComponents $
        buildSparseGraph (2 * numVars) (2 * numClauses) run
    -- posComp ! i is the component of x_i, negComp ! i that of not x_i.
    -- NOTE(review): picking x_i = True when comp[x_i] > comp[not x_i]
    -- assumes 'stronglyConnectedComponents' numbers components in
    -- topological order of the condensation; verify against
    -- Data.Graph.Sparse.SCC.
    (posComp, negComp) = U.splitAt numVars comp

-- | Builder through which the clauses of a 2-SAT instance are collected
-- (see 'addClauseCNF'). It is a plain synonym for 'SparseGraphBuilder';
-- @st@ is the state-thread parameter, and 2-SAT always instantiates
-- @weight@ to @()@.
type CNFBuilder st weight = SparseGraphBuilder st weight

-- | Add the two-literal clause @l_i \\\/ l_j@ to the instance, where
-- @(k, True)@ denotes the literal @x_k@ and @(k, False)@ denotes @not x_k@.
--
-- The clause is stored as the two implications @not l_i -> l_j@ and
-- @not l_j -> l_i@. Vertex @k@ stands for @x_k@ and vertex @k + offset@ for
-- @not x_k@, where @offset@ is half the builder's vertex count (the
-- @numVars@ given to 'twoSat').
--
-- Precondition (not checked): @0 <= i, j < offset@. A larger index would
-- alias another variable's negated literal.
addClauseCNF ::
  (PrimMonad m) =>
  CNFBuilder (PrimState m) () ->
  (Int, Bool) ->
  (Int, Bool) ->
  m ()
addClauseCNF builder (i, f) (j, g) = do
  -- not l_i -> l_j
  addDirectedEdge_ builder (vertex i (not f), vertex j g)
  -- not l_j -> l_i
  addDirectedEdge_ builder (vertex j (not g), vertex i f)
  where
    -- Number of variables: the graph has one vertex per literal.
    !offset = quot (numVerticesSGB builder) 2
    -- Vertex of the literal "x_k == b".
    vertex k True = k
    vertex k False = k + offset