loubnabnl committed
Commit 8280fac · Parent: 622f142

remove tools folder

Files changed (2):
  1. tools/testing_util.py +0 -438
  2. tools/utils.py +0 -188
tools/testing_util.py DELETED
@@ -1,438 +0,0 @@
- import json
- import sys
- import faulthandler
-
- # used for debugging to time steps
- from datetime import datetime
-
- # to run the solution files we're using a timing based approach
- import signal
-
- import numpy as np
- # for capturing the stdout
- from io import StringIO
- # used for testing the code that reads from input
- from unittest.mock import patch, mock_open
-
- from pyext import RuntimeModule
-
- from enum import Enum
- class CODE_TYPE(Enum):
-     call_based = 0
-     standard_input = 1
-
- # set up the signal-based timer used to enforce the per-test timeout
- class TimeoutException(Exception):
-     pass
- def timeout_handler(signum, frame):
-     print("alarm went off")
-     # return
-     raise TimeoutException
- signal.signal(signal.SIGALRM, timeout_handler)
- timeout = 4  # seconds
-
- # used to capture stdout as a list
- # from https://stackoverflow.com/a/16571630/6416660
- # alternative: use redirect_stdout() from contextlib
- class Capturing(list):
-     def __enter__(self):
-         self._stdout = sys.stdout
-         sys.stdout = self._stringio = StringIO()
-         # make closing the StringIO a no-op (solutions sometimes close stdout)
-         self._stringio.close = lambda *args: None
-         return self
-     def __exit__(self, *args):
-         self.extend(self._stringio.getvalue().splitlines())
-         del self._stringio  # free up some memory
-         sys.stdout = self._stdout
-
-
- def run_test(sample, test=None, debug=False):
-     """
-     If test (the generated code) is not None, run it against the sample's test cases.
-     Otherwise just return the parsed input/output pairs.
-     """
-     if debug:
-         print(f"start = {datetime.now().time()}")
-
-     try:
-         in_outs = json.loads(sample["input_output"])
-     except ValueError:
-         in_outs = None
-     if in_outs:
-         if in_outs.get("fn_name") is None:
-             which_type = CODE_TYPE.standard_input  # Standard input
-             method_name = None
-         else:
-             which_type = CODE_TYPE.call_based  # Call-based
-             method_name = in_outs["fn_name"]
-
-     if debug:
-         print(f"loaded input_output = {datetime.now().time()}")
-
-     if test is None:
-         return in_outs
-     elif test is not None:
-         results = []
-         sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n"
-         if debug:
-             print(f"loading test code = {datetime.now().time()}")
-
-         if which_type == CODE_TYPE.call_based:
-             sol += test
-             if debug:
-                 print(f"sol = {sol}")
-             signal.alarm(timeout)
-             try:
-                 tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
-                 if "class Solution" not in test:
-                     tmp = tmp_sol
-                 else:
-                     tmp = tmp_sol.Solution()
-                 signal.alarm(0)
-             except Exception as e:
-                 signal.alarm(0)
-                 if debug:
-                     print(f"type 0 compilation error = {e}")
-                 results.append(-2)
-                 return results
-             signal.alarm(0)
-
-         elif which_type == CODE_TYPE.standard_input:
-             # wrap the solution in a function so it can run against mocked stdin
-             tmp_test = test.split("\n")
-
-             new_test = []
-             for x in tmp_test:
-                 if (not x.startswith("from ")) and (not x.startswith("import ")):
-                     new_test.append("\t" + x + "\n")
-                 else:
-                     new_test.append(x + "\n")
-             tmp_test = new_test
-
-             new_test = ""
-             started = False
-             for i in tmp_test:
-                 if i.startswith("\t") and not started:
-                     new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
-                     new_test += "def code():\n"
-                     new_test += i
-                     started = True
-                 elif started and ((i.startswith("from ")) or (i.startswith("import "))):
-                     new_test += "\t" + i
-                 else:
-                     new_test += i
-             tmp_test = new_test
-
-             sol += tmp_test
-             if debug:
-                 print(f"sol = {sol}")
-             method_name = "code"
-             signal.alarm(timeout)
-             try:
-                 tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
-                 tmp = tmp_sol
-                 signal.alarm(0)
-             except Exception as e:
-                 signal.alarm(0)
-                 if debug:
-                     print(f"type 1 compilation error = {e}")
-                 results.append(-2)
-                 return results
-             signal.alarm(0)
-         if debug:
-             print(f"get method = {datetime.now().time()}")
-
-         try:
-             method = getattr(tmp, method_name)  # getattr's second argument must be a str
-         except Exception:
-             signal.alarm(0)
-             e = sys.exc_info()
-             print(f"unable to get function error = {e}")
-             return results
-
-         for index, inputs in enumerate(in_outs["inputs"]):
-             # JSON forces dictionaries to have string keys; undo this (assuming a singleton list)
-             try:
-                 if isinstance(inputs[0], dict):
-                     inputs = [{int(k): v for k, v in inputs[0].items()}]
-             except Exception:
-                 pass
-             try:
-                 if isinstance(in_outs["outputs"][index], dict):
-                     in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index].items()}]
-             except Exception:
-                 pass
-             try:
-                 if isinstance(in_outs["outputs"][index][0], dict):
-                     in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index][0].items()}]
-             except Exception:
-                 pass
-
-             if debug:
-                 print(f"time: {datetime.now().time()} testing index = {index} inputs = {inputs}, {type(inputs)}. type = {which_type}")
-             if which_type == CODE_TYPE.call_based:  # Call-based
-                 signal.alarm(timeout)
-                 faulthandler.enable()
-                 try:
-                     output = method(*inputs)
-
-                     # ground truth sequences are not tuples
-                     if isinstance(output, tuple):
-                         output = list(output)
-
-                     tmp_result = output == in_outs["outputs"][index]
-                     if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
-                         tmp_result = tmp_result or (output == in_outs["outputs"][index][0])
-
-                     # ground truth sequences are not tuples
-                     try:
-                         if isinstance(output[0], tuple):
-                             tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
-                     except Exception:
-                         pass
-                     results.append(tmp_result)
-
-                     # reset the alarm
-                     signal.alarm(0)
-                 except Exception as e:
-                     signal.alarm(0)
-                     faulthandler.disable()
-                     print(f"Call-based runtime error or time limit exceeded error = {e}")
-                     results.append(-1)
-                     continue
-                 faulthandler.disable()
-                 signal.alarm(0)
-                 if debug:
-                     print(f"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-             elif which_type == CODE_TYPE.standard_input:  # Standard input
-                 faulthandler.enable()
-                 signal.alarm(timeout)
-                 passed = False
-
-                 if isinstance(inputs, list):
-                     inputs = "\n".join(inputs)
-                 if isinstance(in_outs['outputs'][index], list):
-                     in_outs['outputs'][index] = "\n".join(in_outs['outputs'][index])
-
-                 with Capturing() as output:
-                     try:
-                         call_method(method, inputs)
-                         # reset the alarm
-                         signal.alarm(0)
-                         passed = True
-                     except Exception as e:
-                         # runtime error or took too long
-                         signal.alarm(0)
-                         print(f"Standard input runtime error or time limit exceeded error = {repr(e)}{e}")
-                         results.append(-1)
-                     signal.alarm(0)
-
-                 if not passed:
-                     if debug:
-                         nl = "\n"
-                         if not isinstance(inputs, list):
-                             print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-                         else:
-                             print(f"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-                     continue
-
-                 if passed and debug:
-                     print(f"==> output = {output}, test outputs = {in_outs['outputs'][index]}")
-
-                 if custom_compare_(output, in_outs['outputs'][index]):
-                     tmp_result = True
-                     results.append(tmp_result)
-                     continue
-
-                 # ground truth sequences are expressed as lists, not tuples
-                 if isinstance(output, tuple):
-                     output = list(output)
-
-                 tmp_result = False
-                 try:
-                     tmp_result = (output == [in_outs["outputs"][index]])
-                     if isinstance(in_outs["outputs"][index], list):
-                         tmp_result = tmp_result or (output == in_outs["outputs"][index])
-                         if isinstance(output[0], str):
-                             tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
-                 except Exception as e:
-                     if debug:
-                         print(f"Failed check1 exception = {e}")
-                     pass
-
-                 if tmp_result:
-                     results.append(tmp_result)
-                     continue
-
-                 # try one more time without \n
-                 if isinstance(in_outs["outputs"][index], list):
-                     for tmp_index, i in enumerate(in_outs["outputs"][index]):
-                         in_outs["outputs"][index][tmp_index] = i.split("\n")
-                         in_outs["outputs"][index][tmp_index] = [x.strip() for x in in_outs["outputs"][index][tmp_index] if x]
-                 else:
-                     in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
-                     in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
-                     in_outs["outputs"][index] = list(map(lambda x: x.strip(), in_outs["outputs"][index]))
-
-                 try:
-                     tmp_result = (output == [in_outs["outputs"][index]])
-                     if isinstance(in_outs["outputs"][index], list):
-                         tmp_result = tmp_result or (output == in_outs["outputs"][index])
-                 except Exception as e:
-                     if debug:
-                         print(f"Failed check2 exception = {e}")
-                     pass
-
-                 if tmp_result:
-                     results.append(tmp_result)
-                     continue
-
-                 # try by converting the output into a split-up list too
-                 if isinstance(output, list):
-                     output = list(filter(len, output))
-
-                 if debug:
-                     nl = "\n"
-                     if not isinstance(inputs, list):
-                         print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-                     else:
-                         print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-
-                 if tmp_result:
-                     results.append(tmp_result)
-                     continue
-
-                 try:
-                     tmp_result = (output == [in_outs["outputs"][index]])
-                     if isinstance(in_outs["outputs"][index], list):
-                         tmp_result = tmp_result or (output == in_outs["outputs"][index])
-                 except Exception as e:
-                     if debug:
-                         print(f"Failed check3 exception = {e}")
-                     pass
-
-                 try:
-                     output_float = [float(e) for e in output]
-                     gt_float = [float(e) for e in in_outs['outputs'][index]]
-                     tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
-                 except Exception:
-                     pass
-                 try:
-                     if isinstance(output[0], list):
-                         output_float = [float(e) for e in output[0]]
-                         gt_float = [float(e) for e in in_outs['outputs'][index][0]]
-                         tmp_result = tmp_result or ((len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float))
-                 except Exception:
-                     pass
-
-                 if tmp_result:
-                     results.append(tmp_result)
-                     continue
-
-                 # try by converting the ground truth into sets of split-up tokens
-                 if isinstance(in_outs["outputs"][index], list):
-                     for tmp_index, i in enumerate(in_outs["outputs"][index]):
-                         in_outs["outputs"][index][tmp_index] = set(i.split())
-                 else:
-                     in_outs["outputs"][index] = set(in_outs["outputs"][index].split())
-
-                 try:
-                     tmp_result = (output == in_outs["outputs"][index])
-                 except Exception as e:
-                     if debug:
-                         print(f"Failed check4 exception = {e}")
-                     continue
-
-                 if tmp_result:
-                     results.append(tmp_result)
-                     continue
-
-                 # try by converting the output into sets of split-up tokens too
-                 if isinstance(output, list):
-                     for tmp_index, i in enumerate(output):
-                         output[tmp_index] = i.split()
-                     output = list(filter(len, output))
-                     for tmp_index, i in enumerate(output):
-                         output[tmp_index] = set(i)
-                 else:
-                     output = output.split()
-                     output = list(filter(len, output))
-                     output = set(output)
-
-                 try:
-                     tmp_result = (set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index]))
-                 except Exception as e:
-                     if debug:
-                         print(f"Failed check5 exception = {e}")
-
-
-                 # if they are all numbers, round so that similar numbers are treated as identical
-                 try:
-                     tmp_result = tmp_result or (set(frozenset(round(float(t), 3) for t in s) for s in output) ==
-                                                 set(frozenset(round(float(t), 3) for t in s) for s in in_outs["outputs"][index]))
-                 except Exception as e:
-                     if debug:
-                         print(f"Failed check6 exception = {e}")
-
-                 if tmp_result and debug:
-                     print("PASSED")
-
-                 results.append(tmp_result)
-
-                 if debug:
-                     nl = "\n"
-                     if not isinstance(inputs, list):
-                         print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-                     else:
-                         print(f"output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}")
-
-
-     return results
-
-
- def custom_compare_(output, ground_truth):
-
-     if isinstance(output, list):
-         output_1 = "\n".join(output)
-         if stripped_string_compare(output_1, ground_truth):
-             return True
-
-     if isinstance(output, list):
-         output_2 = [o.strip() for o in output]
-         output_2 = "\n".join(output_2)
-         if stripped_string_compare(output_2, ground_truth):
-             return True
-
-     return False
-
- def stripped_string_compare(s1, s2):
-     s1 = s1.strip()
-     s2 = s2.strip()
-     return s1 == s2
-
- def call_method(method, inputs):
-
-     if isinstance(inputs, list):
-         inputs = "\n".join(inputs)
-
-     inputs_line_iterator = iter(inputs.split("\n"))
-
-     # sys.setrecursionlimit(10000)
-
-     # @patch('builtins.input', side_effect=inputs.split("\n"))
-     @patch('builtins.open', mock_open(read_data=inputs))
-     @patch('sys.stdin', StringIO(inputs))
-     @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
-     @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
-     @patch('sys.stdin.read', lambda *args: inputs)
-     # @patch('sys.stdout.write', print)
-     def _inner_call_method(_method):
-         try:
-             return _method()
-         except SystemExit:
-             pass
-         finally:
-             pass
-     return _inner_call_method(method)
-
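
For context: the harness above was driven roughly as sketched below. This is a hypothetical, minimal example (the sample dict and the add() solution are invented for illustration; real rows come from the codeparrot/apps dataset, and since the timeout relies on signal.SIGALRM, run_test only works on Unix):

import json

# One dataset row: test cases are serialized as JSON under "input_output".
# A present "fn_name" selects CODE_TYPE.call_based, otherwise standard_input.
sample = {"input_output": json.dumps({
    "fn_name": "add",
    "inputs": [[1, 2], [3, 4]],
    "outputs": [3, 7],
})}
candidate = "def add(a, b):\n    return a + b\n"

results = run_test(sample, test=candidate)
# One entry per test case: True/False = passed/failed,
# -1 = runtime error or timeout, -2 = compilation error.
print(results)  # [True, True]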
tools/utils.py DELETED
@@ -1,188 +0,0 @@
- import itertools
- import numpy as np
- from typing import Dict
- from datasets import load_dataset
- import tools.testing_util as test_util
-
-
- DATASET = "codeparrot/apps"
-
-
- def evaluate_generations(generations: list, level: str = "all", debug: bool = False):
-     """Take the list of code generations, try to compile them,
-     and run their corresponding unit tests, which are retrieved from the APPS dataset.
-
-     Args:
-         generations: list of code generations (same order as the samples in the APPS dataset)
-         level: difficulty level used in the generation, can be "all", "introductory", "interview" or "competition"
-
-     Returns:
-         results: dictionary of results; the key is the problem index, the value is a list of results for each generation
-         [-2] = compile error, [-1] = runtime error, [False] = failed test case, [True] = passed test case
-     """
-
-     # generations are code generations in the same order as the dataset
-     apps_eval = load_dataset(DATASET, split="test", difficulties=[level])
-     results = {}
-     for index in range(len(generations)):
-         # code generations for problem (index)
-         problem_generations = generations[index]
-         # get the corresponding sample from the APPS dataset
-         sample = apps_eval[index]
-         res = []
-         # loop over the generations
-         for o_idx, o in enumerate(problem_generations):
-             curr_res = [-2]
-             try:
-                 curr_res = test_util.run_test(sample, test=o, debug=debug)
-                 if debug:
-                     print(f"\nSuccessful compilation of task {index}!")
-                 fixed = []
-                 for e in curr_res:
-                     if isinstance(e, np.ndarray):
-                         e = e.item(0)
-                     if isinstance(e, np.bool_):
-                         e = bool(e)
-                     fixed.append(e)
-                 curr_res = fixed
-                 if not np.all(curr_res):
-                     if debug:
-                         print(f"Results were not True for all test cases")
-             except Exception as e:
-                 if debug:
-                     print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
-                 break
-             finally:
-                 assert isinstance(curr_res, list)
-                 res.append(curr_res)
-         results[index] = res
-     return results
-
-
- def estimate_pass_at_k(num_samples, num_correct, k):
-     """Estimates pass@k of each problem and returns them in an array."""
-
-     def estimator(n: int, c: int, k: int) -> float:
-         """Calculates 1 - comb(n - c, k) / comb(n, k)."""
-         if n - c < k:
-             return 1.0
-         return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
-
-     if isinstance(num_samples, int):
-         num_samples_it = itertools.repeat(num_samples, len(num_correct))
-     else:
-         assert len(num_samples) == len(num_correct)
-         num_samples_it = iter(num_samples)
-
-     return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
-
-
- def get_results(results: Dict[int, list], count_errors: bool = False, k_list: list = [1, 10, 100]):
-     """
-     Given the results evaluated against the test cases, output some statistics.
-     For single generations:
-     >>> example_results = {0: [[-2]], 1: [[False,False]], 2: [[True,True]], 3: [[False,True,False,True]], 4: [[-1,-1]]}
-     >>> get_results(example_results, count_errors=True)
-     Computing accuracy metrics...
-     number of compile errors = 1 avg = 0.2
-     number of runtime errors = 1 avg = 0.2
-     number of problems evaluated = 5
-     Average Accuracy : 0.3
-     Strict Accuracy : 0.2
-     {'avg_accuracy': 0.3, 'strict_accuracy': 0.2, 'pass_at_k': None}
-
-     For multiple generations:
-     >>> example_results = {0: [[-2], [True, True, True]], 1: [[-1,-1, -1], [True, False, True]]}
-     >>> get_results(example_results, k_list=[1, 2])
-     Computing pass@k metric for multiple generations...
-     {'pass@1': 0.25, 'pass@2': 0.5}
-     {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 0.25, 'pass@2': 0.5}}
-     """
-
-     metrics = {"avg_accuracy": None, "strict_accuracy": None, "pass_at_k": None}
-
-     if len(results[0]) == 1:
-         # for single generations we compute average accuracy and strict accuracy: the original APPS metrics
-         print("Computing accuracy metrics...")
-         res = []
-         per_prob_res = []
-         all_correct = []
-         for index in results:
-             problem_results = np.asarray(results[index])
-             res.extend(problem_results)
-             per_prob_res.append(np.mean(problem_results > 0))
-             all_correct.append(np.all(problem_results > 0))
-         # we count compilation and runtime errors once per problem
-         compile_errors = len([e for e in res if -2 in e])
-         runtime_errors = len([e for e in res if -1 in e])
-         total_testcases = len(res)
-         if count_errors:
-             print(f"number of compile errors = {compile_errors} avg = {compile_errors / total_testcases}")
-             print(f"number of runtime errors = {runtime_errors} avg = {runtime_errors / total_testcases}")
-             print(f"number of problems evaluated = {total_testcases}")
-
-         print(f"Average Accuracy : {np.mean(per_prob_res)}")
-         print(f"Strict Accuracy : {np.mean(all_correct)}")
-         metrics["avg_accuracy"] = np.mean(per_prob_res)
-         metrics["strict_accuracy"] = np.mean(all_correct)
-
-     else:
-         # for multiple generations we use the pass@k metric from the HumanEval benchmark
-         # strict accuracy applies: a generation only counts if it passes all the tests
-         print("Computing pass@k metric for multiple generations...")
-         # total is a list with the number of generations per task (task=index)
-         # correct is the number of generations that passed all tests per task
-         total = []
-         correct = []
-         for index in results:
-             all_correct = []
-             for generation in results[index]:
-                 gen = np.array(generation)
-                 all_correct.append(np.all(gen > 0))
-             total.append(len(all_correct))
-             correct.append(sum(all_correct))
-         total = np.array(total)
-         correct = np.array(correct)
-         ks = k_list
-         pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in ks if (total >= k).all()}
-         print(pass_at_k)
-         metrics["pass_at_k"] = pass_at_k
-     return metrics
-
- def compute_metrics(generations, level="all", k_list=[1, 10, 100], count_errors=True, debug=False):
-     """Return metrics for the given generations.
-     Args:
-         generations: list of code generations for each problem (each problem has a list of generations)
-         k_list: list of k values to compute pass@k when using multiple generations
-         count_errors: whether to count compilation and runtime errors when using single generations
-         level: difficulty level in the APPS dataset that was used for the given generations (one of "all", "introductory", "interview", "competition")
-     Returns:
-         metrics: dict of metrics
-
-     Examples:
-
-     >>> import json
-     >>> # lists of solutions to the first two APPS problems (note: not all solutions pass all tests)
-     >>> solution_sample1 = json.load(open("test_examples/solutions_problem_1.json", "r"))
-     >>> solution_sample2 = json.load(open("test_examples/solutions_problem_2.json", "r"))
-     >>> single_solutions = [solution_sample1[:1], solution_sample2[:1]]
-     >>> compute_metrics(single_solutions, level="all")
-     Computing accuracy metrics...
-     number of compile errors = 0 avg = 0.0
-     number of runtime errors = 0 avg = 0.0
-     number of problems evaluated = 2
-     Average Accuracy : 1.0
-     Strict Accuracy : 1.0
-     {'avg_accuracy': 1.0, 'strict_accuracy': 1.0, 'pass_at_k': None}
-     >>> multiple_solutions = [solution_sample1[:3], solution_sample2[:3]]
-     >>> compute_metrics(multiple_solutions, level="all", k_list=[1, 2, 3])
-     Computing pass@k metric for multiple generations...
-     {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}
-     {'avg_accuracy': None, 'strict_accuracy': None, 'pass_at_k': {'pass@1': 1.0, 'pass@2': 1.0, 'pass@3': 1.0}}
-     """
-     results = evaluate_generations(generations, level=level, debug=debug)
-     metrics = get_results(results, count_errors=count_errors, k_list=k_list)
-     return metrics
-
- # import doctest
- # doctest.testmod()
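
As a sanity check (standalone, not part of the deleted files), the pass@k numbers in the get_results doctest above follow directly from the estimator 1 - comb(n - c, k) / comb(n, k) implemented in estimate_pass_at_k:

from math import comb

def pass_at_k(n, c, k):
    # n generations, c of which pass every test; probability that at least
    # one of k generations drawn without replacement is fully correct
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# example_results = {0: [[-2], [True, True, True]], 1: [[-1, -1, -1], [True, False, True]]}
# problem 0: n=2, c=1 (only the second generation passes all of its tests)
# problem 1: n=2, c=0 ([True, False, True] fails one test, so it does not count)
print((pass_at_k(2, 1, 1) + pass_at_k(2, 0, 1)) / 2)  # 0.25 -> pass@1
print((pass_at_k(2, 1, 2) + pass_at_k(2, 0, 2)) / 2)  # 0.5  -> pass@2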