|
import keyword |
|
import warnings |
|
import weakref |
|
from collections import OrderedDict, defaultdict, deque |
|
from copy import deepcopy |
|
from itertools import islice, zip_longest |
|
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType |
|
from typing import ( |
|
TYPE_CHECKING, |
|
AbstractSet, |
|
Any, |
|
Callable, |
|
Collection, |
|
Dict, |
|
Generator, |
|
Iterable, |
|
Iterator, |
|
List, |
|
Mapping, |
|
NoReturn, |
|
Optional, |
|
Set, |
|
Tuple, |
|
Type, |
|
TypeVar, |
|
Union, |
|
) |
|
|
|
from typing_extensions import Annotated |
|
|
|
from pydantic.v1.errors import ConfigError |
|
from pydantic.v1.typing import ( |
|
NoneType, |
|
WithArgsTypes, |
|
all_literal_values, |
|
display_as_type, |
|
get_args, |
|
get_origin, |
|
is_literal_type, |
|
is_union, |
|
) |
|
from pydantic.v1.version import version_info |
|
|
|
if TYPE_CHECKING:
    # Names used only in (string) annotations in this module; this block is skipped at runtime, so none of
    # these names are bound when the module is actually imported.
    from inspect import Signature
    from pathlib import Path

    from pydantic.v1.config import BaseConfig
    from pydantic.v1.dataclasses import Dataclass
    from pydantic.v1.fields import ModelField
    from pydantic.v1.main import BaseModel
    from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs

    # Return type of ``Representation.__rich_repr__``: an iterable of bare values or tuples of (name, value, ...).
    RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
|
|
|
# Names exported by ``from <this module> import *``. Some helpers defined below (e.g. ``to_lower_camel``,
# ``unique_list``, ``generate_model_signature``) are intentionally or historically not listed here.
__all__ = (
    'import_string',
    'sequence_like',
    'validate_field_name',
    'lenient_isinstance',
    'lenient_issubclass',
    'in_ipython',
    'is_valid_identifier',
    'deep_update',
    'update_not_none',
    'almost_equal_floats',
    'get_model',
    'to_camel',
    'is_valid_field',
    'smart_deepcopy',
    'PyObjectStr',
    'Representation',
    'GetterDict',
    'ValueItems',
    'version_info',
    'ClassAttribute',
    'path_type',
    'ROOT_KEY',
    'get_unique_discriminator_alias',
    'get_discriminator_alias_and_values',
    'DUNDER_ATTRIBUTES',
)
|
|
|
# Name of the field holding a custom-root model's value. It is the only underscore-prefixed name that
# ``is_valid_field`` accepts, and it is looked up in ``__fields__`` for models with ``__custom_root_type__``.
ROOT_KEY = '__root__'
|
|
|
# Exact types whose instances ``smart_deepcopy`` returns unchanged: they are immutable and are not containers,
# so there is nothing to copy. Membership is tested on ``obj.__class__``, so subclasses are NOT covered.
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
    int,
    float,
    complex,
    str,
    bool,
    bytes,
    type,
    NoneType,
    FunctionType,
    BuiltinFunctionType,
    LambdaType,
    weakref.ref,
    CodeType,
    # NOTE: listing ModuleType means modules are returned as-is rather than passed to ``copy.deepcopy``,
    # which (as far as we know — confirm) cannot copy modules. That deviates from plain deepcopy, but
    # ``smart_deepcopy`` is used internally for field defaults, so this allows a module as a default value.
    ModuleType,
    NotImplemented.__class__,
    Ellipsis.__class__,
}
|
|
|
|
|
# Exact builtin container types for which ``smart_deepcopy`` takes a fast path: an EMPTY instance is
# shallow-copied with ``.copy()`` (or returned as-is for ``tuple``) instead of going through ``deepcopy``.
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
    list,
    set,
    tuple,
    frozenset,
    dict,
    OrderedDict,
    defaultdict,
    deque,
}
|
|
|
|
|
def import_string(dotted_path: str) -> Any:
    """
    Resolve ``"package.module.attr"`` to the object ``attr`` of ``package.module``.

    Stolen approximately from django. Surrounding spaces are ignored.

    :param dotted_path: module path and attribute name joined by a final dot
    :return: the attribute named by the last path component
    :raises ImportError: if the path has no dot, the module cannot be imported, or the attribute is missing
    """
    from importlib import import_module

    cleaned_path = dotted_path.strip(' ')
    try:
        module_path, class_name = cleaned_path.rsplit('.', 1)
    except ValueError as exc:
        # no dot at all: there is no module part to import
        raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from exc

    imported = import_module(module_path)
    try:
        target = getattr(imported, class_name)
    except AttributeError as exc:
        raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from exc
    return target
