'''
BOTTOM UP ENUMERATIVE SYNTHESIS
Ayush Noori
CS252R, Fall 2020

Example of usage:
python synthesis.py --domain arithmetic --examples addition
'''

# load libraries
import numpy as np
import argparse
import itertools
import time

# import examples
from arithmetic import *
from strings import *
from abstract_syntax_tree import *
from examples import example_set, check_examples
import config


# PARSE ARGUMENTS
def parse_args():
    '''
    Parse command line arguments.

    Returns the argparse namespace with the attributes `domain`,
    `examples_key` and `max_weight`.
    '''
    parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis in Python.")

    # define valid choices for the 'domain' argument
    valid_domain_choices = ["arithmetic", "strings"]

    # add arguments
    parser.add_argument('--domain', type=str, required=True,
                        choices=valid_domain_choices,
                        help='Domain of synthesis (either "arithmetic" or "strings").')

    parser.add_argument('--examples', dest='examples_key', type=str, required=True,
                        choices=example_set.keys(),
                        help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')

    parser.add_argument('--max-weight', type=int, required=False, default=3,
                        help='Maximum weight of programs to consider before terminating search.')

    return parser.parse_args()


# EXTRACT CONSTANTS AND VARIABLES
def extract_constants(examples):
    '''
    Extracts the constants from the input-output examples. Also constructs
    variables as needed based on the input-output examples, and adds them
    to the list of constants.

    `examples` is a sequence of (inputs, output) pairs, where `inputs` is
    itself a sequence of argument values.

    Returns a list holding one constant node per distinct input value
    (plus the literal 1), followed by one variable node per argument
    position.

    Raises Exception if an input value or argument type is neither int
    nor str.
    '''
    # check validity of provided examples
    # if valid, extract argument types (the arity is not needed here)
    _arity, arg_types = check_examples(examples)

    # get unique inputs; a dict (rather than a set) keeps first-seen order,
    # so the order of the program bank is reproducible between runs
    unique_inputs = dict.fromkeys(
        value for example in examples for value in example[0]
    )

    # add 1 to the set of inputs
    unique_inputs.setdefault(1)

    # extract constants in input
    constants = []
    for value in unique_inputs:
        # exact type match on purpose: bool is a subclass of int and must
        # still be rejected, as it was before
        if type(value) is int:
            constants.append(IntegerConstant(value))
        elif type(value) is str:
            constants.append(StringConstant(value))
        else:
            raise Exception("Input of unknown type.")

    # extract variables in input
    variables = []
    for position, arg in enumerate(arg_types):
        if arg is int:
            variables.append(IntegerVariable(position))
        elif arg is str:
            variables.append(StringVariable(position))
        else:
            raise Exception("Input of unknown type.")

    return constants + variables


# CHECK OBSERVATIONAL EQUIVALENCE
def observationally_equivalent(program_a, program_b, examples):
    '''
    Returns True if Program A and Program B are observationally equivalent,
    i.e. they evaluate to the same output on the inputs of every example.
    Returns False otherwise.
    '''
    inputs = [example[0] for example in examples]
    a_output = [program_a.evaluate(args) for args in inputs]
    b_output = [program_b.evaluate(args) for args in inputs]
    return a_output == b_output


# CHECK CORRECTNESS
def check_program(program, examples):
    '''
    Check whether the program satisfies the input-output examples, i.e.
    whether evaluating it on each example's inputs gives that example's
    output.
    '''
    expected = [example[1] for example in examples]
    actual = [program.evaluate(example[0]) for example in examples]
    return actual == expected


# RUN SYNTHESIZER
def run_synthesizer(args):
    '''
    Run bottom-up enumerative synthesis.

    `args` must provide `examples_key` (a key of `example_set`), `domain`
    ("arithmetic" or "strings") and `max_weight` (the last level searched,
    inclusive).

    Returns the first program found that satisfies every example, or None
    if no such program is found up to `max_weight`.

    Raises Exception if the domain is not recognized.
    '''
    # retrieve selected input-output examples
    examples = example_set[args.examples_key]

    # extract constants from examples
    program_bank = extract_constants(examples)
    # string forms of everything in the bank, for O(1) duplicate checks
    seen_programs = {p.str() for p in program_bank}
    print("\nSynthesis Log:")
    print(f"- Extracted {len(program_bank)} constants from examples.")

    # define operators
    if args.domain == "arithmetic":
        operators = arithmetic_operators
    elif args.domain == "strings":
        operators = string_operators
    else:
        raise Exception('Domain not recognized. Must be either "arithmetic" or "strings".')

    # a bare constant or variable may already solve the task; any composite
    # program with the same outputs would be pruned as observationally
    # equivalent below, so this is the only place it can be detected
    for program in program_bank:
        if check_program(program, examples):
            return program

    # iterate over each level, up to and including the maximum weight
    for weight in range(2, args.max_weight + 1):

        # print message
        print(f"- Searching level {weight} with {len(program_bank)} primitives.")

        # iterate over each operator
        for op in operators:

            # NOTE(review): arg_types is assumed to be a list or tuple of
            # types; normalising it lets either form match the list below
            expected_types = list(op.arg_types)

            # all ordered argument tuples, repetition allowed, so that e.g.
            # f(x, x) and both f(a, b) and f(b, a) are considered;
            # product() copies the bank up front, so appending to the bank
            # inside the loop does not affect this iteration
            candidates = itertools.product(program_bank, repeat=op.arity)

            # iterate over each argument tuple
            for arguments in candidates:

                # check if type signature matches operator
                if [p.type for p in arguments] != expected_types:
                    continue

                # check that sum of weights of arguments <= w
                if sum(p.weight for p in arguments) > weight:
                    continue

                # create new program
                program = OperatorNode(op, arguments)

                # check if program is in program bank using string representation
                program_str = program.str()
                if program_str in seen_programs:
                    continue

                # check if program is observationally equivalent to any
                # program in program bank (stops at the first match)
                if any(observationally_equivalent(program, p, examples) for p in program_bank):
                    continue

                # add program to program bank
                program_bank.append(program)
                seen_programs.add(program_str)

                # check if program passes all examples
                if check_program(program, examples):
                    return program

    # return None if no program is found
    return None


if __name__ == '__main__':

    # parse command line arguments
    args = parse_args()

    # run bottom-up enumerative synthesis
    start_time = time.time()
    program = run_synthesizer(args)
    end_time = time.time()
    elapsed_time = round(end_time - start_time, 4)

    # check if program was found
    print("\nSynthesis Results:")
    if program is None:
        print(f"- Max weight of {args.max_weight} reached, no program found in {elapsed_time}s.")
    else:
        print(f"- Program found in {elapsed_time}s.")
        print(f"- Program: {program.str()}")
        print(f"- Program weight: {program.weight}")
        print(f"- Program return type: {program.type.__name__}")