Spaces:
Building
Building
File size: 8,036 Bytes
f5f3483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclass utils."""
from __future__ import annotations
import dataclasses
import functools
import typing
from typing import Any, Callable, TypeVar
from etils import epy
from etils.edc import cast_utils
from etils.edc import context
from etils.edc import frozen_utils
from etils.edc import helpers
# Loose alias for "some class object"; NOTE(review): not referenced in the
# visible part of this file.
_Cls = Any
# Type of the decorated class: lets the decorators/helpers below advertise that
# they return the same class they were given.
_ClsT = TypeVar('_ClsT')
# Type of a dataclass instance; `replace` returns the same type as `self`.
_T = TypeVar('_T')
# Typing overload for the decorator-factory form, e.g.
# `@edc.dataclass(kw_only=True)`: no class is passed, and the call returns a
# decorator that maps a class to the same class type.
@typing.overload
def dataclass(
    cls: None = ...,
    *,
    kw_only: bool = ...,
    replace: bool = ...,  # pylint: disable=redefined-outer-name
    repr: bool = ...,  # pylint: disable=redefined-builtin
    auto_cast: bool = ...,
    contextvars: bool = ...,
    allow_unfrozen: bool = ...,
) -> Callable[[_ClsT], _ClsT]:
  ...
# Typing overload for the bare-decorator form, e.g. `@edc.dataclass`: the class
# is passed directly and the (augmented) class itself is returned.
@typing.overload
def dataclass(
    cls: _ClsT,
    *,
    kw_only: bool = ...,
    replace: bool = ...,  # pylint: disable=redefined-outer-name
    repr: bool = ...,  # pylint: disable=redefined-builtin
    auto_cast: bool = ...,
    contextvars: bool = ...,
    allow_unfrozen: bool = ...,
) -> _ClsT:
  ...
def dataclass(
    cls=None,
    *,
    kw_only=False,
    replace=True,  # pylint: disable=redefined-outer-name
    repr=True,  # pylint: disable=redefined-builtin
    auto_cast=True,
    contextvars=True,
    allow_unfrozen=False,
):
  """Augment a dataclass with additional features.

  `auto_cast`: Auto-convert init assignments to the annotated class.

  ```python
  @edc.dataclass
  class A:
    path: edc.AutoCast[epath.Path]
    some_enum: edc.AutoCast[MyEnum]
    x: edc.AutoCast[str]

  a = A(
      path='/some/path',
      some_enum='A',
      x=123
  )
  # Fields annotated with `AutoCast` are automatically cast to their type
  assert a.path == epath.Path('/some/path')
  assert a.some_enum is MyEnum.A
  assert a.x == '123'
  ```

  `allow_unfrozen`: allow nested dataclasses to be updated. This adds two
  methods:

  * `.unfrozen()`: Create a lazy deep-copy of the current dataclass. Updates
    to nested attributes will be propagated to the top-level dataclass.
  * `.frozen()`: Returns the frozen dataclass, after it was mutated.

  Example:

  ```python
  old_x = X(y=Y(z=123))

  x = old_x.unfrozen()
  x.y.z = 456
  x = x.frozen()

  assert x == X(y=Y(z=456))  # Only the new x is mutated
  assert old_x == X(y=Y(z=123))  # The old x is not mutated
  ```

  Note:

  * Only the last `.frozen()` call resolves the dataclass by calling
    `.replace` recursively.
  * The dataclass returned by `.unfrozen()` and its nested attributes are not
    the original dataclasses but proxy objects which track the mutations. As
    such, those objects are not compatible with `isinstance()`,
    `jax.tree.map`,...
  * Only the top-level dataclass needs to be `allow_unfrozen=True`.
  * Avoid using `unfrozen` if 2 attributes of the dataclass point to the
    same nested dataclass. Updates on one attribute might not be reflected on
    the other.

    ```python
    y = Y(y=123)
    x = X(x0=y, x1=y)  # Same instance assigned twice in `x0` and `x1`
    x = x.unfrozen()
    x.x0.y = 456  # Changes in `x0` not reflected in `x1`
    x = x.frozen()

    assert x == X(x0=Y(y=456), x1=Y(y=123))
    ```

    This is because only attributes which are accessed are tracked, so
    `etils` does not know the object exists somewhere else in the attribute
    tree.
  * After `.frozen()` has been called, any of the temporary sub-attributes
    become invalid:

    ```python
    a = a.unfrozen()
    y = a.y
    a = a.frozen()

    y.x  # Raises an error (created between the unfrozen/frozen calls)
    a.y.x  # Works
    ```

  `contextvars`: Fields annotated as `edc.ContextVar` are wrapped in
  a `contextvars.ContextVar`. Afterward each thread / asyncio coroutine will
  have its own version of the fields (similarly to `threading.local`).
  The contextvars are lazily initialized at first usage.

  Example:

  ```python
  @edc.dataclass
  @dataclasses.dataclass
  class Context:
    thread_id: edc.ContextVar[int] = dataclasses.field(
        default_factory=threading.get_native_id
    )
    stack: edc.ContextVar[list[str]] = dataclasses.field(default_factory=list)

  # Global context object
  context = Context(thread_id=0)

  def worker():
    # Inside each thread, the worker uses its own context
    assert context.thread_id != 0
    context.stack.append(1)

  with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
    for _ in range(10):
      executor.submit(worker)
  ```

  Args:
    cls: The dataclass to decorate. When `None`, a decorator configured with
      the other options is returned instead.
    kw_only: If `True`, make the dataclass `__init__` keyword-only.
    replace: If `True`, add a `.replace(` alias of `dataclasses.replace`.
    repr: If `True`, the class `__repr__` will return a pretty-printed `str`
      (one attribute per line).
    auto_cast: If `True`, fields annotated as `x: edc.AutoCast[Cls]` will be
      converted to `x: Cls = edc.field(validator=Cls)`.
    contextvars: If `True`, fields annotated as `x: edc.ContextVar[T]` are
      converted to `contextvars`. This allows a `threading.local`-like
      API for contextvars.
    allow_unfrozen: If `True`, add `.frozen`, `.unfrozen` methods.

  Returns:
    The decorated class, or a decorator if `cls` is `None`.
  """
  # Called with options only (`@edc.dataclass(...)`): return a decorator that
  # re-enters this function once the class is known. Every option must be
  # forwarded here, otherwise it silently falls back to its default value.
  if cls is None:
    return functools.partial(
        dataclass,
        kw_only=kw_only,
        replace=replace,
        repr=repr,
        auto_cast=auto_cast,
        contextvars=contextvars,
        allow_unfrozen=allow_unfrozen,
    )

  if kw_only:
    cls = _make_kw_only(cls)
  if repr:
    cls = add_repr(cls)
  if replace:
    cls = _add_replace(cls)
  if allow_unfrozen:
    cls = frozen_utils.add_unfrozen(cls)

  # Pair each marker annotation with the factory building its descriptor.
  # NOTE(review): `helpers.wrap_new` presumably installs these descriptors on
  # fields carrying the matching annotation; confirm in `helpers`.
  descriptor_fns = []
  if auto_cast:
    descriptor_fns.append(
        helpers.DescriptorInfo(
            annotation=cast_utils.AutoCast,
            descriptor_fn=cast_utils.make_auto_cast_descriptor,
        )
    )
  if contextvars:
    descriptor_fns.append(
        helpers.DescriptorInfo(
            annotation=context.ContextVar,
            descriptor_fn=context.make_contextvar_descriptor,
        )
    )
  cls = helpers.wrap_new(cls, descriptor_fns)
  return cls
def _make_kw_only(cls: _ClsT) -> _ClsT:
"""Replace the `__init__` by a keyword-only version."""
# Use `cls.__dict__` and not `hasattr` to ignore parent classes
if '__init__' not in cls.__dict__:
return cls # Do not mutate the class if __init__ isn't present
old_init = cls.__init__
# Despite `@functools.wraps`, the function has to be called `__init__` (
# see: https://stackoverflow.com/q/29919804/4172685)
@functools.wraps(old_init)
def __init__(self, *args, **kwargs): # pylint: disable=invalid-name
if args:
raise TypeError(
f'{self.__class__.__name__} contructor is keyword-only. '
f'Got {len(args)} positional arguments.'
)
return old_init(self, **kwargs)
cls.__init__ = __init__
return cls
def _add_replace(cls: _ClsT) -> _ClsT:
"""Add a `.replace` method to the class, if not already present."""
# Only add replace if not present
if not hasattr(cls, 'replace'):
cls.replace = replace
return cls
def replace(self: _T, **kwargs: Any) -> _T:
  """Returns a copy of `self` with the given fields overridden.

  Thin wrapper over `dataclasses.replace`, intended to be attached to classes
  as a `.replace(...)` method.

  Args:
    self: The dataclass instance to copy.
    **kwargs: Field values to override in the copy.

  Returns:
    A new instance with the overridden fields.
  """
  updated = dataclasses.replace(self, **kwargs)
  return updated
def add_repr(cls: _ClsT) -> _ClsT:
  """Swaps a default dataclass `__repr__` for the pretty-printing one.

  `cls` is only modified when it defines `__repr__` itself (inherited reprs
  are ignored) AND `epy.text_utils.has_default_repr(cls)` reports that repr
  as the default one. Otherwise (presumably e.g. a hand-written `__repr__`)
  the class is returned as-is.

  Args:
    cls: The class to augment.

  Returns:
    `cls`, possibly mutated in place.
  """
  # Use `cls.__dict__` (not `hasattr`) so parent-class reprs are not considered.
  defines_own_repr = '__repr__' in cls.__dict__
  if defines_own_repr and epy.text_utils.has_default_repr(cls):
    cls.__repr__ = __repr__
  return cls
# Module-level alias installed as `cls.__repr__` by `add_repr`. It renders the
# pretty-printed repr described by the `repr` option of `dataclass` (one
# attribute per line).
__repr__ = epy.pretty_repr
|