Spaces:
Runtime error
Runtime error
import datetime | |
import gc | |
import multiprocessing as mp | |
import pathlib | |
import subprocess | |
from dataclasses import dataclass | |
from typing import Dict, List | |
from tqdm import tqdm | |
@dataclass
class CommandResult:
    """Outcome of running a single command through ``safe_execute``.

    Attributes:
        return_code: The process exit code, or -1 if the command timed out or
            could not be started.
        runtime: Seconds the command took; set to the timeout value when the
            command timed out and to -1 when it could not be started.
        stdout: The captured standard output.
        stderr: The captured standard error, or the error message when the
            command could not be started.
        timed_out: Whether the command exceeded its timeout.
    """

    return_code: int
    runtime: float
    stdout: str
    stderr: str
    timed_out: bool
def safe_execute(
    command_to_run: List[str],
    working_dir: pathlib.Path,
    timeout: int = 10,
) -> CommandResult:
    """Run a command in a subprocess with a timeout, never raising.

    Args:
        command_to_run: The command and its arguments, passed to ``Popen``
            as an argv list (no shell).
        working_dir: The working directory to run the command in.
        timeout: Maximum number of seconds to wait for the command.

    Returns:
        A ``CommandResult``. On success it holds the exit code, the elapsed
        seconds and the decoded output. On timeout, ``timed_out`` is True,
        ``runtime`` equals ``timeout``, ``return_code`` is -1, and
        ``stdout``/``stderr`` hold whatever output was captured before the
        process was killed. If the process could not be run at all,
        ``return_code`` and ``runtime`` are -1 and ``stderr`` holds the
        exception message.
    """
    timed_out = False
    return_code = -1
    runtime = timeout
    stdout_bytes = b""
    stderr_bytes = b""
    start_time = datetime.datetime.now()
    try:
        # The context manager closes the pipes and reaps the child on exit,
        # so neither file descriptors nor zombie processes leak.
        with subprocess.Popen(
            command_to_run,
            cwd=str(working_dir),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as process:
            try:
                stdout_bytes, stderr_bytes = process.communicate(timeout=timeout)
                runtime = (datetime.datetime.now() - start_time).total_seconds()
                return_code = process.returncode
            except subprocess.TimeoutExpired:
                timed_out = True
                runtime = timeout
                # Kill, then communicate again to collect the partial output
                # (the pattern recommended by the subprocess documentation).
                process.kill()
                try:
                    stdout_bytes, stderr_bytes = process.communicate(timeout=1)
                except subprocess.TimeoutExpired as exc:
                    # Grandchildren may still hold the pipes open; keep the
                    # output gathered so far instead of blocking forever.
                    stdout_bytes = exc.stdout or b""
                    stderr_bytes = exc.stderr or b""
            finally:
                # Never leave the child running (no-op if it already exited).
                process.kill()
    except Exception as e:  # e.g. executable not found, invalid cwd
        return CommandResult(
            return_code=-1,
            runtime=-1,
            stderr=str(e),
            stdout="",
            timed_out=False,
        )
    return CommandResult(
        return_code=return_code,
        runtime=runtime,
        # Arbitrary programs may emit non-UTF-8 bytes; decoding must not
        # turn a completed run into a reported launch failure.
        stderr=stderr_bytes.decode("utf-8", errors="replace"),
        stdout=stdout_bytes.decode("utf-8", errors="replace"),
        timed_out=timed_out,
    )
def execute_code(sample: Dict):
    """Run every command of one sample, stopping at the first failure.

    Args:
        sample: A dict with keys ``"qid"``, ``"idx"``, ``"cwd"`` and
            ``"commands"``. ``"cwd"`` is a path (``pathlib.Path`` or ``str``);
            if it names a file, that file's parent directory is used as the
            working directory. ``"commands"`` is a list of dicts, each with
            ``"command"`` (the argv list) and ``"timeout"`` (seconds).

    Returns:
        A dict with the sample's ``qid`` and ``idx``, the resolved
        ``file_path`` as a string, the list of ``CommandResult`` objects for
        the commands that ran, and the ``failed`` / ``timed_out`` flags.
        Execution stops at the first command that times out (sets
        ``timed_out``) or exits with a non-zero code (sets ``failed``).
    """
    # Accept str paths too (e.g. samples loaded from JSON); wrapping an
    # existing Path is harmless.
    file_path = pathlib.Path(sample["cwd"])
    working_dir = file_path.parent if file_path.is_file() else file_path
    working_dir = working_dir.resolve().absolute()

    timed_out = False
    failed = False
    results = []
    for command in sample["commands"]:
        res = safe_execute(
            command["command"],
            working_dir=working_dir,
            timeout=command["timeout"],
        )
        results.append(res)
        # Check the timeout first so a timed-out command is reported as
        # timed_out only, not also as failed.
        if res.timed_out:
            timed_out = True
            break
        if res.return_code != 0:
            failed = True
            break

    return {
        "qid": sample["qid"],
        "idx": sample["idx"],
        "file_path": str(file_path.absolute().resolve()),
        "results": results,
        "failed": failed,
        "timed_out": timed_out,
    }
def execute_predictions(
    predictions: List[Dict],
    num_workers: int = 1,
    max_task_per_child: int = 1,
    garbage_collection_freq: int = 500,
):
    """Execute a list of prediction samples in a pool of worker processes.

    Args:
        predictions: The samples to run; each one is passed to
            ``execute_code``.
        num_workers: The number of worker processes to use.
        max_task_per_child: The maximum number of tasks a worker process
            runs before it is replaced with a fresh process.
        garbage_collection_freq: Run ``gc.collect()`` in the parent process
            after every this many completed samples. A value <= 0 disables
            the periodic collection.

    Returns:
        The list of per-sample result dicts produced by ``execute_code``, in
        completion order (not necessarily the order of ``predictions``).
    """
    results = []
    with mp.Pool(num_workers, maxtasksperchild=max_task_per_child) as pool:
        progress = tqdm(
            pool.imap_unordered(execute_code, predictions),
            total=len(predictions),
            desc="Executing",
        )
        for num_completed, result in enumerate(progress, start=1):
            results.append(result)
            # Guard the modulo so a frequency of 0 disables collection
            # instead of raising ZeroDivisionError.
            if garbage_collection_freq > 0 and num_completed % garbage_collection_freq == 0:
                gc.collect()
        # Every task has been consumed, so shut the workers down cleanly.
        # The terminate() issued when the ``with`` block exits is then
        # harmless.
        pool.close()
        pool.join()
    return results