|
from collections import OrderedDict, deque |
|
from datetime import date, time, datetime |
|
from decimal import Decimal |
|
from fractions import Fraction |
|
import ast |
|
import enum |
|
import typing |
|
|
|
|
|
class CannotEval(Exception):
    """
    Raised by helpers such as ``of_type`` and ``of_standard_types`` when a
    value is rejected.

    Both ``repr()`` and ``str()`` show only the class name; any exception
    arguments are intentionally left out.
    """

    def __repr__(self):
        return type(self).__name__

    # str() deliberately reuses the exact same function object as repr().
    __str__ = __repr__
|
|
|
|
|
def is_any(x, *args):
    """
    Return True if ``x`` is identical (``is``, not ``==``) to any of ``args``.

    With no ``args`` the answer is False.
    """
    for candidate in args:
        if x is candidate:
            return True
    return False
|
|
|
|
|
def of_type(x, *types):
    """
    Return ``x`` unchanged if its exact type is one of ``types``.

    The match is by identity on ``type(x)``, so subclasses do not count.

    Raises:
        CannotEval: if ``type(x)`` is not one of ``types``.
    """
    if not is_any(type(x), *types):
        raise CannotEval
    return x
|
|
|
|
|
def of_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return ``x`` unchanged if ``is_standard_types`` accepts it.

    The keyword-only flags are forwarded to ``is_standard_types`` as given.

    Raises:
        CannotEval: if ``x`` is not made of standard types.
    """
    accepted = is_standard_types(
        x, check_dict_values=check_dict_values, deep=deep
    )
    if not accepted:
        raise CannotEval
    return x
|
|
|
|
|
def is_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return True if ``x`` passes ``_is_standard_types_deep``.

    Inputs nested deeply enough to hit the interpreter's recursion limit,
    including self-referential containers, are reported as False rather
    than raising ``RecursionError``.
    """
    try:
        accepted, _length = _is_standard_types_deep(x, check_dict_values, deep)
    except RecursionError:
        return False
    return accepted
|
|
|
|
|
def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
    """
    Recursive worker behind ``is_standard_types``.

    Returns a pair ``(ok, length)``:

    * ``ok`` says whether ``x`` has one of the accepted exact types and, when
      ``deep`` is true, whether everything reachable inside it does too.
    * ``length`` is a running element count. It starts at ``len(x)`` for
      containers (0 for ``slice``) and grows with each child's count. Once it
      exceeds 100000, the check gives up and reports ``ok=False``.

    With ``check_dict_values`` true, both keys and values of ``dict`` and
    ``OrderedDict`` are inspected. Otherwise plain iteration is used, which
    visits only the keys.
    """
    kind = type(x)

    # Leaf values: accepted outright and contribute nothing to the length.
    if is_any(kind, str, int, bool, float, bytes, complex, date, time,
              datetime, Fraction, Decimal, type(None), object):
        return True, 0

    # Anything that is neither a known leaf nor a known container is rejected.
    if not is_any(kind, tuple, frozenset, list, set, dict, OrderedDict,
                  deque, slice):
        return False, 0

    # slice has no len(); treat it as having no direct elements.
    total = 0 if kind is slice else len(x)
    assert isinstance(deep, bool)
    if not deep:
        return True, total

    if check_dict_values and kind in (dict, OrderedDict):
        children = (part for pair in x.items() for part in pair)
    elif kind is slice:
        children = [x.start, x.stop, x.step]
    else:
        children = x

    for child in children:
        # Stop early on structures too large to be worth checking.
        if total > 100000:
            return False, total
        child_ok, child_total = _is_standard_types_deep(
            child, check_dict_values, deep
        )
        if not child_ok:
            return False, total
        total += child_total
    return True, total
