PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2024 The etils Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unfrozen dataclasses."""
from __future__ import annotations
import dataclasses
import functools
import sys
from typing import Any, Generic, NoReturn, TypeVar, Union
# Type of the class passed to (and returned by) `add_unfrozen`.
_Cls = TypeVar('_Cls')
# Type of the dataclass instance being wrapped / returned by the proxies.
_T = TypeVar('_T')
# Whether the lazy `jax` pytree registration of `_MutableProxy` (attempted in
# `unfrozen()`) has already been done.
_is_tree_registered = False
def add_unfrozen(cls: _Cls) -> _Cls:
  """Adds the `frozen` and `unfrozen` methods to `cls`.

  The call is idempotent: a class which already exposes this module's
  `frozen` / `unfrozen` (e.g. inherited from a parent class on which
  `add_unfrozen` was already applied) is returned unchanged.

  Args:
    cls: The class to update in-place.

  Returns:
    The same `cls` object, with `frozen` and `unfrozen` attached.

  Raises:
    ValueError: If `cls` already defines a `frozen` or `unfrozen` attribute
      which is not the one provided by this module.
  """
  cls_frozen = getattr(cls, 'frozen', None)
  # Look up the attribute which is actually assigned below (`unfrozen`), so
  # that both the idempotency check and the collision check can trigger.
  cls_unfrozen = getattr(cls, 'unfrozen', None)

  # Already inherits from an unfrozen class
  if cls_frozen is frozen and cls_unfrozen is unfrozen:
    return cls

  # Partial implementation, or collision detected
  if cls_frozen is not None or cls_unfrozen is not None:
    raise ValueError(f'{cls} already define `frozen` or `unfrozen`')

  cls.frozen = frozen
  cls.unfrozen = unfrozen
  return cls
def frozen(self: _T) -> _T:
  """Placeholder `.frozen()` for objects which were not unfrozen first.

  Freezing is only meaningful on the proxy returned by `.unfrozen()`, so
  calling this function directly on an object always fails.

  Args:
    self: The object on which `.frozen()` was called (unused).

  Raises:
    ValueError: Always.
  """
  error = ValueError('`.frozen()` can only be called after `.unfrozen()`.')
  raise error
def unfrozen(self: _T) -> _T:
  """Returns a lazy deep-copy of the dataclass, as a mutable proxy.

  Args:
    self: The dataclass instance to unfreeze.

  Returns:
    The public proxy wrapping `self`, marked as the root of the tree.
  """
  global _is_tree_registered

  # One-time pytree registration of the proxy class. `jax` is only looked up
  # in `sys.modules`, never imported from here.
  if not _is_tree_registered:
    jax_module = sys.modules.get('jax', None)
    if jax_module is not None:
      register = jax_module.tree_util.register_pytree_node_class
      register(_MutableProxy)
      _is_tree_registered = True

  # Fresh shared state for this tree of proxies.
  shared = _Common()
  root_impl = _MutableProxyImpl(obj=self, common=shared, is_root=True)
  return root_impl.public_api
# Limitations: Compatibility of `_MutableProxy` with `jax.tree_util`, chex,...
class _MutableProxy(Generic[_T]):
  """Proxy which mutates the dataclass.

  Note: To avoid attribute collisions with the wrapped class, the actual
  implementation is moved inside a `_edc_impl` attribute. We cannot use mangled
  names due to compatibility issues with autograph.

  This class only exposes the public API.
  """

  def __init__(self, impl: _MutableProxyImpl):
    # Bypass our own `__setattr__`, which forwards to the implementation.
    super().__setattr__('_edc_impl', impl)

  def unfrozen(self) -> NoReturn:
    """Always fails: the object is already an unfrozen proxy."""
    raise ValueError('Object is already unfrozen. Cannot call `.unfrozen()`.')

  def frozen(self) -> _T:
    """Delegates to the implementation `frozen()` and returns its result."""
    return self._edc_impl.frozen()

  def __getattr__(self, name: str) -> Any:
    """Forwards attribute reads (missed by normal lookup) to the impl.

    Args:
      name: Name of the attribute being read.

    Returns:
      The value returned by the implementation for `name`.

    Raises:
      AttributeError: If `name` is `_edc_impl` itself, i.e. the proxy was
        created without going through `__init__`.
    """
    # `__getattr__` is only invoked when normal lookup failed. If the missing
    # attribute is `_edc_impl` (instance created through `__new__` only, as
    # `copy` / `pickle` do), forwarding would re-enter this method forever.
    if name == '_edc_impl':
      raise AttributeError(
          f'{type(self).__name__!r} object has no attribute {name!r}'
      )
    return self._edc_impl.getattr(name)

  def __setattr__(self, name: str, value: Any) -> None:
    """Forwards attribute writes to the implementation."""
    return self._edc_impl.setattr(name, value)

  def __repr__(self) -> str:
    return f'{type(self).__name__}({self._edc_impl.resolve()!r})'

  def tree_flatten(self) -> tuple[list[Any], Any]:
    """`jax.tree_util` support."""
    import jax  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    obj = self._edc_impl.resolve()
    # Note: If the wrapped object does not support tree_map, then it will be
    # returned as expected
    return jax.tree_util.tree_flatten(obj)

  @classmethod
  def tree_unflatten(
      cls,
      metadata: Any,
      flattened: list[Any],
  ) -> Any:
    """`jax.tree_util` support."""
    import jax  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    # `metadata` is the treedef produced by `tree_flatten` above.
    return jax.tree_util.tree_unflatten(metadata, flattened)
@dataclasses.dataclass
class _Common:
"""Shared variable across all nested childs of an `unfrozen()` object.
Attributes:
cache: Global mapping `id(_MutableProxyImpl) -> _MutableProxyImpl` to avoid
duplicating the same object proxy
```python
a = a.unfrozen()
a.x = x
a.y = x # `a.x` and `a.y` point to the same object
a = a.frozen()
assert a.x is a.y
```
resolved: Cache of the objects after they have been frozen (to avoid
evaluating the object twice.
is_frozen: Become `True` after `.frozen()` is called. After which all
Mutable are invalid
"""
cache: dict[int, _MutableProxyImpl] = dataclasses.field(default_factory=dict)
resolved: dict[int, Any] = dataclasses.field(default_factory=dict)
is_frozen: bool = False
def get_proxy(self, value: Any) -> _MutableProxyImpl:
"""Returns the proxy associated with the given value, or create it."""
id_ = id(value)
if id_ not in self.cache:
self.cache[id_] = _MutableProxyImpl(obj=value, common=self)
return self.cache[id_]
@dataclasses.dataclass
class _MutableProxyImpl(Generic[_T]):
  """Proxy implementation, kept in a separate class to avoid collisions.

  Attributes:
    obj: The original dataclass instance being wrapped (never mutated).
    common: State shared with the other proxies of the same `unfrozen()` tree.
    is_root: Whether `.frozen()` is allowed on this proxy (only the top-level
      proxy created by `unfrozen()` sets it).
    attrs: Pending attribute values (`name -> value`). Nested dataclass
      instances are stored as `_MutableProxyImpl`, so that their own mutations
      are tracked too.
  """

  obj: _T
  common: _Common
  is_root: bool = False
  # Child info
  attrs: dict[str, Union[_MutableProxyImpl, Any]] = dataclasses.field(
      default_factory=dict
  )

  @functools.cached_property
  def public_api(self) -> _MutableProxy:
    """The user-facing proxy (created once per implementation)."""
    return _MutableProxy(self)

  @functools.cached_property
  def _fields(self) -> dict[str, dataclasses.Field[Any]]:
    """Mapping `name -> dataclasses.Field` of the wrapped object."""
    # Could also filter only init=True fields
    return {f.name: f for f in dataclasses.fields(self.obj)}

  def _is_dataclass_field(self, name: str) -> bool:
    """Returns True if the field is a dataclass attribute."""
    return name in self._fields  # pylint: disable=unsupported-membership-test

  @staticmethod
  def _is_dataclass_instance(value: Any) -> bool:
    """Returns True for dataclass instances (but not for dataclass types)."""
    # `dataclasses.is_dataclass` is also True for dataclass *classes*. Those
    # must be kept as-is: a class stored in a field is a plain value, and
    # cannot be passed to `dataclasses.replace`.
    return dataclasses.is_dataclass(value) and not isinstance(value, type)

  def getattr(self, name: str) -> Any:
    """Returns `obj.name`.

    Args:
      name: Name of the attribute to read.

    Returns:
      The pending value when one was set, otherwise the value from the wrapped
      object. Dataclass instances stored in dataclass fields are returned as
      `_MutableProxy`.

    Raises:
      AttributeError: If `.frozen()` was already called, or if the wrapped
        object has no such attribute.
    """
    if self.common.is_frozen:
      raise AttributeError('Cannot access value after the mutable was frozen.')

    # Reuse cache if it exists
    if name in self.attrs:
      value = self.attrs[name]
    else:
      value = getattr(self.obj, name)
      # Eventually wrap the value in a proxy
      if self._is_dataclass_field(name) and self._is_dataclass_instance(value):
        value = self.common.get_proxy(value)
        # Remember the proxy, so mutations of the child are seen by `resolve`.
        self.attrs[name] = value

    if isinstance(value, _MutableProxyImpl):
      value = value.public_api
    return value

  def setattr(self, name: str, value: Any) -> None:
    """Set `obj.name`.

    Args:
      name: Name of the dataclass field to overwrite.
      value: New value. Can be a `_MutableProxy` from the same `unfrozen()`
        tree.

    Raises:
      AttributeError: If `.frozen()` was already called, or if `name` is not a
        dataclass field.
      ValueError: If `value` is a proxy created by another `unfrozen()` call.
    """
    if self.common.is_frozen:
      raise AttributeError(
          'Cannot set attributes after the mutable was frozen.'
      )
    # TODO(epot): Check that the field we're trying to overwrite is a
    # dataclass field
    if not self._is_dataclass_field(name):
      raise AttributeError(f'Cannot set {name!r}: Not a dataclass attribute.')

    # Trying to set another mutable mapping
    if isinstance(value, _MutableProxy):
      value = value._edc_impl  # pylint: disable=protected-access
      if value.common is not self.common:
        raise ValueError(
            f'Trying to mix `unfrozen` attributes. For: {name}={value}'
        )
    # Wrapping dataclasses in proxy objects
    elif self._is_dataclass_instance(value):
      value = self.common.get_proxy(value)
    # Storing the new value
    self.attrs[name] = value

  def frozen(self) -> _T:
    """Invalidates all the proxies of the tree and returns the final object.

    Raises:
      ValueError: If called on a proxy which is not the root.
    """
    if not self.is_root:
      raise ValueError('Only the top-level dataclass can be `.frozen`')
    self.common.is_frozen = True
    return self.resolve()

  def resolve(self) -> _T:
    """Recursively call `.replace` on instances which were mutated."""
    self.common.resolved.clear()  # Clear the resolved cache
    return self._resolve_inner()

  def _resolve_inner(self) -> _T:
    """Recursively call `.replace` on instances which were mutated."""
    id_ = id(self)
    # Cache resolved object, so that the same proxy is only resolved once
    # (and shared objects stay shared).
    if id_ not in self.common.resolved:
      new_vals = {}
      for k, v in self.attrs.items():
        if isinstance(v, _MutableProxyImpl):
          v = v._resolve_inner()  # pylint: disable=protected-access
        # For optimization, we could skip attributes which were read-only
        # and not mutated (but careful as `dc.attr = A()` will wrap A() in
        # a mapping proxy
        new_vals[k] = v
      if not new_vals:  # Object wasn't mutated
        resolved = self.obj
      else:
        resolved = dataclasses.replace(self.obj, **new_vals)
      self.common.resolved[id_] = resolved
    return self.common.resolved[id_]