Spaces:
Sleeping
Sleeping
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""This is an implementation of the `CodeEval` metric that uses `RestrictedPython` | |
to execute the untrusted code returned by the model.
Lightly adapted and mostly copied verbatim from the implementation in `evaluate`. | |
""" | |
import ast | |
import contextlib | |
import copy | |
import faulthandler | |
import itertools | |
import importlib | |
import importlib.util | |
import io | |
import multiprocessing | |
import os | |
import platform | |
import signal | |
import sys | |
import tempfile | |
import types | |
from typing import Optional, Dict, List, Any | |
from collections import Counter, defaultdict | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
import evaluate | |
# from evaluate.metrics import code_eval | |
import datasets | |
import numpy as np | |
from RestrictedPython import compile_restricted, safe_builtins, limited_builtins, utility_builtins, RestrictingNodeTransformer | |
from RestrictedPython.transformer import copy_locations, IOPERATOR_TO_STR, FORBIDDEN_FUNC_NAMES | |
from RestrictedPython.Eval import default_guarded_getiter, default_guarded_getitem | |
from RestrictedPython.Guards import guarded_iter_unpack_sequence, safer_getattr, guarded_unpack_sequence | |
# Dunder attributes copied from the real list/tuple types onto their
# restricted replacements (only those the real type actually defines).
SAFE_ATTRIBUTES = [
    '__add__',
    '__ge__',
    '__gt__',
    '__le__',
    '__lt__',
    '__mul__',
    '__ne__',
    '__rmul__',
    '__str__',
]
# patch their list implementation to allow empty lists and tuples | |
def limited_list(seq=None):
    """Return a new list built from *seq*.

    Calling it with no argument (or ``None``) yields an empty list. Strings
    are accepted and produce one element per character.
    """
    if seq is None:
        return []
    # list() over a str already yields its individual characters.
    return list(seq)
# Copy every whitelisted attribute that the real `list` type defines onto the
# replacement function, then publish the replacement under the name `list`.
for attr in (candidate for candidate in SAFE_ATTRIBUTES if hasattr(list, candidate)):
    setattr(limited_list, attr, getattr(list, attr))
limited_builtins['list'] = limited_list
def limited_tuple(seq=None):
    """Return a new tuple built from *seq*.

    Calling it with no argument (or ``None``) yields an empty tuple. Strings
    are accepted and produce one element per character.
    """
    if seq is None:
        return ()
    # tuple() over a str already yields its individual characters.
    return tuple(seq)
# Copy every whitelisted attribute that the real `tuple` type defines onto the
# replacement function, then publish the replacement under the name `tuple`.
for attr in (candidate for candidate in SAFE_ATTRIBUTES if hasattr(tuple, candidate)):
    setattr(limited_tuple, attr, getattr(tuple, attr))
limited_builtins['tuple'] = limited_tuple
def limited_range(iFirst, *args):
    """Size-capped stand-in for ``range`` (after Martijn Pieters).

    Accepts the usual one to three arguments and returns a real ``range``.

    Raises:
        AttributeError: more than three arguments were supplied.
        ValueError: the step is zero, or the estimated element count reaches
            the 10000-element cap.
    """
    RANGELIMIT = 10000
    extra = len(args)
    if extra > 2:
        raise AttributeError('range() requires 1-3 int arguments')

    # Normalise the three call shapes to (start, stop, step).
    if extra == 0:
        start, stop, step = 0, iFirst, 1
    elif extra == 1:
        start, stop, step = iFirst, args[0], 1
    else:
        start, stop, step = iFirst, args[0], args[1]

    if step == 0:
        raise ValueError('zero step for range()')

    # Estimated element count; ranges running the "wrong way" count as empty.
    estimated = max(int((stop - start) / step), 0)
    if estimated >= RANGELIMIT:
        raise ValueError(
            'To be created range() object would be too large, '
            'in RestrictedPython we only allow {limit} '
            'elements in a range.'.format(limit=str(RANGELIMIT)),
        )
    return range(start, stop, step)
# Register the size-capped implementation under the name `range` in the
# RestrictedPython `limited_builtins` table.
limited_builtins['range'] = limited_range
# Underscore-prefixed attribute names that sandboxed code may still access.
ALLOWED_UNDERSCORE_NAMES = ['__add__']


def safer_getattr_allowing_string_format(object, name, default=None, getattr=getattr):
    """Getattr implementation allowing str.format(), but preventing access to
    private attributes.

    format() is considered harmful, so use at own risk:
    http://lucumr.pocoo.org/2016/12/29/careful-with-str-format/

    Args:
        object: the object whose attribute is requested.
        name: the attribute name.
        default: value returned when the attribute does not exist.
        getattr: the lookup function to delegate to (defaults to the builtin).

    Returns:
        The attribute value, or *default* if it is missing.

    Raises:
        AttributeError: *name* starts with an underscore and is not listed in
            ``ALLOWED_UNDERSCORE_NAMES``.
    """
    # NOTE: a debug print() used to live here; it fired on every attribute
    # access made by the evaluated program and has been removed.
    if name.startswith('_') and name not in ALLOWED_UNDERSCORE_NAMES:
        raise AttributeError(
            '"{name}" is an invalid attribute name because it '
            'starts with "_"'.format(name=name)
        )
    return getattr(object, name, default)
class AllowAugmentedAssignRestrictingTransformer(RestrictingNodeTransformer):
    """Restricting policy that additionally permits ``target[index] op= value``.

    Augmented assignment to a subscript is rewritten into a plain assignment
    whose right-hand side calls the ``_inplacevar_`` guard. Every other
    augmented assignment is handled by the parent policy.
    """

    def visit_AugAssign(self, node):
        """Rewrite ``a[i] op= v`` as ``a[i] = _inplacevar_('op=', a[i], v)``.

        Note that the subscript expression therefore appears (and is
        evaluated) twice in the generated code: once to read, once to store.
        """
        if not isinstance(node.target, ast.Subscript):
            return super().visit_AugAssign(node)

        # Take the Load-context copy *before* the store target is visited,
        # because visiting may rewrite the original node in place.
        load_subscript = copy.deepcopy(node.target)
        load_subscript.ctx = ast.Load()
        if hasattr(load_subscript.value, 'ctx'):
            load_subscript.value.ctx = ast.Load()  # type: ignore

        # Every user-written sub-expression is passed through the policy so
        # that attribute/name/subscript checks still apply to it. Returning the
        # raw sub-trees (as was done before) left them completely unchecked.
        # The synthetic `_inplacevar_` name is deliberately NOT visited: it
        # starts with an underscore and would be rejected by the name check.
        new_node = ast.Assign(
            targets=[self.visit(node.target)],
            value=ast.Call(
                func=ast.Name('_inplacevar_', ast.Load()),
                args=[
                    # ast.Constant replaces the deprecated ast.Str.
                    ast.Constant(IOPERATOR_TO_STR[type(node.op)]),
                    self.visit(load_subscript),
                    self.visit(node.value),
                ],
                keywords=[]))
        copy_locations(new_node, node)
        return new_node
class AllowAugmentedAssignAndUnderscoreVariableNamesRestrictingTransformer(AllowAugmentedAssignRestrictingTransformer):
    """Policy that also tolerates underscore-prefixed variable names.

    Attribute access keeps rejecting names that start with ``_`` unless they
    are listed in ``ALLOWED_UNDERSCORE_NAMES``. Plain names starting with
    ``_`` are accepted unless they end in ``__roles__`` or appear in
    ``FORBIDDEN_FUNC_NAMES``.
    """

    def visit_Attribute(self, node):
        """Checks and mutates attribute access/assignment.

        'a.b' becomes '_getattr_(a, "b")'
        'a.b = c' becomes '_write_(a).b = c'
        'del a.b' becomes 'del _write_(a).b'

        The _write_ function should return a security proxy.
        """
        # Overriding here to allow select underscore names
        if node.attr.startswith('_') and node.attr != '_' and node.attr not in ALLOWED_UNDERSCORE_NAMES:
            self.error(
                node,
                '"{name}" is an invalid attribute name because it starts '
                'with "_".'.format(name=node.attr))

        if node.attr.endswith('__roles__'):
            self.error(
                node,
                '"{name}" is an invalid attribute name because it ends '
                'with "__roles__".'.format(name=node.attr))

        if isinstance(node.ctx, ast.Load):
            node = self.node_contents_visit(node)
            new_node = ast.Call(
                func=ast.Name('_getattr_', ast.Load()),
                # ast.Constant replaces the deprecated ast.Str.
                args=[node.value, ast.Constant(node.attr)],
                keywords=[])
            copy_locations(new_node, node)
            return new_node

        elif isinstance(node.ctx, (ast.Store, ast.Del)):
            node = self.node_contents_visit(node)
            new_value = ast.Call(
                func=ast.Name('_write_', ast.Load()),
                args=[node.value],
                keywords=[])
            copy_locations(new_value, node.value)
            node.value = new_value
            return node

        else:  # pragma: no cover
            # Impossible Case only ctx Load, Store and Del are defined in ast.
            raise NotImplementedError(
                f"Unknown ctx type: {type(node.ctx)}")

    def check_name(self, node, name, allow_magic_methods=False):
        """Accept underscore-prefixed names the parent policy would reject.

        A name starting with ``_`` is accepted here as long as it neither ends
        with ``__roles__`` nor is one of ``FORBIDDEN_FUNC_NAMES``. Everything
        else is delegated to the parent implementation.
        """
        if name is None:
            return
        if name.startswith('_'):
            # Verify it doesn't do anything else that's not allowed
            if not name.endswith('__roles__') and name not in FORBIDDEN_FUNC_NAMES:
                return
        # Otherwise, flow to parent logic
        return super().check_name(node, name, allow_magic_methods)
# Citation text handed to the metric info as `citation`.
# TODO: Add BibTeX citation (the entry below is still the template placeholder)
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""
# Short description handed to the metric info as `description`.
# TODO: Add description of the module here
_DESCRIPTION = """\
This module implements the same logic as the baseline `code_eval` module but using RestrictedPython.
"""
# Argument/return documentation handed to the metric info as `inputs_description`.
# TODO: Add description of the arguments of the module here (the `timeout` entry is still empty)
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: list of candidates to evaluate. Each candidates should be a list
of strings with several code candidates to solve the problem.
references: a list with a test for each prediction. Each test should evaluate the
correctness of a code candidate.
k: number of code candidates to consider in the evaluation (Default: [1, 10, 100])
num_workers: number of workers used to evaluate the canidate programs (Default: 4).
timeout:
use_safe_builtins: a bool indicating whether to use the `RestrictedPython.safe_builtins`
use_limited_builtins: a bool indicating whether to use the `RestrictedPython.limited_builtins`
use_utility_builtins: a bool indicating whether to use the `RestrictedPython.utility_builtins`
additional_globals: a optional dict of additional globals to pass to the RestrictedPython interpreter
additional_locals: a optional dict of additional locals to pass to the RestrictedPython interpreter
allowed_imports: an optional list of string, modules the tested code is allowed to import
allow_str_format: a bool indicating whether to allow the use of str.format() in the tested code
allow_underscore_variable_names: a bool indicating whether to allow the use of underscore variable names in the tested code
return_output: a bool indicating whether to return the output of the tested code
output_variable: a string indicating the name of the variable to return if return_output is True
Returns:
pass_at_k: dict with pass rates for each k
results: dict with granular results of each unittest
Examples:
>>> code_eval = evaluate.load("RestrictedPython_code_eval")
>>> test_cases = ["assert add(2,3)==5"]
>>> candidates = [["def add(a,b): return a*b", "def add(a, b): return a+b"]]
>>> pass_at_k, results = code_eval.compute(references=test_cases, predictions=candidates, k=[1, 2])
>>> print(pass_at_k)
{'pass@1': 0.5, 'pass@2': 1.0}
"""
# Message of the ValueError raised by `_compute` unless the environment
# variable HF_ALLOW_CODE_EVAL is set to "1".
_WARNING = """
################################################################################
!!!WARNING!!!
################################################################################
The "code_eval" metric executes untrusted model-generated code in Python.
Although it is highly unlikely that model-generated code will do something
overtly malicious in response to this test suite, model-generated code may act
destructively due to a lack of model capability or alignment.
Users are strongly encouraged to sandbox this evaluation suite so that it
does not perform destructive actions on their host or network. For more
information on how OpenAI sandboxes its code, see the paper "Evaluating Large
Language Models Trained on Code" (https://arxiv.org/abs/2107.03374).
Once you have read this disclaimer and taken appropriate precautions,
set the environment variable HF_ALLOW_CODE_EVAL="1". Within Python you can to this
with:
import os
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
################################################################################\
"""
# License text kept alongside the module; not referenced by the code visible in this file.
# TODO: who has the copyright?
_LICENSE = """The MIT License
Copyright (c) OpenAI (https://openai.com)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE."""
class RestrictedPythonCodeEval(evaluate.Metric):
    """pass@k metric equivalent to the built-in `code_eval`, with candidates run under RestrictedPython."""

    def _info(self):
        """Return the `evaluate.MetricInfo` describing this metric."""
        # Each prediction is a list of candidate programs; each reference is one test program.
        input_features = datasets.Features({
            'predictions': datasets.Sequence(datasets.Value("string")),
            'references': datasets.Value('string'),
        })
        return evaluate.MetricInfo(
            module_type="metric",
            # Texts shown on the module page.
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=input_features,
            # Documentation homepage and further links (still template placeholders).
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"],
        )

    def _compute(self, predictions, references, k=[1, 10, 100], num_workers=4, timeout=3.0,
                 use_safe_builtins: bool = True, use_limited_builtins: bool = True, use_utility_builtins: bool = True,
                 additional_globals: Optional[Dict[str, Any]] = None, additional_locals: Optional[Dict[str, Any]] = None,
                 allowed_imports: Optional[List[str]] = None, allow_str_format: bool = False,
                 allow_underscore_variable_names: bool = False, return_output: bool = False, output_variable: str = "output"):
        """Run every candidate against its test and compute pass@k.

        Returns:
            A pair ``(pass_at_k, results)``: ``pass_at_k`` maps ``"pass@<k>"``
            to the mean estimate for every requested k that no task has fewer
            samples than; ``results`` maps each task id to its list of
            ``(completion_id, result_dict)`` tuples, sorted in place.

        Raises:
            ValueError: the environment variable HF_ALLOW_CODE_EVAL is not "1".
            NotImplementedError: running on Windows.
        """
        if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
            raise ValueError(_WARNING)
        if os.name == "nt":
            raise NotImplementedError("This metric is currently not supported on Windows.")

        # Options that are identical for every submitted candidate.
        shared_options = (
            use_safe_builtins, use_limited_builtins, use_utility_builtins,
            additional_globals, additional_locals,
            allowed_imports, allow_str_format, allow_underscore_variable_names,
            return_output, output_variable,
        )

        results = defaultdict(list)
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = []
            for task_id, (candidates, test_case) in enumerate(zip(predictions, references)):
                # completion ids simply number the candidates of one task from zero
                for completion_id, candidate in enumerate(candidates):
                    test_program = candidate + "\n" + test_case
                    futures.append(executor.submit(
                        _check_correctness,
                        test_program, timeout, task_id, completion_id,
                        *shared_options,
                    ))

            for future in as_completed(futures):
                outcome = future.result()
                results[outcome["task_id"]].append((outcome["completion_id"], outcome))

        total, correct = [], []
        for per_task in results.values():
            per_task.sort()
            flags = [entry[1]["passed"] for entry in per_task]
            total.append(len(flags))
            correct.append(sum(flags))
        total = np.array(total)
        correct = np.array(correct)

        pass_at_k = {
            f"pass@{cutoff}": estimate_pass_at_k(total, correct, cutoff).mean()
            for cutoff in k
            if (total >= cutoff).all()
        }
        return pass_at_k, results