|
|
|
|
|
class _E(enum.Enum): |
|
pass |
|
|
|
|
|
class _C: |
|
def foo(self): pass |
|
|
|
def bar(self): pass |
|
|
|
@classmethod |
|
def cm(cls): pass |
|
|
|
@staticmethod |
|
def sm(): pass |
|
|
|
|
|
# One representative object per kind of value whose ``__name__`` is trusted.
# Only the *types* of these values matter: they are collected into
# ``safe_name_types``. Keys are the names the samples carry.
safe_name_samples = dict(
    len=len,                                 # builtin function
    append=list.append,                      # method descriptor
    __add__=list.__add__,                    # slot wrapper
    insert=[].insert,                        # builtin method bound to an instance
    __mul__=[].__mul__,                      # method-wrapper
    fromkeys=dict.__dict__['fromkeys'],      # classmethod descriptor
    is_any=is_any,                           # plain Python function
    __repr__=CannotEval.__repr__,            # function read off a class
    foo=_C().foo,                            # bound method
    bar=_C.bar,                              # function read off a class
    cm=_C.cm,                                # classmethod (bound to the class)
    sm=_C.sm,                                # staticmethod (plain function)
    ast=ast,                                 # module
    CannotEval=CannotEval,                   # ordinary class
    _E=_E,                                   # Enum class
)
|
|
|
# A few bare ``typing`` aliases, keyed by attribute name. Only their types are
# used, collected into ``typing_annotation_types``.
typing_annotation_samples = {
    alias_name: getattr(typing, alias_name)
    for alias_name in ("List", "Dict", "Tuple", "Set", "Callable", "Mapping")
}
|
|
|
# Distinct exact types of the values in ``safe_name_samples`` (order is
# irrelevant; the tuple is only used for identity membership tests).
safe_name_types = tuple(set(map(type, safe_name_samples.values())))
|
|
|
|
|
# Distinct exact types of the aliases in ``typing_annotation_samples`` (order
# is irrelevant; the tuple is only used for identity membership tests).
typing_annotation_types = tuple(set(map(type, typing_annotation_samples.values())))
|
|
|
|
|
def eq_checking_types(a, b):
    """
    Compare ``a`` and ``b`` with ``==``, but only when their exact types match.

    Values of different types are never equal here (e.g. ``1`` vs ``1.0``, or
    ``True`` vs ``1``). When the types match, the result of ``a == b`` is
    returned as-is.
    """
    if type(a) is not type(b):
        return False
    return a == b
|
|
|
|
|
def ast_name(node):
    """
    Return the identifier an AST node refers to, or None.

    ``ast.Name`` yields its ``id`` and ``ast.Attribute`` yields its final
    ``attr`` (``"c"`` for ``a.b.c``). Any other input gives None.
    """
    for node_type, field in ((ast.Name, "id"), (ast.Attribute, "attr")):
        if isinstance(node, node_type):
            return getattr(node, field)
    return None
|
|
|
|
|
def safe_name(value):
    """
    Return the name of ``value`` when it can be read without running
    arbitrary user code, otherwise None.

    * Values whose exact type is in ``safe_name_types`` give ``__name__``.
    * ``typing.Optional`` and ``typing.Union`` give ``"Optional"`` and
      ``"Union"``.
    * Values whose exact type is in ``typing_annotation_types`` give
      ``__name__``, falling back to ``_name`` (either may be None).
    * Anything else gives None.
    """
    value_type = type(value)
    if is_any(value_type, *safe_name_types):
        return value.__name__
    for special_form, label in ((typing.Optional, "Optional"), (typing.Union, "Union")):
        if value is special_form:
            return label
    if not is_any(value_type, *typing_annotation_types):
        return None
    return getattr(value, "__name__", None) or getattr(value, "_name", None)
|
|
|
|
|
def has_ast_name(value, node):
    """
    Return True if ``value``'s safe name equals the identifier in ``node``.

    The name comes from ``safe_name(value)`` and must be a ``str``. It is
    compared with ``ast_name(node)`` via ``eq_checking_types``. ``node`` is
    only inspected when a usable name was found.
    """
    name = safe_name(value)
    return type(name) is str and eq_checking_types(ast_name(node), name)
|
|
|
|
|
def copy_ast_without_context(x):
    """
    Recursively copy an AST node, or a list of them, dropping every ``ctx``.

    Each node is rebuilt from the entries of its ``_fields`` that are
    present, so anything not listed there (such as the position attributes
    in ``_attributes``) is not passed to the new node. Nested nodes and
    lists are copied the same way. Any other value is returned as-is
    (shared, not copied).
    """
    if isinstance(x, ast.AST):
        node_fields = {}
        for field in x._fields:
            # Skip the load/store/del context and any field this node never set.
            if field == 'ctx' or not hasattr(x, field):
                continue
            node_fields[field] = copy_ast_without_context(getattr(x, field))
        clone = type(x)(**node_fields)
        # Python 3.13+ fills in a default ctx when it is omitted from the
        # constructor, so remove it explicitly to keep the copy context-free.
        if hasattr(clone, 'ctx'):
            del clone.ctx
        return clone
    if isinstance(x, list):
        return [copy_ast_without_context(item) for item in x]
    return x
|
|
|
|
|
def ensure_dict(x):
    """
    Best-effort conversion of ``x`` to a new ``dict``.

    Anything ``dict()`` can consume (a mapping or an iterable of pairs) is
    converted, and a dict input is shallow-copied. Inputs ``dict()`` rejects
    yield an empty dict instead of raising.
    """
    try:
        converted = dict(x)
    except Exception:
        # Invalid inputs are deliberately tolerated.
        converted = {}
    return converted
|
|