|
|
|
|
|
def truncate(v: Any, *, max_len: int = 80) -> str:
    """
    Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long.

    Deprecated: emits ``DeprecationWarning`` on every call.

    Strings longer than ``max_len - 2`` are cut so that quote + text + ellipsis + quote is ``max_len`` characters.
    Any other value (types included) is converted with ``__repr__`` and cut to ``max_len`` characters.

    :param v: value to render; any object, not only ``str``
    :param max_len: maximum length of the returned string
    :return: the (possibly truncated) repr of ``v``
    """
    warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
    if isinstance(v, str) and len(v) > (max_len - 2):
        # -3 so quote + string + … + quote has correct length
        return (v[: (max_len - 3)] + '…').__repr__()
    try:
        v = v.__repr__()
    except TypeError:
        # ``v`` is a type: ``cls.__repr__()`` is an unbound call, so use the metaclass repr instead
        v = v.__class__.__repr__(v)
    if len(v) > max_len:
        v = v[: max_len - 1] + '…'
    return v
|
|
|
|
|
def sequence_like(v: Any) -> bool:
    """Return True if ``v`` is a list, tuple, set, frozenset, generator or deque."""
    sequence_types = (list, tuple, set, frozenset, GeneratorType, deque)
    return isinstance(v, sequence_types)
|
|
|
|
|
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
    """
    Ensure that the field's name does not shadow an existing attribute of the model.

    Only attributes with a truthy value on one of ``bases`` count as shadowed.

    :raises NameError: if some base has a truthy attribute named ``field_name``
    """
    shadowed = any(getattr(base, field_name, None) for base in bases)
    if shadowed:
        raise NameError(
            f'Field name "{field_name}" shadows a BaseModel attribute; '
            f'use a different field name with "alias=\'{field_name}\'".'
        )
|
|
|
|
|
def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    """``isinstance`` that answers False instead of raising ``TypeError`` for an invalid second argument."""
    try:
        outcome = isinstance(o, class_or_tuple)
    except TypeError:
        outcome = False
    return outcome
|
|
|
|
|
def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
    """
    ``issubclass`` that answers False when ``cls`` is not a class.

    A ``TypeError`` from ``issubclass`` is swallowed (returning False) only when ``cls`` is a parametrized
    generic (an instance of ``WithArgsTypes``); otherwise it is re-raised.
    """
    try:
        if not isinstance(cls, type):
            return False
        return issubclass(cls, class_or_tuple)
    except TypeError:
        if not isinstance(cls, WithArgsTypes):
            raise
        return False
|
|
|
|
|
def in_ipython() -> bool:
    """
    Check whether we're in an ipython environment, including jupyter notebooks.

    Detection relies on the name ``__IPYTHON__`` being resolvable from this module.
    """
    try:
        eval('__IPYTHON__')
        return True
    except NameError:
        return False
|
|
|
|
|
def is_valid_identifier(identifier: str) -> bool:
    """
    Checks that a string is a valid identifier and not a Python keyword.

    :param identifier: The identifier to test.
    :return: True if the identifier is valid.
    """
    if not identifier.isidentifier():
        return False
    return not keyword.iskeyword(identifier)
|
|
|
|
|
KeyType = TypeVar('KeyType')


def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
    """
    Return a copy of ``mapping`` with each of ``updating_mappings`` applied in order.

    When both the existing and the new value for a key are dicts, they are merged recursively (producing new
    dicts); any other new value replaces the existing one. ``mapping`` itself is not modified.
    """
    result = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            merge_nested = key in result and isinstance(result[key], dict) and isinstance(new_value, dict)
            result[key] = deep_update(result[key], new_value) if merge_nested else new_value
    return result
|
|
|
|
|
def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
    """Update ``mapping`` in place with the keyword arguments whose value is not ``None``."""
    present = {key: value for key, value in update.items() if value is not None}
    mapping.update(present)
