# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Field utils."""

from __future__ import annotations

import dataclasses
import typing
from typing import Any, Callable, Generic, Optional, Type, TypeVar

from etils import epy

_Dataclass = Any
_In = Any
_Out = Any
_InT = TypeVar('_InT')
_OutT = TypeVar('_OutT')


def field(
    *,
    validate: Optional[Callable[[_In], _OutT]] = None,
    **kwargs: Any,
) -> dataclasses.Field[_OutT]:
  """Like `dataclasses.field`, with an optional `validate` callable.

  Args:
    validate: A callable `(x) -> x` applied each time the attribute is
      assigned. When `None`, a plain `dataclasses.field` is returned.
    **kwargs: Kwargs forwarded to `dataclasses.field`.

  Returns:
    The field.
  """
  if validate is not None:
    descriptor = _Field(validate=validate, field_kwargs=kwargs)
    # `_Field` is a descriptor, not a `dataclasses.Field` subclass: the cast
    # only affects static typing.
    return typing.cast(dataclasses.Field, descriptor)  # pylint: disable=g-bare-generic
  return dataclasses.field(**kwargs)


class _Field(Generic[_InT, _OutT]):
  """Descriptor which validates the values assigned to an attribute."""

  def __init__(
      self,
      validate: Callable[[_InT], _OutT],
      field_kwargs: dict[str, Any],
  ) -> None:
    """Constructor.

    Args:
      validate: A callable applied to each value assigned to the attribute.
      field_kwargs: Kwargs forwarded to `dataclasses.field`.
    """
    self._validate_fn = validate
    self._field_kwargs = field_kwargs

    # Both are filled by `__set_name__` and describe where the descriptor
    # lives. E.g. if `A.x = edc.field()`:
    # * _attribute_name = 'x'
    # * _objtype = A
    self._attribute_name: Optional[str] = None
    self._objtype: Optional[Type[_Dataclass]] = None

    # `True` until `__get__` is called on the class (`obj is None`) for the
    # first time. See `__get__` for details.
    self._first_getattr_call: bool = True

  def __set_name__(self, objtype: Type[_Dataclass], name: str) -> None:
    """Records the owner class and attribute name (PEP 487)."""
    self._attribute_name = name
    self._objtype = objtype

  def __get__(
      self,
      obj: Optional[_Dataclass],
      objtype: Optional[Type[_Dataclass]] = None,
  ) -> _OutT:
    """Called when `MyDataclass.x` or `my_dataclass.x`."""
    if obj is not None:
      # Instance access: `my_dataclass.my_path`
      return _getattr(obj, self._attribute_name)

    # From here: class access, i.e. `MyDataclass.my_attribute`
    if not self._first_getattr_call:
      # TODO(epot): Could better handle default value: Either by returning
      # the default value, or raising an AttributeError. Currently, we just
      # return the descriptor:
      # assert isinstance(MyDataclass.my_attribute, _Field)
      return self

    # `dataclasses.dataclass(cls)` calls `getattr(cls, f.name)` more than
    # once, so the calls are counted:
    # * On the first call, a real `dataclasses.Field` is returned to let
    #   dataclass do the magic.
    # * On the second call, `dataclasses.dataclass` deletes the descriptor if
    #   `isinstance(getattr(cls, f.name, None), Field)`, so it is very
    #   important to return anything except a `dataclasses.Field`.
    # This relies on an implementation detail, but seems to hold for python
    # 3.6-3.10.
    self._first_getattr_call = False
    return dataclasses.field(**self._field_kwargs)

  def __set__(self, obj: _Dataclass, value: _InT) -> None:
    """Called as `my_dataclass.x = value`."""
    # The value is validated on every assignment, and only the validated
    # result is stored.
    validated = self._validate(value)
    _setattr(obj, self._attribute_name, validated)

  def _validate(self, value: _InT) -> _OutT:
    """Applies the validation function, prefixing errors with the attribute."""
    try:
      validated = self._validate_fn(value)
    except Exception as e:  # pylint: disable=broad-exception-caught
      epy.reraise(e, prefix=f'Error assigning {self._attribute_name!r}: ')
    else:
      return validated


# A single `_Field` instance exists per class and is shared by every instance
# of that class, so the per-object values have to live on the objects
# themselves. They are kept in an extra `dict[str, value]` attribute:
# `_dataclass_field_values`.


def _getattr(
    obj: _Dataclass,
    attribute_name: str,
) -> _Out:
  """Returns the `obj.attribute_name`."""
  _init_dataclass_state(obj)
  values = obj._dataclass_field_values  # pylint: disable=protected-access
  if attribute_name in values:
    return values[attribute_name]
  # The attribute was read before being set (e.g. before super().__init__)
  raise AttributeError(
      f"type object '{type(obj).__qualname__}' has no attribute "
      f"'{attribute_name}'"
  )


def _setattr(
    obj: _Dataclass,
    attribute_name: str,
    value: _In,
) -> None:
  """Set the `obj.attribute_name = value`."""
  # Note: In `dataclasses.dataclass(frozen=True)`, obj.__setattr__ will
  # correctly raise a `FrozenInstanceError` before the descriptor `__set__`
  # is called.
  _init_dataclass_state(obj)
  values = obj._dataclass_field_values  # pylint: disable=protected-access
  values[attribute_name] = value


def _init_dataclass_state(obj: _Dataclass) -> None:
  """Creates the dict holding the descriptor values, if not already there."""
  if hasattr(obj, '_dataclass_field_values'):
    return
  # `object.__setattr__` is used so this also works on frozen dataclasses.
  object.__setattr__(obj, '_dataclass_field_values', {})