diff --git a/README.md b/README.md index 48521a1..8aa5396 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,6 @@ To focus on the superoptimizer and not making a comprehensive, realistic assembl There are many possible improvements: -- **Start state.** Right now it assumes the start state is always the same, which means there is no concept of program input. -- **Program equivalence.** A set of inputs and outputs should be specified such that two programs can actually be tested for equivalence. - **Pruning.** Many nonsensical programs are generated, which significantly slows it down. - **More instructions.** There need to be more instructions, especially a conditional instruction, to give the superoptimizer more opportunities to make improvements. diff --git a/assembler.py b/assembler.py index 01fae8f..7f073c3 100644 --- a/assembler.py +++ b/assembler.py @@ -1,34 +1,34 @@ import re -from cpu import CPU +from instruction_set import * + + +INSTRUCTION_REGEX = re.compile(r'(\w+)\s+([-\d]+(?:\s*,\s*[-\d]+)*)') + -# Turns a string into a program. def parse(assembly): + """ + Turns a string into a program + """ lines = assembly.split('\n') - program = [] - cpu = CPU(1) + instructions = [] + mem_size = 1 for line in lines: - match = re.match(r'(\w+)\s+([-\d]+)(?:,\s*([-\d]+)(?:,\s*([-\d]+))?)?', line) + line = line.strip() + if line == '': + continue + match = INSTRUCTION_REGEX.fullmatch(line) if match: - op_str, *args_str = match.groups() - op = cpu.ops[op_str] - args = [int(arg) for arg in args_str if arg is not None] - program.append((op, *args)) - return program - -# Turns a program into a string. -def output(program): - if len(program) == 0: return "\n" - cpu = CPU(1) - assembly = "" - for instruction in program: - op = instruction[0] - args = instruction[1:] - if op.__name__ == cpu.load.__name__: - assembly += f"LOAD {args[0]}\n" - elif op.__name__ == cpu.swap.__name__: - assembly += f"SWAP {args[0]}, {args[1]}\n" - elif op.__name__ == cpu.xor.__name__: - assembly += f"XOR {args[0]}, {args[1]}\n" - elif op.__name__ == cpu.inc.__name__: - assembly += f"INC {args[0]}\n" - return assembly \ No newline at end of file + op, args_str = match.groups() + args = tuple(int(arg) for arg in args_str.split(",")) + operand_types = OPS[op] + if len(args) != len(operand_types): + raise ValueError(f'Wrong number of operands: {line}') + for arg, arg_type in zip(args, operand_types): + if arg_type == 'mem': + if arg < 0: + raise ValueError(f'Negative memory address: {line}') + mem_size = max(arg + 1, mem_size) + instructions.append(Instruction(op, args)) + else: + raise ValueError(f'Invalid syntax: {line}') + return Program(tuple(instructions), mem_size) diff --git a/brute_force_equivialence_checker.py b/brute_force_equivialence_checker.py new file mode 100644 index 0000000..ba7dc8c --- /dev/null +++ b/brute_force_equivialence_checker.py @@ -0,0 +1,29 @@ +from cpu import CPU + + +class BruteForceEquivalenceChecker: + def __init__(self, program1, bit_width, input_size): + self.program1 = program1 + self.bit_width = bit_width + self.max_val = 2 ** bit_width + self.input_size = input_size + + def generate_inputs(self, input_size): + """ + Generates all possible tuples of the given size with values ranging from 0 (inclusive) + to `max_val` (exclusive). + """ + if input_size == 0: + yield () + else: + for x in range(self.max_val): + for rest in self.generate_inputs(input_size - 1): + yield x, *rest + + def is_equivalent_to(self, program2): + mem_size = max(self.program1.mem_size, program2.mem_size) + cpu = CPU(mem_size, self.bit_width) + for input in self.generate_inputs(self.input_size): + if cpu.execute(self.program1, input) != cpu.execute(program2, input): + return False + return True diff --git a/cpu.py b/cpu.py index d7d1897..0725bf4 100644 --- a/cpu.py +++ b/cpu.py @@ -1,30 +1,34 @@ +import assembler + + +def run(assembly, bit_width, input=()): + """ + Helper function that runs a piece of assembly code. + """ + program = assembler.parse(assembly) + cpu = CPU(program.mem_size, bit_width) + return cpu.execute(program, input) + + class CPU: - def __init__(self, max_mem_cells): + def __init__(self, max_mem_cells, bit_width): self.max_mem_cells = max_mem_cells - self.state = [0] * max_mem_cells - self.ops = {'LOAD': self.load, 'SWAP': self.swap, 'XOR': self.xor, 'INC': self.inc} + self.limit = 2 ** bit_width - def execute(self, program): - state = self.state.copy() - for instruction in program: - op = instruction[0] - args = list(instruction[1:]) - args.insert(0, state) - state = op(*args) - return state - - def load(self, state, val): - state[0] = val - return state - - def swap(self, state, mem1, mem2): - state[mem1], state[mem2] = state[mem2], state[mem1] - return state - - def xor(self, state, mem1, mem2): - state[mem1] = state[mem1] ^ state[mem2] + def execute(self, program, input=()): + state = [0] * self.max_mem_cells + state[0: len(input)] = input + for instruction in program.instructions: + match instruction.opcode: + case 'LOAD': + state[0] = instruction.args[0] % self.limit + case 'SWAP': + mem1, mem2 = instruction.args + state[mem1], state[mem2] = state[mem2], state[mem1] + case 'XOR': + mem1, mem2 = instruction.args + state[mem1] ^= state[mem2] + case 'INC': + mem = instruction.args[0] + state[mem] = (state[mem] + 1) % self.limit return state - - def inc(self, state, mem): - state[mem] += 1 - return state \ No newline at end of file diff --git a/instruction_set.py b/instruction_set.py new file mode 100644 index 0000000..4af93a8 --- /dev/null +++ b/instruction_set.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + + +@dataclass +class Instruction: + opcode: str + args: tuple[int, ...] + + def __str__(self): + args = ", ".join(str(arg) for arg in self.args) + return f"{self.opcode} {args}" + + +@dataclass +class Program: + instructions: tuple[Instruction, ...] + """ + The instructions that make up this program + """ + + mem_size: int + """ + The amount of memory needed to run this program + """ + + def __str__(self): + return "\n".join(str(instr) for instr in self.instructions) + "\n" + + +OPS = { + "LOAD": ("const",), + "SWAP": ("mem", "mem"), + "XOR": ("mem", "mem"), + "INC": ("mem",) +} diff --git a/main.py b/main.py index 3a1324a..1a1740f 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,40 @@ -from superoptimizer import * +from superoptimizer import optimize +from cpu import run + + +def print_optimal_from_code(assembly, max_length, bit_width, debug=False): + print(f"***Source***{assembly}") + state = run(assembly, bit_width) + print("***State***") + print(state) + print() + print("***Optimal***") + print(optimize(assembly, max_length, bit_width, debug=debug)) + print("=" * 20) + print() + def main(): # Test 1 assembly = """ -LOAD 3 -SWAP 0, 1 -LOAD 3 -SWAP 0, 2 -LOAD 3 -SWAP 0, 3 -LOAD 3 + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + SWAP 0, 3 + LOAD 3 """ - optimal_from_code(assembly, 4, 4, 5) + print_optimal_from_code(assembly, 4, 2) # Test 2 - state = [0, 2, 1] - optimal_from_state(state, 3, 5) + assembly = """ + LOAD 2 + SWAP 0, 1 + LOAD 1 + SWAP 0, 2 + """ + print_optimal_from_code(assembly, 3, 2) - ## Test 3 - Careful, I don't think this will finish for days. - # state = [2, 4, 6, 8, 10, 12] - # optimal_from_state(state, 10, 15, True) -main() \ No newline at end of file +main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..c063b90 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +pytest +z3-solver diff --git a/smt_based_equivalence_checker.py b/smt_based_equivalence_checker.py new file mode 100644 index 0000000..5c71f58 --- /dev/null +++ b/smt_based_equivalence_checker.py @@ -0,0 +1,16 @@ +import z3 +from smt_program_simulator import simulate + + +class SmtBasedEquivalenceChecker: + def __init__(self, program1, bit_width, input_size): + self.solver = z3.Solver() + self.bit_width = bit_width + self.input_size = input_size + self.mem_size = program1.mem_size + self.state1 = simulate(program1, self.mem_size, bit_width, input_size) + + def is_equivalent_to(self, program2): + state2 = simulate(program2, self.mem_size, self.bit_width, self.input_size) + programs_are_different = z3.Or(*(value1 != value2 for value1, value2 in zip(self.state1, state2))) + return self.solver.check(programs_are_different) == z3.unsat diff --git a/smt_program_simulator.py b/smt_program_simulator.py new file mode 100644 index 0000000..258f892 --- /dev/null +++ b/smt_program_simulator.py @@ -0,0 +1,32 @@ +import z3 + + +def simulate(program, mem_size, bit_width, input_size): + """ + Simulate the behavior of the program using an SMT solver. + + The result will be a list containing, for each memory cell, an SMT value representing the + value that will reside in that memory location after running the program. + """ + + def mem_cell(i): + if i < input_size: + return z3.BitVec(f'input{i}', bit_width) + else: + return z3.BitVecVal(0, bit_width) + + state = [mem_cell(i) for i in range(mem_size)] + for instruction in program.instructions: + match instruction.opcode: + case 'LOAD': + state[0] = z3.BitVecVal(instruction.args[0], bit_width) + case 'SWAP': + mem1, mem2 = instruction.args + state[mem1], state[mem2] = state[mem2], state[mem1] + case 'XOR': + mem1, mem2 = instruction.args + state[mem1] ^= state[mem2] + case 'INC': + mem = instruction.args[0] + state[mem] += 1 + return state diff --git a/superoptimizer.py b/superoptimizer.py index b098db1..b4dd501 100644 --- a/superoptimizer.py +++ b/superoptimizer.py @@ -1,61 +1,59 @@ from itertools import product -from cpu import CPU import assembler +from brute_force_equivialence_checker import BruteForceEquivalenceChecker +from instruction_set import * -# Helper function that finds the optimal code given the assembly code. -def optimal_from_code(assembly, max_length, max_mem, max_val, debug=False): - cpu = CPU(max_mem) + +def optimize(assembly, max_length, bit_width, input_size=0, *, + equivalence_checker=BruteForceEquivalenceChecker, + debug=False): + """ + Helper function that finds the optimal code given the assembly code. + """ program = assembler.parse(assembly) - state = cpu.execute(program) - print(f"***Source***{assembly}") - optimal_from_state(state, max_length, max_val, debug) - -# Helper function that finds the optimal code given the goal state. -def optimal_from_state(state, max_length, max_val, debug=False): - max_mem = len(state) - print(f"***State***\n{state}\n") - opt = Superoptimizer() - shortest_program = opt.search(max_length, max_mem, max_val, state, debug) - disassembly = assembler.output(shortest_program) - print(f"***Optimal***\n{disassembly}\n{'='*20}\n") + opt = Superoptimizer(equivalence_checker) + return opt.search(max_length, bit_width, program, input_size, debug) + class Superoptimizer: - def __init__(self): - self.program_cache = {} + def __init__(self, equivalence_checker_class): + self.equivalence_checker_class = equivalence_checker_class + + @staticmethod + def generate_operands(operand_type, max_mem, bit_width): + if operand_type == "const": + return range(2 ** bit_width) + elif operand_type == "mem": + return range(max_mem) + else: + raise ValueError(f"Illegal operand type: {operand_type}") - # Generates all possible programs. - def generate_programs(self, cpu, max_length, max_mem, max_val): + @staticmethod + def generate_programs(max_length, max_mem, bit_width): + """ + Generates all possible programs + """ + yield Program((), 0) for length in range(1, max_length + 1): - for prog in product(cpu.ops.values(), repeat=length): - arg_sets = [] - for op in prog: - if op == cpu.load: - arg_sets.append([tuple([val]) for val in range(max_val + 1)]) - elif op == cpu.swap or op == cpu.xor: - arg_sets.append(product(range(max_mem), repeat=2)) - elif op == cpu.inc: - arg_sets.append([tuple([val]) for val in range(max_mem)]) - for arg_set in product(*arg_sets): - program = [(op, *args) for op, args in zip(prog, arg_set)] - yield program - - # Tests all of the generated programs and returns the shortest. - def search(self, max_length, max_mem, max_val, target_state, debug=False): + instructions = [] + for op, operand_types in OPS.items(): + arg_sets = (Superoptimizer.generate_operands(ot, max_mem, bit_width) for ot in operand_types) + instructions.extend(assembler.Instruction(op, args) for args in product(*arg_sets)) + for prog in product(instructions, repeat=length): + yield Program(prog, max_mem) + + # Tests all the generated programs and returns the shortest. + def search(self, max_length, bit_width, program, input_size=0, debug=False): count = 0 - cpu = CPU(max_mem) - for program in self.generate_programs(cpu, max_length, max_mem, max_val): - state = cpu.execute(program) - if state == target_state: - state = tuple(state) - if state not in self.program_cache or len(program) < len(self.program_cache[state]): - self.program_cache[state] = program - + equivalence_checker = self.equivalence_checker_class(program, bit_width, input_size) + for optimal in self.generate_programs(max_length, program.mem_size, bit_width): + if equivalence_checker.is_equivalent_to(optimal): + return optimal + # Debugging. if debug: count += 1 - if count % 1000000 == 0: print(f"Programs searched: {count:,}") - if count % 10000000 == 0: - solution = self.program_cache.get(tuple(target_state), None) - print(f"Best solution: {solution}") + if count % 1000000 == 0: + print(f"Programs searched: {count:,}") - return self.program_cache.get(tuple(target_state), None) + return None diff --git a/tests/test_assembler.py b/tests/test_assembler.py new file mode 100644 index 0000000..2c730ae --- /dev/null +++ b/tests/test_assembler.py @@ -0,0 +1,50 @@ +import pytest + +from assembler import parse +from instruction_set import * + + +def test_empty_program(): + assert parse('') == Program((), 1) + + +def test_that_all_instructions_and_mem_size(): + assembly = """ + LOAD 42 + XOR 2, 3 + SWAP 42, 23 + INC 13 + """ + instructions = ( + Instruction('LOAD', (42,)), + Instruction('XOR', (2, 3)), + Instruction('SWAP', (42, 23)), + Instruction('INC', (13,)) + ) + assert parse(assembly) == Program(instructions, 43) + + +def test_syntax_errors(): + with pytest.raises(ValueError): + parse("LOAD !&%*") + + with pytest.raises(ValueError): + parse("LOAD") + + with pytest.raises(ValueError): + parse("LOAD 23, 42") + + with pytest.raises(ValueError): + parse("XOR 23") + + with pytest.raises(ValueError): + parse("XOR 23, -42") + + with pytest.raises(ValueError): + parse("INC 1, 2") + + with pytest.raises(ValueError): + parse("SWAP 23") + + with pytest.raises(ValueError): + parse("SWAP 23, -42") diff --git a/tests/test_cpu.py b/tests/test_cpu.py new file mode 100644 index 0000000..778f744 --- /dev/null +++ b/tests/test_cpu.py @@ -0,0 +1,59 @@ +from cpu import run + + +def test_load_and_swap(): + assembly = """ + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + SWAP 0, 3 + LOAD 3 + """ + assert run(assembly, 8) == [3, 3, 3, 3] + assert run(assembly, 1) == [1, 1, 1, 1] + + +def test_load_and_xor(): + assembly = """ + LOAD 42 + XOR 1, 0 + LOAD 23 + XOR 1, 0 + """ + assert run(assembly, 8) == [23, 42 ^ 23] + + +def test_load_and_inc(): + assembly = """ + LOAD 41 + INC 0 + INC 1 + INC 1 + INC 1 + """ + assert run(assembly, 8) == [42, 3] + + +def test_input(): + assembly = """ + XOR 1, 0 + INC 1 + """ + assert run(assembly, 8) == [0, 1] + assert run(assembly, 8, [2]) == [2, 3] + assert run(assembly, 8, [1, 2]) == [1, 4] + + +def test_load_only(): + assembly = 'LOAD 42' + assert run(assembly, 8) == [42] + + +def test_wrap_around(): + assembly = """ + LOAD 255 + INC 0 + """ + assert run(assembly, 8) == [0] diff --git a/tests/test_smt_baseed_equivalence_checker.py b/tests/test_smt_baseed_equivalence_checker.py new file mode 100644 index 0000000..5e3b4c2 --- /dev/null +++ b/tests/test_smt_baseed_equivalence_checker.py @@ -0,0 +1,66 @@ +import assembler +from smt_based_equivalence_checker import SmtBasedEquivalenceChecker + + +def are_equivalent(program1, program2, bit_width, input_size): + return SmtBasedEquivalenceChecker(program1, bit_width, input_size).is_equivalent_to(program2) + + +def test_add_three(): + three_incs = assembler.parse(""" + INC 0 + INC 0 + INC 0 + """) + load_three = assembler.parse('LOAD 3') + one_inc = assembler.parse('INC 0') + + assert are_equivalent(three_incs, load_three, 8, 0) + assert not are_equivalent(three_incs, one_inc, 8, 0) + # With a single bit, += 1 and += 3 are equivalent + assert are_equivalent(three_incs, one_inc, 1, 0) + + # If there's user input, setting to three and increasing by three are no longer equivalent + assert not are_equivalent(three_incs, load_three, 8, 1) + # However, += 1 and += 3 are still equivalent for single bits + assert are_equivalent(three_incs, one_inc, 1, 1) + + +def test_swap_vs_xor(): + swap_with_xor = assembler.parse(""" + XOR 0, 1 + XOR 1, 0 + XOR 0, 1 + """) + swap_with_swap = assembler.parse('SWAP 0, 1') + just_xor = assembler.parse('XOR 0, 1') + + assert are_equivalent(swap_with_xor, swap_with_swap, 8, 2) + assert not are_equivalent(swap_with_xor, just_xor, 8, 2) + assert not are_equivalent(swap_with_swap, just_xor, 8, 2) + + +def test_large_program_with_lots_of_inputs(): + program1 = assembler.parse(""" + INC 0 + XOR 0, 1 + XOR 1, 0 + XOR 0, 1 + INC 1 + SWAP 1, 2 + INC 3 + XOR 3, 2 + INC 4 + INC 5 + """) + program2 = assembler.parse(""" + SWAP 0, 2 + SWAP 0, 1 + INC 2 + INC 2 + INC 3 + XOR 3, 2 + INC 4 + INC 5 + """) + assert are_equivalent(program1, program2, 8, 6) diff --git a/tests/test_superoptimizer.py b/tests/test_superoptimizer.py new file mode 100644 index 0000000..60f0ed5 --- /dev/null +++ b/tests/test_superoptimizer.py @@ -0,0 +1,155 @@ +from superoptimizer import optimize +from assembler import parse, Program +from smt_based_equivalence_checker import SmtBasedEquivalenceChecker + + +MAX_LENGTH = 1000000 + + +def optimize_with_both(*args): + result1 = optimize(*args) + result2 = optimize(*args, equivalence_checker=SmtBasedEquivalenceChecker) + assert result1 == result2 + return result1 + + +def test_four_threes(): + assembly = """ + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + SWAP 0, 3 + LOAD 3 + """ + optimal = parse(""" + LOAD 3 + XOR 1, 0 + XOR 2, 0 + XOR 3, 0 + """) + # This test case takes too long with the SMT-based equivalence checker, so we only test with the brute force one + # (which doesn't actually need much force when the input size is 0) + assert optimize(assembly, MAX_LENGTH, 2, 0) == optimal + + +def test_three_threes(): + assembly = """ + LOAD 3 + SWAP 0, 1 + LOAD 3 + SWAP 0, 2 + LOAD 3 + """ + optimal = parse(""" + LOAD 3 + XOR 1, 0 + XOR 2, 0 + """) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal + # Assert that the program is still found with a tight max_length + assert optimize_with_both(assembly, 3, 2, 0) == optimal + # Assert that the program is not found with a max_length that's below the optimal length + assert optimize_with_both(assembly, 2, 2, 0) is None + + # Changing the input size to 1 doesn't change anything as the first input will be overridden by the load + assert optimize_with_both(assembly, MAX_LENGTH, 2, 1) == optimal + + # For input size 2, we'll need to clear the second input using swap and another load + optimal = parse(""" + LOAD 3 + SWAP 0, 1 + LOAD 3 + XOR 2, 0 + """) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 2) == optimal + + +def test_0_2_1(): + assembly = """ + LOAD 2 + SWAP 0, 1 + LOAD 1 + SWAP 0, 2 + """ + optimal = parse(""" + LOAD 2 + SWAP 0, 1 + INC 2 + """) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal + + +def test_no_op(): + assembly = """ + SWAP 0,0 + """ + empty_program = Program((), 0) + # Program results in the memory being unchanged, so optimal program is empty + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == empty_program + assert optimize_with_both(assembly, MAX_LENGTH, 1, 1) == empty_program + assert optimize_with_both(assembly, MAX_LENGTH, 1, 2) == empty_program + + +def test_increasing_sequence(): + assembly = """ + INC 0 + INC 1 + INC 1 + """ + optimal = parse(""" + LOAD 1 + XOR 1, 0 + INC 1 + """) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal + + assembly = """ + INC 0 + INC 1 + INC 1 + INC 2 + INC 2 + INC 2 + """ + optimal = parse(""" + LOAD 1 + XOR 1, 0 + XOR 2, 0 + INC 1 + XOR 2, 1 + """) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 0) == optimal + + +def test_increasing_from_input(): + # Given the input x, the following program should produce the sequence x+1, x+2, x+3 + assembly = """ + XOR 1, 0 + XOR 2, 0 + INC 0 + INC 1 + INC 1 + INC 2 + INC 2 + INC 2 + """ + optimal = parse(""" + INC 0 + XOR 1, 0 + INC 1 + XOR 2, 1 + INC 2 + """) + assert optimize(assembly, MAX_LENGTH, 2, 1) == optimal + + +def test_add_to_three_mem_cells(): + assembly = """ + INC 0 + INC 1 + INC 2 + """ + optimal = parse(assembly) + assert optimize_with_both(assembly, MAX_LENGTH, 2, 2) == optimal