add example script to evaluate a model and generate code
example_script.py ADDED (+132 -0)
@@ -0,0 +1,132 @@
"""This is an example script to evaluate a code generation model on APPS. You can also use the APPS solutions as code generations.
>>> python example_script.py --model_ckpt MODEL_NAME --num_tasks 10 --difficulty introductory --n_samples 1
>>> python example_script.py --use_solutions True --num_tasks 10 --difficulty introductory --n_samples 1"""

import json
import pprint
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed
from tools.utils import compute_metrics

def generate_prompt(sample):
    """Build the APPS-style prompt for one task."""
    starter_code = None if len(sample["starter_code"]) == 0 else sample["starter_code"]
    try:
        input_output = json.loads(sample["input_output"])
        fn_name = None if not input_output.get("fn_name") else input_output["fn_name"]
    except ValueError:
        fn_name = None
    _input = "\nQUESTION:\n"
    _input += sample["question"]
    if starter_code:
        _input += starter_code
    if fn_name:
        # tasks that define a function name are call-based, the others read from standard input
        _input += "\nUse Call-Based format"
    else:
        _input += "\nUse Standard Input format"

    _input += "\nANSWER:\n"
    return _input


def complete_code(pipe, prompt, num_completions=1, max_length=256, **gen_kwargs):
    """Complete prompt with text generation pipeline and return num_completions."""
    prompt = pipe.tokenizer.eos_token + prompt
    try:
        code_gens = pipe(prompt, num_return_sequences=num_completions, max_length=max_length, **gen_kwargs)
        # strip the prompt and keep only the generated completion
        return [code_gen["generated_text"][len(prompt):] for code_gen in code_gens]
    except IndexError:
        print("prompt is longer than the context size of the model, generation skipped")
        return [""]


def make_generations(dataset, args, model, tokenizer):
    set_seed(args.seed)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)

    # Generation settings
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
    }

    # Generate completions for the evaluation set
    n_tasks = args.num_tasks if args.num_tasks is not None else len(dataset)
    print(f"n_tasks is {n_tasks}")
    generations = []
    for task in tqdm(range(n_tasks)):
        task_generations = []
        prompt = generate_prompt(dataset[task]).strip()
        task_generations.extend(complete_code(pipe, prompt, num_completions=args.n_samples, max_length=args.max_length, **gen_kwargs))
        generations.append([gen.replace(args.eos, "") for gen in task_generations])
    return generations


def main(args):
    DATA_PATH = "codeparrot/apps"
    argsdict = vars(args)
    print(pprint.pformat(argsdict))

    # setup
    print("Loading evaluation dataset...")
    dataset = load_dataset(DATA_PATH, split="test", difficulties=[args.difficulty])
    if args.use_solutions:
        print("Using dataset solutions as code generations")
        model = None
        tokenizer = None
        generations = []
        for index in range(args.num_tasks):
            try:
                sol = json.loads(dataset[index]["solutions"])
                generations.append(sol[:args.n_solutions])
            except ValueError:
                print(f"No solutions for task {index} or not enough to have {args.n_solutions} solutions")
                break

    else:
        print("Loading tokenizer and model...")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
        generations = make_generations(dataset, args, model, tokenizer)

    metrics = compute_metrics(generations, level=args.difficulty, k_list=args.k_list, count_errors=args.count_errors, debug=args.debug)
    print(metrics)
    with open(args.output_file, "w") as fp:
        json.dump(metrics, fp)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Testing a language model on the APPS Python code dataset")
    # model and tokenizer arguments
    parser.add_argument("--model_ckpt", default="loubnabnl/apps-1.5B-model", type=str, help="path to model checkpoint.")
    parser.add_argument("--tokenizer", default="gpt2", type=str, help="tokenizer to use.")
    parser.add_argument("--eos", default="<|endoftext|>", type=str, help="end of sentence token.")
    # generation arguments
    parser.add_argument("--do_sample", default=True, type=bool, help="do sampling in generation")
    parser.add_argument("--temperature", default=0.2, type=float, help="temperature for sampling")
    parser.add_argument("--top_p", default=0.95, type=float, help="top p for sampling")
    parser.add_argument("--top_k", default=0, type=int, help="top k for sampling")
    parser.add_argument("--max_length", default=1024, type=int, help="max length of generated code")
    # evaluation arguments
    parser.add_argument("--difficulty", default="all", type=str,
                        help="difficulty level to select in the dataset from: 'all', 'introductory', 'interview' and 'competition'")
    parser.add_argument("--num_tasks", default=6, type=int, help="number of tasks to evaluate")
    parser.add_argument("--use_solutions", default=False, type=bool, help="use solutions instead of generating new code")
    parser.add_argument("--n_samples", default=1, type=int, help="number of samples to generate")
    parser.add_argument("--n_solutions", default=1, type=int, help="number of solutions to use")
    parser.add_argument("--k_list", default=[1, 2, 3], nargs="+", type=int, help="list of k values to evaluate pass@k")
    parser.add_argument("--count_errors", default=False, type=bool, help="count compilation and runtime errors for single generations")
    # configuration
    parser.add_argument("--seed", default=0, type=int, help="generation seed")
    parser.add_argument("--device_int", default=-1, type=int, help="device on which code generation is run: -1 for CPU, otherwise the GPU index")
    parser.add_argument("--debug", default=False, type=bool, help="debug mode")
    # save
    parser.add_argument("--output_file", default="apps_metrics.json", type=str, help="output file to save the results")

    args = parser.parse_args()
    main(args)
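
As a follow-up, here is a minimal sketch of how the saved metrics file could be inspected after a run. It assumes compute_metrics returns a JSON-serializable dict of scores; the exact key names depend on tools.utils and are illustrative here.

import json

# Load the metrics written by example_script.py (default --output_file is apps_metrics.json)
with open("apps_metrics.json") as fp:
    metrics = json.load(fp)

# Key names such as "pass@1" are illustrative; they depend on tools.utils.compute_metrics
for name, value in metrics.items():
    print(f"{name}: {value}")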