PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `dataclass.py`."""
# pytype: disable=wrong-keyword-args # dataclass_transform
import copy
import dataclasses
import pickle
import sys
from typing import Any, Generic, Mapping, TypeVar
from absl.testing import absltest
from absl.testing import parameterized
from chex._src import asserts
from chex._src import dataclass
from chex._src import pytypes
import cloudpickle
import jax
import numpy as np
import tree
# Short aliases: the two decorators from `chex._src.dataclass` that are under
# test, and the standard-library decorator used to build reference classes.
chex_dataclass = dataclass.dataclass
mappable_dataclass = dataclass.mappable_dataclass
orig_dataclass = dataclasses.dataclass
@chex_dataclass
class NestedDataclass:
  """Two-field chex dataclass; used as the inner node of `Dataclass`."""

  c: pytypes.ArrayDevice
  d: pytypes.ArrayDevice
@chex_dataclass
class PostInitDataclass:
  """Chex dataclass whose `__post_init__` rejects values that are not > 0."""

  a: pytypes.ArrayDevice

  def __post_init__(self):
    """Raises ValueError unless `self.a > 0` evaluates to a truthy value."""
    if self.a > 0:
      return
    raise ValueError('a should be > than 0')
@chex_dataclass
class ReverseOrderNestedDataclass:
  """Same fields as `NestedDataclass`, declared in the opposite order."""

  # The order of c and d is switched compared to NestedDataclass.
  d: pytypes.ArrayDevice
  c: pytypes.ArrayDevice
@chex_dataclass
class Dataclass:
  """Mutable chex dataclass holding a nested dataclass and one array."""

  a: NestedDataclass
  b: pytypes.ArrayDevice
@chex_dataclass(frozen=True)
class FrozenDataclass:
  """Frozen counterpart of `Dataclass` with the same two fields."""

  a: NestedDataclass
  b: pytypes.ArrayDevice
def dummy_dataclass(factor=1., frozen=False):
  """Builds a `Dataclass`/`FrozenDataclass` filled with scaled float32 ones.

  Args:
    factor: Multiplier applied to every array (`b` is additionally doubled).
    frozen: If True a `FrozenDataclass` is returned, otherwise a `Dataclass`.

  Returns:
    An instance whose `a.c`, `a.d` and `b` arrays have shapes (3,), (4,) and
    (5,) respectively.
  """
  if frozen:
    outer_cls = FrozenDataclass
  else:
    outer_cls = Dataclass

  inner = NestedDataclass(
      c=factor * np.ones((3,), dtype=np.float32),
      d=factor * np.ones((4,), dtype=np.float32),
  )
  doubled = factor * 2 * np.ones((5,), dtype=np.float32)
  return outer_cls(a=inner, b=doubled)
def _dataclass_instance_fields(dcls_instance):
  """Serialization-friendly version of dataclasses.fields for instances.

  Args:
    dcls_instance: A dataclass instance.

  Returns:
    A list with the entries of `__dataclass_fields__` whose name is present in
    the instance `__dict__`, in declaration order.
  """
  stored_attributes = dcls_instance.__dict__
  # Pseudo-fields are not stored on the instance, so they are filtered out.
  return [
      candidate
      for candidate in dcls_instance.__dataclass_fields__.values()
      if candidate.name in stored_attributes
  ]
@orig_dataclass
class ClassWithoutMap:
  """Plain stdlib dataclass that is not wrapped with `mappable_dataclass`."""

  k: dict  # pylint:disable=g-bare-generic

  def some_method(self, *args):
    """Always raises, so any call to this method is surfaced as an error."""
    raise RuntimeError('ClassWithoutMap.some_method() was called.')
def _get_mappable_dataclasses(test_type):
  """Creates a shallow and a nested mappable dataclass.

  Args:
    test_type: 'chex' to build the classes with `chex_dataclass(...,
      mappable_dataclass=True)`, or 'original' to build them with the stdlib
      decorator and then wrap them with `mappable_dataclass`.

  Returns:
    A `(shallow_class, nested_class)` tuple.

  Raises:
    ValueError: If `test_type` is neither 'chex' nor 'original'.
  """

  class Class:
    """Shallow class."""

    k_tuple: tuple  # pylint:disable=g-bare-generic
    k_dict: dict  # pylint:disable=g-bare-generic

    def some_method(self, *args):
      raise RuntimeError('Class.some_method() was called.')

  class NestedClass:
    """Nested class."""

    k_any: Any
    k_int: int
    k_str: str
    k_arr: np.ndarray
    k_dclass_with_map: Class
    k_dclass_no_map: ClassWithoutMap
    k_dict_factory: dict = dataclasses.field(  # pylint:disable=g-bare-generic,invalid-field-call
        default_factory=lambda: {'x': 'x', 'y': 'y'})
    k_default: str = 'default_str'
    k_non_init: int = dataclasses.field(default=1, init=False)  # pylint:disable=g-bare-generic,invalid-field-call
    k_init_only: dataclasses.InitVar[int] = 10

    def some_method(self, *args):
      raise RuntimeError('NestedClassWithMap.some_method() was called.')

    def __post_init__(self, k_init_only):
      # `k_non_init` is derived from `k_int` and the init-only argument.
      self.k_non_init = self.k_int * k_init_only

  # Select the wrapping strategy once, then apply it to both classes.
  if test_type == 'chex':

    def make_mappable(klass):
      return chex_dataclass(klass, mappable_dataclass=True)

  elif test_type == 'original':

    def make_mappable(klass):
      return mappable_dataclass(orig_dataclass(klass))

  else:
    raise ValueError(f'Unknown test type: {test_type}')

  shallow = make_mappable(Class)
  nested = make_mappable(NestedClass)
  return shallow, nested
@parameterized.named_parameters(('_original', 'original'), ('_chex', 'chex'))
class MappableDataclassTest(parameterized.TestCase):
  """Checks that mappable dataclasses work with the `tree` library.

  Every test runs once per `test_type` ('original' and 'chex'); see
  `_get_mappable_dataclasses` for how the classes are built in each case.
  """

  def _init_testdata(self, test_type):
    """Builds the dataclass fixtures and their expected flattenings."""
    shallow_cls, nested_cls = _get_mappable_dataclasses(test_type)

    self.dcls_with_map_inner = shallow_cls(
        k_tuple=(1, 2), k_dict={'k1': 32, 'k2': 33})
    self.dcls_with_map_inner_inc = shallow_cls(
        k_tuple=(2, 3), k_dict={'k1': 33, 'k2': 34})
    self.dcls_no_map = ClassWithoutMap(k={'t': 't', 't2': 't2'})

    self.dcls_with_map = nested_cls(
        k_any=None,
        k_int=1,
        k_str='test_str',
        k_arr=np.array(16),
        k_dclass_with_map=self.dcls_with_map_inner,
        k_dclass_no_map=self.dcls_no_map,
    )
    # Same as `dcls_with_map`, but with every int incremented by one.
    self.dcls_with_map_inc_ints = nested_cls(
        k_any=None,
        k_int=2,
        k_str='test_str',
        k_arr=np.array(16),
        k_dclass_with_map=self.dcls_with_map_inner_inc,
        k_dclass_no_map=self.dcls_no_map,
        k_default='default_str',
    )

    # Expected (path, leaf) pairs of a full flattening of `dcls_with_map`.
    self.dcls_flattened_with_path = [
        (('k_any',), None),
        (('k_arr',), np.array(16)),
        (('k_dclass_no_map',), self.dcls_no_map),
        (('k_dclass_with_map', 'k_dict', 'k1'), 32),
        (('k_dclass_with_map', 'k_dict', 'k2'), 33),
        (('k_dclass_with_map', 'k_tuple', 0), 1),
        (('k_dclass_with_map', 'k_tuple', 1), 2),
        (('k_default',), 'default_str'),
        (('k_dict_factory', 'x'), 'x'),
        (('k_dict_factory', 'y'), 'y'),
        (('k_int',), 1),
        (('k_non_init',), 10),
        (('k_str',), 'test_str'),
    ]
    # Expected pairs when 'k_dclass_with_map' is kept as a single leaf.
    self.dcls_flattened_with_path_up_to = [
        (('k_any',), None),
        (('k_arr',), np.array(16)),
        (('k_dclass_no_map',), self.dcls_no_map),
        (('k_dclass_with_map',), self.dcls_with_map_inner),
        (('k_default',), 'default_str'),
        (('k_dict_factory', 'x'), 'x'),
        (('k_dict_factory', 'y'), 'y'),
        (('k_int',), 1),
        (('k_non_init',), 10),
        (('k_str',), 'test_str'),
    ]

    self.dcls_flattened = [
        leaf for _, leaf in self.dcls_flattened_with_path
    ]
    self.dcls_flattened_up_to = [
        leaf for _, leaf in self.dcls_flattened_with_path_up_to
    ]
    self.dcls_tree_size = 18
    self.dcls_tree_size_no_dicts = 14

  def _shallow_structure(self):
    """Returns a shallow copy of the fixture with the inner dataclass unset."""
    shallow = copy.copy(self.dcls_with_map)
    shallow.k_dclass_with_map = None  # Do not descend into 'k_dclass_with_map'
    return shallow

  @staticmethod
  def _add_one_to_ints(x):
    """Increments `x` if it is an int, otherwise returns it unchanged."""
    if isinstance(x, int):
      return x + 1
    return x

  @staticmethod
  def _add_one_to_ints_with_path(path, x):
    """Path-aware variant of `_add_one_to_ints`; `path` is ignored."""
    if isinstance(x, int):
      return x + 1
    return x

  def _assert_matches_incremented(self, mapped_inc_ints):
    """Compares a mapped result against the incremented reference fixture."""
    reference = self.dcls_with_map_inc_ints
    self.assertEqual(reference, mapped_inc_ints)
    self.assertEqual(reference.k_non_init, reference.k_int * 10)
    self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)

  def _keep_inner_dataclass_unchanged(self):
    """Puts the non-incremented inner dataclass into the reference fixture."""
    untouched_inner = self.dcls_with_map.k_dclass_with_map
    self.dcls_with_map_inc_ints.k_dclass_with_map = untouched_inner

  def testFlattenAndUnflatten(self, test_type):
    self._init_testdata(test_type)

    self.assertEqual(self.dcls_flattened, tree.flatten(self.dcls_with_map))
    rebuilt = tree.unflatten_as(
        self.dcls_with_map_inc_ints, self.dcls_flattened)
    self.assertEqual(self.dcls_with_map, rebuilt)

    # The dataclass must also flatten correctly when nested in a sequence.
    dataclass_in_seq = [34, self.dcls_with_map, [1, 2]]
    dataclass_in_seq_flat = [34] + self.dcls_flattened + [1, 2]
    self.assertEqual(dataclass_in_seq_flat, tree.flatten(dataclass_in_seq))
    rebuilt_seq = tree.unflatten_as(dataclass_in_seq, dataclass_in_seq_flat)
    self.assertEqual(dataclass_in_seq, rebuilt_seq)

  def testFlattenUpTo(self, test_type):
    self._init_testdata(test_type)
    shallow = self._shallow_structure()
    self.assertEqual(
        self.dcls_flattened_up_to,
        tree.flatten_up_to(shallow, self.dcls_with_map))

  def testFlattenWithPath(self, test_type):
    self._init_testdata(test_type)
    flattened = tree.flatten_with_path(self.dcls_with_map)
    self.assertEqual(flattened, self.dcls_flattened_with_path)

  def testFlattenWithPathUpTo(self, test_type):
    self._init_testdata(test_type)
    shallow = self._shallow_structure()
    flattened = tree.flatten_with_path_up_to(shallow, self.dcls_with_map)
    self.assertEqual(flattened, self.dcls_flattened_with_path_up_to)

  def testMapStructure(self, test_type):
    self._init_testdata(test_type)
    mapped_inc_ints = tree.map_structure(
        self._add_one_to_ints, self.dcls_with_map)
    self._assert_matches_incremented(mapped_inc_ints)

  def testMapStructureUpTo(self, test_type):
    self._init_testdata(test_type)
    shallow = self._shallow_structure()
    mapped_inc_ints = tree.map_structure_up_to(
        shallow, self._add_one_to_ints, self.dcls_with_map)

    # k_dclass_with_map should be passed through unchanged
    self._keep_inner_dataclass_unchanged()
    self._assert_matches_incremented(mapped_inc_ints)

  def testMapStructureWithPath(self, test_type):
    self._init_testdata(test_type)
    mapped_inc_ints = tree.map_structure_with_path(
        self._add_one_to_ints_with_path, self.dcls_with_map)
    self._assert_matches_incremented(mapped_inc_ints)

  def testMapStructureWithPathUpTo(self, test_type):
    self._init_testdata(test_type)
    shallow = self._shallow_structure()
    mapped_inc_ints = tree.map_structure_with_path_up_to(
        shallow, self._add_one_to_ints_with_path, self.dcls_with_map)

    # k_dclass_with_map should be passed through unchanged
    self._keep_inner_dataclass_unchanged()
    self._assert_matches_incremented(mapped_inc_ints)

  def testTraverse(self, test_type):
    self._init_testdata(test_type)

    bottom_up_nodes = []
    tree.traverse(bottom_up_nodes.append, self.dcls_with_map, top_down=False)
    self.assertLen(bottom_up_nodes, self.dcls_tree_size)

    top_down_nodes = []

    def record_and_replace_dicts(node):
      top_down_nodes.append(node)
      if isinstance(node, dict):
        return 'X'
      return None

    tree.traverse(record_and_replace_dicts, self.dcls_with_map, top_down=True)
    self.assertLen(top_down_nodes, self.dcls_tree_size_no_dicts)

  def testIsDataclass(self, test_type):
    self._init_testdata(test_type)
    candidates = (
        self.dcls_no_map,
        self.dcls_with_map,
        self.dcls_with_map.k_dclass_with_map,
        self.dcls_with_map.k_dclass_no_map,
    )
    for candidate in candidates:
      self.assertTrue(dataclasses.is_dataclass(candidate))
class DataclassesTest(parameterized.TestCase):
  """Tests for `chex_dataclass`: JAX tree utilities, mutability and helpers."""

  @parameterized.parameters([True, False])
  def test_dataclass_tree_leaves(self, frozen):
    """The dummy dataclass flattens into three leaves."""
    obj = dummy_dataclass(frozen=frozen)
    self.assertLen(jax.tree_util.tree_leaves(obj), 3)

  @parameterized.parameters([True, False])
  def test_dataclass_tree_map(self, frozen):
    """`tree_map` scales every leaf and returns an equivalent dataclass."""
    factor = 5.
    obj = dummy_dataclass(frozen=frozen)
    target_obj = dummy_dataclass(factor=factor, frozen=frozen)
    asserts.assert_trees_all_close(
        jax.tree_util.tree_map(lambda t: factor * t, obj), target_obj)

  def test_tree_flatten_with_keys(self):
    """`tree_flatten_with_path` yields `GetAttrKey` paths and round-trips."""
    obj = dummy_dataclass()
    keys_and_leaves, treedef = jax.tree_util.tree_flatten_with_path(obj)
    self.assertEqual(
        [k for k, _ in keys_and_leaves],
        [
            (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c')),
            (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d')),
            (jax.tree_util.GetAttrKey('b'),),
        ],
    )
    leaves = [l for _, l in keys_and_leaves]
    new_obj = treedef.unflatten(leaves)
    self.assertEqual(new_obj, obj)

  def test_tree_map_with_keys(self):
    """`tree_map_with_path` passes each leaf together with its key path."""
    obj = dummy_dataclass()
    key_value_list, unused_treedef = jax.tree_util.tree_flatten_with_path(obj)
    # Convert a list of key-value tuples to a dict.
    flat_obj = dict(key_value_list)

    def f(path, x):
      value = flat_obj[path]
      np.testing.assert_allclose(value, x)
      return path

    out = jax.tree_util.tree_map_with_path(f, obj)
    self.assertEqual(
        out.a.c, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('c'))
    )
    self.assertEqual(
        out.a.d, (jax.tree_util.GetAttrKey('a'), jax.tree_util.GetAttrKey('d'))
    )
    self.assertEqual(out.b, (jax.tree_util.GetAttrKey('b'),))

  def test_tree_map_with_keys_traversal_order(self):
    """`tree_map_with_path` visits leaves in the order of `tree_leaves`."""
    # pytype: disable=wrong-arg-types
    obj = ReverseOrderNestedDataclass(d=1, c=2)
    # pytype: enable=wrong-arg-types
    leaves = []

    def f(_, x):
      leaves.append(x)

    jax.tree_util.tree_map_with_path(f, obj)
    self.assertEqual(leaves, jax.tree_util.tree_leaves(obj))

  @parameterized.parameters([True, False])
  def test_dataclass_replace(self, frozen):
    """`replace` builds updated copies for frozen and mutable dataclasses."""
    factor = 5.
    obj = dummy_dataclass(frozen=frozen)
    # pytype: disable=attribute-error  # dataclass_transform
    obj = obj.replace(a=obj.a.replace(c=factor * obj.a.c))
    obj = obj.replace(a=obj.a.replace(d=factor * obj.a.d))
    obj = obj.replace(b=factor * obj.b)
    target_obj = dummy_dataclass(factor=factor, frozen=frozen)
    asserts.assert_trees_all_close(obj, target_obj)
    # pytype: enable=attribute-error

  def test_dataclass_requires_kwargs_by_default(self):
    """Positional construction of a default chex dataclass raises ValueError."""
    factor = 1.0
    with self.assertRaisesRegex(
        ValueError,
        "Mappable dataclass constructor doesn't support positional args.",
    ):
      Dataclass(
          NestedDataclass(
              c=factor * np.ones((3,), dtype=np.float32),
              d=factor * np.ones((4,), dtype=np.float32),
          ),
          factor * 2 * np.ones((5,), dtype=np.float32),
      )

  def test_dataclass_mappable_dataclass_false(self):
    """With `mappable_dataclass=False` positional construction is accepted."""
    factor = 1.0

    @chex_dataclass(mappable_dataclass=False)
    class NonMappableDataclass:
      a: NestedDataclass
      b: pytypes.ArrayDevice

    NonMappableDataclass(
        NestedDataclass(
            c=factor * np.ones((3,), dtype=np.float32),
            d=factor * np.ones((4,), dtype=np.float32),
        ),
        factor * 2 * np.ones((5,), dtype=np.float32),
    )

  def test_inheritance_is_possible_thanks_to_kw_only(self):
    """A child may add a non-default field after a defaulted base field."""
    # Compare the full (major, minor) pair and report the test as skipped
    # rather than silently passing on interpreters without `kw_only` support.
    if sys.version_info < (3, 10):
      self.skipTest('`kw_only` dataclasses require Python >= 3.10.')

    @chex_dataclass(kw_only=True)
    class Base:
      default: int = 1

    @chex_dataclass(kw_only=True)
    class Child(Base):
      non_default: int

    Child(non_default=2)

  def test_unfrozen_dataclass_is_mutable(self):
    """Fields of a non-frozen dataclass can be reassigned in place."""
    factor = 5.
    obj = dummy_dataclass(frozen=False)
    obj.a.c = factor * obj.a.c
    obj.a.d = factor * obj.a.d
    obj.b = factor * obj.b
    target_obj = dummy_dataclass(factor=factor, frozen=False)
    asserts.assert_trees_all_close(obj, target_obj)

  def test_frozen_dataclass_raise_error(self):
    """Assigning to a field of a frozen dataclass raises an error."""
    factor = 5.
    obj = dummy_dataclass(frozen=True)
    obj.a.c = factor * obj.a.c  # mutable since obj.a is not frozen.
    with self.assertRaisesRegex(dataclass.FrozenInstanceError,
                                'cannot assign to field'):
      obj.b = factor * obj.b  # raises error because obj is frozen.

  @parameterized.named_parameters(
      ('frozen', True),
      ('mutable', False),
  )
  def test_get_and_set_state(self, frozen):
    """State taken with `__getstate__` can be restored via `__setstate__`."""

    @chex_dataclass(frozen=frozen)
    class SimpleClass():
      data: int = 1

    obj_a = SimpleClass(data=1)
    state = getattr(obj_a, '__getstate__')()
    obj_b = SimpleClass(data=2)
    getattr(obj_b, '__setstate__')(state)
    self.assertEqual(obj_a, obj_b)

  def test_unexpected_kwargs(self):
    """Unknown keyword arguments are rejected with a ValueError."""

    @chex_dataclass()
    class SimpleDataclass:
      a: int
      b: int = 2

    SimpleDataclass(a=1, b=3)
    with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'):
      SimpleDataclass(a=1, b=3, c=4)  # pytype: disable=wrong-keyword-args

  def test_tuple_conversion(self):
    """`to_tuple`/`from_tuple` follow the field declaration order."""

    @chex_dataclass()
    class SimpleDataclass:
      b: int
      a: int

    obj = SimpleDataclass(a=2, b=1)
    self.assertSequenceEqual(getattr(obj, 'to_tuple')(), (1, 2))

    obj2 = getattr(SimpleDataclass, 'from_tuple')((1, 2))
    self.assertEqual(obj.a, obj2.a)
    self.assertEqual(obj.b, obj2.b)

  @parameterized.named_parameters(
      ('frozen', True),
      ('mutable', False),
  )
  def test_tuple_rev_conversion(self, frozen):
    """`from_tuple(to_tuple(obj))` reproduces the original object."""
    obj = dummy_dataclass(frozen=frozen)
    asserts.assert_trees_all_close(
        type(obj).from_tuple(obj.to_tuple()),  # pytype: disable=attribute-error
        obj,
    )

  @parameterized.named_parameters(
      ('frozen', True),
      ('mutable', False),
  )
  def test_inheritance(self, frozen):
    """`isinstance` relations hold between base and derived dataclasses."""

    @chex_dataclass(frozen=frozen)
    class Base:
      x: int

    @chex_dataclass(frozen=frozen)
    class Derived(Base):
      y: int

    base_obj = Base(x=1)
    self.assertNotIsInstance(base_obj, Derived)
    self.assertIsInstance(base_obj, Base)

    derived_obj = Derived(x=1, y=2)
    self.assertIsInstance(derived_obj, Derived)
    self.assertIsInstance(derived_obj, Base)

  def test_inheritance_from_empty_frozen_base(self):
    """A frozen base allows frozen children and rejects non-frozen ones."""

    @chex_dataclass(frozen=True)
    class FrozenBase:
      pass

    @chex_dataclass(frozen=True)
    class DerivedFrozen(FrozenBase):
      j: int

    df = DerivedFrozen(j=2)
    self.assertIsInstance(df, FrozenBase)

    with self.assertRaisesRegex(
        TypeError, 'cannot inherit non-frozen dataclass from a frozen one'):
      # pylint:disable=unused-variable
      @chex_dataclass
      class DerivedMutable(FrozenBase):
        j: int
      # pylint:enable=unused-variable

  def test_disallowed_fields(self):
    """Field names that clash with chex dataclass helpers raise ValueError."""
    # pylint:disable=unused-variable
    with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):
      @chex_dataclass(mappable_dataclass=False)
      class InvalidNonMappable:
        from_tuple: int

    # A field called `get` is accepted when the class is not mappable.
    @chex_dataclass(mappable_dataclass=False)
    class ValidNonMappable:
      get: int

    with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):
      @chex_dataclass(mappable_dataclass=True)
      class InvalidMappable:
        get: int
        from_tuple: int
    # pylint:enable=unused-variable

  @parameterized.parameters(True, False)
  def test_flatten_is_leaf(self, is_mappable):
    """`is_leaf` can stop JAX flattening at nested dataclass instances."""

    @chex_dataclass(mappable_dataclass=is_mappable)
    class _InnerDcls:
      v_1: int
      v_2: int

    @chex_dataclass(mappable_dataclass=is_mappable)
    class _Dcls:
      str_val: str
      # pytype: disable=invalid-annotation  # enable-bare-annotations
      inner_dcls: _InnerDcls
      dct: Mapping[str, _InnerDcls]
      # pytype: enable=invalid-annotation  # enable-bare-annotations

    dcls = _Dcls(
        str_val='test',
        inner_dcls=_InnerDcls(v_1=1, v_2=11),
        dct={
            'md1': _InnerDcls(v_1=2, v_2=22),
            'md2': _InnerDcls(v_1=3, v_2=33)
        })

    def _is_leaf(value) -> bool:
      # Must not traverse over integers.
      self.assertNotIsInstance(value, int)
      return isinstance(value, (_InnerDcls, str))

    leaves = jax.tree_util.tree_flatten(dcls, is_leaf=_is_leaf)[0]
    self.assertCountEqual(
        (dcls.str_val, dcls.inner_dcls, dcls.dct['md1'], dcls.dct['md2']),
        leaves)

    asserts.assert_trees_all_equal_structs(
        jax.tree_util.tree_map(lambda x: x, dcls, is_leaf=_is_leaf), dcls)

  def test_decorator_alias(self):
    """A pre-configured decorator can be reused for several classes."""
    # Make sure, that creating a decorator alias works correctly.
    configclass = chex_dataclass(frozen=True)

    @configclass
    class Foo:
      bar: int = 1
      toto: int = 2

    @configclass
    class Bar:
      bar: int = 1
      toto: int = 2

    # Verify that both Foo and Bar are correctly registered with jax.tree_util.
    self.assertLen(jax.tree_util.tree_flatten(Foo())[0], 2)
    self.assertLen(jax.tree_util.tree_flatten(Bar())[0], 2)

  @parameterized.named_parameters(
      ('mappable', True),
      ('not_mappable', False),
  )
  def test_generic_dataclass(self, mappable):
    """The decorator can be applied to a `Generic` class."""
    T = TypeVar('T')

    @chex_dataclass(mappable_dataclass=mappable)
    class GenericDataclass(Generic[T]):
      a: T  # pytype: disable=invalid-annotation  # enable-bare-annotations

    obj = GenericDataclass(a=np.array([1.0, 1.0]))
    asserts.assert_trees_all_close(obj.a, 1.0)

  def test_mappable_eq_override(self):
    """A user-defined `__eq__` is used by a mappable dataclass."""

    @chex_dataclass(mappable_dataclass=True)
    class EqDataclass:
      a: pytypes.ArrayDevice

      def __eq__(self, other):
        # Only the first element of `a` takes part in the comparison.
        if isinstance(other, EqDataclass):
          return other.a[0] == self.a[0]
        return False

    obj1 = EqDataclass(a=np.array([1.0, 1.0]))
    obj2 = EqDataclass(a=np.array([1.0, 0.0]))
    obj3 = EqDataclass(a=np.array([0.0, 1.0]))
    self.assertEqual(obj1, obj2)
    self.assertNotEqual(obj1, obj3)

  @parameterized.parameters([NestedDataclass, ReverseOrderNestedDataclass])
  def test_dataclass_instance_fields(self, dcls):
    """`_dataclass_instance_fields` agrees with `dataclasses.fields`."""
    obj = dcls(c=1, d=2)
    self.assertSequenceEqual(
        dataclasses.fields(obj), _dataclass_instance_fields(obj))

  @parameterized.parameters((pickle, NestedDataclass),
                            (cloudpickle, ReverseOrderNestedDataclass))
  def test_roundtrip_serialization(self, serialization_lib, dcls):
    """Fields and leaves survive a dumps/loads round trip and a tree_map."""
    obj = dcls(c=1, d=2)
    obj_fields = [
        (f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj)
    ]
    self.assertLen(obj_fields, 2)

    # The payload is produced by this test itself, so loading it is safe.
    obj2 = serialization_lib.loads(serialization_lib.dumps(obj))
    obj2_fields = [(f.name, getattr(obj2, f.name))
                   for f in _dataclass_instance_fields(obj2)]
    self.assertSequenceEqual(obj_fields, obj2_fields)
    self.assertSequenceEqual(jax.tree_util.tree_leaves(obj2), [1, 2])

    obj3 = jax.tree_util.tree_map(lambda x: x, obj2)
    obj3_fields = [(f.name, getattr(obj3, f.name))
                   for f in _dataclass_instance_fields(obj3)]
    self.assertSequenceEqual(obj_fields, obj3_fields)
    self.assertSequenceEqual(jax.tree_util.tree_leaves(obj3), [1, 2])

  @parameterized.parameters([NestedDataclass, ReverseOrderNestedDataclass])
  def test_flatten_roundtrip_ordering(self, dcls):
    """Leaves come out as [c, d] and unflattening keeps the field order."""
    obj = dcls(c=1, d=2)
    leaves, treedef = jax.tree_util.tree_flatten(obj)
    self.assertSequenceEqual(leaves, [1, 2])
    obj2 = jax.tree_util.tree_unflatten(treedef, leaves)
    self.assertSequenceEqual(dataclasses.fields(obj2), dataclasses.fields(obj))

  def test_flatten_respects_post_init(self):
    """Mapping `a` to 0 raises the ValueError from `__post_init__`."""
    obj = PostInitDataclass(a=1)  # pytype: disable=wrong-arg-types
    with self.assertRaises(ValueError):
      _ = jax.tree_util.tree_map(lambda x: 0, obj)

  @parameterized.parameters([False, True])
  def test_keys_and_values_type(self, frozen):
    """`keys()` and `values()` return the same view types as a dict."""
    obj = dummy_dataclass(frozen=frozen)
    self.assertEqual(
        type(obj.keys()),  # pytype: disable=attribute-error
        type({}.keys()),
    )
    self.assertEqual(
        type(obj.values()),  # pytype: disable=attribute-error
        type({}.values()),
    )

  @parameterized.parameters([False, True])
  def test_keys_and_values_override(self, frozen):
    """A field named `values` takes precedence over the mapping method."""

    @chex_dataclass(frozen=frozen)
    class _Dataclass:
      x: int
      values: int

    obj = _Dataclass(x=1, values=2)
    self.assertEqual(
        list(obj.keys()),  # pytype: disable=attribute-error
        ['x', 'values'],
    )
    self.assertEqual(obj.values, 2)
# Run the absl test runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()