"""Language Model Programs (LMPs): prompt a code-completion model for Python,
then execute the generated code in a restricted namespace."""

import ast
from time import sleep

import astunparse
import openai
from openai.error import RateLimitError, APIConnectionError  # pre-1.0 openai SDK
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter


class LMP:
    """A single prompt-engineered LMP: builds a prompt around a query, asks the
    completion API for code, and executes what comes back."""

    def __init__(self, name, cfg, lmp_fgen, fixed_vars, variable_vars, md_logger):
        self._name = name
        self._cfg = cfg
        self._md_logger = md_logger

        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()

        self._stop_tokens = list(self._cfg['stop'])

        self._lmp_fgen = lmp_fgen

        # fixed_vars are static (e.g. imported modules); variable_vars accumulate
        # runtime state such as freshly generated functions.
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self.exec_hist = ''

    def clear_exec_hist(self):
        self.exec_hist = ''

    def build_prompt(self, query, context=''):
        # Advertise runtime-provided variables to the model as an import line.
        if len(self._variable_vars) > 0:
            variable_vars_imports_str = f"from utils import {', '.join(self._variable_vars.keys())}"
        else:
            variable_vars_imports_str = ''
        prompt = self._base_prompt.replace('{variable_vars_imports}', variable_vars_imports_str)

        if self._cfg['maintain_session']:
            prompt += f'\n{self.exec_hist}'

        if context != '':
            prompt += f'\n{context}'

        use_query = f'{self._cfg["query_prefix"]}{query}{self._cfg["query_suffix"]}'
        prompt += f'\n{use_query}'

        return prompt, use_query

    def __call__(self, query, context='', **kwargs):
        prompt, use_query = self.build_prompt(query, context=context)

        # Retry indefinitely on transient API errors.
        while True:
            try:
                code_str = openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens']
                )['choices'][0]['text'].strip()
                break
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API error: {e}')
                print('Retrying after 10s.')
                sleep(10)

        if self._cfg['include_context'] and context != '':
            to_exec = f'{context}\n{code_str}'
            to_log = f'{context}\n{use_query}\n{code_str}'
        else:
            to_exec = code_str
            to_log = f'{use_query}\n{to_exec}'

        to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter())
        print(f'LMP {self._name} generated code:\n{to_log_pretty}')
        self._md_logger.log_text(f'LMP {self._name} Generated Code:')
        self._md_logger.log_code(to_log)

        # Define any not-yet-defined functions the generated code calls, then run it.
        new_fs = self._lmp_fgen.create_new_fs_from_code(code_str)
        self._variable_vars.update(new_fs)

        gvars = merge_dicts([self._fixed_vars, self._variable_vars])
        lvars = kwargs

        if not self._cfg['debug_mode']:
            exec_safe(to_exec, gvars, lvars)

        self.exec_hist += f'\n{to_exec}'

        if self._cfg['maintain_session']:
            self._variable_vars.update(lvars)

        if self._cfg['has_return']:
            return lvars[self._cfg['return_val_name']]
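

# A minimal sketch of wiring an LMP together. All cfg values, the prompt file,
# and 'md_logger' are hypothetical stand-ins, not part of this module; LMPFGen
# is defined below. A real setup would use separate cfgs and prompt files for
# the LMP and its function generator.
def _example_lmp_usage(md_logger):
    cfg = {
        'prompt_path': 'prompts/tabletop.txt',  # hypothetical prompt file
        'stop': ['#'],
        'maintain_session': True,
        'include_context': True,
        'debug_mode': False,
        'has_return': False,
        'query_prefix': '# ',
        'query_suffix': '.',
        'temperature': 0,
        'engine': 'code-davinci-002',  # any Completions-capable engine
        'max_tokens': 256,
    }
    lmp_fgen = LMPFGen(cfg, fixed_vars={}, variable_vars={}, md_logger=md_logger)
    lmp = LMP('tabletop_ui', cfg, lmp_fgen, {}, {}, md_logger)
    lmp('put the red block on the blue block')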


class LMPFGen:
    """Generates function bodies on demand for function calls that appear in
    LMP output but are not yet defined."""

    def __init__(self, cfg, fixed_vars, variable_vars, md_logger):
        self._cfg = cfg

        self._stop_tokens = list(self._cfg['stop'])
        self._fixed_vars = fixed_vars
        self._variable_vars = variable_vars
        self._md_logger = md_logger

        with open(self._cfg['prompt_path'], 'r') as f:
            self._base_prompt = f.read()

    def create_f_from_sig(self, f_name, f_sig, other_vars=None, fix_bugs=False, return_src=False):
        print(f'Creating function: {f_sig}')

        use_query = f'{self._cfg["query_prefix"]}{f_sig}{self._cfg["query_suffix"]}'
        prompt = f'{self._base_prompt}\n{use_query}'

        # Retry indefinitely on transient API errors.
        while True:
            try:
                f_src = openai.Completion.create(
                    prompt=prompt,
                    stop=self._stop_tokens,
                    temperature=self._cfg['temperature'],
                    engine=self._cfg['engine'],
                    max_tokens=self._cfg['max_tokens']
                )['choices'][0]['text'].strip()
                break
            except (RateLimitError, APIConnectionError) as e:
                print(f'OpenAI API error: {e}')
                print('Retrying after 10s.')
                sleep(10)

        if fix_bugs:
            # Optional cleanup pass over the generated source via the Edits API.
            f_src = openai.Edit.create(
                model='code-davinci-edit-001',
                input='# ' + f_src,
                temperature=0,
                instruction='Fix the bug if there is one. Improve readability. Keep same inputs and outputs. Only small changes. No comments.',
            )['choices'][0]['text'].strip()

        if other_vars is None:
            other_vars = {}
        gvars = merge_dicts([self._fixed_vars, self._variable_vars, other_vars])
        lvars = {}

        exec_safe(f_src, gvars, lvars)

        f = lvars[f_name]

        to_print = f'{use_query}\n{f_src}'
        to_print_pretty = highlight(to_print, PythonLexer(), TerminalFormatter())
        print(f'LMPFGen generated code:\n{to_print_pretty}')
        self._md_logger.log_text('Generated Function:')
        self._md_logger.log_code(to_print)

        if return_src:
            return f, f_src
        return f

    def create_new_fs_from_code(self, code_str, other_vars=None, fix_bugs=False, return_src=False):
        fs, f_assigns = {}, {}
        f_parser = FunctionParser(fs, f_assigns)
        f_parser.visit(ast.parse(code_str))
        # Prefer the full assignment statement as the signature when the call's
        # result is assigned, so the generated function returns a value.
        for f_name, f_assign in f_assigns.items():
            if f_name in fs:
                fs[f_name] = f_assign

        if other_vars is None:
            other_vars = {}

        new_fs = {}
        srcs = {}
        for f_name, f_sig in fs.items():
            all_vars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
            if not var_exists(f_name, all_vars):
                f, f_src = self.create_f_from_sig(f_name, f_sig, new_fs, fix_bugs=fix_bugs, return_src=True)

                # Recursively define any new functions called inside the body
                # of the function that was just generated.
                f_def_body = astunparse.unparse(ast.parse(f_src).body[0].body)
                child_fs, child_f_srcs = self.create_new_fs_from_code(
                    f_def_body, other_vars=all_vars, fix_bugs=fix_bugs, return_src=True
                )

                if len(child_fs) > 0:
                    new_fs.update(child_fs)
                    srcs.update(child_f_srcs)

                    # Re-define the parent so the new child functions are in scope.
                    gvars = merge_dicts([self._fixed_vars, self._variable_vars, new_fs, other_vars])
                    lvars = {}

                    exec_safe(f_src, gvars, lvars)

                    f = lvars[f_name]

                new_fs[f_name], srcs[f_name] = f, f_src

        if return_src:
            return new_fs, srcs
        return new_fs
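

# Sketch: given generated code that calls a helper that does not exist yet,
# LMPFGen synthesizes it. 'put_first_on_second' is a hypothetical name used
# purely for illustration.
def _example_fgen_usage(lmp_fgen):
    new_fs = lmp_fgen.create_new_fs_from_code("put_first_on_second('red block', 'blue block')")
    # new_fs maps 'put_first_on_second' to a freshly generated function,
    # assuming the name was not already in fixed_vars/variable_vars.
    return new_fs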


class FunctionParser(ast.NodeTransformer):
    """Collects every named function call (and call-result assignment) found
    in a piece of generated code."""

    def __init__(self, fs, f_assigns):
        super().__init__()
        self._fs = fs
        self._f_assigns = f_assigns

    def visit_Call(self, node):
        self.generic_visit(node)
        if isinstance(node.func, ast.Name):
            f_sig = astunparse.unparse(node).strip()
            f_name = astunparse.unparse(node.func).strip()
            self._fs[f_name] = f_sig
        return node

    def visit_Assign(self, node):
        self.generic_visit(node)
        if isinstance(node.value, ast.Call):
            assign_str = astunparse.unparse(node).strip()
            f_name = astunparse.unparse(node.value.func).strip()
            self._f_assigns[f_name] = assign_str
        return node
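

# Sketch of what FunctionParser extracts; the snippet below is hypothetical.
def _example_function_parser():
    fs, f_assigns = {}, {}
    FunctionParser(fs, f_assigns).visit(ast.parse('x = f(1)\ng(x)'))
    # fs == {'f': 'f(1)', 'g': 'g(x)'} and f_assigns == {'f': 'x = f(1)'}
    return fs, f_assigns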


def var_exists(name, all_vars):
    # Evaluate against a copy so eval does not inject __builtins__ into the
    # caller's dict; catch Exception rather than everything, so Ctrl-C still works.
    try:
        eval(name, dict(all_vars))
    except Exception:
        exists = False
    else:
        exists = True
    return exists
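

# var_exists treats any name that evaluates without error as defined:
def _example_var_exists():
    assert var_exists('a', {'a': 1})
    assert not var_exists('missing', {})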


def merge_dicts(dicts):
    return {
        k: v
        for d in dicts
        for k, v in d.items()
    }
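

# Later dicts win on key collisions, since later items overwrite earlier ones:
def _example_merge_dicts():
    assert merge_dicts([{'a': 1, 'b': 2}, {'b': 3}]) == {'a': 1, 'b': 3}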


def exec_safe(code_str, gvars=None, lvars=None):
    # A lightweight guard, not a real sandbox: reject imports and dunder
    # access, and stub out exec/eval inside the executed code.
    banned_phrases = ['import', '__']
    for phrase in banned_phrases:
        assert phrase not in code_str, f'banned phrase "{phrase}" in generated code'

    if gvars is None:
        gvars = {}
    if lvars is None:
        lvars = {}
    empty_fn = lambda *args, **kwargs: None
    custom_gvars = merge_dicts([
        gvars,
        {'exec': empty_fn, 'eval': empty_fn}
    ])
    exec(code_str, custom_gvars, lvars)
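

# Sketch: values assigned by the executed code land in lvars; banned phrases
# raise AssertionError before anything runs.
def _example_exec_safe():
    lvars = {}
    exec_safe('y = x + 1', {'x': 1}, lvars)
    assert lvars['y'] == 2
    try:
        exec_safe('import os')
    except AssertionError:
        pass  # rejected as expected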