Spaces:
Paused
Paused
# SPDX-License-Identifier: MIT | |
import inspect | |
import platform | |
import sys | |
import threading | |
from collections.abc import Mapping, Sequence # noqa: F401 | |
from typing import _GenericAlias | |
# Interpreter / version feature flags, evaluated once at import time.
# True when running on the PyPy implementation.
PYPY = platform.python_implementation() == "PyPy"
# Each flag is True when the running interpreter is at least that version
# (only the major/minor pair is compared).
PY_3_8_PLUS = sys.version_info[:2] >= (3, 8)
PY_3_9_PLUS = sys.version_info[:2] >= (3, 9)
# NOTE: despite the name, this means "3.10 or newer", not "exactly 3.10".
PY310 = sys.version_info[:2] >= (3, 10)
PY_3_12_PLUS = sys.version_info[:2] >= (3, 12)
# ``typing.Protocol`` only exists on Python 3.8+.  On older interpreters use
# the ``typing_extensions`` backport when it is installed, and otherwise fall
# back to plain ``object`` so that ``Protocol`` is always a usable base class.
if sys.version_info < (3, 8):
    try:
        from typing_extensions import Protocol
    except ImportError:  # pragma: no cover
        Protocol = object
else:
    from typing import Protocol  # noqa: F401
class _AnnotationExtractor:
    """
    Extract type annotations from a callable, returning None whenever there
    is none.

    The signature is resolved once, at construction time.  If
    `inspect.signature` raises ``ValueError`` or ``TypeError`` for the given
    object, ``sig`` is set to None and both getters return None.
    """

    __slots__ = ["sig"]

    def __init__(self, callable):
        """
        Inspect *callable* and store its signature (or None on failure).
        """
        try:
            self.sig = inspect.signature(callable)
        except (ValueError, TypeError):  # inspect failed
            self.sig = None

    def get_first_param_type(self):
        """
        Return the type annotation of the first argument if it's not empty.

        Returns None if the signature could not be obtained, if the callable
        takes no parameters, or if the first parameter is unannotated.
        """
        if self.sig is None:
            return None

        # Only the first parameter matters; no need to materialize them all.
        first = next(iter(self.sig.parameters.values()), None)
        if first is not None and first.annotation is not inspect.Parameter.empty:
            return first.annotation

        return None

    def get_return_type(self):
        """
        Return the return type if it's not empty.

        Returns None if the signature could not be obtained or if there is
        no return annotation.
        """
        if (
            self.sig is not None
            and self.sig.return_annotation is not inspect.Signature.empty
        ):
            return self.sig.return_annotation

        return None
# Thread-local global to track attrs instances which are already being repr'd.
# This is needed because there is no other (thread-safe) way to pass info
# about the instances that are already being repr'd through the call stack
# in order to ensure we don't perform infinite recursion.
#
# For instance, if an instance contains a dict which contains that instance,
# we need to know that we're already repr'ing the outside instance from within
# the dict's repr() call.
#
# This lives here rather than in _make.py so that the functions in _make.py
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()
def get_generic_base(cl):
    """
    If this is a generic class (A[str]), return the generic base for it.

    Only objects whose class is exactly ``typing._GenericAlias`` are treated
    as generic aliases; for anything else (including the unsubscripted class
    itself) None is returned.
    """
    # Identity check on the class, not isinstance: subclasses of
    # _GenericAlias are deliberately not matched.
    if cl.__class__ is _GenericAlias:
        return cl.__origin__
    return None