PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclass utils."""
from __future__ import annotations
import dataclasses
import functools
import typing
from typing import Any, Callable, TypeVar
from etils import epy
from etils.edc import cast_utils
from etils.edc import context
from etils.edc import frozen_utils
from etils.edc import helpers
_Cls = Any
_ClsT = TypeVar('_ClsT')
_T = TypeVar('_T')
# Overload for the options-only call form, e.g. `@edc.dataclass(kw_only=True)`:
# no class is passed, so a decorator taking and returning the class is returned.
@typing.overload
def dataclass(
    cls: None = ...,
    *,
    kw_only: bool = ...,
    replace: bool = ...,  # pylint: disable=redefined-outer-name
    repr: bool = ...,  # pylint: disable=redefined-builtin
    auto_cast: bool = ...,
    contextvars: bool = ...,
    allow_unfrozen: bool = ...,
) -> Callable[[_ClsT], _ClsT]:
    ...
# Overload for the direct form, e.g. bare `@edc.dataclass` or
# `edc.dataclass(MyCls)`: the class is passed and the decorated class returned.
@typing.overload
def dataclass(
    cls: _ClsT,
    *,
    kw_only: bool = ...,
    replace: bool = ...,  # pylint: disable=redefined-outer-name
    repr: bool = ...,  # pylint: disable=redefined-builtin
    auto_cast: bool = ...,
    contextvars: bool = ...,
    allow_unfrozen: bool = ...,
) -> _ClsT:
    ...
def dataclass(
    cls=None,
    *,
    kw_only=False,
    replace=True,  # pylint: disable=redefined-outer-name
    repr=True,  # pylint: disable=redefined-builtin
    auto_cast=True,
    contextvars=True,
    allow_unfrozen=False,
):
    """Augment a dataclass with additional features.

    This decorator patches an existing class (typically one already decorated
    with `@dataclasses.dataclass`, as in the examples below).

    `auto_cast`: Auto-convert init assignments to the annotated class.

    ```python
    @edc.dataclass
    @dataclasses.dataclass
    class A:
        path: edc.AutoCast[epath.Path]
        some_enum: edc.AutoCast[MyEnum]
        x: edc.AutoCast[str]

    a = A(
        path='/some/path',
        some_enum='A',
        x=123
    )
    # Fields annotated with `AutoCast` are automatically cast to their type
    assert a.path == epath.Path('/some/path')
    assert a.some_enum is MyEnum.A
    assert a.x == '123'
    ```

    `allow_unfrozen`: allow nested dataclasses to be updated. This adds two
    methods:

    * `.unfrozen()`: Create a lazy deep-copy of the current dataclass. Updates
      to nested attributes will be propagated to the top-level dataclass.
    * `.frozen()`: Returns the frozen dataclass, after it was mutated.

    Example:

    ```python
    old_x = X(y=Y(z=123))

    x = old_x.unfrozen()
    x.y.z = 456
    x = x.frozen()

    assert x == X(y=Y(z=456))  # Only new x is mutated
    assert old_x == X(y=Y(z=123))  # Old x is not mutated
    ```

    Note:

    * Only the last `.frozen()` call resolves the dataclass by calling
      `.replace` recursively.
    * Dataclasses returned by `.unfrozen()` and their nested attributes are not
      the original dataclasses but proxy objects which track the mutations. As
      such, those objects are not compatible with `isinstance()`,
      `jax.tree.map`,...
    * Only the top-level dataclass needs to be `allow_unfrozen=True`.
    * Avoid using `unfrozen` if 2 attributes of the dataclass point to the
      same nested dataclass. Updates on one attribute might not be reflected on
      the other.

      ```python
      y = Y(y=123)
      x = X(x0=y, x1=y)  # Same instance assigned twice in `x0` and `x1`
      x = x.unfrozen()
      x.x0.y = 456  # Changes in `x0` not reflected in `x1`
      x = x.frozen()

      assert x == X(x0=Y(y=456), x1=Y(y=123))
      ```

      This is because only attributes which are accessed are tracked, so
      `etils` does not know the object exists somewhere else in the attribute
      tree.
    * After `.frozen()` has been called, any temporary sub-attributes become
      invalid:

      ```python
      a = a.unfrozen()
      y = a.y
      a = a.frozen()

      y.x  # Raises an error (created between the unfrozen/frozen call)
      a.y.x  # Works
      ```

    `contextvars`: Fields annotated as `edc.ContextVar` are wrapped in
    a `contextvars.ContextVar`. Afterwards, each thread / asyncio coroutine
    will have its own version of the fields (similarly to `threading.local`).

    The contextvars are lazily initialized at first usage.

    Example:

    ```python
    @edc.dataclass
    @dataclasses.dataclass
    class Context:
        thread_id: edc.ContextVar[int] = dataclasses.field(
            default_factory=threading.get_native_id
        )
        stack: edc.ContextVar[list[str]] = dataclasses.field(
            default_factory=list
        )

    # Global context object
    context = Context(thread_id=0)

    def worker():
        # Inside each thread, the worker uses its own context
        assert context.thread_id != 0
        context.stack.append('item')

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        for _ in range(10):
            executor.submit(worker)
    ```

    Args:
        cls: The dataclass to decorate. If `None`, a decorator with the given
            options is returned instead (`@edc.dataclass(...)` usage).
        kw_only: If True, make the dataclass `__init__` keyword-only.
        replace: If `True`, add a `.replace()` alias of `dataclasses.replace`.
        repr: If `True`, the class `__repr__` will return a pretty-printed `str`
            (one attribute per line).
        auto_cast: If `True`, fields annotated as `x: edc.AutoCast[Cls]` will be
            converted to `x: Cls = edc.field(validator=Cls)`.
        contextvars: If `True`, fields annotated as `x: edc.ContextVar[T]` are
            converted to `contextvars`. This allows a `threading.local`-like
            API for contextvars.
        allow_unfrozen: If `True`, add `.frozen`, `.unfrozen` methods.

    Returns:
        The decorated class, or a decorator if `cls` is `None`.
    """
    # Options-only call (`@edc.dataclass(...)`): return a decorator which
    # re-enters this function once the class is known. Every option must be
    # forwarded, otherwise it silently falls back to its default.
    if cls is None:
        return functools.partial(
            dataclass,
            kw_only=kw_only,
            replace=replace,
            repr=repr,
            auto_cast=auto_cast,
            contextvars=contextvars,
            allow_unfrozen=allow_unfrozen,
        )

    if kw_only:
        cls = _make_kw_only(cls)
    if repr:
        cls = add_repr(cls)
    if replace:
        cls = _add_replace(cls)
    if allow_unfrozen:
        cls = frozen_utils.add_unfrozen(cls)

    # Collect the annotation -> descriptor factories to install; they are
    # applied together by a single `helpers.wrap_new` call below.
    descriptor_fns = []
    if auto_cast:
        descriptor_fns.append(
            helpers.DescriptorInfo(
                annotation=cast_utils.AutoCast,
                descriptor_fn=cast_utils.make_auto_cast_descriptor,
            )
        )
    if contextvars:
        descriptor_fns.append(
            helpers.DescriptorInfo(
                annotation=context.ContextVar,
                descriptor_fn=context.make_contextvar_descriptor,
            )
        )
    cls = helpers.wrap_new(cls, descriptor_fns)

    return cls
def _make_kw_only(cls: _ClsT) -> _ClsT:
"""Replace the `__init__` by a keyword-only version."""
# Use `cls.__dict__` and not `hasattr` to ignore parent classes
if '__init__' not in cls.__dict__:
return cls # Do not mutate the class if __init__ isn't present
old_init = cls.__init__
# Despite `@functools.wraps`, the function has to be called `__init__` (
# see: https://stackoverflow.com/q/29919804/4172685)
@functools.wraps(old_init)
def __init__(self, *args, **kwargs): # pylint: disable=invalid-name
if args:
raise TypeError(
f'{self.__class__.__name__} contructor is keyword-only. '
f'Got {len(args)} positional arguments.'
)
return old_init(self, **kwargs)
cls.__init__ = __init__
return cls
def _add_replace(cls: _ClsT) -> _ClsT:
"""Add a `.replace` method to the class, if not already present."""
# Only add replace if not present
if not hasattr(cls, 'replace'):
cls.replace = replace
return cls
def replace(self: _T, **kwargs: Any) -> _T:
    """Return a copy of `self` with the given fields overridden.

    Method-style wrapper around `dataclasses.replace`, meant to be attached to
    classes by `_add_replace`.
    """
    updated = dataclasses.replace(self, **kwargs)
    return updated
def add_repr(cls: _ClsT) -> _ClsT:
    """Swap the default `__repr__` of `cls` for the pretty-printing one.

    The module-level `__repr__` (`epy.pretty_repr`) is installed only when
    `cls` defines `__repr__` itself (e.g. generated by
    `@dataclasses.dataclass`) and `epy.text_utils.has_default_repr(cls)`
    reports it. In any other case `cls` is left unchanged.

    Args:
        cls: The class to patch.

    Returns:
        `cls` itself.
    """
    # Look in `cls.__dict__` rather than using `hasattr` so an inherited
    # `__repr__` is ignored; `and` short-circuits so the epy check only runs
    # for classes that define their own `__repr__`.
    if '__repr__' in cls.__dict__ and epy.text_utils.has_default_repr(cls):
        cls.__repr__ = __repr__
    return cls
# Pretty-printing `__repr__` that `add_repr` installs on classes (the `repr`
# option of `dataclass`: one attribute per line).
__repr__ = epy.pretty_repr