def estimate_pass_at_k(num_samples, num_correct, k):
    """Return an array holding the pass@k estimate of every problem.

    Args:
        num_samples: samples per problem; either one int shared by all
            problems or a sequence as long as *num_correct*.
        num_correct: number of passing samples per problem.
        k: the k of pass@k.
    """

    def pass_probability(n: int, c: int, k: int) -> float:
        """Evaluate 1 - comb(n - c, k) / comb(n, k)."""
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))  # type: ignore

    if not isinstance(num_samples, int):
        assert len(num_samples) == len(num_correct)
        sample_counts = iter(num_samples)
    else:
        # A single int applies to every problem.
        sample_counts = itertools.repeat(num_samples, len(num_correct))

    estimates = []
    for n, c in zip(sample_counts, num_correct):
        estimates.append(pass_probability(int(n), int(c), k))
    return np.array(estimates)
def _check_correctness(check_program, timeout, task_id, completion_id,
                       use_safe_builtins: bool = True, use_limited_builtins: bool = True, use_utility_builtins: bool = True,
                       additional_globals: Optional[Dict[str, Any]] = None, additional_locals: Optional[Dict[str, Any]] = None,
                       allowed_imports: Optional[List[str]] = None, allow_str_format: bool = False,
                       allow_underscore_variable_names: bool = False, return_output: bool = False, output_variable: str = "output"):
    """
    Evaluates the functional correctness of a completion by running the test
    suite provided in the problem.

    The program is executed by `_unsafe_execute` in a separate process that is
    given ``timeout + 1`` seconds before it is killed.

    :param check_program: source of the candidate followed by its test.
    :param timeout: time limit (seconds) forwarded to the child process.
    :param task_id: identifier of the problem, echoed in the result.
    :param completion_id: an optional completion ID so we can match
        the results later even if execution finishes asynchronously.
    :return: dict with ``task_id``, ``completion_id``, ``passed`` and
        ``result`` (first entry reported by the child). For failed runs that
        also reported an exception, ``exception_type`` and
        ``exception_description`` are added.
    """
    # The manager is used as a context manager so that its server process is
    # shut down deterministically instead of lingering until garbage collection.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        args = (
            check_program, result, timeout,
            use_safe_builtins, use_limited_builtins, use_utility_builtins,
            additional_globals, additional_locals,
            allowed_imports, allow_str_format, allow_underscore_variable_names,
            return_output, output_variable
        )
        p = multiprocessing.Process(target=_unsafe_execute, args=args)
        p.start()
        p.join(timeout=timeout + 1)
        if p.is_alive():
            p.kill()
            p.join()  # reap the killed child so it does not stay a zombie
        # Copy into a plain list while the manager is still running.
        outcome = result[:]

    if not outcome:
        outcome.append("Result evaluates to False (probably timed out)")

    first = outcome[0]
    passed = outcome[-1] == "passed"
    out_dict = dict(
        task_id=task_id,
        passed=passed,
        result=first,
        completion_id=completion_id,
    )
    # With return_output the first entry is whatever the program produced, so
    # it may not be a string at all (the unguarded `in` test used to raise
    # TypeError for e.g. an int) or may be a string that merely contains
    # "failed". Only genuine failure reports carry an exception object.
    if not passed and isinstance(first, str) and 'failed' in first and len(outcome) > 1:
        exc = outcome[1]
        out_dict["exception_type"] = type(exc).__name__
        out_dict["exception_description"] = str(exc)
    return out_dict
