{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PSET 1: Bottom-Up Synthesis\n",
"\n",
"I follow Algorithm 1 in the BUSTLE paper:\n",
"\n",
"> Odena, A. *et al.* BUSTLE: Bottom-Up Program Synthesis Through Learning-Guided Exploration. In *9th International Conference on Learning Representations*; 2021 May 3-7; Austria.\n",
"\n",
"First, I import the required libraries."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import itertools\n",
"\n",
"# argument parser for command-line arguments\n",
"import argparse\n",
"\n",
"# import the arithmetic domain (provides arithmetic_operators)\n",
"from arithmetic import *\n",
"from abstract_syntax_tree import OperatorNode\n",
"from examples import example_set, check_examples\n",
"from synthesizer import extract_constants, observationally_equivalent, check_program\n",
"import config"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, I define variables as proxies for the command-line arguments that would otherwise be passed to the synthesizer."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"domain = \"arithmetic\"\n",
"examples_key = \"addition\"\n",
"examples = example_set[examples_key]\n",
"max_weight = 3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
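"Then I initialize the program bank with the constants that `extract_constants` extracts from the examples, and keep a parallel list of their string representations for duplicate checks."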
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# initialize program bank\n",
"program_bank = extract_constants(examples)\n",
"program_bank_str = [p.str() for p in program_bank]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
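"Finally, I run the bottom-up enumeration. For each weight level, I apply every operator to argument tuples drawn from the program bank, skip candidates whose argument types do not match the operator or whose arguments are too heavy, discard duplicates and programs observationally equivalent to one already in the bank, and print any program that satisfies all examples."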
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(x0 + x1)\n"
]
}
],
"source": [
"# define operators\n",
"operators = arithmetic_operators\n",
"\n",
"# iterate over each weight level, up to and including max_weight\n",
"for weight in range(2, max_weight + 1):\n",
"\n",
"    for op in operators:\n",
"\n",
"        # enumerate ordered argument tuples (repetition allowed) from the program bank;\n",
"        # itertools.product copies the bank up front, so appending to it below is safe\n",
"        argument_tuples = itertools.product(program_bank, repeat=op.arity)\n",
"\n",
"        # iterate over each argument tuple\n",
"        for args in argument_tuples:\n",
"\n",
"            # get type signature\n",
"            type_signature = [p.type for p in args]\n",
"\n",
"            # check if type signature matches operator\n",
"            if type_signature != op.arg_types:\n",
"                continue\n",
"\n",
"            # check that sum of weights of arguments <= weight\n",
"            if sum(p.weight for p in args) > weight:\n",
"                continue\n",
"\n",
"            # create new program\n",
"            program = OperatorNode(op, args)\n",
"\n",
"            # skip programs already in the bank (string comparison)\n",
"            if program.str() in program_bank_str:\n",
"                continue\n",
"\n",
"            # skip programs observationally equivalent to one already in the bank\n",
"            if any(observationally_equivalent(program, p, examples) for p in program_bank):\n",
"                continue\n",
"\n",
"            # add program to program bank\n",
"            program_bank.append(program)\n",
"            program_bank_str.append(program.str())\n",
"\n",
"            # report any program that passes all examples\n",
"            # (outside a function there is nothing to return, so print it instead)\n",
"            if check_program(program, examples):\n",
"                print(program.str())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}