|
""" |
|
This script aims to quickly initialize a sketch for HumanEvalPlus. It's not going to be
|
perfect, but we will either manually or automatically fix/complete it later. |
|
+ CHANGE 1: Adds "contract", "base_input", "atol" in addition to HumanEval. |
|
""" |
|
|
|
import json |
|
import os |
|
import pathlib |
|
from importlib import import_module |
|
from inspect import getsource |
|
from typing import Tuple |
|
|
|
from tempdir import TempDir |
|
|
|
from evalplus.data.humaneval import get_human_eval |
|
|
|
# Destination of the generated dataset: the file "HumanEvalPlus.jsonl" placed
# three ``.parent`` hops up from this script (i.e. two directories above the
# directory that contains it).
HUMANEVAL_PLUS_PATH = pathlib.Path(__file__).parent.parent.parent.joinpath(
    "HumanEvalPlus.jsonl"
)
|
|
|
|
|
def _ret(entry_point) -> str: |
|
"""This is a hacky function to return some garbages so that we can |
|
successfully run the function . |
|
""" |
|
if entry_point == "sort_third" or entry_point == "sort_even": |
|
return [1, 2, 3] |
|
elif entry_point == "bf": |
|
return () |
|
return "1" |
|
|
|
|
|
def instrument_inputs(entry_point, prompt, test) -> list:
    """Collect every argument tuple the task's test passes to ``entry_point``.

    A stub named ``entry_point`` is generated that appends its positional
    arguments to a list and returns the placeholder chosen by ``_ret``.  The
    test source is then executed with every ``"assert "`` removed (so the
    comparisons are evaluated but can no longer fail) and ``check`` is called
    with the stub.

    Args:
        entry_point: Name of the function under test.
        prompt: Task prompt; only the text before ``def <entry_point>`` is
            executed, ahead of the stub definition.
        test: Test source; it must define ``check`` because
            ``check(<entry_point>)`` is invoked after executing it.

    Returns:
        The recorded argument tuples, in call order.
    """
    # Run everything in a throwaway namespace instead of this module's
    # globals(): a stub or prompt helper can then neither shadow a name of
    # this script nor leak into the processing of the next task.
    namespace = {"_inputs": []}

    preamble = prompt.split(f"def {entry_point}")[0]
    fn_text = f"""{preamble}

def {entry_point}(*args):
    _inputs.append(args)
    return {_ret(entry_point)}
"""

    # NOTE(security): exec() runs code taken from the dataset verbatim; only
    # feed this function prompts/tests from a trusted source.
    exec(fn_text, namespace)
    exec(test.replace("assert ", ""), namespace)
    exec(f"check({entry_point})", namespace)
    return namespace["_inputs"]
|
|
|
|
|
def get_contract_and_ref(task_id: int, entry_point) -> Tuple[str, str]:
    """Split a ground-truth solution into its contract lines and its body.

    The function ``entry_point`` is loaded from the module
    ``groundtruth.humaneval.<NNN>_<entry_point>`` (``NNN`` being ``task_id``
    zero-padded to three digits).  Its source is taken, the docstring text is
    blanked out, and only what follows the emptied docstring is kept.

    Args:
        task_id: Numeric HumanEval task id.
        entry_point: Name of the solution function inside that module.

    Returns:
        ``(contract, impl)``.  ``contract`` is the leading run of lines that
        contain ``$_CONTRACT_$``, prefixed with a newline (or ``""`` when
        there are none).  ``impl`` is every remaining line, wrapped in one
        leading and one extra trailing newline.

    Raises:
        AssertionError: If nothing is left after stripping the docstring, or
            if the remainder still starts with ``def``.
    """
    module_name = f"groundtruth.humaneval.{str(task_id).zfill(3)}_{entry_point}"
    fn = getattr(import_module(module_name), entry_point)

    doc = fn.__doc__
    if task_id == 51:
        # Turn the real newlines in these two snippets back into the
        # two-character sequence backslash + "n"; presumably that is how they
        # are spelled in the source file, so the docstring can be matched
        # against the getsource() text below.
        doc = doc.replace("bcdf\nghjklm", r"bcdf\nghjklm")
        doc = doc.replace("abcdef\nghijklm", r"abcdef\nghijklm")

    # Blank the docstring, normalise an emptied '''-docstring to the """ form,
    # and keep only the text after the (now empty) docstring line.
    source = getsource(fn).replace(doc, "")
    source = source.replace("''''''", '""""""')
    code = source.split('""""""\n')[-1]

    assert code, f"Something wrong with {task_id}!"
    assert code[:3] != "def", f"Something wrong with the {task_id}!"

    lines = code.strip("\n").split("\n")

    # Only an uninterrupted run of marker lines at the very top counts as the
    # contract; a marker appearing later stays in the implementation.
    n_contract = 0
    while n_contract < len(lines) and "$_CONTRACT_$" in lines[n_contract]:
        n_contract += 1

    contract = "".join(text + "\n" for text in lines[:n_contract])
    impl = "".join(text + "\n" for text in lines[n_contract:])

    if contract:
        contract = "\n" + contract
    return contract, "\n" + impl + "\n"
|
|
|
|
|
def get_atol(task_id: int) -> float:
    """Look up the absolute tolerance stored in the ``atol`` field of a task.

    Args:
        task_id: Numeric HumanEval task id.

    Returns:
        ``1e-6`` for tasks 2 and 4, ``1e-4`` for task 32, and ``0`` for every
        other task.
    """
    if task_id in (2, 4):
        return 1e-6
    if task_id == 32:
        return 1e-4
    # All remaining tasks get a zero tolerance.
    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: extend every HumanEval task with "contract",
    # "base_input" and "atol", rewrite "canonical_solution" from the ground
    # truth, and emit one JSON object per line to HUMANEVAL_PLUS_PATH.
    import shutil

    # Refuse to overwrite a previously generated dataset.
    assert not HUMANEVAL_PLUS_PATH.exists(), f"{HUMANEVAL_PLUS_PATH} already exists!"

    human_eval = get_human_eval()
    with TempDir() as temp_dir:
        # Join only the file *name*.  os.path.join() throws away temp_dir
        # when the second component is an absolute path, so joining the full
        # HUMANEVAL_PLUS_PATH wrote straight to the final destination and a
        # failure half-way left a truncated dataset behind.
        tmp_file = os.path.join(temp_dir, HUMANEVAL_PLUS_PATH.name)
        with open(tmp_file, "w") as writer:
            for task in human_eval:
                # Task ids look like "<prefix>/<number>"; keep the number.
                task_id = int(task["task_id"].split("/")[-1])
                task["contract"], task["canonical_solution"] = get_contract_and_ref(
                    task_id, task["entry_point"]
                )
                task["base_input"] = instrument_inputs(
                    task["entry_point"], task["prompt"], task["test"]
                )
                task["atol"] = get_atol(task_id)

                writer.write(json.dumps(task) + "\n")

        # Publish only after every task was written successfully.
        # shutil.move() also works when the temp directory lives on another
        # filesystem than the destination, where os.rename() would fail.
        shutil.move(tmp_file, str(HUMANEVAL_PLUS_PATH))
|
|