Kewen Zhao committed on
Commit
306dfab
·
1 Parent(s): fcdd54c

bugfix: input and output should support List[str], not just a single str

Browse files
Files changed (2) hide show
  1. code_eval_stdio.py +4 -4
  2. execute.py +41 -42
code_eval_stdio.py CHANGED
@@ -26,7 +26,7 @@ import numpy as np
26
 
27
  import evaluate
28
 
29
- from .execute import check_correctness
30
 
31
 
32
  _CITATION = """\
@@ -145,8 +145,8 @@ class CodeEval(evaluate.Metric):
145
  "predictions": datasets.Sequence(datasets.Value("string")),
146
  "references": datasets.Features(
147
  {
148
- "input": datasets.Value("string"),
149
- "reference_output": datasets.Value("string"),
150
  }
151
  ),
152
  }
@@ -178,7 +178,7 @@ class CodeEval(evaluate.Metric):
178
 
179
  for task_id, (candidates, reference) in enumerate(zip(predictions, references)):
180
  for candidate in candidates:
181
- args = (candidate, reference['input'], reference['reference_output'], timeout, task_id, completion_id[task_id])
182
  future = executor.submit(check_correctness, *args)
183
  futures.append(future)
184
  completion_id[task_id] += 1
 
26
 
27
  import evaluate
28
 
29
+ from .execute import check_correctness
30
 
31
 
32
  _CITATION = """\
 
145
  "predictions": datasets.Sequence(datasets.Value("string")),
146
  "references": datasets.Features(
147
  {
148
+ "inputs": datasets.Sequence(datasets.Value("string")),
149
+ "reference_outputs": datasets.Sequence(datasets.Value("string")),
150
  }
151
  ),
152
  }
 
178
 
179
  for task_id, (candidates, reference) in enumerate(zip(predictions, references)):
180
  for candidate in candidates:
181
+ args = (candidate, reference['inputs'], reference['reference_outputs'], timeout, task_id, completion_id[task_id])
182
  future = executor.submit(check_correctness, *args)
183
  futures.append(future)
184
  completion_id[task_id] += 1
execute.py CHANGED
@@ -25,49 +25,47 @@ import signal
25
  import tempfile
26
 
27
 
28
- def check_correctness(program, test_input, test_output, timeout, task_id, completion_id):
29
  """
30
  Evaluates the functional correctness of a completion by running the test
31
  suite provided in the problem.
32
-
33
  :param program: The program string to evaluate.
34
- :param test_input: The input string to provide to the program via STDIN.
35
- :param test_output: The expected output string from the program via STDOUT.
36
  :param timeout: Maximum execution time in seconds.
37
  :param task_id: ID of the task being evaluated.
38
  :param completion_id: Completion ID to match results later.
39
  """
40
  manager = multiprocessing.Manager()
41
- result = manager.list()
42
 
43
- process = multiprocessing.Process(target=unsafe_execute, args=(program, test_input, test_output, result, timeout))
44
  process.start()
45
  process.join(timeout=timeout + 1)
46
 
47
  if process.is_alive():
48
  process.kill()
49
- result.append("timed out")
50
 
51
  return dict(
52
  task_id=task_id,
53
- passed=result[0] == "passed",
54
- result=result[0],
55
  completion_id=completion_id,
 
56
  )
57
 
58
 
59
- def unsafe_execute(program, test_input, test_output, result, timeout):
60
  """
61
  Executes the program with redirected STDIN and compares its STDOUT to the expected output.
62
-
63
  :param program: The program string to execute.
64
- :param test_input: Input to provide to the program via STDIN.
65
- :param test_output: Expected output to compare with STDOUT.
66
- :param result: A multiprocessing.Manager().list() to store the execution result.
67
  :param timeout: Maximum execution time in seconds.
68
  """
69
  with create_tempdir():
70
-
71
  # These system calls are needed when cleaning up tempdir.
72
  import os
73
  import shutil
@@ -79,36 +77,37 @@ def unsafe_execute(program, test_input, test_output, result, timeout):
79
  # Disable functionalities that can make destructive changes to the test.
80
  reliability_guard()
81
 
82
- # Run program.
83
  try:
84
  exec_globals = {}
85
- actual_output = None
86
-
87
- # Redirect I/O and execute
88
- input_stream = io.StringIO(test_input)
89
- output_stream = io.StringIO()
90
-
91
- with swallow_io(input_stream, output_stream):
92
- with time_limit(timeout):
93
- exec(program, exec_globals)
94
-
95
- actual_output = output_stream.getvalue().strip()
96
- expected_output = test_output.strip()
97
-
98
- if actual_output == expected_output:
99
- result.append("passed")
100
- else:
101
- result.append(f"failed: got '{actual_output}', expected '{expected_output}'")
102
-
103
- except TimeoutException:
104
- result.append("timed out")
105
- except BaseException as e:
106
- result.append(f"failed: {e}")
107
 
