File size: 6,732 Bytes
ecbc9c7
 
 
 
3f02d46
 
2056e9c
ecbc9c7
 
 
 
3f02d46
5b04db9
 
3f02d46
 
a417ea3
6b5a85b
5b04db9
a417ea3
389a372
ecbc9c7
3f02d46
a417ea3
 
3f02d46
 
 
 
 
 
 
ea7fc19
3f02d46
 
 
 
 
 
 
a417ea3
 
3f02d46
6b5a85b
389a372
3f02d46
 
 
 
 
a417ea3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b5a85b
a417ea3
 
 
 
 
 
 
 
 
 
 
 
6b5a85b
a417ea3
 
 
 
 
 
5b04db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a417ea3
 
 
 
 
 
 
 
 
 
 
5b04db9
6b5a85b
5b04db9
 
 
6b5a85b
 
ea7fc19
6b5a85b
 
 
5b04db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a417ea3
5b04db9
 
a417ea3
 
3f02d46
 
 
a417ea3
 
3f02d46
 
5b04db9
 
 
 
 
 
6b5a85b
5b04db9
6b5a85b
5b04db9
6b5a85b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
'''
BOTTOM UP ENUMERATIVE SYNTHESIS
Ayush Noori
CS252R, Fall 2020

Example of usage:
python synthesis.py --domain arithmetic --examples addition
'''

# load libraries
import numpy as np
import argparse
import itertools
import time

# import examples
from arithmetic import *
from strings import *
from abstract_syntax_tree import *
from examples import example_set, check_examples
import config


