ayushnoori commited on
Commit
3872a55
·
1 Parent(s): c9abdca

Add variable assignment and checking

Browse files
Files changed (4) hide show
  1. README.md +4 -0
  2. abstract_syntax_tree.py +21 -9
  3. arithmetic.py +35 -11
  4. demonstration.ipynb +2 -25
README.md CHANGED
@@ -49,6 +49,10 @@ python3 synthesizer.py --domain arithmetic --examples addition
49
 
50
  To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `example_set` and set the value to be a list of tuples.
51
 
 
 
 
 
52
  ## 🔮 Virtual Environment
53
 
54
  To create a virtual environment, run:
 
49
 
50
  To add additional input-output examples, modify `examples.py`. Add a new key to the dictionary `example_set` and set the value to be a list of tuples.
51
 
52
+ ## 🔎 Abstract Syntax Tree
53
+
54
+ The most important data structure in this implementation is the abstract syntax tree (AST). The AST is a tree representation of a program, where each node is either a primitive or a compound expression. The AST is represented by the `OperatorNode` class in `abstract_syntax_tree.py`. My AST implementation includes functions to recursively evaluate the operator and its operands, and also to generate a string representation of the program.
55
+
56
  ## 🔮 Virtual Environment
57
 
58
  To create a virtual environment, run:
abstract_syntax_tree.py CHANGED
@@ -11,28 +11,40 @@ class OperatorNode:
11
  operator (object): operator object (e.g., Add, Subtract, etc.)
12
  children (list): list of children nodes (operands)
13
 
14
- Example of usage:
 
 
 
 
 
15
 
16
- add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])
17
- subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])
18
- multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])
 
 
19
  '''
 
20
  def __init__(self, operator, children):
21
  self.operator = operator # Operator object (e.g., Add, Subtract, etc.)
22
  self.children = children # list of children nodes (operands)
23
 
24
- def evaluate(self):
25
- # check arity
 
26
  if len(self.children) != self.operator.arity:
27
  raise ValueError("Invalid number of operands for operator")
 
28
  # recursively evaluate the operator and its operands
29
- operands = [child.evaluate() for child in self.children]
30
- return self.operator.evaluate(*operands)
31
 
32
  def str(self):
33
- # check arity
 
34
  if len(self.children) != self.operator.arity:
35
  raise ValueError("Invalid number of operands for operator")
 
36
  # recursively generate a string representation of the AST
37
  operand_strings = [child.str() for child in self.children]
38
  return self.operator.str(*operand_strings)
 
11
  operator (object): operator object (e.g., Add, Subtract, etc.)
12
  children (list): list of children nodes (operands)
13
 
14
+ Example:
15
+ add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])
16
+ subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])
17
+ multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])
18
+ multiply_node.evaluate() # returns 24
19
+ multiply_node.str() # returns "((7 + 5) * (3 - 1))"
20
 
21
+ For variable computation, the input arguments are passed to the evaluate() method.
22
+ For example, if instead:
23
+
24
+ add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
25
+ multiply_node.evaluate([7]) # returns 24
26
  '''
27
+
28
  def __init__(self, operator, children):
29
  self.operator = operator # Operator object (e.g., Add, Subtract, etc.)
30
  self.children = children # list of children nodes (operands)
31
 
32
+ def evaluate(self, input = None):
33
+
34
+ # check arity of operator in AST
35
  if len(self.children) != self.operator.arity:
36
  raise ValueError("Invalid number of operands for operator")
37
+
38
  # recursively evaluate the operator and its operands
39
+ operands = [child.evaluate(input) for child in self.children]
40
+ return self.operator.evaluate(*operands, input)
41
 
42
  def str(self):
43
+
44
+ # check arity of operator in AST
45
  if len(self.children) != self.operator.arity:
46
  raise ValueError("Invalid number of operands for operator")
47
+
48
  # recursively generate a string representation of the AST
49
  operand_strings = [child.str() for child in self.children]
50
  return self.operator.str(*operand_strings)
arithmetic.py CHANGED
@@ -13,15 +13,39 @@ class IntegerVariable:
13
  For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.
