sat-solver/main.py
2025-03-07 09:45:39 -06:00

605 lines
16 KiB
Python
Executable file

#!/usr/bin/env python3.10
import string;
import itertools;
import argparse;
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    argv: list of argument strings to parse. When None (the default),
    argparse falls back to sys.argv[1:], which is what the script relies on.

    Returns an argparse.Namespace with the attributes input_file,
    dump_automata and validate.
    """
    parser = argparse.ArgumentParser(
        prog='main',
        description="""
Proof-of-concept SAT solver.
""",
        epilog="""
Bottom Text.
""")
    parser.add_argument('input_file', help="""
Set input file path.
""")
    # "%%" is argparse's escape for a literal "%" in help strings.
    parser.add_argument('-d', '--dump-automata', action='store_true', help="""
While solving the SAT problem, write out intermediate automata
and the final automaton as Graphviz .dot files.
The output file names follow the form "%%05u.dot" and are written
to the current working directory.
""")
    parser.add_argument('-v', '--validate', action="store_true",
        help="""
Automatically validate the solution (or every solution) by feeding
each solution back into the boolean expression, checking that the
conclusions made by the solver truly are correct.
If the solver concludes that there are no solutions, this option
will check that by evaluating all inputs, checking that they
are indeed no solutions.
If the solver concludes that all possible inputs are valid
solutions, this option will check all inputs, checking that they
are indeed valid solutions.
NOTE: validation is not implemented yet.
""")
    # None lets argparse read sys.argv[1:] itself; an explicit list makes the
    # parser usable from other code and tests.
    return parser.parse_args(argv)
def _parse_clause(line):
    """Turn one clause line into a list of (variable, value) tuples.

    Every token must be a signed integer: -k means variable k must be 0 and
    k means it must be 1. A literal 0 is the DIMACS end-of-clause marker and
    is dropped. Only one clause per line is supported, so any literal after
    the 0 is an error.

    Raises ValueError for non-integer tokens or for literals after the 0.
    """
    literals = [int(token) for token in line.split()]
    if 0 in literals:
        end = literals.index(0)
        if literals[end + 1:]:
            raise ValueError(f"expected one clause per line, got: {line!r}")
        literals = literals[:end]
    return [(abs(literal), 0 if literal < 0 else 1) for literal in literals]


def parse_file(filename):
    """Parse a DIMACS-style CNF file into a list of problems.

    The file may contain several problems. Each starts with a header line
    "p cnf <variables> <clauses>" (the word "cnf" is optional), followed by
    <clauses> clause lines. Blank lines and lines starting with "c" are
    skipped everywhere. A line starting with "%" ends the input, as in the
    trailer of SATLIB benchmark files.

    Returns a list of (number_of_variables, clauses) tuples. clauses is a
    list of clauses, and each clause is a list of (variable, value) tuples
    where value is 1 for a positive literal and 0 for a negated one.

    Raises ValueError on a malformed header or literal, or when the file ends
    before the announced number of clauses has been read.
    """
    problems = []
    with open(filename) as stream:
        # Only meaningful lines: stripped, without blanks and comments.
        lines = (
            stripped
            for stripped in (raw.strip() for raw in stream)
            if stripped and not stripped.startswith("c"))
        for line in lines:
            if line.startswith("%"):
                break
            if not line.startswith("p"):
                raise ValueError(f"expected a 'p' header line, got: {line!r}")
            fields = line.split()[1:]
            # Standard DIMACS writes "p cnf V C"; also accept the bare "p V C".
            if fields and fields[0] == "cnf":
                fields = fields[1:]
            if len(fields) < 2:
                raise ValueError(f"malformed header line: {line!r}")
            number_of_variables = int(fields[0])
            number_of_clauses = int(fields[1])
            print(f"number_of_variables = {number_of_variables}")
            print(f"number_of_clauses = {number_of_clauses}")
            clauses = []
            for _ in range(number_of_clauses):
                clause_line = next(lines, None)
                if clause_line is None or clause_line.startswith("%"):
                    raise ValueError(
                        f"expected {number_of_clauses} clauses, "
                        f"found only {len(clauses)}")
                clauses.append(_parse_clause(clause_line))
            problems.append((number_of_variables, clauses))
    return problems
def create_variable_automata(variable, value, number_of_variables):
    """Return the automaton for the single literal "position `variable` == `value`".

    States 0 .. number_of_variables - 1 each read one bit and move to the next
    state on either bit. The last state (index number_of_variables) is the
    only accepting one and has no outgoing transitions. At state `variable`
    only the edge for `value` (by truthiness: 1 if truthy, else 0) is kept;
    the other edge is None. `variable` is used directly as the state index.
    """
    print("create_variable_automata()")
    automata = [(0, (layer + 1, layer + 1)) for layer in range(number_of_variables)]
    automata.append((1, (None, None)))
    successor = variable + 1
    restricted = (None, successor) if value else (successor, None)
    automata[variable] = (0, restricted)
    return automata
def mass_union(args, subautomatas):
    """Build the product automaton accepting the union of `subautomatas`.

    A product state is a tuple with one state index per sub-automaton, or
    None where that sub-automaton has no transition left. It accepts when any
    component accepts. A transition exists when at least one component has
    one. States are numbered in breadth-first discovery order, starting from
    the all-zeros tuple. `args` is not used.
    """
    print(f"mass_union(subautomatas = {subautomatas})")
    start = (0,) * len(subautomatas)
    index_of = {start: 0}
    queue = [(start, 0)]
    automata = [()]
    head = 0
    while head < len(queue):
        components, index = queue[head]
        head += 1
        accepting = int(any(
            state is not None and subautomatas[k][state][0]
            for k, state in enumerate(components)))
        edges = [
            (None, None) if state is None else subautomatas[k][state][1]
            for k, state in enumerate(components)]
        newtos = [None, None]
        if edges:
            for bit, target in enumerate(zip(*edges)):
                if all(part is None for part in target):
                    continue
                if target not in index_of:
                    index_of[target] = len(automata)
                    automata.append(())
                    queue.append((target, index_of[target]))
                newtos[bit] = index_of[target]
        automata[index] = (accepting, newtos)
    return automata
def intersect_automata(args, lauto, rauto):
    """Build the product automaton accepting strings accepted by both inputs.

    Product states are (left index, right index) pairs, numbered in
    breadth-first discovery order from (0, 0). A transition is kept only when
    both sides have it, and a state accepts when both sides accept. When
    args.dump_automata is set, the result is also written via dump_automata().
    """
    print(f"intersect_automata(left = {lauto}, right = {rauto})")
    index_of = {(0, 0): 0}
    queue = [((0, 0), 0)]
    nauto = [()]
    head = 0
    while head < len(queue):
        (left, right), nindex = queue[head]
        head += 1
        laccepting, ltos = lauto[left]
        raccepting, rtos = rauto[right]
        newtos = [None, None]
        for bit, pair in enumerate(zip(ltos, rtos)):
            if pair[0] is None or pair[1] is None:
                continue
            if pair not in index_of:
                index_of[pair] = len(nauto)
                nauto.append(())
                queue.append((pair, index_of[pair]))
            newtos[bit] = index_of[pair]
        nauto[nindex] = (int(laccepting and raccepting), newtos)
    if args.dump_automata:
        dump_automata(nauto, title = f"intersection: {len(nauto)}")
    return nauto
def mass_intersection(args, subautomatas):
    """Intersect every automaton in `subautomatas` by balanced divide and conquer.

    The list is halved. Each half is intersected recursively (left half
    first), and the two results are combined with intersect_automata() and
    then reduced with simplify_automata(). A list with one automaton returns
    that same object unchanged. An empty list fails the assert. With
    args.dump_automata set, the intersection is dumped before and after
    simplification.
    """
    print(f"mass_intersection(subautomatas = {subautomatas})")
    assert subautomatas
    if len(subautomatas) == 1:
        return subautomatas[0]
    middle = len(subautomatas) // 2
    halves = [
        mass_intersection(args, part)
        for part in (subautomatas[:middle], subautomatas[middle:])]
    combined = intersect_automata(args, halves[0], halves[1])
    dumping = args.dump_automata
    if dumping:
        dump_automata(combined, title = f"mass-intersection: {len(combined)}")
    combined = simplify_automata(combined)
    if dumping:
        dump_automata(combined, title = f"simplified mass-intersection: {len(combined)}")
    return combined
#def mass_intersection(args, subautomatas):
# print(f"mass_intersection(subautomatas = {subautomatas})");
#
# n = len(subautomatas);
#
# mapping = {((0,) * n): 0}; # (indexes of subautomatas) -> index in automata
#
# todo = [((0, ) * n, 0)]
#
# automata = [()];
#
# while todo:
# subindexes, index = todo.pop(0);
#
# accepting = int(all(subautomatas[i][x][0] for i, x in enumerate(subindexes)))
#
# newtos = [None, None];
#
# for on, to in enumerate(zip(*(subautomatas[i][x][1] for i, x in enumerate(subindexes)))):
# if all(x is not None for x in to):
# if to in mapping:
# newtos[on] = mapping[to];
# else:
# newto = len(automata);
#
# newtos[on] = newto;
#
# automata.append(());
#
# mapping[to] = newto;
#
# todo.append((to, newto));
#
# automata[index] = (accepting, newtos);
#
# if args.dump_automata:
# dump_automata(automata, title = f"mass-intersection: {len(automata)}");
#
# return automata;
def complement_automata(automata):
    """Return an automaton accepting exactly the strings `automata` rejects.

    States reachable from state 0 are copied in breadth-first order with
    their accepting flag inverted. Every missing transition is redirected to
    one shared accepting sink state ("phi") that loops to itself on both
    bits. The sink is appended the first time it is needed.
    """
    index_of = {0: 0}
    queue = [(0, 0)]
    result = [()]
    sink = None
    head = 0
    while head < len(queue):
        old, new = queue[head]
        head += 1
        accepting, tos = automata[old]
        newtos = [None, None]
        for bit, target in enumerate(tos):
            if target is None:
                if sink is None:
                    sink = len(result)
                    result.append((1, (sink, sink)))
                newtos[bit] = sink
                continue
            if target not in index_of:
                index_of[target] = len(result)
                result.append(())
                queue.append((target, index_of[target]))
            newtos[bit] = index_of[target]
        result[new] = (int(not accepting), newtos)
    return result
def _equivalent_states(automata):
    """Return {state: set of states indistinguishable from it}.

    All pairs start out equivalent. A pair is marked different when the
    accepting flags differ, or when one state has a transition on a bit
    where the other has none. Differences are then propagated backwards to
    every pair whose successors on some bit form an already-different pair.
    A transition into a dead state still counts as a transition here, so the
    result is sound but may not be fully minimal.
    """
    n = len(automata)
    print(f"n = {n}")
    same_as = {i: set(range(n)) for i in range(n)}
    # (p, q) -> pairs (i, j) that are different whenever (p, q) is.
    dependents = {}
    todo = set()
    print("simplify_automata: trivial differences")
    for i, j in itertools.product(range(n), range(n)):
        iaccepting, itos = automata[i]
        jaccepting, jtos = automata[j]
        if iaccepting != jaccepting or any(
                (x is None) != (y is None) for x, y in zip(itos, jtos)):
            todo.add((i, j))
        else:
            for dep in zip(itos, jtos):
                dependents.setdefault(dep, set()).add((i, j))
    print("simplify_automata: nontrivial differences")
    processed = 0
    while todo:
        print(f"simplify_automata: {processed} of {processed + len(todo)} ({processed * 100 / (processed + len(todo)):.2f}%)")
        i, j = todo.pop()
        processed += 1
        if j not in same_as[i]:
            # Already marked; its dependents were queued at that time.
            # Re-queuing them would repeat work and never end on cycles.
            continue
        same_as[i].discard(j)
        same_as[j].discard(i)
        todo.update(dependents.get((i, j), ()))
    return same_as


def _states_reaching_acceptance(automata):
    """Return the set of states from which some accepting state is reachable."""
    predecessors = {}  # state -> non-accepting states with an edge into it
    todo = set()
    for i, (accepting, tos) in enumerate(automata):
        if accepting:
            todo.add(i)
        else:
            for to in tos:
                predecessors.setdefault(to, set()).add(i)
    can_reach = set()
    while todo:
        i = todo.pop()
        if i in can_reach:
            continue  # without this a cycle would be revisited forever
        can_reach.add(i)
        todo.update(predecessors.get(i, ()))
    return can_reach


def simplify_automata(automata):
    """Return a smaller automaton accepting the same strings as `automata`.

    1. Group indistinguishable states (_equivalent_states).
    2. Find states that can still reach an accepting state.
    3. Copy breadth-first from state 0. Every transition is redirected to the
       smallest-indexed equivalent state, and transitions into states that
       cannot reach acceptance are dropped.
    Only states reached by the kept transitions are copied; they are
    renumbered in discovery order.
    """
    print("simplify_automata()")
    same_as = _equivalent_states(automata)
    print("simplify_automata: reachable")
    can_reach = _states_reaching_acceptance(automata)
    print("simplify_automata: clone")
    mapping = {0: 0}  # old representative index -> new index
    todo = [(0, 0)]
    new_automata = [()]
    head = 0
    while head < len(todo):
        oldindex, newindex = todo[head]
        head += 1
        accepting, oldtos = automata[oldindex]
        newtos = [None, None]
        for on, oldto in enumerate(oldtos):
            if oldto is None or oldto not in can_reach:
                # Dead end: no accepting state lies behind this edge.
                continue
            oldto = min(same_as[oldto])
            if oldto not in mapping:
                mapping[oldto] = len(new_automata)
                new_automata.append(())
                todo.append((oldto, mapping[oldto]))
            newtos[on] = mapping[oldto]
        new_automata[newindex] = (accepting, newtos)
    return new_automata
dump_id = 0  # sequence number used for the next .dot file name


def dump_automata(automata, title = ""):
    """Write `automata` as a Graphviz digraph and advance dump_id.

    The file is named after dump_id, zero-padded to five digits, with a .dot
    suffix, in the current working directory. `title` becomes the graph
    label. Accepting states are drawn as double circles, and each existing
    transition becomes an edge labelled with its bit. Empty entries (the ()
    placeholders used while automata are built) are skipped.
    """
    global dump_id
    with open(f"{dump_id:05}.dot", "w") as stream:
        stream.write(
            "\ndigraph {\n"
            "rankdir = LR;\n"
            f'label = "{title}"\n'
            "node [\n"
            "shape = circle\n"
            "];\n")
        for state, entry in enumerate(automata):
            if not entry:
                continue
            accepting, (on0, on1) = entry
            if accepting:
                stream.write(f"{state} [ shape = doublecircle ];\n")
            for bit, target in (("0", on0), ("1", on1)):
                if target is not None:
                    stream.write(f'{state} -> {target} [label = "{bit}"];\n')
        stream.write("\n}\n")
    dump_id += 1
def build_automata(args, problem):
    """Compile a CNF problem into an automaton over assignment bit strings.

    problem: (number_of_variables, clauses), as produced by parse_file().

    Each literal becomes a create_variable_automata() automaton. Each clause
    is the mass_union() of its literals, then simplified. The problem is the
    mass_intersection() of its clauses, simplified once more. With
    args.dump_automata set, the intermediate automata are dumped.

    NOTE(review): bit position k corresponds to the k-th distinct variable in
    order of first appearance in `clauses`, not to the variable's number in
    the input file.

    A problem with no clauses yields the automaton that accepts every bit
    string of length number_of_variables.

    Raises ValueError if the clauses use more distinct variables than
    number_of_variables. Those extra positions would index past, or
    overwrite, the final accepting state.
    """
    number_of_variables, clauses = problem
    distinct = {variable for clause in clauses for variable, _ in clause}
    if len(distinct) > number_of_variables:
        raise ValueError(
            f"clauses use {len(distinct)} distinct variables, "
            f"but the problem declares only {number_of_variables}")
    if not clauses:
        # Intersection of zero clauses: every assignment is a solution.
        automata = [(0, (i + 1, i + 1)) for i in range(number_of_variables)]
        automata.append((1, (None, None)))
        if args.dump_automata:
            dump_automata(automata, "final")
        return automata
    clause_automatas = []
    variable_lookup = {}  # input variable -> bit position (first appearance)
    for clause in clauses:
        variable_automatas = []
        for variable, value in clause:
            variable_index = variable_lookup.setdefault(variable, len(variable_lookup))
            variable_automata = create_variable_automata(
                variable_index, value, number_of_variables)
            if args.dump_automata:
                title = f"{variable} ({variable_index}) = {value}"
                dump_automata(variable_automata, title)
            variable_automatas.append(variable_automata)
        clause_automata = mass_union(args, variable_automatas)
        clause_automata = simplify_automata(clause_automata)
        if args.dump_automata:
            title = " or ".join(f"{a} = {b}" for a, b in clause)
            dump_automata(clause_automata, title)
        clause_automatas.append(clause_automata)
    automata = mass_intersection(args, clause_automatas)
    if args.dump_automata:
        dump_automata(automata, "final")
    automata = simplify_automata(automata)
    if args.dump_automata:
        dump_automata(automata, "simplified: final")
    return automata
def find_solutions(number_of_variables, automata):
    """Classify the problem encoded by `automata` and describe its solutions.

    A solution is a bit string of length number_of_variables that leads from
    state 0 to an accepting state. A missing transition (None) rejects.

    Returns one of:
    - "all solutions" if every such bit string is accepted;
    - "no solutions" if none is;
    - otherwise the lexicographically smallest accepted string (0 before 1),
      formatted as "0 = b0, 1 = b1, ...".

    NOTE(review): the numbers on the left are 0-based bit positions.
    build_automata assigns them in order of first appearance of each
    variable, so they are not necessarily the input file's variable numbers.

    Both searches memoise (state, depth) pairs, so each pair is evaluated at
    most once. Enumerating every path instead would be exponential in
    number_of_variables.
    """
    rejecting = {}  # (state, depth) -> True if some suffix from here is rejected

    def has_rejected_suffix(state, depth):
        key = (state, depth)
        if key not in rejecting:
            accepting, tos = automata[state]
            if depth == number_of_variables:
                result = not accepting
            else:
                result = False
                for to in tos:
                    # A missing transition drops the string, i.e. rejects it.
                    if to is None or has_rejected_suffix(to, depth + 1):
                        result = True
                        break
            rejecting[key] = result
        return rejecting[key]

    if not has_rejected_suffix(0, 0):
        return "all solutions"

    dead = set()  # (state, depth) pairs from which no suffix is accepted

    def first_accepted_suffix(state, depth):
        if (state, depth) in dead:
            return None
        accepting, tos = automata[state]
        if depth == number_of_variables:
            if accepting:
                return ()
        else:
            for on, to in enumerate(tos):
                if to is not None:
                    suffix = first_accepted_suffix(to, depth + 1)
                    if suffix is not None:
                        return (on,) + suffix
        dead.add((state, depth))
        return None

    solution = first_accepted_suffix(0, 0)
    if solution is None:
        return "no solutions"
    return ", ".join(f"{v} = {s}" for v, s in enumerate(solution))
def do_valdation(automata, solution=None, number_of_variables=None):
    """Validate the solver's conclusion for one problem (not implemented yet).

    The intended check, described in the --validate help text, evaluates
    assignments against the original clauses. These arguments are not enough
    for that: the clauses are not passed in, and build_automata does not
    return its mapping from bit positions to input variables.

    solution and number_of_variables default to None, so callers passing only
    `automata` reach the explicit error below rather than a TypeError.

    Raises NotImplementedError, always.
    """
    # An `assert` would be stripped under `python -O`, which would let the
    # caller report "(validated.)" without any check having run.
    raise NotImplementedError("--validate is not implemented yet")
def main(args):
    """Solve every problem in args.input_file and print one result line each.

    Each printed result is the value returned by find_solutions(). With
    args.validate set, do_valdation() is called first, and " (validated.)" is
    appended once it returns.
    """
    problems = parse_file(args.input_file)
    for problem in problems:
        print(f"problem = {problem}")
        number_of_variables, _ = problem
        automata = build_automata(args, problem)
        solution = find_solutions(number_of_variables, automata)
        if args.validate:
            # Pass every argument do_valdation declares.
            do_valdation(automata, solution, number_of_variables)
            solution = f"{solution} (validated.)"
        print(solution)
import sys

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    sys.exit(main(parse_args()))