File size: 4,612 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
from collections import OrderedDict, deque
from datetime import date, time, datetime
from decimal import Decimal
from fractions import Fraction
import ast
import enum
import typing
class CannotEval(Exception):
    """
    Raised when a value/expression cannot be evaluated safely.

    Both ``repr()`` and ``str()`` render just the (sub)class name, ignoring
    any constructor arguments.
    """

    def __repr__(self):
        return type(self).__name__

    __str__ = __repr__
def is_any(x, *args):
    """
    Return True if ``x`` is the very same object (``is``) as one of ``args``.

    Identity, not equality, is used, so e.g. ``1`` and ``1.0`` or a class and
    its subclass never match each other.
    """
    for candidate in args:
        if x is candidate:
            return True
    return False
def of_type(x, *types):
    """
    Return ``x`` unchanged if its exact type is one of ``types``.

    The check is by identity on ``type(x)`` (via ``is_any``), so subclasses
    of a listed type are rejected.

    Raises:
        CannotEval: if ``type(x)`` is not one of ``types``.
    """
    if not is_any(type(x), *types):
        raise CannotEval
    return x
def of_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return ``x`` unchanged if ``is_standard_types`` accepts it.

    Keyword arguments are forwarded verbatim to ``is_standard_types``.

    Raises:
        CannotEval: if ``x`` is not made only of the accepted standard types.
    """
    accepted = is_standard_types(
        x, check_dict_values=check_dict_values, deep=deep
    )
    if not accepted:
        raise CannotEval
    return x
def is_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return True if ``x`` consists only of the accepted builtin/stdlib types.

    Delegates to ``_is_standard_types_deep`` and keeps only its boolean
    verdict. If the recursive walk overflows the stack (e.g. a container
    that contains itself when ``deep`` is true), the answer is False.
    """
    try:
        verdict, _size = _is_standard_types_deep(x, check_dict_values, deep)
    except RecursionError:
        return False
    return verdict
def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
typ = type(x)
if is_any(
typ,
str,
int,
bool,
float,
bytes,
complex,
date,
time,
datetime,
Fraction,
Decimal,
type(None),
object,
):
return True, 0
if is_any(typ, tuple, frozenset, list, set, dict, OrderedDict, deque, slice):
if typ in [slice]:
length = 0
else:
length = len(x)
assert isinstance(deep, bool)
if not deep:
return True, length
if check_dict_values and typ in (dict, OrderedDict):
items = (v for pair in x.items() for v in pair)
elif typ is slice:
items = [x.start, x.stop, x.step]
else:
items = x
for item in items:
if length > 100000:
return False, length
is_standard, item_length = _is_standard_types_deep(
item, check_dict_values, deep
)
if not is_standard:
return False, length
length += item_length
return True, length
return False, 0
class _E(enum.Enum):
pass
class _C:
def foo(self): pass # pragma: nocover
def bar(self): pass # pragma: nocover
@classmethod
def cm(cls): pass # pragma: nocover
@staticmethod
def sm(): pass # pragma: nocover
# Representative objects keyed by their ``__name__``. Only their *types*
# are used: ``safe_name_types`` collects them so ``safe_name`` can read
# ``__name__`` from any object of one of these kinds (builtins, method
# descriptors, bound/unbound methods, functions, modules, classes, enums).
safe_name_samples = dict(
    len=len,
    append=list.append,
    __add__=list.__add__,
    insert=[].insert,
    __mul__=[].__mul__,
    fromkeys=dict.__dict__['fromkeys'],
    is_any=is_any,
    __repr__=CannotEval.__repr__,
    foo=_C().foo,
    bar=_C.bar,
    cm=_C.cm,
    sm=_C.sm,
    ast=ast,
    CannotEval=CannotEval,
    _E=_E,
)

# Unsubscripted ``typing`` aliases; their types are gathered into
# ``typing_annotation_types`` for ``safe_name``.
typing_annotation_samples = {
    alias: getattr(typing, alias)
    for alias in ("List", "Dict", "Tuple", "Set", "Callable", "Mapping")
}

# Deduplicated type objects of the samples above (tuples so they can be
# splatted into ``is_any``).
safe_name_types = tuple(set(map(type, safe_name_samples.values())))

typing_annotation_types = tuple(set(map(type, typing_annotation_samples.values())))
def eq_checking_types(a, b):
    """
    Equality that additionally requires both operands to have the exact same
    type (so ``1`` vs ``1.0`` or ``True`` vs ``1`` compare unequal).

    When the types match, the result of ``a == b`` is returned as-is.
    """
    if type(a) is not type(b):
        return False
    return a == b
def ast_name(node):
    """
    Return the identifier an AST node refers to by name.

    ``ast.Name`` yields its ``id``; ``ast.Attribute`` yields the final
    ``attr`` (e.g. ``"b"`` for ``a.b``). Any other node gives ``None``.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None
def safe_name(value):
    """
    Return a name for ``value`` when it can be obtained without running
    arbitrary user code, else ``None``.

    * Objects whose exact type is in ``safe_name_types`` -> ``__name__``.
    * ``typing.Optional`` / ``typing.Union`` -> ``"Optional"`` / ``"Union"``.
    * Objects whose exact type is in ``typing_annotation_types`` ->
      ``__name__`` if present and truthy, otherwise ``_name`` (or ``None``).
    """
    kind = type(value)
    if is_any(kind, *safe_name_types):
        return value.__name__
    if value is typing.Optional:
        return "Optional"
    if value is typing.Union:
        return "Union"
    if is_any(kind, *typing_annotation_types):
        dunder_name = getattr(value, "__name__", None)
        return dunder_name or getattr(value, "_name", None)
    return None
def has_ast_name(value, node):
    """
    Return True if ``node`` (a Name/Attribute AST node) names ``value``.

    ``value``'s name comes from ``safe_name``; if that is not exactly a
    ``str``, the answer is False without inspecting ``node``.
    """
    expected = safe_name(value)
    return type(expected) is str and eq_checking_types(ast_name(node), expected)
def copy_ast_without_context(x):
    """
    Return a structural copy of an AST (sub)tree with every ``ctx`` removed.

    Lists are copied element-wise; AST nodes are rebuilt from their
    ``_fields`` (skipping ``ctx`` and any field not set on the node); any
    other value is returned unchanged. Node attributes outside ``_fields``
    (such as positions) are not copied. The result lets nodes that differ
    only in Load/Store/Del context compare equal, e.g. via ``ast.dump``.
    """
    if isinstance(x, list):
        return [copy_ast_without_context(element) for element in x]
    if not isinstance(x, ast.AST):
        return x

    field_values = {}
    for field_name in x._fields:
        if field_name == 'ctx' or not hasattr(x, field_name):
            continue
        field_values[field_name] = copy_ast_without_context(
            getattr(x, field_name)
        )
    clone = type(x)(**field_values)
    if hasattr(clone, 'ctx'):
        # Python 3.13.0b2+ defaults to Load when we don't pass ctx
        # https://github.com/python/cpython/pull/118871
        del clone.ctx
    return clone
def ensure_dict(x):
    """
    Return ``dict(x)``, or an empty dict if that conversion raises.

    Deliberately best-effort: any ``Exception`` from ``dict(x)`` (wrong type,
    malformed pairs, errors from a user iterable) is swallowed.
    """
    try:
        converted = dict(x)
    except Exception:
        converted = {}
    return converted
|