# Repository path: program-synthesis/synthesis.py
# Author: ayushnoori
# Commit: ea7fc19 ("Enrich string examples")
'''
BOTTOM UP ENUMERATIVE SYNTHESIS
Ayush Noori
CS252R, Fall 2020
Example of usage:
python synthesis.py --domain arithmetic --examples addition
'''
# load libraries
import numpy as np
import argparse
import itertools
import time
# import examples
from arithmetic import *
from strings import *
from abstract_syntax_tree import *
from examples import example_set, check_examples
import config
# PARSE ARGUMENTS
def parse_args(argv=None):
    '''
    Parse command line arguments.

    Args:
        argv: optional list of argument strings to parse instead of the
            process command line. Defaults to None, in which case argparse
            reads sys.argv[1:] (the original behavior).

    Returns:
        argparse.Namespace with attributes `domain`, `examples_key` and
        `max_weight`.
    '''
    parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis in Python.")
    # define valid choices for the 'domain' argument
    valid_domain_choices = ["arithmetic", "strings"]
    # add arguments
    parser.add_argument('--domain', type=str, required=True,
                        choices=valid_domain_choices,
                        help='Domain of synthesis (either "arithmetic" or "strings").')
    # example set names are taken from the keys of the `example_set` dictionary
    parser.add_argument('--examples', dest='examples_key', type=str, required=True,
                        choices=list(example_set.keys()),
                        help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')
    parser.add_argument('--max-weight', type=int, required=False, default=3,
                        help='Maximum weight of programs to consider before terminating search.')
    return parser.parse_args(argv)
# EXTRACT CONSTANTS AND VARIABLES
def extract_constants(examples):
    '''
    Build the initial program bank from the input-output examples.

    Every distinct input value appearing in the examples becomes a constant
    (IntegerConstant for int values, StringConstant for str values), and the
    integer 1 is always included. Then one variable per argument position is
    created from the argument types returned by check_examples
    (IntegerVariable for int, StringVariable for str).

    Args:
        examples: sequence of examples whose element [0] is the iterable of
            input values for that example.

    Returns:
        list: the constants, in order of first appearance in the examples
        (with 1 appended if it did not already appear), followed by the
        variables in argument-position order.

    Raises:
        Exception: if an input value or argument type is neither int nor str.
        Anything raised by check_examples for invalid examples.
    '''
    # check validity of provided examples; if valid, extract argument types
    _, arg_types = check_examples(examples)

    # Deduplicate the input values while keeping first-occurrence order.
    # A plain set would iterate str values in hash order, which changes from
    # one interpreter run to the next (hash randomization), making the program
    # bank order -- and therefore the synthesized program -- non-deterministic.
    unique_inputs = dict.fromkeys(
        value for example in examples for value in example[0]
    )
    # always make the integer constant 1 available
    unique_inputs.setdefault(1, None)

    # extract constants in input
    constants = []
    for value in unique_inputs:
        # exact type comparison (not isinstance) so bool values are rejected
        if type(value) == int:
            constants.append(IntegerConstant(value))
        elif type(value) == str:
            constants.append(StringConstant(value))
        else:
            raise Exception("Input of unknown type.")

    # one variable per argument position, typed from the examples
    variables = []
    for position, arg in enumerate(arg_types):
        if arg == int:
            variables.append(IntegerVariable(position))
        elif arg == str:
            variables.append(StringVariable(position))
        else:
            raise Exception("Input of unknown type.")

    return constants + variables
# CHECK OBSERVATIONAL EQUIVALENCE
def observationally_equivalent(program_a, program_b, examples):
    """
    Decide whether two programs cannot be told apart on the example inputs.

    Both programs are run on the input part (element [0]) of every example:
    first program A on all inputs, then program B on all inputs. They count
    as equivalent when the two output lists are equal.

    Returns:
        bool: True when the outputs match on every example input.
    """
    probe_inputs = [case[0] for case in examples]

    def outputs_of(candidate):
        # evaluate one program across every probe input, in order
        return [candidate.evaluate(value) for value in probe_inputs]

    # the left operand is fully evaluated before the right one
    return outputs_of(program_a) == outputs_of(program_b)
# CHECK CORRECTNESS
def check_program(program, examples):
    '''
    Report whether a program reproduces every expected example output.

    Element [0] of each example is passed to program.evaluate. The list of
    results is compared with the list of element [1] values.

    Returns:
        bool: True when every produced output equals its expected output.
    '''
    arguments = [case[0] for case in examples]
    targets = [case[1] for case in examples]
    return list(map(program.evaluate, arguments)) == targets
# RUN SYNTHESIZER
def run_synthesizer(args):
    '''
    Run bottom-up enumerative synthesis.

    Starting from the constants and variables extracted from the selected
    examples, every operator of the chosen domain is applied to tuples of
    programs already in the bank. This is done for levels 2 through
    args.max_weight inclusive; at each level, only argument tuples whose
    summed weight is at most the level are used. A new program is kept only
    if its string form is new and it is not observationally equivalent to
    any program already in the bank.

    Args:
        args: namespace with `examples_key`, `domain` and `max_weight`
            (as produced by parse_args).

    Returns:
        The first program found that satisfies all examples, or None if no
        such program is found within the weight limit.

    Raises:
        Exception: if args.domain is neither "arithmetic" nor "strings".
    '''
    # retrieve selected input-output examples
    examples = example_set[args.examples_key]
    # extract constants (and variables) from examples
    program_bank = extract_constants(examples)
    # string forms of the bank, used as a cheap syntactic duplicate check
    seen_strings = {p.str() for p in program_bank}
    print("\nSynthesis Log:")
    print(f"- Extracted {len(program_bank)} constants from examples.")

    # define operators
    if args.domain == "arithmetic":
        operators = arithmetic_operators
    elif args.domain == "strings":
        operators = string_operators
    else:
        raise Exception('Domain not recognized. Must be either "arithmetic" or "strings".')

    # a bare constant or variable may already satisfy the examples
    for program in program_bank:
        if check_program(program, examples):
            return program

    # iterate over each level, including max_weight itself
    for weight in range(2, args.max_weight + 1):
        print(f"- Searching level {weight} with {len(program_bank)} primitives.")
        for op in operators:
            # normalize so a list- or tuple-valued arg_types compares correctly
            expected_types = tuple(op.arg_types)
            # product (not combinations): argument order matters for
            # non-commutative operators, and an argument may be reused (x + x).
            # product snapshots program_bank, so appending below is safe.
            for combination in itertools.product(program_bank, repeat=op.arity):
                # check if type signature matches operator
                if tuple(p.type for p in combination) != expected_types:
                    continue
                # check that sum of weights of arguments <= weight
                if sum(p.weight for p in combination) > weight:
                    continue
                program = OperatorNode(op, combination)
                program_str = program.str()
                # skip syntactic duplicates
                if program_str in seen_strings:
                    continue
                # skip programs that behave like one already in the bank
                if any(observationally_equivalent(program, p, examples) for p in program_bank):
                    continue
                program_bank.append(program)
                seen_strings.add(program_str)
                # check if program passes all examples
                if check_program(program, examples):
                    return program

    # no program found within the weight limit
    return None
if __name__ == '__main__':
    # parse command line arguments
    args = parse_args()
    # run bottom-up enumerative synthesis; time it with a monotonic clock,
    # which (unlike time.time) is unaffected by system clock adjustments
    start_time = time.perf_counter()
    program = run_synthesizer(args)
    elapsed_time = round(time.perf_counter() - start_time, 4)
    # report whether a program was found
    print("\nSynthesis Results:")
    if program is None:
        print(f"- Max weight of {args.max_weight} reached, no program found in {elapsed_time}s.")
    else:
        print(f"- Program found in {elapsed_time}s.")
        print(f"- Program: {program.str()}")
        print(f"- Program weight: {program.weight}")
        print(f"- Program return type: {program.type.__name__}")