|
|
|
|
|
def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
    """
    Return True if two floats are almost equal, i.e. their absolute difference is at most ``delta``.
    """
    difference = abs(value_1 - value_2)
    return difference <= delta
|
|
|
|
|
def generate_model_signature(
    init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
    """
    Generate signature for model based on its fields.

    Parameters of ``init`` (after the first one) are kept. If ``init`` accepts ``**kwargs``, each field that is
    not already a parameter is added as a keyword-only parameter named after its alias. When the alias is not a
    valid identifier, the field name is used instead if ``config.allow_population_by_field_name`` is set.
    ``**kwargs`` is kept only if ``init`` has it AND it is needed: either some field could not be exposed by name,
    or ``config.extra`` is ``Extra.allow``.

    :param init: the model's ``__init__``; its first parameter (``self``) is skipped
    :param fields: mapping of field name to ``ModelField``
    :param config: the model's config class
    :return: an ``inspect.Signature`` with ``return_annotation=None``
    """
    from inspect import Parameter, Signature, signature

    from pydantic.v1.config import Extra

    present_params = signature(init).parameters.values()
    merged_params: Dict[str, Parameter] = {}
    var_kw = None
    use_var_kw = False

    for param in islice(present_params, 1, None):  # skip self arg
        if param.kind is param.VAR_KEYWORD:
            var_kw = param
            continue
        merged_params[param.name] = param

    if var_kw:  # if custom init has no var_kw, fields which are not declared in it cannot be passed through
        allow_names = config.allow_population_by_field_name
        for field_name, field in fields.items():
            param_name = field.alias
            if field_name in merged_params or param_name in merged_params:
                continue
            elif not is_valid_identifier(param_name):
                if allow_names and is_valid_identifier(field_name):
                    param_name = field_name
                else:
                    # the field can only be passed through **kwargs
                    use_var_kw = True
                    continue

            # TODO: replace annotation with actual expected types once #1055 solved
            kwargs = {'default': field.default} if not field.required else {}
            merged_params[param_name] = Parameter(
                param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
            )

    if config.extra is Extra.allow:
        use_var_kw = True

    if var_kw and use_var_kw:
        # Make sure the parameter for extra kwargs
        # does not have the same name as a field
        default_model_signature = [
            ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
            ('data', Parameter.VAR_KEYWORD),
        ]
        if [(p.name, p.kind) for p in present_params] == default_model_signature:
            # if this is the standard model signature, use extra_data as the extra args name
            var_kw_name = 'extra_data'
        else:
            # else start from var_kw
            var_kw_name = var_kw.name

        # Generate a name that is definitely unique: it must clash neither with a field name nor with a
        # parameter already collected (e.g. a field whose *alias* is ``extra_data``). Otherwise that parameter
        # would be silently overwritten, or ``Signature`` would reject a VAR_KEYWORD placed before
        # keyword-only parameters.
        while var_kw_name in fields or var_kw_name in merged_params:
            var_kw_name += '_'
        merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)

    return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
|
|
|
|
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
    """
    Return the ``BaseModel`` class behind ``obj``.

    If ``obj`` has a ``__pydantic_model__`` attribute (presumably a pydantic dataclass), that attribute is used;
    otherwise ``obj`` itself.

    :raises TypeError: if the resolved class is not a ``BaseModel`` subclass
    """
    from pydantic.v1.main import BaseModel

    model_cls = getattr(obj, '__pydantic_model__', obj)
    if issubclass(model_cls, BaseModel):
        return model_cls
    raise TypeError('Unsupported type, must be either BaseModel or dataclass')
|
|
|
|
|
def to_camel(string: str) -> str:
    """
    Convert a snake_case string to PascalCase, e.g. ``'snake_case'`` -> ``'SnakeCase'``.

    Each ``_``-separated word goes through ``str.capitalize``, so the rest of every word is lower-cased.
    """
    return ''.join(word.capitalize() for word in string.split('_'))


def to_lower_camel(string: str) -> str:
    """
    Convert a snake_case string to camelCase, e.g. ``'snake_case'`` -> ``'snakeCase'``.

    Inputs whose PascalCase form is empty (``''``, ``'_'``, ``'__'``, ...) yield ``''`` instead of raising
    ``IndexError``.
    """
    pascal_string = to_camel(string)
    # slicing instead of indexing keeps this safe when ``pascal_string`` is empty
    return pascal_string[:1].lower() + pascal_string[1:]
