File size: 11,785 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 |
import inspect
from inspect import Parameter
from types import FunctionType
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar, Union, get_type_hints
from .typing_utils import get_args, issubtype
# Type variables used to annotate the callables handled in this module. Two
# distinct variables are declared so that the super- and sub-callable
# parameters of one function can be annotated independently of each other.
_WrappedMethod = TypeVar("_WrappedMethod", bound=Union[FunctionType, Callable])
_WrappedMethod2 = TypeVar("_WrappedMethod2", bound=Union[FunctionType, Callable])
def _contains_unbound_typevar(t: Type) -> bool:
"""Recursively check if `t` or any types contained by `t` is a `TypeVar`.
Examples where we return `True`: `T`, `Optional[T]`, `Tuple[Optional[T], ...]`, ...
Examples where we return `False`: `int`, `Optional[str]`, ...
:param t: Type to evaluate.
:return: `True` if the input type contains an unbound `TypeVar`, `False` otherwise.
"""
# Check self
if isinstance(t, TypeVar):
return True
# Check children
for arg in get_args(t):
if _contains_unbound_typevar(arg):
return True
return False
def _issubtype(left, right):
    """Lenient subtype check that only answers `False` on a definite mismatch.

    `True` is returned without consulting `issubtype` when either side
    contains a `TypeVar` or when `right` is `None`, and also when
    `issubtype` itself raises `TypeError`.

    :param left: Candidate subtype.
    :param right: Candidate supertype, or `None` when there is nothing to compare against.
    :return: Result of `issubtype(left, right)`, or `True` in the cases above.
    """
    undecidable = (
        _contains_unbound_typevar(left)
        or right is None
        or _contains_unbound_typevar(right)
    )
    if undecidable:
        return True
    try:
        verdict = issubtype(left, right)
    except TypeError:
        # Ignore all broken cases
        return True
    return verdict
def _get_type_hints(callable) -> Optional[Dict]:
try:
return get_type_hints(callable)
except (NameError, TypeError):
return None
def _is_same_module(callable1: _WrappedMethod, callable2: _WrappedMethod2) -> bool:
mod1 = callable1.__module__.split(".")[0]
# "__module__" attribute may be missing in CPython or it can be None
# in PyPy: https://github.com/mkorpela/overrides/issues/118
mod2 = getattr(callable2, "__module__", None)
if mod2 is None:
return False
mod2 = mod2.split(".")[0]
return mod1 == mod2
def ensure_signature_is_compatible(
    super_callable: _WrappedMethod,
    sub_callable: _WrappedMethod2,
    is_static: bool = False,
) -> None:
    """Verify that `sub_callable` can be called everywhere `super_callable` can.

    The following criteria are checked:

    1. The return type of `sub_callable` is a subtype of the return type of `super_callable`.
    2. All parameters of `super_callable` are present in `sub_callable`, unless `sub_callable`
       declares `*args` or `**kwargs`.
    3. All positional parameters of `super_callable` appear in the same order in `sub_callable`.
    4. All parameters of `super_callable` are a subtype of the corresponding parameters of `sub_callable`.
    5. All required parameters of `sub_callable` are present in `super_callable`, unless `super_callable`
       declares `*args` or `**kwargs`.

    Checks 1-4 are only run when the type hints of both callables could be
    resolved; check 5 always runs. Nothing is checked when no signature can
    be obtained for `super_callable`.

    :param super_callable: Function to check compatibility with.
    :param sub_callable: Function to check compatibility of.
    :param is_static: True if staticmethod and should check first argument.
    :raises TypeError: If one of the criteria is violated.
    """
    super_callable = _unbound_func(super_callable)
    sub_callable = _unbound_func(sub_callable)

    try:
        super_sig = inspect.signature(super_callable)
    except ValueError:
        # No signature available for the super callable: nothing to compare.
        return

    super_type_hints = _get_type_hints(super_callable)
    sub_sig = inspect.signature(sub_callable)
    sub_type_hints = _get_type_hints(sub_callable)
    method_name = sub_callable.__qualname__
    same_main_module = _is_same_module(sub_callable, super_callable)

    hints_resolved = super_type_hints is not None and sub_type_hints is not None
    if hints_resolved:
        ensure_return_type_compatibility(
            super_type_hints=super_type_hints,
            sub_type_hints=sub_type_hints,
            method_name=method_name,
        )
        ensure_all_kwargs_defined_in_sub(
            super_sig=super_sig,
            sub_sig=sub_sig,
            super_type_hints=super_type_hints,
            sub_type_hints=sub_type_hints,
            check_first_parameter=is_static,
            method_name=method_name,
        )
        ensure_all_positional_args_defined_in_sub(
            super_sig=super_sig,
            sub_sig=sub_sig,
            super_type_hints=super_type_hints,
            sub_type_hints=sub_type_hints,
            check_first_parameter=is_static,
            is_same_main_module=same_main_module,
            method_name=method_name,
        )

    ensure_no_extra_args_in_sub(
        super_sig=super_sig,
        sub_sig=sub_sig,
        check_first_parameter=is_static,
        method_name=method_name,
    )