# ALLOWED_SYS_NAMES = ['maxsize'] | |
class AllowListImporter:
    """``__import__`` replacement that only imports allow-listed packages.

    A module may be imported when its top-level package name is contained in
    ``allowed_imports``; submodules of an allowed package are allowed too.
    """

    def __init__(self, allowed_imports: List[str]):
        """Remember the collection of importable top-level package names."""
        self.allowed_imports = allowed_imports

    def __call__(self, name, globals=None, locals=None, fromlist=(), level=0):
        """Import *name* if permitted; the signature mirrors ``__import__``.

        Raises:
            ImportError: for relative imports and for any module whose
                top-level package is not allow-listed.
        """
        # Relative imports are signalled through `level`, not through a
        # leading dot in `name`, so both are checked.
        if level != 0 or name.startswith('.'):
            raise ImportError("Relative imports are not allowed.")
        package_name = name.partition('.')[0]
        if package_name not in self.allowed_imports:
            # Previously this fell through and returned None, which bound the
            # imported name to None instead of failing the import statement.
            raise ImportError(f"Import of module '{name}' is not allowed.")
        return importlib.__import__(name, globals, locals, fromlist, level)
def _default_write_(obj): | |
if isinstance(obj, types.ModuleType): | |
raise ValueError("Modules are not allowed in to be written to.") | |
return obj | |
class DefaultPrinter:
    """`_print_` factory: collects printed text and echoes it via ``print``.

    Calling the instance returns everything written so far as one string.
    """

    def __init__(self, _getattr_=None, *args, **kwargs):
        """Store the attribute guard and any extra construction arguments."""
        self.txt = []
        self._getattr_ = _getattr_
        self.args, self.kwargs = args, kwargs

    def write(self, text):
        """Record *text* and also print it."""
        self.txt.append(text)
        print(text)

    def __call__(self):
        """Return all recorded text concatenated."""
        return ''.join(self.txt)

    def _call_print(self, *objects, **kwargs):
        """Implement ``print(...)``; output defaults to this collector."""
        destination = kwargs.get('file', None)
        if destination is not None:
            # Run the guarded attribute lookup on the explicit target's `write`.
            self._getattr_(destination, 'write')  # type: ignore
        else:
            kwargs['file'] = self
        print(*objects, **kwargs)
