# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """JAX/dm-tree friendly dataclass implementation reusing Python dataclasses.""" import collections import dataclasses import functools import sys from absl import logging import jax from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet FrozenInstanceError = dataclasses.FrozenInstanceError _RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple")) def mappable_dataclass(cls): """Exposes dataclass as ``collections.abc.Mapping`` descendent. Allows to traverse dataclasses in methods from `dm-tree` library. NOTE: changes dataclasses constructor to dict-type (i.e. positional args aren't supported; however can use generators/iterables). Args: cls: A dataclass to mutate. Returns: Mutated dataclass implementing ``collections.abc.Mapping`` interface. """ if not dataclasses.is_dataclass(cls): raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).") # Define methods for compatibility with `collections.abc.Mapping`. setattr(cls, "__getitem__", lambda self, x: self.__dict__[x]) setattr(cls, "__len__", lambda self: len(self.__dict__)) setattr(cls, "__iter__", lambda self: iter(self.__dict__)) # Override the default `collections.abc.Mapping` method implementation for # cleaner visualization. Without this change x.keys() shows the full repr(x) # instead of only the dict_keys present. The same goes for values and items. setattr(cls, "keys", lambda self: self.__dict__.keys()) setattr(cls, "values", lambda self: self.__dict__.values()) setattr(cls, "items", lambda self: self.__dict__.items()) # Update constructor. orig_init = cls.__init__ all_fields = set(f.name for f in cls.__dataclass_fields__.values()) init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init] @functools.wraps(orig_init) def new_init(self, *orig_args, **orig_kwargs): if (orig_args and orig_kwargs) or len(orig_args) > 1: raise ValueError( "Mappable dataclass constructor doesn't support positional args." "(it has the same constructor as python dict)") all_kwargs = dict(*orig_args, **orig_kwargs) unknown_kwargs = set(all_kwargs.keys()) - all_fields if unknown_kwargs: raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.") # Pass only arguments corresponding to fields with `init=True`. valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields} orig_init(self, **valid_kwargs) cls.__init__ = new_init # Update base class to derive from Mapping dct = dict(cls.__dict__) if "__dict__" in dct: dct.pop("__dict__") # Avoid self-references. # Remove object from the sequence of base classes. Deriving from both Mapping # and object will cause a failure to create a MRO for the updated class bases = tuple(b for b in cls.__bases__ if b != object) cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct) return cls @dataclass_transform() def dataclass( cls=None, *, init=True, repr=True, # pylint: disable=redefined-builtin eq=True, order=False, unsafe_hash=False, frozen=False, kw_only: bool = False, mappable_dataclass=True, # pylint: disable=redefined-outer-name ): """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`. This wrapper class registers new dataclasses with JAX so that tree utils operate correctly. Additionally a replace method is provided making it easy to operate on the class when made immutable (frozen=True). Args: cls: A class to decorate. init: See :py:func:`dataclasses.dataclass`. repr: See :py:func:`dataclasses.dataclass`. eq: See :py:func:`dataclasses.dataclass`. order: See :py:func:`dataclasses.dataclass`. unsafe_hash: See :py:func:`dataclasses.dataclass`. frozen: See :py:func:`dataclasses.dataclass`. kw_only: See :py:func:`dataclasses.dataclass`. mappable_dataclass: If True (the default), methods to make the class implement the :py:class:`collections.abc.Mapping` interface will be generated and the class will include :py:class:`collections.abc.Mapping` in its base classes. `True` is the default, because being an instance of `Mapping` makes `chex.dataclass` compatible with e.g. `jax.tree_util.tree_*` methods, the `tree` library, or methods related to tensorflow/python/utils/nest.py. As a side-effect, e.g. `np.testing.assert_array_equal` will only check the field names are equal and not the content. Use `chex.assert_tree_*` instead. Returns: A JAX-friendly dataclass. """ def dcls(cls): # Make sure to create a separate _Dataclass instance for each `cls`. return _Dataclass( init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass )(cls) if cls is None: return dcls return dcls(cls) class _Dataclass(): """JAX-friendly wrapper for `dataclasses.dataclass`.""" def __init__( self, init=True, repr=True, # pylint: disable=redefined-builtin eq=True, order=False, unsafe_hash=False, frozen=False, kw_only=False, mappable_dataclass=True, # pylint: disable=redefined-outer-name ): self.init = init self.repr = repr # pylint: disable=redefined-builtin self.eq = eq self.order = order self.unsafe_hash = unsafe_hash self.frozen = frozen self.kw_only = kw_only self.mappable_dataclass = mappable_dataclass def __call__(self, cls): """Forwards class to dataclasses's wrapper and registers it with JAX.""" # Remove once https://github.com/python/cpython/pull/24484 is merged. for base in cls.__bases__: if (dataclasses.is_dataclass(base) and getattr(base, "__dataclass_params__").frozen and not self.frozen): raise TypeError("cannot inherit non-frozen dataclass from a frozen one") # `kw_only` is only available starting from 3.10. version_dependent_args = {} version = sys.version_info if version.major == 3 and version.minor >= 10: version_dependent_args = {"kw_only": self.kw_only} # pytype: disable=wrong-keyword-args dcls = dataclasses.dataclass( cls, init=self.init, repr=self.repr, eq=self.eq, order=self.order, unsafe_hash=self.unsafe_hash, frozen=self.frozen, **version_dependent_args, ) # pytype: enable=wrong-keyword-args fields_names = set(f.name for f in dataclasses.fields(dcls)) invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES) if invalid_fields: raise ValueError(f"The following dataclass fields are disallowed: " f"{invalid_fields} ({dcls}).") if self.mappable_dataclass: dcls = mappable_dataclass(dcls) def _from_tuple(args): return dcls(zip(dcls.__dataclass_fields__.keys(), args)) def _to_tuple(self): return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys()) def _replace(self, **kwargs): return dataclasses.replace(self, **kwargs) def _getstate(self): return self.__dict__ # Register the dataclass at definition. As long as the dataclass is defined # outside __main__, this is sufficient to make JAX's PyTree registry # recognize the dataclass and the dataclass' custom PyTreeDef, especially # when unpickling either the dataclass object, its type, or its PyTreeDef, # in a different process, because the defining module will be imported. # # However, if the dataclass is defined in __main__, unpickling in a # subprocess does not trigger re-registration. Therefore we also need to # register when deserializing the object, or construction (e.g. when the # dataclass type is being unpickled). Unfortunately, there is not yet a way # to trigger re-registration when the treedef is unpickled as that's handled # by JAX. # # See internal dataclass_test for unit tests demonstrating the problems. register_dataclass_type_with_jax_tree_util(dcls) # Patch __setstate__ to register the dataclass on deserialization. def _setstate(self, state): register_dataclass_type_with_jax_tree_util(dcls) self.__dict__.update(state) orig_init = dcls.__init__ # Patch __init__ such that the dataclass is registered on creation if it is # not registered on deserialization. @functools.wraps(orig_init) def _init(self, *args, **kwargs): register_dataclass_type_with_jax_tree_util(dcls) return orig_init(self, *args, **kwargs) setattr(dcls, "from_tuple", _from_tuple) setattr(dcls, "to_tuple", _to_tuple) setattr(dcls, "replace", _replace) setattr(dcls, "__getstate__", _getstate) setattr(dcls, "__setstate__", _setstate) setattr(dcls, "__init__", _init) return dcls def _dataclass_unflatten(dcls, keys, values): """Creates a chex dataclass from a flatten jax.tree_util representation.""" dcls_object = dcls.__new__(dcls) attribute_dict = dict(zip(keys, values)) # Looping over fields instead of keys & values preserves the field order. # Using dataclasses.fields fails because dataclass uids change after # serialisation (eg, with cloudpickle). for field in dcls.__dataclass_fields__.values(): if field.name in attribute_dict: # Filter pseudo-fields. object.__setattr__(dcls_object, field.name, attribute_dict[field.name]) # Need to manual call post_init here as we have avoided calling __init__ if getattr(dcls_object, "__post_init__", None): dcls_object.__post_init__() return dcls_object def _flatten_with_path(dcls): path = [] keys = [] for k, v in sorted(dcls.__dict__.items()): k = jax.tree_util.GetAttrKey(k) path.append((k, v)) keys.append(k) return path, keys @functools.cache def register_dataclass_type_with_jax_tree_util(data_class): """Register an existing dataclass so JAX knows how to handle it. This means that functions in jax.tree_util operate over the fields of the dataclass. See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for further information. Args: data_class: A class created using dataclasses.dataclass. It must be constructable from keyword arguments corresponding to the members exposed in instance.__dict__. """ flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1] unflatten = functools.partial(_dataclass_unflatten, data_class) try: jax.tree_util.register_pytree_with_keys( nodetype=data_class, flatten_with_keys=_flatten_with_path, flatten_func=flatten, unflatten_func=unflatten) except ValueError: logging.info("%s is already registered as JAX PyTree node.", data_class)