def _unbound_func(callable: _WrappedMethod) -> _WrappedMethod:
if hasattr(callable, "__self__") and hasattr(callable, "__func__"):
return callable.__func__ # type: ignore
return callable
def ensure_all_kwargs_defined_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    super_type_hints: Dict,
    sub_type_hints: Dict,
    check_first_parameter: bool,
    method_name: str,
):
    """Check every keyword-passable parameter of the super signature against the sub signature.

    `*args` and positional-only parameters of the super signature are skipped
    here. Every other parameter must be accepted by the sub signature, have a
    compatible kind, not appear later than in the super signature (unless it
    is keyword-only) and, when both sides are annotated, have a type that is
    a subtype of the sub parameter's type.

    :param super_sig: Signature of the overridden callable.
    :param sub_sig: Signature of the overriding callable.
    :param super_type_hints: Resolved type hints of the overridden callable.
    :param sub_type_hints: Resolved type hints of the overriding callable.
    :param check_first_parameter: When `False`, the first super parameter is skipped.
    :param method_name: Name used as prefix in error messages.
    :raises TypeError: If a parameter fails one of the checks.
    """
    sub_params = sub_sig.parameters
    sub_has_var_kwargs = any(
        candidate.kind == Parameter.VAR_KEYWORD for candidate in sub_params.values()
    )
    # Position of each parameter name in the sub signature, computed once.
    sub_positions = {param_name: pos for pos, param_name in enumerate(sub_params)}
    skipped_kinds = (Parameter.VAR_POSITIONAL, Parameter.POSITIONAL_ONLY)

    for super_index, (name, super_param) in enumerate(super_sig.parameters.items()):
        if super_index == 0 and not check_first_parameter:
            continue
        if super_param.kind in skipped_kinds:
            continue

        # `True` is passed for "sub has *args" unconditionally.
        if not is_param_defined_in_sub(
            name, True, sub_has_var_kwargs, sub_sig, super_param
        ):
            raise TypeError(f"{method_name}: `{name}` is not present.")

        # Nothing further to compare unless the sub signature names the parameter.
        if name not in sub_params or super_param.kind == Parameter.VAR_KEYWORD:
            continue

        sub_index = sub_positions[name]
        sub_param = sub_params[name]

        # A keyword-only parameter may become positional-or-keyword in the sub signature.
        keyword_only_relaxed = (
            super_param.kind == Parameter.KEYWORD_ONLY
            and sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD
        )
        if super_param.kind != sub_param.kind and not keyword_only_relaxed:
            raise TypeError(f"{method_name}: `{name}` is not `{super_param.kind}`")

        if super_index > sub_index and super_param.kind != Parameter.KEYWORD_ONLY:
            raise TypeError(
                f"{method_name}: `{name}` is not parameter at index `{super_index}`"
            )

        if (
            name in super_type_hints
            and name in sub_type_hints
            and not _issubtype(super_type_hints[name], sub_type_hints[name])
        ):
            raise TypeError(
                f"`{method_name}: {name} must be a supertype of `{super_param.annotation}` but is `{sub_param.annotation}`"
            )
