Spaces:
Running
Running
Commit
·
a417ea3
1
Parent(s):
771c860
Extract constants and variables
Browse files- README.md +3 -3
- abstract_syntax_trees.py +5 -0
- arithmetic.py +33 -21
- demonstration.ipynb +109 -39
- examples.py +41 -4
- synthesizer.py +73 -9
README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
|
3 |
🚨🚨PLEASE DO NOT GRADE YET🚨🚨
|
4 |
|
5 |
-
Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught Fall 2023 by Prof. Nada Amin.
|
6 |
|
7 |
## 🛠️ Background
|
8 |
|
@@ -37,7 +37,7 @@ optional arguments:
|
|
37 |
--domain {arithmetic,string}
|
38 |
Domain of synthesis (either "arithmetic" or "string").
|
39 |
--examples {addition,subtraction,multiplication,division}
|
40 |
-
Examples to synthesize program from. Must be a valid key in the "
|
41 |
--max_weight MAX_WEIGHT
|
42 |
Maximum weight of programs to consider before terminating search.
|
43 |
```
|
@@ -47,7 +47,7 @@ For example, to synthesize programs in the arithmetic domain from the addition i
|
|
47 |
python3 synthesizer.py --domain arithmetic --examples addition
|
48 |
```
|
49 |
|
50 |
-
To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `
|
51 |
|
52 |
## 🔮 Virtual Environment
|
53 |
|
|
|
2 |
|
3 |
🚨🚨PLEASE DO NOT GRADE YET🚨🚨
|
4 |
|
5 |
+
Completed for [CS252R: Program Synthesis](https://synthesis.metareflection.club/) at the Harvard John A. Paulson School of Engineering and Applied Sciences, taught in Fall 2023 by Prof. Nada Amin.
|
6 |
|
7 |
## 🛠️ Background
|
8 |
|
|
|
37 |
--domain {arithmetic,string}
|
38 |
Domain of synthesis (either "arithmetic" or "string").
|
39 |
--examples {addition,subtraction,multiplication,division}
|
40 |
+
Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.
|
41 |
--max_weight MAX_WEIGHT
|
42 |
Maximum weight of programs to consider before terminating search.
|
43 |
```
|
|
|
47 |
python3 synthesizer.py --domain arithmetic --examples addition
|
48 |
```
|
49 |
|
50 |
+
To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `example_set` and set the value to be a list of tuples.
|
51 |
|
52 |
## 🔮 Virtual Environment
|
53 |
|
abstract_syntax_trees.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
ABSTRACT SYNTAX TREES
|
3 |
+
This file contains Python classes that define the abstract syntax tree (AST) for program synthesis.
|
4 |
+
'''
|
5 |
+
|
arithmetic.py
CHANGED
@@ -7,22 +7,36 @@ This file contains Python classes that define the arithmetic operators for progr
|
|
7 |
CLASS DEFINITIONS
|
8 |
'''
|
9 |
|
10 |
-
class
|
11 |
'''
|
12 |
-
Class to represent an
|
|
|
13 |
'''
|
14 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
15 |
self.value = value
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
class Add:
|
19 |
'''
|
20 |
Operator to add two numerical values.
|
21 |
'''
|
22 |
def __init__(self):
|
23 |
-
self.arity = 2
|
24 |
-
self.
|
25 |
-
self.return_type = int
|
|
|
26 |
|
27 |
def __call__(self, x, y):
|
28 |
return x + y
|
@@ -35,9 +49,10 @@ class Subtract:
|
|
35 |
Operator to subtract two numerical values.
|
36 |
'''
|
37 |
def __init__(self):
|
38 |
-
self.arity = 2
|
39 |
-
self.
|
40 |
-
self.return_type = int
|
|
|
41 |
|
42 |
def __call__(self, x, y):
|
43 |
return x - y
|
@@ -50,9 +65,10 @@ class Multiply:
|
|
50 |
Operator to multiply two numerical values.
|
51 |
'''
|
52 |
def __init__(self):
|
53 |
-
self.arity = 2
|
54 |
-
self.
|
55 |
-
self.return_type = int
|
|
|
56 |
|
57 |
def __call__(self, x, y):
|
58 |
return x * y
|
@@ -65,9 +81,10 @@ class Divide:
|
|
65 |
Operator to divide two numerical values.
|
66 |
'''
|
67 |
def __init__(self):
|
68 |
-
self.arity = 2
|
69 |
-
self.
|
70 |
-
self.return_type = int
|
|
|
71 |
|
72 |
def __call__(self, x, y):
|
73 |
try: # check for division by zero error
|
@@ -79,11 +96,6 @@ class Divide:
|
|
79 |
return f"{x} / {y}"
|
80 |
|
81 |
|
82 |
-
'''
|
83 |
-
FUNCTION DEFINITIONS
|
84 |
-
'''
|
85 |
-
|
86 |
-
|
87 |
'''
|
88 |
GLOBAL CONSTANTS
|
89 |
'''
|
|
|
7 |
CLASS DEFINITIONS
|
8 |
'''
|
9 |
|
10 |
+
class IntegerVariable:
|
11 |
'''
|
12 |
+
Class to represent an integer variable. Note that position is the position of the variable in the input.
|
13 |
+
For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.
|
14 |
'''
|
15 |
+
def __init__(self, position):
|
16 |
+
self.value = None # value of the variable, initially None
|
17 |
+
self.position = position # position of the variable in the arguments to program
|
18 |
+
self.type = int # type of the variable
|
19 |
+
|
20 |
+
def assign(self, value):
|
21 |
self.value = value
|
22 |
+
|
23 |
+
class IntegerConstant:
|
24 |
+
'''
|
25 |
+
Class to represent an integer constant.
|
26 |
+
'''
|
27 |
+
def __init__(self, value):
|
28 |
+
self.value = value # value of the constant
|
29 |
+
self.type = int # type of the constant
|
30 |
|
31 |
class Add:
|
32 |
'''
|
33 |
Operator to add two numerical values.
|
34 |
'''
|
35 |
def __init__(self):
|
36 |
+
self.arity = 2 # number of arguments
|
37 |
+
self.arg_types = [int, int] # argument types
|
38 |
+
self.return_type = int # return type
|
39 |
+
self.weight = 1 # weight
|
40 |
|
41 |
def __call__(self, x, y):
|
42 |
return x + y
|
|
|
49 |
Operator to subtract two numerical values.
|
50 |
'''
|
51 |
def __init__(self):
|
52 |
+
self.arity = 2 # number of arguments
|
53 |
+
self.arg_types = [int, int] # argument types
|
54 |
+
self.return_type = int # return type
|
55 |
+
self.weight = 1 # weight
|
56 |
|
57 |
def __call__(self, x, y):
|
58 |
return x - y
|
|
|
65 |
Operator to multiply two numerical values.
|
66 |
'''
|
67 |
def __init__(self):
|
68 |
+
self.arity = 2 # number of arguments
|
69 |
+
self.arg_types = [int, int] # argument types
|
70 |
+
self.return_type = int # return type
|
71 |
+
self.weight = 1 # weight
|
72 |
|
73 |
def __call__(self, x, y):
|
74 |
return x * y
|
|
|
81 |
Operator to divide two numerical values.
|
82 |
'''
|
83 |
def __init__(self):
|
84 |
+
self.arity = 2 # number of arguments
|
85 |
+
self.arg_types = [int, int] # argument types
|
86 |
+
self.return_type = int # return type
|
87 |
+
self.weight = 1 # weight
|
88 |
|
89 |
def __call__(self, x, y):
|
90 |
try: # check for division by zero error
|
|
|
96 |
return f"{x} / {y}"
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
99 |
'''
|
100 |
GLOBAL CONSTANTS
|
101 |
'''
|
demonstration.ipynb
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
},
|
16 |
{
|
17 |
"cell_type": "code",
|
18 |
-
"execution_count":
|
19 |
"metadata": {},
|
20 |
"outputs": [],
|
21 |
"source": [
|
@@ -27,7 +27,7 @@
|
|
27 |
"\n",
|
28 |
"# import arithmetic module\n",
|
29 |
"# from arithmetic import *\n",
|
30 |
-
"from examples import
|
31 |
"import config"
|
32 |
]
|
33 |
},
|
@@ -40,16 +40,23 @@
|
|
40 |
},
|
41 |
{
|
42 |
"cell_type": "code",
|
43 |
-
"execution_count":
|
44 |
"metadata": {},
|
45 |
"outputs": [],
|
46 |
"source": [
|
47 |
"domain = \"arithmetic\"\n",
|
48 |
"examples_key = \"addition\"\n",
|
49 |
-
"examples =
|
50 |
"max_weight = 3"
|
51 |
]
|
52 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
{
|
54 |
"cell_type": "markdown",
|
55 |
"metadata": {},
|
@@ -59,35 +66,40 @@
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
-
"execution_count":
|
63 |
"metadata": {},
|
64 |
"outputs": [],
|
65 |
"source": [
|
66 |
-
"
|
67 |
-
"
|
68 |
-
"
|
69 |
-
"
|
|
|
|
|
|
|
|
|
|
|
70 |
"\n",
|
71 |
-
"
|
72 |
-
"
|
73 |
-
"''' \n",
|
74 |
"\n",
|
75 |
-
"class
|
76 |
" '''\n",
|
77 |
-
" Class to represent an
|
78 |
" '''\n",
|
79 |
" def __init__(self, value):\n",
|
80 |
-
" self.value = value\n",
|
81 |
-
" self.type = int\n",
|
82 |
"\n",
|
83 |
"class Add:\n",
|
84 |
" '''\n",
|
85 |
" Operator to add two numerical values.\n",
|
86 |
" '''\n",
|
87 |
" def __init__(self):\n",
|
88 |
-
" self.arity = 2
|
89 |
-
" self.
|
90 |
-
" self.return_type = int
|
|
|
91 |
"\n",
|
92 |
" def __call__(self, x, y):\n",
|
93 |
" return x + y\n",
|
@@ -100,9 +112,10 @@
|
|
100 |
" Operator to subtract two numerical values.\n",
|
101 |
" '''\n",
|
102 |
" def __init__(self):\n",
|
103 |
-
" self.arity = 2
|
104 |
-
" self.
|
105 |
-
" self.return_type = int
|
|
|
106 |
"\n",
|
107 |
" def __call__(self, x, y):\n",
|
108 |
" return x - y\n",
|
@@ -115,9 +128,10 @@
|
|
115 |
" Operator to multiply two numerical values.\n",
|
116 |
" '''\n",
|
117 |
" def __init__(self):\n",
|
118 |
-
" self.arity = 2
|
119 |
-
" self.
|
120 |
-
" self.return_type = int
|
|
|
121 |
"\n",
|
122 |
" def __call__(self, x, y):\n",
|
123 |
" return x * y\n",
|
@@ -130,9 +144,10 @@
|
|
130 |
" Operator to divide two numerical values.\n",
|
131 |
" '''\n",
|
132 |
" def __init__(self):\n",
|
133 |
-
" self.arity = 2
|
134 |
-
" self.
|
135 |
-
" self.return_type = int
|
|
|
136 |
"\n",
|
137 |
" def __call__(self, x, y):\n",
|
138 |
" try: # check for division by zero error\n",
|
@@ -145,11 +160,6 @@
|
|
145 |
"\n",
|
146 |
"\n",
|
147 |
"'''\n",
|
148 |
-
"FUNCTION DEFINITIONS\n",
|
149 |
-
"''' \n",
|
150 |
-
"\n",
|
151 |
-
"\n",
|
152 |
-
"'''\n",
|
153 |
"GLOBAL CONSTANTS\n",
|
154 |
"''' \n",
|
155 |
"\n",
|
@@ -161,7 +171,70 @@
|
|
161 |
"cell_type": "markdown",
|
162 |
"metadata": {},
|
163 |
"source": [
|
164 |
-
"I define
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
]
|
166 |
},
|
167 |
{
|
@@ -198,16 +271,13 @@
|
|
198 |
"metadata": {},
|
199 |
"outputs": [],
|
200 |
"source": [
|
201 |
-
"# initialize program bank\n",
|
202 |
-
"program_bank = []\n",
|
203 |
-
"\n",
|
204 |
"# iterate over each level\n",
|
205 |
-
"for i in range(
|
206 |
"\n",
|
207 |
" # define level program bank\n",
|
208 |
" level_program_bank = []\n",
|
209 |
"\n",
|
210 |
-
" for op in arithmetic_operators
|
211 |
"\n",
|
212 |
" break"
|
213 |
]
|
|
|
15 |
},
|
16 |
{
|
17 |
"cell_type": "code",
|
18 |
+
"execution_count": 1,
|
19 |
"metadata": {},
|
20 |
"outputs": [],
|
21 |
"source": [
|
|
|
27 |
"\n",
|
28 |
"# import arithmetic module\n",
|
29 |
"# from arithmetic import *\n",
|
30 |
+
"from examples import example_set, check_examples\n",
|
31 |
"import config"
|
32 |
]
|
33 |
},
|
|
|
40 |
},
|
41 |
{
|
42 |
"cell_type": "code",
|
43 |
+
"execution_count": 2,
|
44 |
"metadata": {},
|
45 |
"outputs": [],
|
46 |
"source": [
|
47 |
"domain = \"arithmetic\"\n",
|
48 |
"examples_key = \"addition\"\n",
|
49 |
+
"examples = example_set[examples_key]\n",
|
50 |
"max_weight = 3"
|
51 |
]
|
52 |
},
|
53 |
+
{
|
54 |
+
"cell_type": "markdown",
|
55 |
+
"metadata": {},
|
56 |
+
"source": [
|
57 |
+
"First, I define a function to check that, across all input-output pairs, all inputs are of the same length and that argument types are consistent across inputs."
|
58 |
+
]
|
59 |
+
},
|
60 |
{
|
61 |
"cell_type": "markdown",
|
62 |
"metadata": {},
|
|
|
66 |
},
|
67 |
{
|
68 |
"cell_type": "code",
|
69 |
+
"execution_count": 8,
|
70 |
"metadata": {},
|
71 |
"outputs": [],
|
72 |
"source": [
|
73 |
+
"class IntegerVariable:\n",
|
74 |
+
" '''\n",
|
75 |
+
" Class to represent an integer variable. Note that position is the position of the variable in the input.\n",
|
76 |
+
" For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.\n",
|
77 |
+
" '''\n",
|
78 |
+
" def __init__(self, position):\n",
|
79 |
+
" self.value = None # value of the variable, initially None\n",
|
80 |
+
" self.position = position # position of the variable in the arguments to program\n",
|
81 |
+
" self.type = int # type of the variable\n",
|
82 |
"\n",
|
83 |
+
" def assign(self, value):\n",
|
84 |
+
" self.value = value\n",
|
|
|
85 |
"\n",
|
86 |
+
"class IntegerConstant:\n",
|
87 |
" '''\n",
|
88 |
+
" Class to represent an integer constant.\n",
|
89 |
" '''\n",
|
90 |
" def __init__(self, value):\n",
|
91 |
+
" self.value = value # value of the constant\n",
|
92 |
+
" self.type = int # type of the constant\n",
|
93 |
"\n",
|
94 |
"class Add:\n",
|
95 |
" '''\n",
|
96 |
" Operator to add two numerical values.\n",
|
97 |
" '''\n",
|
98 |
" def __init__(self):\n",
|
99 |
+
" self.arity = 2 # number of arguments\n",
|
100 |
+
" self.arg_types = [int, int] # argument types\n",
|
101 |
+
" self.return_type = int # return type\n",
|
102 |
+
" self.weight = 1 # weight\n",
|
103 |
"\n",
|
104 |
" def __call__(self, x, y):\n",
|
105 |
" return x + y\n",
|
|
|
112 |
" Operator to subtract two numerical values.\n",
|
113 |
" '''\n",
|
114 |
" def __init__(self):\n",
|
115 |
+
" self.arity = 2 # number of arguments\n",
|
116 |
+
" self.arg_types = [int, int] # argument types\n",
|
117 |
+
" self.return_type = int # return type\n",
|
118 |
+
" self.weight = 1 # weight\n",
|
119 |
"\n",
|
120 |
" def __call__(self, x, y):\n",
|
121 |
" return x - y\n",
|
|
|
128 |
" Operator to multiply two numerical values.\n",
|
129 |
" '''\n",
|
130 |
" def __init__(self):\n",
|
131 |
+
" self.arity = 2 # number of arguments\n",
|
132 |
+
" self.arg_types = [int, int] # argument types\n",
|
133 |
+
" self.return_type = int # return type\n",
|
134 |
+
" self.weight = 1 # weight\n",
|
135 |
"\n",
|
136 |
" def __call__(self, x, y):\n",
|
137 |
" return x * y\n",
|
|
|
144 |
" Operator to divide two numerical values.\n",
|
145 |
" '''\n",
|
146 |
" def __init__(self):\n",
|
147 |
+
" self.arity = 2 # number of arguments\n",
|
148 |
+
" self.arg_types = [int, int] # argument types\n",
|
149 |
+
" self.return_type = int # return type\n",
|
150 |
+
" self.weight = 1 # weight\n",
|
151 |
"\n",
|
152 |
" def __call__(self, x, y):\n",
|
153 |
" try: # check for division by zero error\n",
|
|
|
160 |
"\n",
|
161 |
"\n",
|
162 |
"'''\n",
|
|
|
|
|
|
|
|
|
|
|
163 |
"GLOBAL CONSTANTS\n",
|
164 |
"''' \n",
|
165 |
"\n",
|
|
|
171 |
"cell_type": "markdown",
|
172 |
"metadata": {},
|
173 |
"source": [
|
174 |
+
"I define a function to extract constants from examples."
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": 9,
|
180 |
+
"metadata": {},
|
181 |
+
"outputs": [],
|
182 |
+
"source": [
|
183 |
+
"def extract_constants(examples):\n",
|
184 |
+
" '''\n",
|
185 |
+
" Extracts the constants from the input-output examples. Also constructs variables as needed\n",
|
186 |
+
" based on the input-output examples, and adds them to the list of constants.\n",
|
187 |
+
" '''\n",
|
188 |
+
"\n",
|
189 |
+
" # check validity of provided examples\n",
|
190 |
+
" # if valid, extract arity and argument types\n",
|
191 |
+
" arity, arg_types = check_examples(examples)\n",
|
192 |
+
"\n",
|
193 |
+
" # initialize list of constants\n",
|
194 |
+
" constants = []\n",
|
195 |
+
"\n",
|
196 |
+
" # get unique set of inputs\n",
|
197 |
+
" inputs = [input for example in examples for input in example[0]]\n",
|
198 |
+
" inputs = set(inputs)\n",
|
199 |
+
"\n",
|
200 |
+
" # add 1 to the set of inputs\n",
|
201 |
+
" inputs.add(1)\n",
|
202 |
+
"\n",
|
203 |
+
" # extract constants in input\n",
|
204 |
+
" for input in inputs:\n",
|
205 |
+
"\n",
|
206 |
+
" if type(input) == int:\n",
|
207 |
+
" constants.append(IntegerConstant(input))\n",
|
208 |
+
" elif type(input) == str:\n",
|
209 |
+
" # constants.append(StringConstant(input))\n",
|
210 |
+
" pass\n",
|
211 |
+
" else:\n",
|
212 |
+
" raise Exception(\"Input of unknown type.\")\n",
|
213 |
+
" \n",
|
214 |
+
" # initialize list of variables\n",
|
215 |
+
" variables = []\n",
|
216 |
+
"\n",
|
217 |
+
" # extract variables in input\n",
|
218 |
+
" for position, arg in enumerate(arg_types):\n",
|
219 |
+
" if arg == int:\n",
|
220 |
+
" variables.append(IntegerVariable(position))\n",
|
221 |
+
" elif arg == str:\n",
|
222 |
+
" # variables.append(StringVariable(position))\n",
|
223 |
+
" pass\n",
|
224 |
+
" else:\n",
|
225 |
+
" raise Exception(\"Input of unknown type.\")\n",
|
226 |
+
"\n",
|
227 |
+
" return constants + variables"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": 16,
|
233 |
+
"metadata": {},
|
234 |
+
"outputs": [],
|
235 |
+
"source": [
|
236 |
+
"# initialize program bank\n",
|
237 |
+
"program_bank = extract_constants(examples)"
|
238 |
]
|
239 |
},
|
240 |
{
|
|
|
271 |
"metadata": {},
|
272 |
"outputs": [],
|
273 |
"source": [
|
|
|
|
|
|
|
274 |
"# iterate over each level\n",
|
275 |
+
"for i in range(2, max_weight):\n",
|
276 |
"\n",
|
277 |
" # define level program bank\n",
|
278 |
" level_program_bank = []\n",
|
279 |
"\n",
|
280 |
+
" for op in arithmetic_operators:\n",
|
281 |
"\n",
|
282 |
" break"
|
283 |
]
|
examples.py
CHANGED
@@ -1,13 +1,18 @@
|
|
1 |
'''
|
2 |
EXAMPLES
|
3 |
This file contains input-output examples for both arithmetic and string domain-specific languages (DSLs).
|
4 |
-
To add a new example, add a new key to the dictionary '
|
|
|
|
|
|
|
|
|
|
|
5 |
'''
|
6 |
|
7 |
# define examples
|
8 |
-
|
9 |
# arithmetic examples
|
10 |
-
'addition': [([7, 2], 9), ([
|
11 |
'subtraction': [([9, 2], 7), ([6, 1], 5), ([7, 3], 4), ([8, 4], 4), ([10, 2], 8)],
|
12 |
'multiplication': [([2, 3], 6), ([4, 5], 20), ([7, 8], 56), ([9, 2], 18), ([3, 4], 12)],
|
13 |
'division': [([6, 2], 3), ([8, 4], 2), ([9, 3], 3), ([10, 5], 2), ([12, 6], 2)]
|
@@ -15,4 +20,36 @@ examples = {
|
|
15 |
# string examples
|
16 |
|
17 |
# custom user examples
|
18 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
'''
|
2 |
EXAMPLES
|
3 |
This file contains input-output examples for both arithmetic and string domain-specific languages (DSLs).
|
4 |
+
To add a new example, add a new key to the dictionary 'example_set' and set the value to be a list of tuples.
|
5 |
+
|
6 |
+
Note that we synthesize programs with a consistent arity. Therefore, in each set of input-output examples, all
|
7 |
+
input examples must be of the same length. Further, argument types must remain consistent across examples. We
|
8 |
+
test for these conditions in the `check_examples` function below, which is called by the `extract_constants`
|
9 |
+
function in the synthesizer.
|
10 |
'''
|
11 |
|
12 |
# define examples
|
13 |
+
example_set = {
|
14 |
# arithmetic examples
|
15 |
+
'addition': [([7, 2], 9), ([8, 1], 9), ([3, 9], 12), ([5, 8], 13)], # ([4, 6], 10),
|
16 |
'subtraction': [([9, 2], 7), ([6, 1], 5), ([7, 3], 4), ([8, 4], 4), ([10, 2], 8)],
|
17 |
'multiplication': [([2, 3], 6), ([4, 5], 20), ([7, 8], 56), ([9, 2], 18), ([3, 4], 12)],
|
18 |
'division': [([6, 2], 3), ([8, 4], 2), ([9, 3], 3), ([10, 5], 2), ([12, 6], 2)]
|
|
|
20 |
# string examples
|
21 |
|
22 |
# custom user examples
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
# CHECK EXAMPLE VALIDITY
|
27 |
+
def check_examples(examples):
|
28 |
+
'''
|
29 |
+
Checks that all input examples are of same length and that argument types are consistent across examples.
|
30 |
+
If valid, returns arity and argument types of function to be generated.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
examples (list): list of tuples, where each tuple is of the form (input, output)
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
input_lengths[0] (int): arity of function
|
37 |
+
arg_types[0] (list): argument types of function
|
38 |
+
'''
|
39 |
+
|
40 |
+
# get input examples
|
41 |
+
inputs = [example[0] for example in examples]
|
42 |
+
|
43 |
+
# check all inputs are of same length
|
44 |
+
input_lengths = [len(input) for input in inputs]
|
45 |
+
if len(set(input_lengths)) != 1:
|
46 |
+
raise ValueError("All input examples must be of same length.")
|
47 |
+
|
48 |
+
# check that types of arguments are same
|
49 |
+
arg_types = [[type(arg) for arg in input] for input in inputs]
|
50 |
+
consistent_types = all([arg_types[0] == arg_type for arg_type in arg_types])
|
51 |
+
if not consistent_types:
|
52 |
+
raise ValueError("Argument types must be consistent across inputs.")
|
53 |
+
|
54 |
+
# return arity and argument types
|
55 |
+
return input_lengths[0], arg_types[0]
|
synthesizer.py
CHANGED
@@ -12,11 +12,13 @@ import numpy as np
|
|
12 |
import argparse
|
13 |
|
14 |
# import examples
|
15 |
-
from
|
|
|
16 |
import config
|
17 |
|
18 |
|
19 |
-
|
|
|
20 |
'''
|
21 |
Parse command line arguments.
|
22 |
'''
|
@@ -32,8 +34,8 @@ def parse_args(examples):
|
|
32 |
help='Domain of synthesis (either "arithmetic" or "string").')
|
33 |
|
34 |
parser.add_argument('--examples', dest='examples_key', type=str, required=True, # default="addition",
|
35 |
-
choices=
|
36 |
-
help='Examples to synthesize program from. Must be a valid key in the "
|
37 |
|
38 |
parser.add_argument('--max_weight', type=int, required=False, default=3,
|
39 |
help='Maximum weight of programs to consider before terminating search.')
|
@@ -42,13 +44,75 @@ def parse_args(examples):
|
|
42 |
return args
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
if __name__ == '__main__':
|
46 |
|
47 |
# parse command line arguments
|
48 |
-
args = parse_args(
|
49 |
-
print(args
|
50 |
-
print(args.examples_key)
|
51 |
-
print(args.max_weight)
|
52 |
|
53 |
# run bottom-up enumerative synthesis
|
54 |
-
|
|
|
12 |
import argparse
|
13 |
|
14 |
# import examples
|
15 |
+
from arithmetic import *
|
16 |
+
from examples import example_set, check_examples
|
17 |
import config
|
18 |
|
19 |
|
20 |
+
# PARSE ARGUMENTS
|
21 |
+
def parse_args():
|
22 |
'''
|
23 |
Parse command line arguments.
|
24 |
'''
|
|
|
34 |
help='Domain of synthesis (either "arithmetic" or "string").')
|
35 |
|
36 |
parser.add_argument('--examples', dest='examples_key', type=str, required=True, # default="addition",
|
37 |
+
choices=example_set.keys(),
|
38 |
+
help='Examples to synthesize program from. Must be a valid key in the "example_set" dictionary.')
|
39 |
|
40 |
parser.add_argument('--max_weight', type=int, required=False, default=3,
|
41 |
help='Maximum weight of programs to consider before terminating search.')
|
|
|
44 |
return args
|
45 |
|
46 |
|
47 |
+
# EXTRACT CONSTANTS AND VARIABLES
|
48 |
+
def extract_constants(examples):
|
49 |
+
'''
|
50 |
+
Extracts the constants from the input-output examples. Also constructs variables as needed
|
51 |
+
based on the input-output examples, and adds them to the list of constants.
|
52 |
+
'''
|
53 |
+
|
54 |
+
# check validity of provided examples
|
55 |
+
# if valid, extract arity and argument types
|
56 |
+
arity, arg_types = check_examples(examples)
|
57 |
+
|
58 |
+
# initialize list of constants
|
59 |
+
constants = []
|
60 |
+
|
61 |
+
# get unique set of inputs
|
62 |
+
inputs = [input for example in examples for input in example[0]]
|
63 |
+
inputs = set(inputs)
|
64 |
+
|
65 |
+
# add 1 to the set of inputs
|
66 |
+
inputs.add(1)
|
67 |
+
|
68 |
+
# extract constants in input
|
69 |
+
for input in inputs:
|
70 |
+
|
71 |
+
if type(input) == int:
|
72 |
+
constants.append(IntegerConstant(input))
|
73 |
+
elif type(input) == str:
|
74 |
+
# constants.append(StringConstant(input))
|
75 |
+
pass
|
76 |
+
else:
|
77 |
+
raise Exception("Input of unknown type.")
|
78 |
+
|
79 |
+
# initialize list of variables
|
80 |
+
variables = []
|
81 |
+
|
82 |
+
# extract variables in input
|
83 |
+
for position, arg in enumerate(arg_types):
|
84 |
+
if arg == int:
|
85 |
+
variables.append(IntegerVariable(position))
|
86 |
+
elif arg == str:
|
87 |
+
# variables.append(StringVariable(position))
|
88 |
+
pass
|
89 |
+
else:
|
90 |
+
raise Exception("Input of unknown type.")
|
91 |
+
|
92 |
+
return constants + variables
|
93 |
+
|
94 |
+
|
95 |
+
# RUN SYNTHESIZER
|
96 |
+
def run_synthesizer(args):
|
97 |
+
'''
|
98 |
+
Run bottom-up enumerative synthesis.
|
99 |
+
'''
|
100 |
+
|
101 |
+
# retrieve selected input-output examples
|
102 |
+
examples = example_set[args.examples_key]
|
103 |
+
|
104 |
+
# extract constants from examples
|
105 |
+
program_bank = extract_constants(examples)
|
106 |
+
print(examples)
|
107 |
+
|
108 |
+
pass
|
109 |
+
|
110 |
+
|
111 |
if __name__ == '__main__':
|
112 |
|
113 |
# parse command line arguments
|
114 |
+
args = parse_args()
|
115 |
+
# print(args)
|
|
|
|
|
116 |
|
117 |
# run bottom-up enumerative synthesis
|
118 |
+
run_synthesizer(args)
|