def _unsafe_execute(check_program, result, timeout,
                    use_safe_builtins: bool = True, use_limited_builtins: bool = True, use_utility_builtins: bool = True,
                    additional_globals: Optional[Dict[str, Any]] = None, additional_locals: Optional[Dict[str, Any]] = None,
                    allowed_imports: Optional[List[str]] = None, allow_str_format: bool = False,
                    allow_underscore_variable_names: bool = False, return_output: bool = False, output_variable: str = "output"):
    """Compile `check_program` with RestrictedPython and execute it.

    Outcomes are reported by appending to `result`:
      * success: the value of `output_variable` (only if `return_output`),
        then "passed";
      * "EOF error" or "timed out" for the corresponding exceptions;
      * any other exception: a "failed (...)" description, then the
        exception object itself.

    The work happens inside a temporary working directory and after
    `reliability_guard()` has been applied.
    """
    with create_tempdir():
        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil
        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        # Disable functionalities that can make destructive changes to the test.
        reliability_guard()
        # The output variable is read back from the locals mapping, so one must exist.
        if return_output and additional_locals is None:
            additional_locals = {}
        # Run program.
        try:
            # Assemble the builtins visible to the program from the selected tables.
            builtins = {}
            if use_safe_builtins:
                builtins.update(safe_builtins)
            if use_limited_builtins:
                builtins.update(limited_builtins)
            if use_utility_builtins:
                builtins.update(utility_builtins)
            exec_globals = {'__builtins__': builtins}
            if additional_globals is None:
                additional_globals = {}
            # New keys are added as-is; a key that already exists (only
            # '__builtins__' at this point) is merged via dict.update.
            for key, glob in additional_globals.items():
                if key not in exec_globals:
                    exec_globals[key] = glob
                else:
                    exec_globals[key].update(glob)
            if allowed_imports is not None:
                if '__import__' in exec_globals['__builtins__']:
                    raise ValueError("Cannot specify allowed_imports when __import__ is in additional_globals.")
                exec_globals['__builtins__']['__import__'] = AllowListImporter(allowed_imports)
            if allow_str_format:
                exec_globals['getattr'] = safer_getattr_allowing_string_format  # type: ignore
                exec_globals['__builtins__']['_getattr_'] = safer_getattr_allowing_string_format
            # Fill in every guard/hook the caller did not supply.
            if '__metaclass__' not in exec_globals:
                exec_globals['__metaclass__'] = type  # type: ignore
            if '__name__' not in exec_globals:
                exec_globals['__name__'] = '__main__'  # type: ignore
            if '_getiter_' not in exec_globals:
                exec_globals['_getiter_'] = default_guarded_getiter  # type: ignore
            if '_iter_unpack_sequence_' not in exec_globals:
                exec_globals['_iter_unpack_sequence_'] = guarded_iter_unpack_sequence  # type: ignore
            if '_unpack_sequence_' not in exec_globals:
                exec_globals['_unpack_sequence_'] = guarded_unpack_sequence  # type: ignore
            if '_getitem_' not in exec_globals:
                exec_globals['_getitem_'] = default_guarded_getitem  # type: ignore
            # Skipped when 'getattr' was already set, e.g. by allow_str_format above.
            if 'getattr' not in exec_globals:
                exec_globals['getattr'] = safer_getattr  # type: ignore
                exec_globals['__builtins__']['_getattr_'] = safer_getattr
            if '_write_' not in exec_globals:
                exec_globals['_write_'] = _default_write_  # type: ignore
            if '_inplacevar_' not in exec_globals:
                exec_globals['_inplacevar_'] = protected_inplacevar  # type: ignore
            if '_print_' not in exec_globals:
                exec_globals['_print_'] = DefaultPrinter  # type: ignore
            if '_apply_' not in exec_globals:
                exec_globals['_apply_'] = _apply  # type: ignore
            with swallow_io():
                # The transformer that tolerates underscore names is used only on request.
                policy_class = AllowAugmentedAssignAndUnderscoreVariableNamesRestrictingTransformer if allow_underscore_variable_names else AllowAugmentedAssignRestrictingTransformer
                # NOTE(review): nesting reconstructed — compile is assumed to run inside the time limit; confirm.
                with time_limit(timeout):
                    byte_code = compile_restricted(check_program, filename="<model output>", mode="exec", policy=policy_class)
                    exec(byte_code, exec_globals, additional_locals)
            if return_output:
                result.append(additional_locals[output_variable])
            result.append("passed")
        except EOFError:
            result.append("EOF error")
        except TimeoutException:
            result.append("timed out")
        except BaseException as e:
            result.append(f"failed ({type(e)}): {str(e)}")
            result.append(e)
        # Needed for cleaning up.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir
@contextlib.contextmanager
def time_limit(seconds):
    """Context manager that raises ``TimeoutException`` after *seconds*.

    Uses ``SIGALRM`` with the real-time interval timer. The timer is always
    cancelled when the block is left.

    Args:
        seconds: time budget for the managed block (fractions allowed).
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler before arming the timer, so an alarm that fires
    # immediately is already routed to `signal_handler`.
    signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
@contextlib.contextmanager
def swallow_io():
    """Context manager redirecting stdout, stderr and stdin to one throw-away stream.

    The stream is a ``WriteOnlyStringIO``: written output is collected and
    dropped, while reading from the redirected stdin raises.
    """
    stream = WriteOnlyStringIO()
    with contextlib.redirect_stdout(stream), \
            contextlib.redirect_stderr(stream), \
            redirect_stdin(stream):
        yield
@contextlib.contextmanager
def create_tempdir():
    """Context manager that creates a temporary directory and makes it the cwd.

    Yields:
        The path of the temporary directory.
    """
    with tempfile.TemporaryDirectory() as dirname, chdir(dirname):
        yield dirname
class TimeoutException(Exception):
    """Raised when the evaluated program exceeds its time limit."""
class WriteOnlyStringIO(io.StringIO):
    """StringIO that throws an exception when it's read from"""

    def _refuse_read(self, *args, **kwargs):
        """Shared implementation of every read-style method: always fails."""
        raise OSError

    # All read entry points share the refusing implementation.
    read = _refuse_read
    readline = _refuse_read
    readlines = _refuse_read

    def readable(self, *args, **kwargs):
        """Returns True if the IO object can be read."""
        return False