def ensure_all_positional_args_defined_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    super_type_hints: Dict,
    sub_type_hints: Dict,
    check_first_parameter: bool,
    is_same_main_module: bool,
    method_name: str,
):
    """Check that positional parameters of the sub signature line up with the super signature.

    Keyword-only and `**kwargs` parameters are ignored on both sides. The
    remaining parameters are compared pairwise by position: kinds must be
    compatible and the super parameter's type must be a subtype of the sub
    parameter's type.

    :param super_sig: Signature of the overridden callable.
    :param sub_sig: Signature of the overriding callable.
    :param super_type_hints: Resolved type hints of the overridden callable.
    :param sub_type_hints: Resolved type hints of the overriding callable.
    :param check_first_parameter: When `False`, the first sub parameter is skipped.
    :param is_same_main_module: When `True`, types are compared even if the
        super parameter has no type hint.
    :param method_name: Name used as prefix in error messages.
    :raises TypeError: If the sub signature is too short, requires an extra
        positional argument, drops `*args`, or has an incompatible kind or type.
    """
    keyword_kinds = (Parameter.KEYWORD_ONLY, Parameter.VAR_KEYWORD)
    sub_parameter_values = [
        v for v in sub_sig.parameters.values() if v.kind not in keyword_kinds
    ]
    super_parameter_values = [
        v for v in super_sig.parameters.values() if v.kind not in keyword_kinds
    ]
    sub_has_var_args = any(
        p.kind == Parameter.VAR_POSITIONAL for p in sub_parameter_values
    )
    super_has_var_args = any(
        p.kind == Parameter.VAR_POSITIONAL for p in super_parameter_values
    )

    if not sub_has_var_args and len(sub_parameter_values) < len(super_parameter_values):
        raise TypeError(f"{method_name}: parameter list too short")

    # Offset added to the sub index to locate the matching super parameter.
    # It is decremented whenever that super parameter is `*args`, so the next
    # sub parameter is matched against the same super slot again.
    super_shift = 0
    for index, sub_param in enumerate(sub_parameter_values):
        if index == 0 and not check_first_parameter:
            continue

        if index + super_shift >= len(super_parameter_values):
            # The super signature has no positional slot left for this parameter.
            if sub_param.kind == Parameter.VAR_POSITIONAL:
                continue
            # Identity check against the sentinel: `!=` would invoke the
            # default value's own comparison, which may raise or return a
            # non-boolean for user-defined types.
            if (
                sub_param.kind == Parameter.POSITIONAL_ONLY
                and sub_param.default is not Parameter.empty
            ):
                continue
            if sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD:
                continue  # Assume use as keyword
            raise TypeError(
                f"{method_name}: `{sub_param.name}` positionally required in subclass but not in supertype"
            )

        if sub_param.kind == Parameter.VAR_POSITIONAL:
            # Sub `*args` accepts whatever positional arguments remain.
            return

        super_param = super_parameter_values[index + super_shift]
        if super_param.kind == Parameter.VAR_POSITIONAL:
            super_shift -= 1
            if not sub_has_var_args:
                raise TypeError(f"{method_name}: `{super_param.name}` must be present")
            continue

        if (
            super_param.kind != sub_param.kind
            and not (
                super_param.kind == Parameter.POSITIONAL_ONLY
                and sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD
            )
            and not (sub_param.kind == Parameter.POSITIONAL_ONLY and super_has_var_args)
        ):
            raise TypeError(
                f"{method_name}: `{sub_param.name}` is not `{super_param.kind}` and is `{sub_param.kind}`"
            )
        elif (
            super_param.name in super_type_hints or is_same_main_module
        ) and not _issubtype(
            super_type_hints.get(super_param.name, None),
            sub_type_hints.get(sub_param.name, None),
        ):
            raise TypeError(
                f"`{method_name}: {sub_param.name} overriding must be a supertype of `{super_param.annotation}` but is `{sub_param.annotation}`"
            )
