File size: 6,194 Bytes
4ae0b03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""RootModel class and type definitions."""
from __future__ import annotations as _annotations
import typing
from copy import copy, deepcopy
from pydantic_core import PydanticUndefined
from . import PydanticUserError
from ._internal import _model_construction, _repr
from .main import BaseModel, _object_setattr
if typing.TYPE_CHECKING:
    from typing import Any

    from typing_extensions import Literal, Self, dataclass_transform

    from .fields import Field as PydanticModelField, PrivateAttr as PydanticModelPrivateAttr

    # Applying `dataclass_transform` straight onto `RootModel` would be shadowed by the one already
    # present on `ModelMetaclass` (observed at least with pyright). For type checkers only, we
    # therefore declare a dedicated metaclass that carries the transform we want.
    @dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr))
    class _RootModelMetaclass(_model_construction.ModelMetaclass):
        """Type-checker-only metaclass carrying the `dataclass_transform` for `RootModel`."""

else:
    # At runtime no extra metaclass is created; the name is a plain alias of `ModelMetaclass`.
    _RootModelMetaclass = _model_construction.ModelMetaclass

__all__ = ('RootModel',)

# Type variable standing for the type of the wrapped `root` value.
RootModelRootType = typing.TypeVar('RootModelRootType')
class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
    """Usage docs: https://docs.pydantic.dev/2.8/concepts/models/#rootmodel-and-custom-root-types

    A Pydantic `BaseModel` for the root object of the model.

    Attributes:
        root: The root object of the model.
        __pydantic_root_model__: Whether the model is a RootModel.
        __pydantic_private__: Private fields in the model.
        __pydantic_extra__: Extra fields in the model.
    """

    __pydantic_root_model__ = True
    __pydantic_private__ = None
    __pydantic_extra__ = None

    root: RootModelRootType

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses whose `model_config` sets `extra`, then defer to the parent hook.

        Raises:
            PydanticUserError: If `model_config['extra']` is anything other than `None`.
        """
        if cls.model_config.get('extra') is not None:
            raise PydanticUserError(
                "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
            )
        super().__init_subclass__(**kwargs)

    def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None:  # type: ignore
        """Validate the root value and populate this instance.

        Args:
            root: The root value, given positionally.
            **data: Keyword arguments; when present, the resulting dict is used as the root value.

        Raises:
            ValueError: If both a positional root and keyword arguments are supplied.
        """
        __tracebackhide__ = True
        if data and root is not PydanticUndefined:
            raise ValueError(
                '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
            )
        # Keyword arguments, when given, become the root value as a dict.
        payload = data if data else root
        self.__pydantic_validator__.validate_python(payload, self_instance=self)

    __init__.__pydantic_base_init__ = True  # pyright: ignore[reportFunctionMemberAccess]

    @classmethod
    def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self:  # type: ignore
        """Create a new model using the provided root object and update fields set.

        The call is forwarded to the parent `model_construct`, passing `root` as a keyword.

        Args:
            root: The root object of the model.
            _fields_set: The set of fields to be updated.

        Returns:
            The new model.
        """
        constructed = super().model_construct(root=root, _fields_set=_fields_set)
        return constructed

    def __getstate__(self) -> dict[Any, Any]:
        """Return the state to pickle: the instance `__dict__` and the fields-set."""
        state = {}
        state['__dict__'] = self.__dict__
        state['__pydantic_fields_set__'] = self.__pydantic_fields_set__
        return state

    def __setstate__(self, state: dict[Any, Any]) -> None:
        """Restore the fields-set and then the instance `__dict__` from `state`."""
        for attr_name in ('__pydantic_fields_set__', '__dict__'):
            _object_setattr(self, attr_name, state[attr_name])

    def __copy__(self) -> Self:
        """Returns a shallow copy of the model."""
        model_cls = type(self)
        duplicate = model_cls.__new__(model_cls)
        _object_setattr(duplicate, '__dict__', copy(self.__dict__))
        _object_setattr(duplicate, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return duplicate

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
        """Returns a deep copy of the model."""
        model_cls = type(self)
        duplicate = model_cls.__new__(model_cls)
        _object_setattr(duplicate, '__dict__', deepcopy(self.__dict__, memo=memo))
        # A shallow copy is enough here: `__pydantic_fields_set__` is a set[str], so a deepcopy
        # would only be marginally slower without changing the result.
        _object_setattr(duplicate, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
        return duplicate

    if typing.TYPE_CHECKING:

        def model_dump(  # type: ignore
            self,
            *,
            mode: Literal['json', 'python'] | str = 'python',
            include: Any = None,
            exclude: Any = None,
            context: dict[str, Any] | None = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal['none', 'warn', 'error'] = True,
            serialize_as_any: bool = False,
        ) -> Any:
            """Type-checker-only stub that gives `model_dump` a more accurate return type.

            It sits inside an `if TYPE_CHECKING:` block because no runtime override is needed.
            See the documentation of `BaseModel.model_dump` for details about the arguments.

            The return type is generally `RootModelRootType`, provided `RootModelRootType` is not a
            `BaseModel` subclass. When it is a `BaseModel` subclass, the return type will likely be
            `dict[str, Any]`, because `model_dump` calls are recursive. A custom serializer may
            produce yet another type.

            `Any` is therefore used to cover all of these cases.
            """
            ...

    def __eq__(self, other: Any) -> bool:
        """Compare with another `RootModel`: root annotations must match, then defer to the parent.

        Returns `NotImplemented` when `other` is not a `RootModel` instance.
        """
        if isinstance(other, RootModel):
            annotations_match = self.model_fields['root'].annotation == other.model_fields['root'].annotation
            return annotations_match and super().__eq__(other)
        return NotImplemented

    def __repr_args__(self) -> _repr.ReprArgs:
        """Yield the single `('root', value)` pair used to build the repr."""
        yield ('root', self.root)
|