|
|
|
|
|
T = TypeVar('T')


def unique_list(
    input_list: Union[List[T], Tuple[T, ...]],
    *,
    name_factory: Callable[[T], str] = str,
) -> List[T]:
    """
    Make a list unique while maintaining order.

    Items are identified by ``name_factory(item)``. When a later item has the same name as an earlier one, it
    replaces the earlier item in place (e.g. root validator overridden in subclass), keeping the position of the
    first occurrence.

    :param input_list: items to de-duplicate
    :param name_factory: maps an item to the name used for de-duplication
    :return: a new list with one item per distinct name
    """
    result: List[T] = []
    # name -> position in ``result``; a dict makes each lookup O(1) instead of scanning a list of names
    index_by_name: Dict[str, int] = {}
    for v in input_list:
        v_name = name_factory(v)
        position = index_by_name.get(v_name)
        if position is None:
            index_by_name[v_name] = len(result)
            result.append(v)
        else:
            result[position] = v

    return result
|
|
|
|
|
class PyObjectStr(str):
    """
    A ``str`` whose ``repr`` is the bare text, without quotes.

    Useful with Representation when you want to return a string representation of something that is valid
    (or pseudo-valid) python.
    """

    def __repr__(self) -> str:
        plain: str = str.__str__(self)
        return plain
|
|
|
|
|
class Representation:
    """
    Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.

    __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
    of objects.
    """

    # empty so the mixin itself adds no instance attributes to slotted subclasses
    __slots__: Tuple[str, ...] = tuple()

    def __repr_args__(self) -> 'ReprArgs':
        """
        Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.

        Can either return:
        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`

        The default reads every name in ``self.__slots__`` and skips attributes whose value is ``None``.
        """
        shown = []
        for slot_name in self.__slots__:
            slot_value = getattr(self, slot_name)
            if slot_value is not None:
                shown.append((slot_name, slot_value))
        return shown

    def __repr_name__(self) -> str:
        """
        Name of the instance's class, used in __repr__.
        """
        klass = self.__class__
        return klass.__name__

    def __repr_str__(self, join_str: str) -> str:
        """Render ``__repr_args__`` as ``name=value!r`` (or bare ``value!r``) items joined by ``join_str``."""
        rendered = []
        for arg_name, arg_value in self.__repr_args__():
            rendered.append(repr(arg_value) if arg_name is None else f'{arg_name}={arg_value!r}')
        return join_str.join(rendered)

    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
        """
        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects

        Yields text fragments interleaved with the integers 1 / 0 / -1, which devtools interprets as layout
        markers (presumably indent / newline / dedent — see devtools' docs).
        """
        opener = self.__repr_name__() + '('
        yield opener
        yield 1
        for arg_name, arg_value in self.__repr_args__():
            if arg_name is not None:
                yield arg_name + '='
            yield fmt(arg_value)
            yield ','
            yield 0
        yield -1
        yield ')'

    def __str__(self) -> str:
        return self.__repr_str__(' ')

    def __repr__(self) -> str:
        return '{}({})'.format(self.__repr_name__(), self.__repr_str__(', '))

    def __rich_repr__(self) -> 'RichReprResult':
        """Get fields for Rich library: bare values for unnamed args, ``(name, value)`` for named ones."""
        for arg_name, arg_value in self.__repr_args__():
            yield arg_value if arg_name is None else (arg_name, arg_value)
