# Copyright 2024 The etils Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Helper utils.""" from __future__ import annotations import dataclasses import functools from typing import Any, Callable, Dict, Type, TypeVar from etils import epy import typing_extensions from typing_extensions import Annotated _Cls = Type[Any] _ClsT = TypeVar('_ClsT') Hint = Any _Hints = Dict[str, Hint] Descriptor = Any @dataclasses.dataclass class DescriptorInfo: annotation: Hint descriptor_fn: Callable[[dataclasses.Field[Any], Hint], Descriptor] @functools.cached_property def annotated_token(self) -> object: """Returns the Annotated sentinel.""" assert typing_extensions.get_origin(self.annotation) is Annotated (annotated_token,) = self.annotation.__metadata__ return annotated_token def wrap_new(cls: _ClsT, descriptor_infos: list[DescriptorInfo]) -> _ClsT: """`__new__` decorator to replace the fields by descriptors on first usage.""" if not descriptor_infos: return cls cls._edc_processed = False # pylint: disable=protected-access old_new_fn = cls.__new__ @functools.wraps(old_new_fn) def new_new_fn(cls, *args, **kwargs): if old_new_fn is object.__new__: self = old_new_fn(cls) else: self = old_new_fn(cls, *args, **kwargs) # Already called, skipping initialization if cls.__dict__.get('_edc_processed'): return self # First time, apply to all parent classes . for curr_cls in cls.mro(): # Apply to all parent classes if cls.__dict__.get('_edc_processed', True): # Either: # This class is not a `@edc.dataclass` (but parent might) # This class is already processed continue _replace_field_by_descriptor(curr_cls, descriptor_infos=descriptor_infos) cls._edc_processed = True # pylint: disable=protected-access return self cls.__new__ = new_new_fn return cls def _replace_field_by_descriptor( cls: _Cls, *, descriptor_infos: list[DescriptorInfo], ): """Iterate over the dataclass fields and replace the fields by descriptors.""" if not dataclasses.is_dataclass(cls): # e.g. object return fields = {f.name: f for f in dataclasses.fields(cls)} hints = _get_type_hints(cls, include_extras=True) for name, hint in hints.items(): if name not in cls.__annotations__: continue # Only add typing from the current class # TODO(epot): Should create a typing parsing util. if typing_extensions.get_origin(hint) is not Annotated: continue hint_cls = hint.__origin__ # Unwrap the original type field = fields[name] # Make the descriptor for descriptor_info in descriptor_infos: if not any( a is descriptor_info.annotated_token for a in hint.__metadata__ ): continue descriptor = descriptor_info.descriptor_fn(field, hint_cls) setattr(cls, name, descriptor) # cls.__dict__[name] = cast_field descriptor.__set_name__(cls, name) # Notify the descriptor # Could merge this function with the one in `dataclass_array` in a util. def _get_type_hints(cls, *, include_extras: bool = False) -> _Hints: """`get_type_hints` with better error reporting.""" # At this point, `ForwardRef` should have been resolved. try: return _get_type_hints_fix(cls, include_extras=include_extras) except Exception as e: # pylint: disable=broad-except msg = ( f'Could not infer typing annotation of {cls.__qualname__} ' f'defined in {cls.__module__}:\n' ) lines = [f' * {k}: {v!r}' for k, v in cls.__annotations__.items()] lines = '\n'.join(lines) epy.reraise(e, prefix=msg + lines + '\n') # pytype: disable=bad-return-type def _get_type_hints_fix(cls, *, include_extras: bool = False) -> _Hints: """`get_type_hints` with bug fixes.""" # TODO(py311): `get_type_hints` fail for `_: dataclasses.KW_ONLY` old_annotations = [_fix_annotations(subcls) for subcls in cls.mro()] try: return typing_extensions.get_type_hints(cls, include_extras=include_extras) finally: # Restore the annotations for subcls, annotations in zip(cls.mro(), old_annotations): if annotations: subcls.__annotations__ = annotations def _fix_annotations(cls): """Remove the `_: dataclasses.KW_ONLY` annotation.""" if cls is object or '_' not in getattr(cls, '__annotations__', {}): return old_annotations = dict(cls.__annotations__) cls.__annotations__.pop('_') return old_annotations