'''
BOTTOM UP ENUMERATIVE SYNTHESIS
Ayush Noori
CS252R, Fall 2020

Example of usage:
python synthesis.py --domain arithmetic --examples addition
'''

# load libraries
import numpy as np
import argparse
import itertools
import time

# import examples
from arithmetic import *
from strings import *
from abstract_syntax_tree import *
from examples import example_set, check_examples
import config


# PARSE ARGUMENTS
def parse_args(argv=None):
    '''
    Parse command line arguments.

    Args:
        argv: optional list of argument strings to parse. When None (the default),
            argparse falls back to sys.argv[1:], matching the previous behavior.

    Returns:
        argparse.Namespace with attributes `domain`, `examples_key` and `max_weight`.
    '''

    parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis in Python.")

    # define valid choices for the 'domain' argument
    valid_domain_choices = ["arithmetic", "strings"]

    # domain of synthesis; selects which operator set the synthesizer uses
    parser.add_argument('--domain', type=str, required=True,
                        choices=valid_domain_choices,
                        help='Domain of synthesis (either "arithmetic" or "strings").')

    # key into the example_set dictionary; stored as args.examples_key
    parser.add_argument('--examples', dest='examples_key', type=str, required=True,
                        choices=example_set.keys(),
                        help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')

    parser.add_argument('--max-weight', type=int, required=False, default=3,
                        help='Maximum weight of programs to consider before terminating search.')

    return parser.parse_args(argv)


# EXTRACT CONSTANTS AND VARIABLES
def extract_constants(examples):
    '''
    Build the initial program bank (constants plus variables) from the examples.

    Every distinct value that appears among the example inputs becomes a constant
    (IntegerConstant for ints, StringConstant for strings), and the integer 1 is
    always added as a constant. Then one variable is created per argument position,
    using the argument types reported by `check_examples`.

    Args:
        examples: sequence of input-output examples; `example[0]` is the sequence
            of argument values for that example.

    Returns:
        list: the constants, in order of first appearance, followed by the
        variables, in argument-position order.

    Raises:
        TypeError: if an input value or an argument type is neither int nor str.
        Any exception raised by `check_examples` for invalid examples is propagated.
    '''

    # validate the examples and get the type of each argument position
    # (the arity itself is not needed here)
    _arity, arg_types = check_examples(examples)

    # distinct input values in order of first appearance; a set would make the
    # constant order (and hence which program is synthesized first) depend on
    # per-process string hash randomization
    values = dict.fromkeys(value for example in examples for value in example[0])

    # 1 is always available as a constant, even in the strings domain
    values.setdefault(1)

    # one constant per distinct input value; exact type checks (not isinstance)
    # so that e.g. bools are rejected instead of being treated as ints
    constants = []
    for value in values:
        if type(value) is int:
            constants.append(IntegerConstant(value))
        elif type(value) is str:
            constants.append(StringConstant(value))
        else:
            raise TypeError("Input of unknown type.")

    # one variable per argument position, typed from the examples
    variables = []
    for position, arg_type in enumerate(arg_types):
        if arg_type == int:
            variables.append(IntegerVariable(position))
        elif arg_type == str:
            variables.append(StringVariable(position))
        else:
            raise TypeError("Input of unknown type.")

    return constants + variables


# CHECK OBSERVATIONAL EQUIVALENCE
def observationally_equivalent(program_a, program_b, examples):
    """
    Decide whether two programs are indistinguishable on the given examples.

    Both programs are run on the input part (`example[0]`) of every example.
    Program A is run on all inputs first, then program B. The resulting output
    lists are compared for equality.

    Returns:
        bool: True if both programs produce the same output on every example input.
    """

    example_inputs = [ex[0] for ex in examples]

    outputs_a = list(map(program_a.evaluate, example_inputs))
    outputs_b = list(map(program_b.evaluate, example_inputs))

    return outputs_a == outputs_b


# CHECK CORRECTNESS
def check_program(program, examples):
    '''
    Report whether `program` reproduces every input-output example.

    Each example supplies its input as `example[0]` and its expected output as
    `example[1]`. The program is evaluated on every input, and the list of results
    must equal the list of expected outputs.

    Returns:
        bool: True when all produced outputs match the expected outputs.
    '''

    given_inputs = [ex[0] for ex in examples]
    expected_outputs = [ex[1] for ex in examples]

    produced_outputs = list(map(program.evaluate, given_inputs))
    return produced_outputs == expected_outputs


# RUN SYNTHESIZER
def _select_operators(domain):
    '''
    Return the operator list for the given synthesis domain.

    Raises:
        Exception: if the domain is neither "arithmetic" nor "strings".
    '''
    if domain == "arithmetic":
        return arithmetic_operators
    if domain == "strings":
        return string_operators
    raise Exception('Domain not recognized. Must be either "arithmetic" or "strings".')


def _output_signature(program, inputs):
    '''
    Return the tuple of outputs `program` produces on each example input.

    Two programs are observationally equivalent when their signatures are equal.
    '''
    return tuple(program.evaluate(input_) for input_ in inputs)


def run_synthesizer(args):
    '''
    Run bottom-up enumerative synthesis.

    The program bank starts with the constants and variables extracted from the
    examples. It is then grown level by level, for levels 2 up to and including
    `args.max_weight`. At each level, every operator is applied to every ordered
    tuple of bank programs (repetition allowed) such that:
      - the tuple's types match the operator's argument types, and
      - the arguments' combined weight does not exceed the level.
    A new program is discarded if its string form or its outputs on the example
    inputs match a program already in the bank.

    Args:
        args: namespace with `examples_key`, `domain` and `max_weight` attributes
            (as produced by parse_args).

    Returns:
        The first program that satisfies every example, or None if none is found.
    '''

    # retrieve selected input-output examples
    examples = example_set[args.examples_key]
    example_inputs = [example[0] for example in examples]

    # seed the bank with constants and variables extracted from the examples
    program_bank = extract_constants(examples)
    print("\nSynthesis Log:")
    print(f"- Extracted {len(program_bank)} constants from examples.")

    operators = _select_operators(args.domain)

    # an atom (e.g. the identity variable) may already solve the task; it must be
    # checked here because every program equivalent to it is pruned below
    for atom in program_bank:
        if check_program(atom, examples):
            return atom

    # string forms and output signatures of every program in the bank, kept in sets
    # so that duplicate detection does not re-evaluate the whole bank per candidate
    # NOTE(review): assumes program outputs are hashable (ints/strings in the
    # current domains) — revisit if a domain with unhashable outputs is added
    seen_strings = {p.str() for p in program_bank}
    seen_signatures = {_output_signature(p, example_inputs) for p in program_bank}

    # iterate over each level, including max_weight itself
    for weight in range(2, args.max_weight + 1):

        print(f"- Searching level {weight} with {len(program_bank)} primitives.")

        for op in operators:
            expected_types = tuple(op.arg_types)

            # ordered tuples with repetition: non-commutative operators need both
            # argument orders, and an operand may be reused (e.g. x + x).
            # product() snapshots program_bank up front, so appending below is safe.
            for combination in itertools.product(program_bank, repeat=op.arity):

                # argument types must match the operator's signature
                if tuple(p.type for p in combination) != expected_types:
                    continue

                # sum of argument weights must not exceed the current level
                if sum(p.weight for p in combination) > weight:
                    continue

                program = OperatorNode(op, combination)

                # syntactic duplicate of a program already in the bank
                program_str = program.str()
                if program_str in seen_strings:
                    continue

                # observationally equivalent to a program already in the bank
                signature = _output_signature(program, example_inputs)
                if signature in seen_signatures:
                    continue

                program_bank.append(program)
                seen_strings.add(program_str)
                seen_signatures.add(signature)

                if check_program(program, examples):
                    return program

    # no program found within the weight budget
    return None


if __name__ == '__main__':

    # parse command line arguments
    args = parse_args()

    # run bottom-up enumerative synthesis, timed with perf_counter(): a monotonic,
    # high-resolution clock that, unlike time.time(), is not affected by
    # system clock adjustments
    start_time = time.perf_counter()
    program = run_synthesizer(args)
    elapsed_time = round(time.perf_counter() - start_time, 4)

    # report whether a program was found
    print("\nSynthesis Results:")
    if program is None:
        print(f"- Max weight of {args.max_weight} reached, no program found in {elapsed_time}s.")
    else:
        print(f"- Program found in {elapsed_time}s.")
        print(f"- Program: {program.str()}")
        print(f"- Program weight: {program.weight}")
        print(f"- Program return type: {program.type.__name__}")