|
|
|
|
|
class GetterDict(Representation):
    """
    Hack to make objects smell just enough like dicts for validate_model.

    Keys are the names from ``dir(obj)`` that do not start with an underscore; values are read with ``getattr``.
    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
    """

    __slots__ = ('_obj',)

    def __init__(self, obj: Any):
        self._obj = obj

    def __getitem__(self, key: str) -> Any:
        """Return ``getattr(obj, key)``, turning a missing attribute into ``KeyError`` as a dict would."""
        try:
            return getattr(self._obj, key)
        except AttributeError as e:
            raise KeyError(key) from e

    def get(self, key: Any, default: Any = None) -> Any:
        """Return ``getattr(obj, key)``, or ``default`` if the attribute is missing."""
        return getattr(self._obj, key, default)

    def extra_keys(self) -> Set[Any]:
        """
        We don't want to get any other attributes of obj if the model didn't explicitly ask for them
        """
        return set()

    def keys(self) -> List[Any]:
        """
        Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
        dictionaries.
        """
        return list(self)

    def values(self) -> List[Any]:
        return [self[k] for k in self]

    def items(self) -> Iterator[Tuple[str, Any]]:
        for k in self:
            yield k, self.get(k)

    def __iter__(self) -> Iterator[str]:
        # only public (non-underscore) attribute names are exposed as keys
        for name in dir(self._obj):
            if not name.startswith('_'):
                yield name

    def __len__(self) -> int:
        return sum(1 for _ in self)

    def __contains__(self, item: Any) -> bool:
        return item in self.keys()

    def __eq__(self, other: Any) -> bool:
        """
        Compare as dicts with any object that has a callable ``items()``.

        For anything else ``NotImplemented`` is returned, so Python falls back to the reflected comparison /
        identity (``==`` gives ``False``) instead of the comparison raising ``AttributeError``.
        """
        other_items = getattr(other, 'items', None)
        if not callable(other_items):
            return NotImplemented
        return dict(self) == dict(other_items())

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, dict(self))]

    def __repr_name__(self) -> str:
        return f'GetterDict[{display_as_type(self._obj)}]'
|
|
|
|
|
class ValueItems(Representation):
    """
    Class for more convenient calculation of excluded or included fields on values.

    ``items`` (a set or a mapping of keys / indexes) is stored as a mapping. A value of ``True`` or ``...`` means
    "the whole element"; a nested set or mapping selects part of that element.
    """

    # NOTE(review): ``_type`` is declared but never assigned in this class
    __slots__ = ('_items', '_type')

    def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
        items = self._coerce_items(items)

        if isinstance(value, (list, tuple)):
            # sequences are addressed by index: resolve negative indexes and expand '__all__'
            items = self._normalize_indexes(items, len(value))

        self._items: 'MappingIntStrAny' = items

    def is_excluded(self, item: Any) -> bool:
        """
        Check if item is fully excluded.

        :param item: key or index of a value
        """
        return self.is_true(self._items.get(item))

    def is_included(self, item: Any) -> bool:
        """
        Check if value is contained in self._items

        :param item: key or index of value
        """
        return item in self._items

    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
        """
        :param e: key or index of element on value
        :return: raw values for element if self._items is dict and contain needed element
        """

        item = self._items.get(e)
        # a "whole element" marker carries no nested selection
        return item if not self.is_true(item) else None

    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
        """
        :param items: dict or set of indexes which will be normalized
        :param v_length: length of sequence indexes of which will be

        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
        {0: True, 2: True, 3: True}
        >>> self._normalize_indexes({'__all__': True}, 4)
        {0: True, 1: True, 2: True, 3: True}
        """

        normalized_items: 'DictIntStrAny' = {}
        all_items = None
        for i, v in items.items():
            if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
            if i == '__all__':
                all_items = self._coerce_value(v)
                continue
            if not isinstance(i, int):
                raise TypeError(
                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
                    'expected integer keys or keyword "__all__"'
                )
            normalized_i = v_length + i if i < 0 else i
            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))

        if not all_items:
            return normalized_items
        if self.is_true(all_items):
            # '__all__' selects every element entirely; explicit per-index entries are kept as they are
            for i in range(v_length):
                normalized_items.setdefault(i, ...)
            return normalized_items
        for i in range(v_length):
            normalized_item = normalized_items.setdefault(i, {})
            if not self.is_true(normalized_item):
                normalized_items[i] = self.merge(all_items, normalized_item)
        return normalized_items

    @classmethod
    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
        """
        Merge a ``base`` item with an ``override`` item.

        Both ``base`` and ``override`` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the sets entries as keys and
        Ellipsis as values.

        Each key-value pair existing in ``base`` is merged with ``override``,
        while the rest of the key-value pairs are updated recursively with this function.

        Merging takes place based on the "union" of keys if ``intersect`` is
        set to ``False`` (default) and on the intersection of keys if
        ``intersect`` is set to ``True``.
        """
        override = cls._coerce_value(override)
        base = cls._coerce_value(base)
        if override is None:
            return base
        if cls.is_true(base) or base is None:
            return override
        if cls.is_true(override):
            return base if intersect else override

        # intersection or union of keys while preserving ``base`` ordering first. Each key is listed once, so
        # every nested value is merged exactly once (listing common keys twice would redo the recursive merge).
        if intersect:
            merge_keys = [k for k in base if k in override]
        else:
            merge_keys = list(base) + [k for k in override if k not in base]

        merged: 'DictIntStrAny' = {}
        for k in merge_keys:
            merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
            if merged_item is not None:
                merged[k] = merged_item

        return merged

    @staticmethod
    def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
        """Return mappings unchanged, turn sets into ``{key: ...}`` dicts, reject anything else with TypeError."""
        if isinstance(items, Mapping):
            pass
        elif isinstance(items, AbstractSet):
            items = dict.fromkeys(items, ...)
        else:
            class_name = getattr(items, '__class__', '???')
            assert_never(
                items,
                f'Unexpected type of exclude value {class_name}',
            )
        return items

    @classmethod
    def _coerce_value(cls, value: Any) -> Any:
        """Like ``_coerce_items`` but lets ``None`` and the ``True``/``...`` markers through unchanged."""
        if value is None or cls.is_true(value):
            return value
        return cls._coerce_items(value)

    @staticmethod
    def is_true(v: Any) -> bool:
        """Return True for the "whole element" markers ``True`` and ``...`` (identity checks only)."""
        return v is True or v is ...

    def __repr_args__(self) -> 'ReprArgs':
        return [(None, self._items)]