14
  '''
15
  def __init__(self, position):
16
- self.value = None # value of the variable, initially None
17
- self.position = position # position of the variable in the arguments to program
18
  self.type = int # type of the variable
19
 
20
- def assign(self, value):
21
- self.value = value
22
 
23
- def evaluate(self):
24
- return self.value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def str(self):
27
  return f"x{self.position}"
@@ -34,7 +58,7 @@ class IntegerConstant:
34
  self.value = value # value of the constant
35
  self.type = int # type of the constant
36
 
37
- def evaluate(self):
38
  return self.value
39
 
40
  def str(self):
@@ -50,7 +74,7 @@ class Add:
50
  self.return_type = int # return type
51
  self.weight = 1 # weight
52
 
53
- def evaluate(self, x, y):
54
  return x + y
55
 
56
  def str(self, x, y):
@@ -66,7 +90,7 @@ class Subtract:
66
  self.return_type = int # return type
67
  self.weight = 1 # weight
68
 
69
- def evaluate(self, x, y):
70
  return x - y
71
 
72
  def str(self, x, y):
@@ -82,7 +106,7 @@ class Multiply:
82
  self.return_type = int # return type
83
  self.weight = 1 # weight
84
 
85
- def evaluate(self, x, y):
86
  return x * y
87
 
88
  def str(self, x, y):
@@ -98,7 +122,7 @@ class Divide:
98
  self.return_type = int # return type
99
  self.weight = 1 # weight
100
 
101
- def evaluate(self, x, y):
102
  try: # check for division by zero error
103
  return x / y
104
  except ZeroDivisionError:
 
13
  For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.
14
  '''
15
  def __init__(self, position):
16
+ # self.value = None # value of the variable, initially None
17
+ self.position = position # zero-indexed position of the variable in the arguments to program
18
  self.type = int # type of the variable
19
 
20
+ # def assign(self, value):
21
+ # self.value = value
22
 
23
+ # def evaluate(self, input = None):
24
+ # # check that variable has been assigned a value
25
+ # if self.value is None:
26
+ # raise ValueError(f"Variable {self.position} has not been assigned a value.")
27
+
28
+ # return self.value
29
+
30
+ def evaluate(self, input = None):
31
+
32
+ # check that input is not None
33
+ if input is None:
34
+ raise ValueError("Input is None.")
35
+
36
+ # check that input is a list
37
+ if type(input) != list:
38
+ raise ValueError("Input is not a list.")
39
+
40
+ # check that input is not empty
41
+ if len(input) == 0:
42
+ raise ValueError("Input is empty.")
43
+
44
+ # check that position is valid
45
+ if self.position >= len(input):
46
+ raise ValueError(f"Position {self.position} is out of range for input of length {len(input)}.")
47
+
48
+ return input[self.position]
49
 
50
  def str(self):
51
  return f"x{self.position}"
 
58
  self.value = value # value of the constant
59
  self.type = int # type of the constant
60
 
61
+ def evaluate(self, input = None):
62
  return self.value
63
 
64
  def str(self):
 
74
  self.return_type = int # return type
75
  self.weight = 1 # weight
76
 
77
+ def evaluate(self, x, y, input = None):
78
  return x + y
79
 
80
  def str(self, x, y):
 
90
  self.return_type = int # return type
91
  self.weight = 1 # weight
92
 
93
+ def evaluate(self, x, y, input = None):
94
  return x - y
95
 
96
  def str(self, x, y):
 
106
  self.return_type = int # return type
107
  self.weight = 1 # weight
108
 
109
+ def evaluate(self, x, y, input = None):
110
  return x * y
111
 
112
  def str(self, x, y):
 
122
  self.return_type = int # return type
123
  self.weight = 1 # weight
124
 
125
+ def evaluate(self, x, y, input = None):
126
  try: # check for division by zero error
127
  return x / y
128
  except ZeroDivisionError:
demonstration.ipynb CHANGED
@@ -26,35 +26,12 @@
26
  "import argparse\n",
27
  "\n",
28
  "# import arithmetic module\n",
29
- "# from arithmetic import *\n",
30
- "# from abstract_syntax_tree import OperatorNode\n",
31
  "from examples import example_set, check_examples\n",
32
  "import config"
33
  ]
34
  },
35
- {
36
- "cell_type": "code",
37
- "execution_count": 14,
38
- "metadata": {},
39
- "outputs": [
40
- {
41
- "data": {
42
- "text/plain": [
43
- "24"
44
- ]
45
- },
46
- "execution_count": 14,
47
- "metadata": {},
48
- "output_type": "execute_result"
49
- }
50
- ],
51
- "source": [
52
- "add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])\n",
53
- "subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])\n",
54
- "multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])\n",
55
- "multiply_node.evaluate()"
56
- ]
57
- },
58
  {
59
  "cell_type": "markdown",
60
  "metadata": {},
 
26
  "import argparse\n",
27
  "\n",
28
  "# import arithmetic module\n",
29
+ "from arithmetic import *\n",
30
+ "from abstract_syntax_tree import OperatorNode\n",
31
  "from examples import example_set, check_examples\n",
32
  "import config"
33
  ]
34
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  {
36
  "cell_type": "markdown",
37
  "metadata": {},