# PARSE ARGUMENTS
def parse_args(argv=None):
    '''
    Parse command line arguments.

    Parameters:
        argv: optional list of argument strings to parse. When None (the
            default), argparse falls back to the process command line.

    Returns:
        The argparse namespace with the attributes "domain", "examples_key"
        (populated from --examples) and "max_weight".
    '''

    parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis in Python.")

    # define valid choices for the 'domain' argument
    valid_domain_choices = ["arithmetic", "strings"]

    # synthesis domain; the help text must name the same values as the choices
    parser.add_argument('--domain', type=str, required=True,
                        choices=valid_domain_choices,
                        help='Domain of synthesis (either "arithmetic" or "strings").')

    # key selecting which input-output examples to synthesize from
    parser.add_argument('--examples', dest='examples_key', type=str, required=True,
                        choices=example_set.keys(),
                        help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')

    # bound on the search depth
    parser.add_argument('--max-weight', type=int, required=False, default=3,
                        help='Maximum weight of programs to consider before terminating search.')

    return parser.parse_args(argv)


# EXTRACT CONSTANTS AND VARIABLES
def extract_constants(examples):
    '''
    Extracts the constants from the input-output examples. Also constructs variables as needed
    based on the input-output examples, and adds them to the list of constants.

    Parameters:
        examples: iterable of examples; element 0 of each example is the
            collection of input values for that example.

    Returns:
        A list holding one constant node per distinct input value (plus the
        integer 1), followed by one variable node per argument position.

    Raises:
        Exception: if an input value or an argument type is neither int nor str.
    '''

    # check validity of provided examples
    # if valid, extract argument types (the arity is not needed here)
    _, arg_types = check_examples(examples)

    # de-duplicate the inputs while keeping first-seen order; iterating a set
    # of strings changes order between runs (hash randomisation), which would
    # make the order of the constants, and hence the search, irreproducible
    unique_inputs = dict.fromkeys(
        value for example in examples for value in example[0]
    )

    # always make the constant 1 available
    unique_inputs.setdefault(1, None)

    # extract constants in input
    constants = []
    for value in unique_inputs:
        # exact type match on purpose: bool is a subclass of int and must
        # still be rejected as an unknown type
        if type(value) is int:
            constants.append(IntegerConstant(value))
        elif type(value) is str:
            constants.append(StringConstant(value))
        else:
            raise Exception("Input of unknown type.")

    # extract variables in input, one per argument position
    variables = []
    for position, arg in enumerate(arg_types):
        if arg == int:
            variables.append(IntegerVariable(position))
        elif arg == str:
            variables.append(StringVariable(position))
        else:
            raise Exception("Input of unknown type.")

    return constants + variables


# CHECK OBSERVATIONAL EQUIVALENCE
def observationally_equivalent(program_a, program_b, examples):
    """
    Tell whether two programs are indistinguishable on the example inputs.

    Each program is evaluated on element 0 of every example; the result is
    True when both programs yield the same list of outputs, False otherwise.
    """

    sample_inputs = [pair[0] for pair in examples]

    def trace(program):
        # outputs of one program, in example order
        results = []
        for arguments in sample_inputs:
            results.append(program.evaluate(arguments))
        return results

    # Program A is evaluated on all inputs before Program B is touched
    outputs_a = trace(program_a)
    outputs_b = trace(program_b)
    return outputs_a == outputs_b


# CHECK CORRECTNESS
def check_program(program, examples):
    '''
    Report whether the program reproduces every example.

    Element 0 of each example is passed to the program's evaluate method and
    the collected results are compared against element 1 of each example.
    '''

    arguments = [pair[0] for pair in examples]
    expected = [pair[1] for pair in examples]

    # run the program on every input, in example order
    produced = []
    for argument in arguments:
        produced.append(program.evaluate(argument))

    return produced == expected


# RUN SYNTHESIZER
def run_synthesizer(args):
    '''
    Run bottom-up enumerative synthesis.

    Parameters:
        args: namespace providing "examples_key" (key into example_set),
            "domain" ("arithmetic" or "strings") and "max_weight" (highest
            search level, inclusive).

    Returns:
        The first program found that satisfies every input-output example,
        or None if no such program is found up to the maximum weight.

    Raises:
        Exception: if the domain is not recognized.
    '''

    # retrieve selected input-output examples
    examples = example_set[args.examples_key]

    # extract constants from examples
    program_bank = extract_constants(examples)
    # string forms are kept in a set so the duplicate check is O(1)
    program_bank_str = {p.str() for p in program_bank}
    print("\nSynthesis Log:")
    print(f"- Extracted {len(program_bank)} constants from examples.")

    # define operators
    if args.domain == "arithmetic":
        operators = arithmetic_operators
    elif args.domain == "strings":
        operators = string_operators
    else:
        raise Exception('Domain not recognized. Must be either "arithmetic" or "strings".')

    # a primitive may already satisfy the examples (e.g. the identity function);
    # it must be checked here, because any larger program that behaves the same
    # is pruned below as observationally equivalent and would never be returned
    for primitive in program_bank:
        if check_program(primitive, examples):
            return primitive

    # iterate over each level, including the maximum weight itself
    for weight in range(2, args.max_weight + 1):

        # print message
        print(f"- Searching level {weight} with {len(program_bank)} primitives.")

        # iterate over each operator
        for op in operators:

            # enumerate every ordered argument tuple, repetition included, so
            # that programs such as "x + x" or "y - x" are reachable; product()
            # snapshots the bank up front, so appending to it below is safe
            candidates = itertools.product(program_bank, repeat=op.arity)

            # iterate over each argument tuple
            for combination in candidates:

                # get type signature
                type_signature = [p.type for p in combination]

                # check if type signature matches operator
                if type_signature != op.arg_types:
                    continue

                # check that sum of weights of arguments <= w
                if sum(p.weight for p in combination) > weight:
                    continue

                # create new program
                program = OperatorNode(op, combination)

                # check if program is in program bank using string representation
                program_str = program.str()
                if program_str in program_bank_str:
                    continue

                # check if program is observationally equivalent to any program in program bank
                # (generator form stops at the first equivalent program)
                if any(observationally_equivalent(program, p, examples) for p in program_bank):
                    continue

                # add program to program bank
                program_bank.append(program)
                program_bank_str.add(program_str)

                # check if program passes all examples
                if check_program(program, examples):
                    return program

    # return None if no program is found
    return None


if __name__ == '__main__':

    # read the command line options
    cli_args = parse_args()

    # time the bottom-up enumerative search
    started_at = time.time()
    result = run_synthesizer(cli_args)
    elapsed_time = round(time.time() - started_at, 4)

    # report the outcome
    print("\nSynthesis Results:")
    if result is not None:
        print(f"- Program found in {elapsed_time}s.")
        print(f"- Program: {result.str()}")
        print(f"- Program weight: {result.weight}")
        print(f"- Program return type: {result.type.__name__}")
    else:
        print(f"- Max weight of {cli_args.max_weight} reached, no program found in {elapsed_time}s.")