|
|
|
|
|
class ClassAttribute:
    """
    Hide class attribute from its instances.

    Used as a descriptor: reading it on the owner class returns ``value``; reading it on an instance raises
    ``AttributeError``.
    """

    __slots__ = (
        'name',
        'value',
    )

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self.value = value

    def __get__(self, instance: Any, owner: Type[Any]) -> Any:
        # ``instance is None`` means the attribute was accessed on the class itself
        if instance is None:
            return self.value
        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
|
|
|
|
|
# ``pathlib.Path`` predicate method name -> human readable description, checked in this order by ``path_type``
path_types = {
    'is_dir': 'directory',
    'is_file': 'file',
    'is_mount': 'mount point',
    'is_symlink': 'symlink',
    'is_block_device': 'block device',
    'is_char_device': 'char device',
    'is_fifo': 'FIFO',
    'is_socket': 'socket',
}


def path_type(p: 'Path') -> str:
    """
    Find out what sort of thing a path is.

    :param p: a path that must exist
    :return: the description of the first matching ``path_types`` predicate, or ``'unknown'``
    :raises AssertionError: if ``p`` does not exist
    """
    # explicit raise instead of ``assert`` so the check is not stripped under ``python -O``; the exception
    # type is unchanged for callers that may rely on it
    if not p.exists():
        raise AssertionError('path does not exist')
    for method, name in path_types.items():
        if getattr(p, method)():
            return name

    return 'unknown'
|
|
|
|
|
Obj = TypeVar('Obj')


def smart_deepcopy(obj: Obj) -> Obj:
    """
    Copy ``obj`` as cheaply as is safe.

    * instances whose exact type is in ``IMMUTABLE_NON_COLLECTIONS_TYPES`` are returned as is
    * empty instances whose exact type is in ``BUILTIN_COLLECTIONS`` get a shallow ``.copy()``
      (an empty tuple is returned as is, since tuples have no ``copy`` method)
    * everything else goes through ``copy.deepcopy()``
    """
    obj_type = obj.__class__
    if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        # immutable and not a container: nothing to copy
        return obj

    try:
        # ``not obj`` runs first and may raise for objects with unusual truthiness; those errors fall
        # through to the deepcopy below
        is_empty_builtin = not obj and obj_type in BUILTIN_COLLECTIONS
        if is_empty_builtin:
            # no members, so a shallow copy is as good as a deep one
            if obj_type is tuple:
                return obj
            return obj.copy()  # type: ignore
    except (TypeError, ValueError, RuntimeError):
        pass

    return deepcopy(obj)
|
|
|
|
|
def is_valid_field(name: str) -> bool:
    """
    Return True if ``name`` may be used as a model field name.

    Names without a leading underscore are accepted; the only underscore-prefixed name accepted is ``ROOT_KEY``.
    """
    return not name.startswith('_') or ROOT_KEY == name