class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    """Context manager that temporarily replaces ``sys.stdin`` (stdin counterpart of ``redirect_stdout``)."""

    # Name of the `sys` attribute that the base class swaps in and out.
    _stream = "stdin"
@contextlib.contextmanager
def chdir(root):
    """Context manager that switches the working directory to *root*.

    The previous working directory is restored on exit, also when the block
    raises. ``"."`` is a no-op: nothing is changed or restored.
    """
    if root == ".":
        yield
        return
    cwd = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        # Exceptions from the block propagate unchanged after the restore.
        os.chdir(cwd)
def reliability_guard(maximum_memory_bytes=None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.

    The changes are process-wide and are not undone by this function.

    Args:
        maximum_memory_bytes: if given, used as soft and hard limit for the
            address space and data segment (and the stack, except on macOS).
    """
    if maximum_memory_bytes is not None:
        import resource

        limits = (maximum_memory_bytes, maximum_memory_bytes)
        resource.setrlimit(resource.RLIMIT_AS, limits)
        resource.setrlimit(resource.RLIMIT_DATA, limits)
        if platform.uname().system != "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, limits)

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None
    # Goes through the `builtins` module instead of `__builtins__["help"]`:
    # `__builtins__` is a dict in imported modules but the module object itself
    # in `__main__`, where item assignment would fail.
    builtins.help = None

    import os

    # Set before `os.putenv` is disabled below (writes to os.environ may rely on it).
    os.environ["OMP_NUM_THREADS"] = "1"

    for name in (
        "kill", "system", "putenv", "remove", "removedirs", "rmdir", "fchdir",
        "setuid", "fork", "forkpty", "killpg", "rename", "renames", "truncate",
        "replace", "unlink", "fchmod", "fchown", "chmod", "chown", "chroot",
        "lchflags", "lchmod", "lchown", "getcwd", "chdir",
    ):
        setattr(os, name, None)

    import shutil

    for name in ("rmtree", "move", "chown"):
        setattr(shutil, name, None)

    import subprocess

    subprocess.Popen = None  # type: ignore

    import sys

    # A None entry in sys.modules makes a later import of that module fail.
    for name in ("ipdb", "joblib", "resource", "psutil", "tkinter"):
        sys.modules[name] = None  # type: ignore
""" | |
Borrowed implementation of _inplacevar_ from the Zope Foundation's AccessControl module
https://github.com/zopefoundation/AccessControl/blob/f9ae58816f0712eb6ea97459b4ccafbf4662d9db/src/AccessControl/ZopeGuards.py#L530 | |
""" | |
# Types whose own in-place methods may be used by untrusted code.
valid_inplace_types = (list, set)

# Augmented-assignment operator -> name of the special method implementing it.
inplace_slots = {
    '+=': '__iadd__',
    '-=': '__isub__',
    '*=': '__imul__',
    # Python 3 only has true division (the old `1 / 2 == 0` probe was a
    # Python 2 leftover that always selected this name).
    '/=': '__itruediv__',
    '//=': '__ifloordiv__',
    '%=': '__imod__',
    '**=': '__ipow__',
    '<<=': '__ilshift__',
    '>>=': '__irshift__',
    '&=': '__iand__',
    '^=': '__ixor__',
    '|=': '__ior__',
    '@=': '__imatmul__',
}


# One tiny function per operator: apply it and return the (possibly new) value.
def __iadd__(x, y):
    x += y
    return x


def __isub__(x, y):
    x -= y
    return x


def __imul__(x, y):
    x *= y
    return x


def __idiv__(x, y):
    x /= y
    return x


def __ifloordiv__(x, y):
    x //= y
    return x


def __imod__(x, y):
    x %= y
    return x


def __ipow__(x, y):
    x **= y
    return x


def __ilshift__(x, y):
    x <<= y
    return x


def __irshift__(x, y):
    x >>= y
    return x


def __iand__(x, y):
    x &= y
    return x


def __ixor__(x, y):
    x ^= y
    return x


def __ior__(x, y):
    x |= y
    return x


def __imatmul__(x, y):
    x @= y
    return x


# Augmented-assignment operator -> function performing it.
inplace_ops = {
    '+=': __iadd__,
    '-=': __isub__,
    '*=': __imul__,
    '/=': __idiv__,
    '//=': __ifloordiv__,
    '%=': __imod__,
    '**=': __ipow__,
    '<<=': __ilshift__,
    '>>=': __irshift__,
    '&=': __iand__,
    '^=': __ixor__,
    '|=': __ior__,
    # '@=' was missing from both tables, so `a @= b` died with a KeyError.
    '@=': __imatmul__,
}


def protected_inplacevar(op, var, expr):
    """Do an inplace operation

    If the var has an inplace slot, then disallow the operation
    unless the var an instance of ``valid_inplace_types``.

    Args:
        op: the operator as text, e.g. ``'+='``.
        var: current value of the assignment target.
        expr: the right-hand operand.

    Returns:
        The value to store back into the target.

    Raises:
        TypeError: *var* defines the in-place method for *op* and is not an
            instance of ``valid_inplace_types``.
        KeyError: *op* is not a known augmented-assignment operator.
    """
    if hasattr(var, inplace_slots[op]) and \
            not isinstance(var, valid_inplace_types):
        try:
            cls = var.__class__
        except AttributeError:
            cls = type(var)
        raise TypeError(
            "Augmented assignment to %s objects is not allowed"
            " in untrusted code" % cls.__name__)
    return inplace_ops[op](var, expr)
def _apply(f, *a, **kw): | |
return f(*a, **kw) |