def is_param_defined_in_sub(
    name: str,
    sub_has_var_args: bool,
    sub_has_var_kwargs: bool,
    sub_sig: inspect.Signature,
    super_param: inspect.Parameter,
) -> bool:
    """Tell whether the sub signature can receive the given super parameter.

    A parameter is accepted when the sub signature declares it by name, or
    when the sub signature's `*args` / `**kwargs` can absorb it given the
    kind of the super parameter.

    :param name: Name of the super parameter.
    :param sub_has_var_args: Whether the sub signature is treated as having `*args`.
    :param sub_has_var_kwargs: Whether the sub signature is treated as having `**kwargs`.
    :param sub_sig: Signature of the overriding callable.
    :param super_param: The parameter of the overridden callable.
    :return: `True` if the parameter is accepted, `False` otherwise.
    """
    if name in sub_sig.parameters:
        return True

    kind = super_param.kind
    # Purely positional kinds only need `*args` on the sub side.
    if kind == Parameter.VAR_POSITIONAL or kind == Parameter.POSITIONAL_ONLY:
        return sub_has_var_args
    # Purely keyword kinds only need `**kwargs` on the sub side.
    if kind == Parameter.VAR_KEYWORD or kind == Parameter.KEYWORD_ONLY:
        return sub_has_var_kwargs
    # A parameter passable either way needs both catch-alls.
    if kind == Parameter.POSITIONAL_OR_KEYWORD:
        return sub_has_var_args and sub_has_var_kwargs
    return False
def ensure_no_extra_args_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    check_first_parameter: bool,
    method_name: str,
) -> None:
    """Reject required parameters of the sub signature that the super signature cannot supply.

    A sub parameter is acceptable when it is also named in the super
    signature, has a default, is `*args` / `**kwargs`, or can be fed from
    the super signature's `*args` (positional kinds) or `**kwargs`
    (keyword-only). Positional-only parameters at the same index on both
    sides are accepted regardless of their names.

    :param super_sig: Signature of the overridden callable.
    :param sub_sig: Signature of the overriding callable.
    :param check_first_parameter: When `False`, the first sub parameter is never rejected.
    :param method_name: Name used as prefix in error messages.
    :raises TypeError: If the sub signature requires a parameter unknown to the super signature.
    """
    # Materialise once: the parameters are indexed by position inside the loop.
    super_params = list(super_sig.parameters.values())
    super_var_args = any(p.kind == Parameter.VAR_POSITIONAL for p in super_params)
    super_var_kwargs = any(p.kind == Parameter.VAR_KEYWORD for p in super_params)
    variadic_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)

    for sub_index, (name, sub_param) in enumerate(sub_sig.parameters.items()):
        if (
            sub_param.kind == Parameter.POSITIONAL_ONLY
            and sub_index < len(super_params)
            and super_params[sub_index].kind == Parameter.POSITIONAL_ONLY
        ):
            # Positional-only on both sides: the name is irrelevant to callers.
            continue
        if name in super_sig.parameters:
            continue
        # Identity check against the sentinel: `==` would invoke the default
        # value's own comparison, which may raise or return a non-boolean
        # for user-defined types.
        if sub_param.default is not Parameter.empty:
            continue
        if sub_param.kind in variadic_kinds:
            continue
        if sub_index == 0 and not check_first_parameter:
            continue

        # Keyword-only parameters can be fed from the super `**kwargs`;
        # the remaining (positional) kinds from the super `*args`.
        if sub_param.kind == Parameter.KEYWORD_ONLY:
            absorbed_by_super = super_var_kwargs
        else:
            absorbed_by_super = super_var_args
        if absorbed_by_super:
            continue

        raise TypeError(f"{method_name}: `{name}` is not a valid parameter.")
def ensure_return_type_compatibility(
    super_type_hints: Dict, sub_type_hints: Dict, method_name: str
):
    """Check that the sub callable's return type is a subtype of the super callable's.

    The return types are read from the `"return"` entry of each hint mapping.
    Nothing is raised when the super callable declares no return type.

    :param super_type_hints: Resolved type hints of the overridden callable.
    :param sub_type_hints: Resolved type hints of the overriding callable.
    :param method_name: Name used as prefix in the error message.
    :raises TypeError: If the sub return type is not compatible.
    """
    super_return = super_type_hints.get("return")
    sub_return = sub_type_hints.get("return")

    compatible = _issubtype(sub_return, super_return)
    if compatible or super_return is None:
        return
    raise TypeError(
        f"{method_name}: return type `{sub_return}` is not a `{super_return}`."
    )
|