Kewen Zhao commited on
Commit
af3f724
·
1 Parent(s): 47feec4

stdio format

Browse files
Files changed (2) hide show
  1. code_eval_stdio.py +3 -4
  2. execute.py +45 -35
code_eval_stdio.py CHANGED
@@ -152,7 +152,7 @@ class CodeEval(evaluate.Metric):
152
  license=_LICENSE,
153
  )
154
 
155
- def _compute(self, predictions, references, k=[1, 10, 100], num_workers=4, timeout=3.0):
156
  """Returns the scores"""
157
 
158
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
@@ -167,10 +167,9 @@ class CodeEval(evaluate.Metric):
167
  n_samples = 0
168
  results = defaultdict(list)
169
 
170
- for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
171
  for candidate in candidates:
172
- test_program = candidate + "\n" + test_case
173
- args = (test_program, timeout, task_id, completion_id[task_id])
174
  future = executor.submit(check_correctness, *args)
175
  futures.append(future)
176
  completion_id[task_id] += 1
 
152
  license=_LICENSE,
153
  )
154
 
155
+ def _compute(self, program, test_input, test_output, k=[1, 10, 100], num_workers=4, timeout=3.0):
156
  """Returns the scores"""
157
 
158
  if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
 
167
  n_samples = 0
168
  results = defaultdict(list)
169
 
170
+ for task_id, (candidates, inputs, outputs) in enumerate(zip(program, test_input, test_output)):
171
  for candidate in candidates:
172
+ args = (candidate, inputs, outputs, timeout, task_id, completion_id[task_id])
 
173
  future = executor.submit(check_correctness, *args)
174
  futures.append(future)
175
  completion_id[task_id] += 1
execute.py CHANGED
@@ -25,24 +25,27 @@ import signal
25
  import tempfile
26
 
27
 
28
- def check_correctness(check_program, timeout, task_id, completion_id):
29
  """
30
  Evaluates the functional correctness of a completion by running the test
31
  suite provided in the problem.
32
 
33
- :param completion_id: an optional completion ID so we can match
34
- the results later even if execution finishes asynchronously.
 
 
 
 
35
  """
36
  manager = multiprocessing.Manager()
37
  result = manager.list()
38
 
39
- p = multiprocessing.Process(target=unsafe_execute, args=(check_program, result, timeout))
40
- p.start()
41
- p.join(timeout=timeout + 1)
42
- if p.is_alive():
43
- p.kill()
44
 
45
- if not result:
 
46
  result.append("timed out")
47
 
48
  return dict(
@@ -53,8 +56,16 @@ def check_correctness(check_program, timeout, task_id, completion_id):
53
  )
54
 
55
 
56
- def unsafe_execute(check_program, result, timeout):
 
 
57
 
 
 
 
 
 
 
58
  with create_tempdir():
59
 
60
  # These system calls are needed when cleaning up tempdir.
@@ -71,10 +82,24 @@ def unsafe_execute(check_program, result, timeout):
71
  # Run program.
72
  try:
73
  exec_globals = {}
74
- with swallow_io():
 
 
 
 
 
 
75
  with time_limit(timeout):
76
- exec(check_program, exec_globals)
77
- result.append("passed")
 
 
 
 
 
 
 
 
78
  except TimeoutException:
79
  result.append("timed out")
80
  except BaseException as e:
@@ -100,11 +125,13 @@ def time_limit(seconds):
100
 
101
 
102
  @contextlib.contextmanager
103
- def swallow_io():
104
- stream = WriteOnlyStringIO()
105
- with contextlib.redirect_stdout(stream):
106
- with contextlib.redirect_stderr(stream):
107
- with redirect_stdin(stream):
 
 
108
  yield
109
 
110
 
@@ -119,23 +146,6 @@ class TimeoutException(Exception):
119
  pass
120
 
121
 
122
- class WriteOnlyStringIO(io.StringIO):
123
- """StringIO that throws an exception when it's read from"""
124
-
125
- def read(self, *args, **kwargs):
126
- raise OSError
127
-
128
- def readline(self, *args, **kwargs):
129
- raise OSError
130
-
131
- def readlines(self, *args, **kwargs):
132
- raise OSError
133
-
134
- def readable(self, *args, **kwargs):
135
- """Returns True if the IO object can be read."""
136
- return False
137
-
138
-
139
  class redirect_stdin(contextlib._RedirectStream): # type: ignore
140
  _stream = "stdin"
141
 
 
25
  import tempfile
26
 
27
 
28
+ def check_correctness(program, test_input, test_output, timeout, task_id, completion_id):
29
  """
30
  Evaluates the functional correctness of a completion by running the test
31
  suite provided in the problem.
32
 
33
+ :param program: The program string to evaluate.
34
+ :param test_input: The input string to provide to the program via STDIN.
35
+ :param test_output: The expected output string from the program via STDOUT.
36
+ :param timeout: Maximum execution time in seconds.
37
+ :param task_id: ID of the task being evaluated.
38
+ :param completion_id: Completion ID to match results later.
39
  """
40
  manager = multiprocessing.Manager()
41
  result = manager.list()
42
 
43
+ process = multiprocessing.Process(target=unsafe_execute, args=(program, test_input, test_output, result, timeout))
44
+ process.start()
45
+ process.join(timeout=timeout + 1)
 
 
46
 
47
+ if process.is_alive():
48
+ process.kill()
49
  result.append("timed out")
50
 
51
  return dict(
 
56
  )
57
 
58
 
59
+ def unsafe_execute(program, test_input, test_output, result, timeout):
60
+ """
61
+ Executes the program with redirected STDIN and compares its STDOUT to the expected output.
62
 
63
+ :param program: The program string to execute.
64
+ :param test_input: Input to provide to the program via STDIN.
65
+ :param test_output: Expected output to compare with STDOUT.
66
+ :param result: A multiprocessing.Manager().list() to store the execution result.
67
+ :param timeout: Maximum execution time in seconds.
68
+ """
69
  with create_tempdir():
70
 
71
  # These system calls are needed when cleaning up tempdir.
 
82
  # Run program.
83
  try:
84
  exec_globals = {}
85
+ actual_output = None
86
+
87
+ # Redirect I/O and execute
88
+ input_stream = io.StringIO(test_input)
89
+ output_stream = io.StringIO()
90
+
91
+ with swallow_io(input_stream, output_stream):
92
  with time_limit(timeout):
93
+ exec(program, exec_globals)
94
+
95
+ actual_output = output_stream.getvalue().strip()
96
+ expected_output = test_output.strip()
97
+
98
+ if actual_output == expected_output:
99
+ result.append("passed")
100
+ else:
101
+ result.append(f"failed: got '{actual_output}', expected '{expected_output}'")
102
+
103
  except TimeoutException:
104
  result.append("timed out")
105
  except BaseException as e:
 
125
 
126
 
127
  @contextlib.contextmanager
128
+ def swallow_io(input_stream, output_stream):
129
+ """
130
+ Redirects STDIN, STDOUT, and STDERR for isolated execution.
131
+ """
132
+ with contextlib.redirect_stdout(output_stream):
133
+ with contextlib.redirect_stderr(output_stream):
134
+ with redirect_stdin(input_stream):
135
  yield
136
 
137
 
 
146
  pass
147
 
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  class redirect_stdin(contextlib._RedirectStream): # type: ignore
150
  _stream = "stdin"
151