|
|
|
|
|
# Underscore-prefixed class-namespace names that are never treated as private attributes.
DUNDER_ATTRIBUTES = {
    '__annotations__',
    '__classcell__',
    '__doc__',
    '__module__',
    '__orig_bases__',
    '__orig_class__',
    '__qualname__',
}


def is_valid_private_name(name: str) -> bool:
    """Return True if ``name`` is not a valid field name and is not one of ``DUNDER_ATTRIBUTES``."""
    if is_valid_field(name):
        return False
    return name not in DUNDER_ATTRIBUTES
|
|
|
|
|
# Fill value for the shorter iterable in ``all_identical``; never identical to any real item.
_EMPTY = object()


def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """
    Check that the items of `left` are the same objects as those in `right`.

    Iterables of different length are never identical.

    >>> a, b = object(), object()
    >>> all_identical([a, b, a], [a, b, a])
    True
    >>> all_identical([a, b, [a]], [a, b, [a]])  # new list object, while "equal" is not "identical"
    False
    """
    pairs = zip_longest(left, right, fillvalue=_EMPTY)
    return all(first is second for first, second in pairs)
|
|
|
|
|
def assert_never(obj: NoReturn, msg: str) -> NoReturn:
    """
    Mark a branch that should be unreachable once all possible types are covered.

    Typed so ``mypy`` can check exhaustiveness, see
    https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
    At runtime it always raises ``TypeError(msg)``; ``obj`` is not used.
    """
    error = TypeError(msg)
    raise error
|
|
|
|
|
def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
    """
    Validate that all aliases are the same and if that's the case return the alias.

    :raises ConfigError: if more than one distinct alias is present
    """
    distinct_aliases = set(all_aliases)
    if len(distinct_aliases) <= 1:
        return distinct_aliases.pop()
    raise ConfigError(
        f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
    )
|
|
|
|
|
def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
    """
    Get alias and all valid values in the `Literal` type of the discriminator field
    `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.

    After unwrapping `Annotated` and `__pydantic_model__`, three cases are handled:
    * a `Union`: recurse into every member (via `_get_union_alias_and_all_values`) and flatten their values
    * a custom-root model: recurse into the type of its `ROOT_KEY` field; all members must yield the same
      values tuple, which is returned
    * otherwise: `tp.__fields__[discriminator_key]` must exist and its type must be a `Literal`

    :raises TypeError: if reading `tp.__fields__[discriminator_key].type_` raises AttributeError
    :raises ConfigError: if aliases differ between members, the discriminator field is missing or is not a
        `Literal`, or members of a custom-root union disagree on the values
    """
    # NOTE(review): computed from the original `tp`, i.e. before the `Annotated` / `__pydantic_model__`
    # unwrapping below — confirm this is intended for `Annotated[<root model>, ...]`.
    is_root_model = getattr(tp, '__custom_root_type__', False)

    if get_origin(tp) is Annotated:
        tp = get_args(tp)[0]

    if hasattr(tp, '__pydantic_model__'):
        tp = tp.__pydantic_model__

    if is_union(get_origin(tp)):
        alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
        # flatten the per-member value tuples into one tuple
        return alias, tuple(v for values in all_values for v in values)
    elif is_root_model:
        union_type = tp.__fields__[ROOT_KEY].type_
        alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)

        # every member of the root union must expose exactly the same discriminator values
        if len(set(all_values)) > 1:
            raise ConfigError(
                f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
            )

        return alias, all_values[0]

    else:
        try:
            t_discriminator_type = tp.__fields__[discriminator_key].type_
        except AttributeError as e:
            raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
        except KeyError as e:
            raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e

        if not is_literal_type(t_discriminator_type):
            raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')

        return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
|
|
|
|
|
def _get_union_alias_and_all_values(
    union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
    """
    Collect the discriminator alias and values for every member of ``union_type``.

    :return: the single alias shared by all members, and one tuple of discriminator values per member
    :raises ConfigError: (from ``get_unique_discriminator_alias``) if members use different aliases
    """
    per_member = [get_discriminator_alias_and_values(member, discriminator_key) for member in get_args(union_type)]
    aliases, values_per_member = zip(*per_member)
    shared_alias = get_unique_discriminator_alias(aliases, discriminator_key)
    return shared_alias, values_per_member
|
|