Spaces:
Building
Building
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Tests for `dataclass.py`.""" | |
# pytype: disable=wrong-keyword-args # dataclass_transform | |
import copy | |
import dataclasses | |
import pickle | |
import sys | |
from typing import Any, Generic, Mapping, TypeVar | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
from chex._src import asserts | |
from chex._src import dataclass | |
from chex._src import pytypes | |
import cloudpickle | |
import jax | |
import numpy as np | |
import tree | |
chex_dataclass = dataclass.dataclass | |
mappable_dataclass = dataclass.mappable_dataclass | |
orig_dataclass = dataclasses.dataclass | |
class NestedDataclass(): | |
c: pytypes.ArrayDevice | |
d: pytypes.ArrayDevice | |
class PostInitDataclass: | |
a: pytypes.ArrayDevice | |
def __post_init__(self): | |
if not self.a > 0: | |
raise ValueError('a should be > than 0') | |
class ReverseOrderNestedDataclass(): | |
# The order of c and d are switched comapred to NestedDataclass. | |
d: pytypes.ArrayDevice | |
c: pytypes.ArrayDevice | |
class Dataclass(): | |
a: NestedDataclass | |
b: pytypes.ArrayDevice | |
class FrozenDataclass(): | |
a: NestedDataclass | |
b: pytypes.ArrayDevice | |
def dummy_dataclass(factor=1., frozen=False): | |
class_ctor = FrozenDataclass if frozen else Dataclass | |
return class_ctor( | |
a=NestedDataclass( | |
c=factor * np.ones((3,), dtype=np.float32), | |
d=factor * np.ones((4,), dtype=np.float32)), | |
b=factor * 2 * np.ones((5,), dtype=np.float32)) | |
def _dataclass_instance_fields(dcls_instance): | |
"""Serialization-friendly version of dataclasses.fields for instances.""" | |
attribute_dict = dcls_instance.__dict__ | |
fields = [] | |
for field in dcls_instance.__dataclass_fields__.values(): | |
if field.name in attribute_dict: # Filter pseudo-fields. | |
fields.append(field) | |
return fields | |
class ClassWithoutMap: | |
k: dict # pylint:disable=g-bare-generic | |
def some_method(self, *args): | |
raise RuntimeError('ClassWithoutMap.some_method() was called.') | |
def _get_mappable_dataclasses(test_type): | |
"""Generates shallow and nested mappable dataclasses.""" | |
class Class: | |
"""Shallow class.""" | |
k_tuple: tuple # pylint:disable=g-bare-generic | |
k_dict: dict # pylint:disable=g-bare-generic | |
def some_method(self, *args): | |
raise RuntimeError('Class.some_method() was called.') | |
class NestedClass: | |
"""Nested class.""" | |
k_any: Any | |
k_int: int | |
k_str: str | |
k_arr: np.ndarray | |
k_dclass_with_map: Class | |
k_dclass_no_map: ClassWithoutMap | |
k_dict_factory: dict = dataclasses.field( # pylint:disable=g-bare-generic,invalid-field-call | |
default_factory=lambda: dict(x='x', y='y')) | |
k_default: str = 'default_str' | |
k_non_init: int = dataclasses.field(default=1, init=False) # pylint:disable=g-bare-generic,invalid-field-call | |
k_init_only: dataclasses.InitVar[int] = 10 | |
def some_method(self, *args): | |
raise RuntimeError('NestedClassWithMap.some_method() was called.') | |
def __post_init__(self, k_init_only): | |
self.k_non_init = self.k_int * k_init_only | |
if test_type == 'chex': | |
cls = chex_dataclass(Class, mappable_dataclass=True) | |
nested_cls = chex_dataclass(NestedClass, mappable_dataclass=True) | |
elif test_type == 'original': | |
cls = mappable_dataclass(orig_dataclass(Class)) | |
nested_cls = mappable_dataclass(orig_dataclass(NestedClass)) | |
else: | |
raise ValueError(f'Unknown test type: {test_type}') | |
return cls, nested_cls | |
class MappableDataclassTest(parameterized.TestCase): | |
def _init_testdata(self, test_type): | |
"""Initializes test data.""" | |
map_cls, nested_map_cls = _get_mappable_dataclasses(test_type) | |
self.dcls_with_map_inner = map_cls( | |
k_tuple=(1, 2), k_dict=dict(k1=32, k2=33)) | |
self.dcls_with_map_inner_inc = map_cls( | |
k_tuple=(2, 3), k_dict=dict(k1=33, k2=34)) | |
self.dcls_no_map = ClassWithoutMap(k=dict(t='t', t2='t2')) | |
self.dcls_with_map = nested_map_cls( | |
k_any=None, | |
k_int=1, | |
k_str='test_str', | |
k_arr=np.array(16), | |
k_dclass_with_map=self.dcls_with_map_inner, | |
k_dclass_no_map=self.dcls_no_map) | |
self.dcls_with_map_inc_ints = nested_map_cls( | |
k_any=None, | |
k_int=2, | |
k_str='test_str', | |
k_arr=np.array(16), | |
k_dclass_with_map=self.dcls_with_map_inner_inc, | |
k_dclass_no_map=self.dcls_no_map, | |
k_default='default_str') | |
self.dcls_flattened_with_path = [ | |
(('k_any',), None), | |
(('k_arr',), np.array(16)), | |
(('k_dclass_no_map',), self.dcls_no_map), | |
(('k_dclass_with_map', 'k_dict', 'k1'), 32), | |
(('k_dclass_with_map', 'k_dict', 'k2'), 33), | |
(('k_dclass_with_map', 'k_tuple', 0), 1), | |
(('k_dclass_with_map', 'k_tuple', 1), 2), | |
(('k_default',), 'default_str'), | |
(('k_dict_factory', 'x'), 'x'), | |
(('k_dict_factory', 'y'), 'y'), | |
(('k_int',), 1), | |
(('k_non_init',), 10), | |
(('k_str',), 'test_str'), | |
] | |
self.dcls_flattened_with_path_up_to = [ | |
(('k_any',), None), | |
(('k_arr',), np.array(16)), | |
(('k_dclass_no_map',), self.dcls_no_map), | |
(('k_dclass_with_map',), self.dcls_with_map_inner), | |
(('k_default',), 'default_str'), | |
(('k_dict_factory', 'x'), 'x'), | |
(('k_dict_factory', 'y'), 'y'), | |
(('k_int',), 1), | |
(('k_non_init',), 10), | |
(('k_str',), 'test_str'), | |
] | |
self.dcls_flattened = [v for (_, v) in self.dcls_flattened_with_path] | |
self.dcls_flattened_up_to = [ | |
v for (_, v) in self.dcls_flattened_with_path_up_to | |
] | |
self.dcls_tree_size = 18 | |
self.dcls_tree_size_no_dicts = 14 | |
def testFlattenAndUnflatten(self, test_type): | |
self._init_testdata(test_type) | |
self.assertEqual(self.dcls_flattened, tree.flatten(self.dcls_with_map)) | |
self.assertEqual( | |
self.dcls_with_map, | |
tree.unflatten_as(self.dcls_with_map_inc_ints, self.dcls_flattened)) | |
dataclass_in_seq = [34, self.dcls_with_map, [1, 2]] | |
dataclass_in_seq_flat = [34] + self.dcls_flattened + [1, 2] | |
self.assertEqual(dataclass_in_seq_flat, tree.flatten(dataclass_in_seq)) | |
self.assertEqual(dataclass_in_seq, | |
tree.unflatten_as(dataclass_in_seq, dataclass_in_seq_flat)) | |
def testFlattenUpTo(self, test_type): | |
self._init_testdata(test_type) | |
structure = copy.copy(self.dcls_with_map) | |
structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map' | |
self.assertEqual(self.dcls_flattened_up_to, | |
tree.flatten_up_to(structure, self.dcls_with_map)) | |
def testFlattenWithPath(self, test_type): | |
self._init_testdata(test_type) | |
self.assertEqual( | |
tree.flatten_with_path(self.dcls_with_map), | |
self.dcls_flattened_with_path) | |
def testFlattenWithPathUpTo(self, test_type): | |
self._init_testdata(test_type) | |
structure = copy.copy(self.dcls_with_map) | |
structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map' | |
self.assertEqual( | |
tree.flatten_with_path_up_to(structure, self.dcls_with_map), | |
self.dcls_flattened_with_path_up_to) | |
def testMapStructure(self, test_type): | |
self._init_testdata(test_type) | |
add_one_to_ints_fn = lambda x: x + 1 if isinstance(x, int) else x | |
mapped_inc_ints = tree.map_structure(add_one_to_ints_fn, self.dcls_with_map) | |
self.assertEqual(self.dcls_with_map_inc_ints, mapped_inc_ints) | |
self.assertEqual(self.dcls_with_map_inc_ints.k_non_init, | |
self.dcls_with_map_inc_ints.k_int * 10) | |
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) | |
def testMapStructureUpTo(self, test_type): | |
self._init_testdata(test_type) | |
structure = copy.copy(self.dcls_with_map) | |
structure.k_dclass_with_map = None # Do not map over 'k_dclass_with_map' | |
add_one_to_ints_fn = lambda x: x + 1 if isinstance(x, int) else x | |
mapped_inc_ints = tree.map_structure_up_to(structure, add_one_to_ints_fn, | |
self.dcls_with_map) | |
# k_dclass_with_map should be passed through unchanged | |
class_with_map = self.dcls_with_map.k_dclass_with_map | |
self.dcls_with_map_inc_ints.k_dclass_with_map = class_with_map | |
self.assertEqual(self.dcls_with_map_inc_ints, mapped_inc_ints) | |
self.assertEqual(self.dcls_with_map_inc_ints.k_non_init, | |
self.dcls_with_map_inc_ints.k_int * 10) | |
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) | |
def testMapStructureWithPath(self, test_type): | |
self._init_testdata(test_type) | |
add_one_to_ints_fn = lambda path, x: x + 1 if isinstance(x, int) else x | |
mapped_inc_ints = tree.map_structure_with_path(add_one_to_ints_fn, | |
self.dcls_with_map) | |
self.assertEqual(self.dcls_with_map_inc_ints, mapped_inc_ints) | |
self.assertEqual(self.dcls_with_map_inc_ints.k_non_init, | |
self.dcls_with_map_inc_ints.k_int * 10) | |
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) | |
def testMapStructureWithPathUpTo(self, test_type): | |
self._init_testdata(test_type) | |
structure = copy.copy(self.dcls_with_map) | |
structure.k_dclass_with_map = None # Do not map over 'k_dclass_with_map' | |
add_one_to_ints_fn = lambda path, x: x + 1 if isinstance(x, int) else x | |
mapped_inc_ints = tree.map_structure_with_path_up_to( | |
structure, add_one_to_ints_fn, self.dcls_with_map) | |
# k_dclass_with_map should be passed through unchanged | |
class_with_map = self.dcls_with_map.k_dclass_with_map | |
self.dcls_with_map_inc_ints.k_dclass_with_map = class_with_map | |
self.assertEqual(self.dcls_with_map_inc_ints, mapped_inc_ints) | |
self.assertEqual(self.dcls_with_map_inc_ints.k_non_init, | |
self.dcls_with_map_inc_ints.k_int * 10) | |
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10) | |
def testTraverse(self, test_type): | |
self._init_testdata(test_type) | |
visited = [] | |
tree.traverse(visited.append, self.dcls_with_map, top_down=False) | |
self.assertLen(visited, self.dcls_tree_size) | |
visited_without_dicts = [] | |
def visit_without_dicts(x): | |
visited_without_dicts.append(x) | |
return 'X' if isinstance(x, dict) else None | |
tree.traverse(visit_without_dicts, self.dcls_with_map, top_down=True) | |
self.assertLen(visited_without_dicts, self.dcls_tree_size_no_dicts) | |
def testIsDataclass(self, test_type): | |
self._init_testdata(test_type) | |
self.assertTrue(dataclasses.is_dataclass(self.dcls_no_map)) | |
self.assertTrue(dataclasses.is_dataclass(self.dcls_with_map)) | |
self.assertTrue( | |
dataclasses.is_dataclass(self.dcls_with_map.k_dclass_with_map)) | |
self.assertTrue( | |
dataclasses.is_dataclass(self.dcls_with_map.k_dclass_no_map)) | |
class DataclassesTest(parameterized.TestCase): | |
def test_dataclass_tree_leaves(self, frozen): | |
obj = dummy_dataclass(frozen=frozen) | |
self.assertLen(jax.tree_util.tree_leaves(obj), 3) | |
def test_dataclass_tree_map(self, frozen): | |
factor = 5. | |
obj = dummy_dataclass(frozen=frozen) | |
target_obj = dummy_dataclass(factor=factor, frozen=frozen) | |
asserts.assert_trees_all_close( | |
jax.tree_util.tree_map(lambda t: factor * t, obj), target_obj) | |
def test_tree_flatten_with_keys(self): | |
obj = dummy_dataclass() | |
keys_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(obj) | |
self.assertEqual( | |
[k for k, _ in keys_and_leaves], | |
[ | |
(jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c')), | |
(jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d')), | |
(jax.tree_util.GetAttrKey('b'),), | |
], | |
) | |
leaves = [l for _, l in keys_and_leaves] | |
new_obj = treedef.unflatten(leaves) | |
self.assertEqual(new_obj, obj) | |
def test_tree_map_with_keys(self): | |
obj = dummy_dataclass() | |
key_value_list, unused_treedef = jax.tree_util.tree_flatten_with_path(obj) | |
# Convert a list of key-value tuples to a dict. | |
flat_obj = dict(key_value_list) | |
def f(path, x): | |
value = flat_obj[path] | |
np.testing.assert_allclose(value, x) | |
return path | |
out = jax.tree_util.tree_map_with_path(f, obj) | |
self.assertEqual( | |
out.a.c, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c')) | |
) | |
self.assertEqual( | |
out.a.d, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d')) | |
) | |
self.assertEqual(out.b, (jax.tree_util.GetAttrKey('b'),)) | |
def test_tree_map_with_keys_traversal_order(self): | |
# pytype: disable=wrong-arg-types | |
obj = ReverseOrderNestedDataclass(d=1, c=2) | |
# pytype: enable=wrong-arg-types | |
leaves = [] | |
def f(_, x): | |
leaves.append(x) | |
jax.tree_util.tree_map_with_path(f, obj) | |
self.assertEqual(leaves, jax.tree_util.tree_leaves(obj)) | |
def test_dataclass_replace(self, frozen): | |
factor = 5. | |
obj = dummy_dataclass(frozen=frozen) | |
# pytype: disable=attribute-error # dataclass_transform | |
obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c)) | |
obj = obj.replace(a=obj.a.replace(d=factor * obj.a.d)) | |
obj = obj.replace(b=factor * obj.b) | |
target_obj = dummy_dataclass(factor=factor, frozen=frozen) | |
asserts.assert_trees_all_close(obj, target_obj) | |
# pytype: enable=attribute-error | |
def test_dataclass_requires_kwargs_by_default(self): | |
factor = 1.0 | |
with self.assertRaisesRegex( | |
ValueError, | |
"Mappable dataclass constructor doesn't support positional args.", | |
): | |
Dataclass( | |
NestedDataclass( | |
c=factor * np.ones((3,), dtype=np.float32), | |
d=factor * np.ones((4,), dtype=np.float32), | |
), | |
factor * 2 * np.ones((5,), dtype=np.float32), | |
) | |
def test_dataclass_mappable_dataclass_false(self): | |
factor = 1.0 | |
class NonMappableDataclass: | |
a: NestedDataclass | |
b: pytypes.ArrayDevice | |
NonMappableDataclass( | |
NestedDataclass( | |
c=factor * np.ones((3,), dtype=np.float32), | |
d=factor * np.ones((4,), dtype=np.float32), | |
), | |
factor * 2 * np.ones((5,), dtype=np.float32), | |
) | |
def test_inheritance_is_possible_thanks_to_kw_only(self): | |
if sys.version_info.minor < 10: # Feature only available for Python >= 3.10 | |
return | |
class Base: | |
default: int = 1 | |
class Child(Base): | |
non_default: int | |
Child(non_default=2) | |
def test_unfrozen_dataclass_is_mutable(self): | |
factor = 5. | |
obj = dummy_dataclass(frozen=False) | |
obj.a.c = factor * obj.a.c | |
obj.a.d = factor * obj.a.d | |
obj.b = factor * obj.b | |
target_obj = dummy_dataclass(factor=factor, frozen=False) | |
asserts.assert_trees_all_close(obj, target_obj) | |
def test_frozen_dataclass_raise_error(self): | |
factor = 5. | |
obj = dummy_dataclass(frozen=True) | |
obj.a.c = factor * obj.a.c # mutable since obj.a is not frozen. | |
with self.assertRaisesRegex(dataclass.FrozenInstanceError, | |
'cannot assign to field'): | |
obj.b = factor * obj.b # raises error because obj is frozen. | |
def test_get_and_set_state(self, frozen): | |
class SimpleClass(): | |
data: int = 1 | |
obj_a = SimpleClass(data=1) | |
state = getattr(obj_a, '__getstate__')() | |
obj_b = SimpleClass(data=2) | |
getattr(obj_b, '__setstate__')(state) | |
self.assertEqual(obj_a, obj_b) | |
def test_unexpected_kwargs(self): | |
class SimpleDataclass: | |
a: int | |
b: int = 2 | |
SimpleDataclass(a=1, b=3) | |
with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'): | |
SimpleDataclass(a=1, b=3, c=4) # pytype: disable=wrong-keyword-args | |
def test_tuple_conversion(self): | |
class SimpleDataclass: | |
b: int | |
a: int | |
obj = SimpleDataclass(a=2, b=1) | |
self.assertSequenceEqual(getattr(obj, 'to_tuple')(), (1, 2)) | |
obj2 = getattr(SimpleDataclass, 'from_tuple')((1, 2)) | |
self.assertEqual(obj.a, obj2.a) | |
self.assertEqual(obj.b, obj2.b) | |
def test_tuple_rev_conversion(self, frozen): | |
obj = dummy_dataclass(frozen=frozen) | |
asserts.assert_trees_all_close( | |
type(obj).from_tuple(obj.to_tuple()), # pytype: disable=attribute-error | |
obj, | |
) | |
def test_inheritance(self, frozen): | |
class Base: | |
x: int | |
class Derived(Base): | |
y: int | |
base_obj = Base(x=1) | |
self.assertNotIsInstance(base_obj, Derived) | |
self.assertIsInstance(base_obj, Base) | |
derived_obj = Derived(x=1, y=2) | |
self.assertIsInstance(derived_obj, Derived) | |
self.assertIsInstance(derived_obj, Base) | |
def test_inheritance_from_empty_frozen_base(self): | |
class FrozenBase: | |
pass | |
class DerivedFrozen(FrozenBase): | |
j: int | |
df = DerivedFrozen(j=2) | |
self.assertIsInstance(df, FrozenBase) | |
with self.assertRaisesRegex( | |
TypeError, 'cannot inherit non-frozen dataclass from a frozen one'): | |
# pylint:disable=unused-variable | |
class DerivedMutable(FrozenBase): | |
j: int | |
# pylint:enable=unused-variable | |
def test_disallowed_fields(self): | |
# pylint:disable=unused-variable | |
with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'): | |
class InvalidNonMappable: | |
from_tuple: int | |
class ValidMappable: | |
get: int | |
with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'): | |
class InvalidMappable: | |
get: int | |
from_tuple: int | |
# pylint:enable=unused-variable | |
def test_flatten_is_leaf(self, is_mappable): | |
class _InnerDcls: | |
v_1: int | |
v_2: int | |
class _Dcls: | |
str_val: str | |
# pytype: disable=invalid-annotation # enable-bare-annotations | |
inner_dcls: _InnerDcls | |
dct: Mapping[str, _InnerDcls] | |
# pytype: enable=invalid-annotation # enable-bare-annotations | |
dcls = _Dcls( | |
str_val='test', | |
inner_dcls=_InnerDcls(v_1=1, v_2=11), | |
dct={ | |
'md1': _InnerDcls(v_1=2, v_2=22), | |
'md2': _InnerDcls(v_1=3, v_2=33) | |
}) | |
def _is_leaf(value) -> bool: | |
# Must not traverse over integers. | |
self.assertNotIsInstance(value, int) | |
return isinstance(value, (_InnerDcls, str)) | |
leaves = jax.tree_util.tree_flatten(dcls, is_leaf=_is_leaf)[0] | |
self.assertCountEqual( | |
(dcls.str_val, dcls.inner_dcls, dcls.dct['md1'], dcls.dct['md2']), | |
leaves) | |
asserts.assert_trees_all_equal_structs( | |
jax.tree_util.tree_map(lambda x: x, dcls, is_leaf=_is_leaf), dcls) | |
def test_decorator_alias(self): | |
# Make sure, that creating a decorator alias works correctly. | |
configclass = chex_dataclass(frozen=True) | |
class Foo: | |
bar: int = 1 | |
toto: int = 2 | |
class Bar: | |
bar: int = 1 | |
toto: int = 2 | |
# Verify that both Foo and Bar are correctly registered with jax.tree_util. | |
self.assertLen(jax.tree_util.tree_flatten(Foo())[0], 2) | |
self.assertLen(jax.tree_util.tree_flatten(Bar())[0], 2) | |
def test_generic_dataclass(self, mappable): | |
T = TypeVar('T') | |
class GenericDataclass(Generic[T]): | |
a: T # pytype: disable=invalid-annotation # enable-bare-annotations | |
obj = GenericDataclass(a=np.array([1.0, 1.0])) | |
asserts.assert_trees_all_close(obj.a, 1.0) | |
def test_mappable_eq_override(self): | |
class EqDataclass: | |
a: pytypes.ArrayDevice | |
def __eq__(self, other): | |
if isinstance(other, EqDataclass): | |
return other.a[0] == self.a[0] | |
return False | |
obj1 = EqDataclass(a=np.array([1.0, 1.0])) | |
obj2 = EqDataclass(a=np.array([1.0, 0.0])) | |
obj3 = EqDataclass(a=np.array([0.0, 1.0])) | |
self.assertEqual(obj1, obj2) | |
self.assertNotEqual(obj1, obj3) | |
def test_dataclass_instance_fields(self, dcls): | |
obj = dcls(c=1, d=2) | |
self.assertSequenceEqual( | |
dataclasses.fields(obj), _dataclass_instance_fields(obj)) | |
def test_roundtrip_serialization(self, serialization_lib, dcls): | |
obj = dcls(c=1, d=2) | |
obj_fields = [ | |
(f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj) | |
] | |
self.assertLen(obj_fields, 2) | |
obj2 = serialization_lib.loads(serialization_lib.dumps(obj)) | |
obj2_fields = [(f.name, getattr(obj2, f.name)) | |
for f in _dataclass_instance_fields(obj2)] | |
self.assertSequenceEqual(obj_fields, obj2_fields) | |
self.assertSequenceEqual(jax.tree_util.tree_leaves(obj2), [1, 2]) | |
obj3 = jax.tree_util.tree_map(lambda x: x, obj2) | |
obj3_fields = [(f.name, getattr(obj3, f.name)) | |
for f in _dataclass_instance_fields(obj3)] | |
self.assertSequenceEqual(obj_fields, obj3_fields) | |
self.assertSequenceEqual(jax.tree_util.tree_leaves(obj3), [1, 2]) | |
def test_flatten_roundtrip_ordering(self, dcls): | |
obj = dcls(c=1, d=2) | |
leaves, treedef = jax.tree_util.tree_flatten(obj) | |
self.assertSequenceEqual(leaves, [1, 2]) | |
obj2 = jax.tree_util.tree_unflatten(treedef, leaves) | |
self.assertSequenceEqual(dataclasses.fields(obj2), dataclasses.fields(obj)) | |
def test_flatten_respects_post_init(self): | |
obj = PostInitDataclass(a=1) # pytype: disable=wrong-arg-types | |
with self.assertRaises(ValueError): | |
_ = jax.tree_util.tree_map(lambda x: 0, obj) | |
def test_keys_and_values_type(self, frozen): | |
obj = dummy_dataclass(frozen=frozen) | |
self.assertEqual( | |
type(obj.keys()), # pytype: disable=attribute-error | |
type({}.keys()), | |
) | |
self.assertEqual( | |
type(obj.values()), # pytype: disable=attribute-error | |
type({}.values()), | |
) | |
def test_keys_and_values_override(self, frozen): | |
class _Dataclass: | |
x: int | |
values: int | |
obj = _Dataclass(x=1, values=2) | |
self.assertEqual( | |
list(obj.keys()), # pytype: disable=attribute-error | |
['x', 'values'], | |
) | |
self.assertEqual(obj.values, 2) | |
if __name__ == '__main__': | |
absltest.main() | |