108
- # Needed for cleaning up.
109
- shutil.rmtree = rmtree
110
- os.rmdir = rmdir
111
- os.chdir = chdir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
  @contextlib.contextmanager
 
25
  import tempfile
26
 
27
 
28
def check_correctness(program, test_inputs, test_outputs, timeout, task_id, completion_id):
    """
    Evaluates the functional correctness of a completion by running the test
    suite provided in the problem.

    The program is executed in a separate process (via ``unsafe_execute``) once
    per input/output pair; each per-test outcome is collected as a dict with at
    least a ``"status"`` key.

    :param program: The program string to evaluate.
    :param test_inputs: A list of input strings to provide to the program via STDIN.
    :param test_outputs: A list of expected output strings from the program via STDOUT.
    :param timeout: Maximum execution time in seconds (applied per test case by the worker).
    :param task_id: ID of the task being evaluated.
    :param completion_id: Completion ID to match results later.
    :return: A dict with ``task_id``, ``passed`` (True only if every test case
        ran and reported ``"passed"``), ``results`` (list of per-test dicts),
        ``completion_id`` and ``program``.
    """
    # zip() in the worker stops at the shorter list, so that is the number of
    # test cases that will actually run.
    expected_count = min(len(test_inputs), len(test_outputs))

    # The context manager shuts the Manager server process down on exit
    # instead of leaking one per call.
    with multiprocessing.Manager() as manager:
        results = manager.list()

        process = multiprocessing.Process(
            target=unsafe_execute,
            args=(program, test_inputs, test_outputs, results, timeout),
        )
        process.start()
        # The worker applies `timeout` to each test case, so the overall budget
        # must scale with the number of cases (plus one second of slack).
        process.join(timeout=timeout * max(expected_count, 1) + 1)

        if process.is_alive():
            process.kill()
            # Keep the same dict shape as the worker's entries; a bare string
            # here would break the `.get("status")` check below.
            results.append({"status": "timed out"})

        # Copy out of the proxy while the manager is still running.
        collected = list(results)

    # Require one result per test case so a worker that died before recording
    # anything is not reported as a vacuous pass.
    passed = len(collected) == expected_count and all(
        result.get("status") == "passed" for result in collected
    )

    return dict(
        task_id=task_id,
        passed=passed,
        results=collected,
        completion_id=completion_id,
        program=program,
    )
57
 
58
 
59
+ def unsafe_execute(program, test_inputs, test_outputs, results, timeout):
60
  """
61
  Executes the program with redirected STDIN and compares its STDOUT to the expected output.
 
62
  :param program: The program string to execute.
63
+ :param test_inputs: List of inputs to provide to the program via STDIN.
64
+ :param test_outputs: List of expected outputs to compare with STDOUT.
65
+ :param results: A multiprocessing.Manager().list() to store the execution results.
66
  :param timeout: Maximum execution time in seconds.
67
  """
68
  with create_tempdir():
 
69
  # These system calls are needed when cleaning up tempdir.
70
  import os
71
  import shutil
 
77
  # Disable functionalities that can make destructive changes to the test.
78
  reliability_guard()
79
 
80
+ # Run the program for each input-output pair.
81
  try:
82
  exec_globals = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ for test_input, test_output in zip(test_inputs, test_outputs):
85
+ input_stream = io.StringIO(test_input)
86
+ output_stream = io.StringIO()
87
+
88
+ try:
89
+ with swallow_io(input_stream, output_stream):
90
+ with time_limit(timeout):
91
+ exec(program, exec_globals)
92
+
93
+ actual_output = output_stream.getvalue().strip()
94
+ expected_output = test_output.strip()
95
+
96
+ if actual_output == expected_output:
97
+ results.append({"status": "passed", "input": test_input, "expected": test_output, "actual": actual_output})
98
+ else:
99
+ results.append({"status": "failed", "input": test_input, "expected": test_output, "actual": actual_output})
100
+
101
+ except TimeoutException:
102
+ results.append({"status": "timed out", "input": test_input})
103
+ except BaseException as e:
104
+ results.append({"status": "failed", "input": test_input, "error": str(e)})
105
+
106
+ finally:
107
+ # Restore system calls for cleaning up
108
+ shutil.rmtree = rmtree
109
+ os.rmdir = rmdir
110
+ os.chdir = chdir
111
 
112
 
113
  @contextlib.contextmanager