loubnabnl HF staff committed on
Commit
2988706
·
1 Parent(s): 38d768b

add example script to evaluate a model and generate code

Browse files
Files changed (1) hide show
  1. example_script.py +132 -0
example_script.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This is an example script to evaluate a code generation model on APPS; you can also use the APPS solutions as the code generations:
2
+ >>> python example_script.py --model_ckpt MODEL_NAME --num_tasks 10 --difficulty introductory --n_samples 1
3
+ >>> python example_script.py --use_solutions True --num_tasks 10 --difficulty introductory --n_samples 1"""
4
+
5
+ import json
6
+ import pprint
7
+ from tqdm import tqdm
8
+ from datasets import load_dataset
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed
10
+ from tools.utils import compute_metrics
11
+
12
def generate_prompt(sample):
    """Build the text prompt for one APPS sample.

    Args:
        sample: Mapping providing the keys ``"question"``, ``"starter_code"``
            and ``"input_output"``. ``"input_output"`` is parsed as JSON and
            may contain an ``"fn_name"`` entry.

    Returns:
        The prompt string: ``"\\nQUESTION:\\n"``, the question, the starter
        code (if any), a call-format hint and the ``"\\nANSWER:\\n"`` marker.
    """
    starter_code = sample["starter_code"] or None
    try:
        input_output = json.loads(sample["input_output"])
        fn_name = input_output.get("fn_name") or None
    except (ValueError, TypeError, AttributeError):
        # Empty, missing or malformed test metadata: no entry point is known.
        fn_name = None

    prompt = "\nQUESTION:\n" + sample["question"]
    if starter_code:
        prompt += starter_code
    # A named entry point means the tests call that function directly
    # (call-based); without one the solution communicates through stdin.
    if fn_name:
        prompt += "\nUse Call-Based format"
    else:
        prompt += "\nUse Standard Input format"
    prompt += "\nANSWER:\n"
    return prompt
30
+
31
+
32
def complete_code(pipe, prompt, num_completions=1, max_length=256, **gen_kwargs):
    """Complete ``prompt`` with a text-generation pipeline.

    Args:
        pipe: Callable pipeline exposing ``pipe.tokenizer.eos_token``; it is
            called with the prompt plus ``num_return_sequences``,
            ``max_length`` and ``gen_kwargs``.
        prompt: Text to complete. The tokenizer's EOS token is prepended.
        num_completions: Number of sequences requested from the pipeline.
        max_length: Forwarded to the pipeline as ``max_length``.
        **gen_kwargs: Extra generation settings forwarded unchanged.

    Returns:
        A list with the generated text of each returned sequence, with the
        (EOS-prefixed) prompt cut off the front. If the pipeline raises
        ``IndexError`` a single empty string is returned instead.
    """
    # Fall back to no prefix when the tokenizer defines no EOS token.
    prompt = (pipe.tokenizer.eos_token or "") + prompt
    try:
        code_gens = pipe(prompt, num_return_sequences=num_completions, max_length=max_length, **gen_kwargs)
    except IndexError:
        # NOTE(review): treated as "prompt exceeds the context window";
        # confirm this is the only IndexError the pipeline can raise.
        print("prompt is longer than the context size of the model, generation skipped")
        return [""]
    return [code_gen["generated_text"][len(prompt):] for code_gen in code_gens]
42
+
43
+
44
def make_generations(dataset, args, model, tokenizer):
    """Generate code completions for the first tasks of ``dataset``.

    Args:
        dataset: Indexable collection of samples accepted by
            ``generate_prompt``.
        args: Namespace providing ``seed``, ``device_int``, ``do_sample``,
            ``temperature``, ``top_p``, ``top_k``, ``num_tasks``,
            ``n_samples``, ``max_length`` and ``eos``.
        model: Model handed to the text-generation pipeline.
        tokenizer: Tokenizer handed to the text-generation pipeline.

    Returns:
        One list of generated strings per task, with every occurrence of
        ``args.eos`` removed from each string.
    """
    set_seed(args.seed)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)

    # Generation settings
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_p": args.top_p,
        # The command-line parser may deliver top_k as a float; cast to int.
        "top_k": int(args.top_k),
    }

    # Never index past the end of the dataset, even if more tasks are requested.
    n_tasks = len(dataset) if args.num_tasks is None else min(args.num_tasks, len(dataset))
    print(f"ntasks is {n_tasks}")

    generations = []
    for task in tqdm(range(n_tasks)):
        prompt = generate_prompt(dataset[task]).strip()
        task_generations = complete_code(
            pipe,
            prompt,
            num_completions=args.n_samples,
            max_length=args.max_length,
            **gen_kwargs,
        )
        generations.append([gen.replace(args.eos, "") for gen in task_generations])
    return generations
66
+
67
+
68
def main(args):
    """Evaluate code generations on APPS and write the metrics to a JSON file.

    The generations are either the dataset's own solutions
    (``args.use_solutions``) or completions produced by the model loaded from
    ``args.model_ckpt``. The metrics returned by ``compute_metrics`` are
    printed and dumped to ``args.output_file``.

    Args:
        args: Parsed command-line namespace (see the argument parser in the
            ``__main__`` section of this script).
    """
    DATA_PATH = "codeparrot/apps"
    print(pprint.pformat(vars(args)))

    # setup
    print("Loading evaluation dataset...")
    dataset = load_dataset(DATA_PATH, split="test", difficulties=[args.difficulty])
    if args.use_solutions:
        print("Using data solutions as code generations")
        # Use exactly the requested number of tasks, capped at the dataset size.
        n_tasks = len(dataset) if args.num_tasks is None else min(args.num_tasks, len(dataset))
        generations = []
        for index in range(n_tasks):
            try:
                sol = json.loads(dataset[index]["solutions"])
            except ValueError:
                print(f"No solutions for task {index} or not enough to have {args.n_solutions} solutions")
                # Stop at the first gap so the collected entries stay
                # contiguous from task 0.
                break
            generations.append(sol[:args.n_solutions])
    else:
        print("Loading tokenizer and model...")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
        generations = make_generations(dataset, args, model, tokenizer)

    metrics = compute_metrics(
        generations,
        level=args.difficulty,
        k_list=args.k_list,
        count_errors=args.count_errors,
        debug=args.debug,
    )
    print(metrics)
    with open(args.output_file, "w") as fp:
        json.dump(metrics, fp)
99
+
100
+
101
def str_to_bool(value):
    """Convert a command-line token such as ``"True"`` or ``"0"`` to a bool.

    ``type=bool`` cannot be used with argparse because every non-empty
    string, including ``"False"``, is truthy.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Testing a Language Model on APPS Python Code dataset")
    # model and tokenizer arguments
    parser.add_argument("--model_ckpt", default="loubnabnl/apps-1.5B-model", type=str, help="path to model checkpoint.")
    parser.add_argument("--tokenizer", default="gpt2", type=str, help="tokenizer to use.")
    parser.add_argument("--eos", default="<|endoftext|>", type=str, help="end of sentence token.")
    # generation arguments
    parser.add_argument("--do_sample", default=True, type=str_to_bool, help="do sampling in generation")
    parser.add_argument("--temperature", default=0.2, type=float, help="temperature for sampling")
    parser.add_argument("--top_p", default=0.95, type=float, help="top p for sampling")
    parser.add_argument("--top_k", default=0, type=int, help="top k for sampling")
    parser.add_argument("--max_length", default=1024, type=int, help="max length of generated code")
    # evaluation arguments
    parser.add_argument(
        "--difficulty",
        default="all",
        type=str,
        help="difficulty level to select in the dataset from: "
             "'all', 'introductory', 'interview' and 'competition' ",
    )
    parser.add_argument("--num_tasks", default=6, type=int, help="number of tasks to evaluate")
    parser.add_argument("--use_solutions", default=False, type=str_to_bool, help="use solutions instead of generating new code")
    parser.add_argument("--n_samples", default=1, type=int, help="number of samples to generate")
    parser.add_argument("--n_solutions", default=1, type=int, help="number of solutions to use")
    # One or more integers, e.g. ``--k_list 1 10 100``.
    parser.add_argument("--k_list", default=[1, 2, 3], type=int, nargs="+", help="list of k values to evaluate pass@k")
    parser.add_argument("--count_errors", default=False, type=str_to_bool, help="count compilation and runtime errors for single generations")
    # configuration
    parser.add_argument("--seed", default=0, type=int, help="generation seed")
    parser.add_argument("--device_int", default=-1, type=int, help="device on which code generation is run, if positive use GPU")
    parser.add_argument("--debug", default=False, type=str_to_bool, help="debug mode")
    # save
    parser.add_argument("--output_file", default="apps_metrics.json", type=str, help="output file to save the results")

    args = parser.parse_args()
    main(args)