PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `asserts.py`."""
import functools
import re
from absl.testing import absltest
from absl.testing import parameterized
from chex._src import asserts
from chex._src import asserts_internal
from chex._src import variants
import jax
import jax.numpy as jnp
import numpy as np
# Short module-level aliases for helpers from `asserts_internal`.
# `_get_err_regex` wraps an expected message into the regex passed to
# `assertRaisesRegex` throughout this file.
_get_err_regex, _num_devices_available = (
    asserts_internal.get_err_regex,
    asserts_internal.num_devices_available,
)
def as_arrays(arrays):
  """Converts every element of `arrays` into a NumPy array.

  Args:
    arrays: an iterable of array-likes (scalars, nested lists, arrays).

  Returns:
    A list with one `np.ndarray` per input element, in the same order.
  """
  return list(map(np.asarray, arrays))
def array_from_shape(*shape):
  """Returns a NumPy array of ones whose shape is the positional arguments.

  For example, `array_from_shape(2, 3)` has shape `(2, 3)`; calling it with no
  arguments yields a 0-d array.
  """
  dims = tuple(shape)
  return np.ones(dims)
def emplace(arrays, dtype):
  """Builds a JAX array from `arrays` with the requested `dtype`."""
  result = jnp.array(arrays, dtype=dtype)
  return result
class AssertsSwitchTest(parameterized.TestCase):
  """Tests for enable/disable_asserts."""

  def test_enable_disable_asserts(self):
    # `disable_asserts` changes state that outlives this test. Register the
    # re-enable as a cleanup so it runs even if a check below fails part-way
    # through. Otherwise a failure here would leave every later test running
    # with chex assertions switched off.
    self.addCleanup(asserts.enable_asserts)

    with self.assertRaisesRegex(AssertionError, _get_err_regex('scalar')):
      asserts.assert_scalar('test')

    # While disabled, a failing assertion is a no-op.
    asserts.disable_asserts()
    asserts.assert_scalar('test')

    asserts.enable_asserts()
    with self.assertRaisesRegex(AssertionError, _get_err_regex('scalar')):
      asserts.assert_scalar('test')

    # 13 is not divisible by 5, but this must not raise while disabled.
    asserts.disable_asserts()
    asserts.assert_is_divisible(13, 5)
    # Assertions are re-enabled by the cleanup registered above.
class AssertMaxTracesTest(variants.TestCase):
  """Tests for `asserts.assert_max_traces`."""

  def setUp(self):
    super().setUp()
    # Reset the trace counters kept by `asserts` so every test starts from zero.
    asserts.clear_trace_counter()

  def _init(self, fn_, init_type, max_traces, kwargs, static_arg):
    """Initializes common test cases.

    Args:
      fn_: the function under test, called as `fn_(x, y)`.
      init_type: how `assert_max_traces` is applied:
        't1' - as a decorator with arguments, then wrapped by `self.variant`;
        't2' - called directly on the function, then wrapped by `self.variant`;
        't3' - stacked under `@self.variant` on a separately defined function.
      max_traces: the trace budget passed to `assert_max_traces`.
      kwargs: if True, pass the budget as `n=max_traces`; otherwise positionally.
      static_arg: if True, mark argument 1 (`y`) as static for the variant.

    Returns:
      `(fn, fn_jitted)`: the non-transformed callable and the transformed one.
    """
    variant_kwargs = {}
    if static_arg:
      variant_kwargs['static_argnums'] = 1
    if kwargs:
      args, kwargs = [], {'n': max_traces}
    else:
      args, kwargs = [max_traces], {}
    if init_type == 't1':

      @asserts.assert_max_traces(*args, **kwargs)
      def fn(x, y):
        # A static `y` must reach the function as a concrete value.
        if static_arg:
          self.assertNotIsInstance(y, jax.core.Tracer)
        return fn_(x, y)

      fn_jitted = self.variant(fn, **variant_kwargs)
    elif init_type == 't2':

      def fn(x, y):
        if static_arg:
          self.assertNotIsInstance(y, jax.core.Tracer)
        return fn_(x, y)

      fn = asserts.assert_max_traces(fn, *args, **kwargs)
      fn_jitted = self.variant(fn, **variant_kwargs)
    elif init_type == 't3':

      # Here `fn` is the plain function; only `fn_jitted` is counted.
      def fn(x, y):
        if static_arg:
          self.assertNotIsInstance(y, jax.core.Tracer)
        return fn_(x, y)

      @self.variant(**variant_kwargs)
      @asserts.assert_max_traces(*args, **kwargs)
      def fn_jitted(x, y):
        self.assertIsInstance(x, jax.core.Tracer)
        return fn_(x, y)
    else:
      raise ValueError(f'Unknown type {init_type}.')
    return fn, fn_jitted

  @variants.variants(with_jit=True, with_pmap=True)
  @parameterized.named_parameters(
      variants.params_product((
          ('type1', 't1'),
          ('type2', 't2'),
          ('type3', 't3'),
      ), (
          ('args', False),
          ('kwargs', True),
      ), (
          ('no_static_arg', False),
          ('with_static_arg', True),
      ), (
          ('max_traces_0', 0),
          ('max_traces_1', 1),
          ('max_traces_2', 2),
          ('max_traces_10', 10),
      ),
                              named=True))
  def test_assert(self, init_type, kwargs, static_arg, max_traces):
    fn_ = lambda x, y: x + y
    fn, fn_jitted = self._init(fn_, init_type, max_traces, kwargs, static_arg)
    # Original function: calls outside of a transformation are not counted as
    # traces, so `fn` may be called more than `max_traces` times.
    for _ in range(max_traces + 3):
      self.assertEqual(fn(1, 2), 3)
    # Each value of `i` gives a new argument shape and therefore one
    # re-trace. The inner `k` calls reuse that shape, so they must not add
    # traces. This uses exactly `max_traces` traces in total.
    for i in range(max_traces):
      for k in range(5):
        arg = jnp.zeros(i + 1) + k
        np.testing.assert_array_equal(fn_jitted(arg, 2), arg + 2)
    # Original function (still uncounted, also with non-array arguments).
    for _ in range(max_traces + 3):
      self.assertEqual(fn(1, 2), 3)
      self.assertEqual(fn([1], [2]), [1, 2])
      self.assertEqual(fn('a', 'b'), 'ab')
    # (max_traces + 1)-th re-tracing: a not-yet-seen shape must exceed the budget.
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('fn.* is traced > .* times!')):
      arg = jnp.zeros(max_traces + 1)
      fn_jitted(arg, 2)

  def test_incorrect_ordering(self):
    """`assert_max_traces` must wrap the function before any JAX transform."""
    # pylint:disable=g-error-prone-assert-raises,unused-variable
    with self.assertRaisesRegex(ValueError, 'change wrappers ordering'):

      @asserts.assert_max_traces(1)
      @jax.jit
      def fn(_):
        pass

    def dummy_wrapper(fn):

      @functools.wraps(fn)
      def fn_wrapped():
        return fn()

      return fn_wrapped

    # The wrong ordering is detected even through an intermediate
    # `functools.wraps`-based wrapper.
    with self.assertRaisesRegex(ValueError, 'change wrappers ordering'):

      @asserts.assert_max_traces(1)
      @dummy_wrapper
      @jax.jit
      def fn_2():
        pass
    # pylint:enable=g-error-prone-assert-raises,unused-variable

  def test_redefined_traced_function(self):

    def outer_fn(x):

      @jax.jit
      @asserts.assert_max_traces(3)
      def inner_fn(y):
        return y.sum()

      return inner_fn(2 * x)

    self.assertEqual(outer_fn(1), 2)
    self.assertEqual(outer_fn(2), 4)
    self.assertEqual(outer_fn(3), 6)
    # Fails since the traced inner function is redefined at each call. Each
    # redefinition is traced afresh, and the count evidently carries over
    # across redefinitions, so the 4th call exceeds the budget of 3.
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('fn.* is traced > .* times!')):
      outer_fn(4)
    # After a reset, three more calls fit in the budget and later ones fail,
    # even with an unchanged argument.
    asserts.clear_trace_counter()
    for i in range(10):
      if i > 2:
        with self.assertRaisesRegex(
            AssertionError, _get_err_regex('fn.* is traced > .* times!')):
          outer_fn(1)
      else:
        outer_fn(1)

  def test_nested_functions(self):

    @jax.jit
    def jitted_outer_fn(x):

      @jax.jit
      @asserts.assert_max_traces(1)
      def inner_fn(y):
        return y.sum()

      return inner_fn(2 * x)

    # The inner assert_max_traces has no effect since jitted_outer_fn is
    # traced only once (every call below uses the same argument shape).
    for i in range(10):
      self.assertEqual(jitted_outer_fn(i), 2 * i)
class ScalarAssertTest(parameterized.TestCase):
  """Tests for the `assert_scalar*` family of assertions."""

  def test_scalar(self):
    # Python ints and floats are accepted; a 0-d NumPy array is not.
    for scalar in (1, 1.):
      asserts.assert_scalar(scalar)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('must be a scalar')):
      asserts.assert_scalar(np.array(1.))  # pytype: disable=wrong-arg-types

  def test_scalar_positive(self):
    asserts.assert_scalar_positive(0.5)
    expected = _get_err_regex('must be positive')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_scalar_positive(-0.5)

  def test_scalar_non_negative(self):
    # Zero counts as non-negative.
    for value in (0.5, 0.):
      asserts.assert_scalar_non_negative(value)
    expected = _get_err_regex('must be non-negative')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_scalar_non_negative(-0.5)

  def test_scalar_negative(self):
    asserts.assert_scalar_negative(-0.5)
    expected = _get_err_regex('argument must be negative')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_scalar_negative(0.5)

  def test_scalar_in(self):
    asserts.assert_scalar_in(0.5, 0, 1)
    # Values below and above [0, 1] are both rejected.
    for outside in (-0.5, 1.5):
      with self.assertRaisesRegex(AssertionError,
                                  _get_err_regex('argument must be in')):
        asserts.assert_scalar_in(outside, 0, 1)

  def test_scalar_in_excluded(self):
    asserts.assert_scalar_in(0.5, 0, 1, included=False)
    # With `included=False` the interval bounds themselves are rejected.
    for bound in (0, 1):
      with self.assertRaisesRegex(AssertionError,
                                  _get_err_regex('argument must be in')):
        asserts.assert_scalar_in(bound, 0, 1, included=False)
class EqualSizeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_equal_size`."""

  @parameterized.named_parameters(
      ('scalar_vector_matrix', [1, 2, [3], [[4, 5]]]),
      ('vector_matrix', [[1], [2], [[3, 5]]]),
      ('matrix', [[[1, 2]], [[3], [4], [5]]]),
  )
  def test_equal_size_should_fail(self, arrays):
    converted = as_arrays(arrays)
    expected = _get_err_regex('Arrays have different sizes')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_equal_size(converted)

  @parameterized.named_parameters(
      ('scalar_vector_matrix', [1, 2, [3], [[4]]]),
      ('vector_matrix', [[1], [2], [[3]]]),
      ('matrix', [[[1, 2]], [[3], [4]]]),
  )
  def test_equal_size_should_pass(self, arrays):
    # Shapes may differ (e.g. (1, 2) vs (2, 1)) as long as element counts match.
    asserts.assert_equal_size(as_arrays(arrays))
class SizeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_size`."""

  @parameterized.named_parameters(
      ('wrong_size', [1, 2], 2),
      ('some_wrong_size', [[1, 2], [3, 4]], (2, 3)),
      ('wrong_common_shape', [[1, 2], [3, 4, 3]], 3),
      ('wrong_common_shape_2', [[1, 2, 3], [1, 2]], 2),
      ('some_wrong_size_set', [[1, 2], [3, 4]], (2, {3, 4})),
  )
  def test_size_should_fail(self, arrays, sizes):
    converted = as_arrays(arrays)
    expected = _get_err_regex('input .+ has size .+ but expected .+')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_size(converted, sizes)

  @parameterized.named_parameters(
      ('too_many_sizes', [[1]], (1, 1)),
      ('not_enough_sizes', [[1, 2], [3, 4], [5, 6]], (2, 2)),
  )
  def test_size_should_fail_wrong_length(self, arrays, sizes):
    converted = as_arrays(arrays)
    expected = _get_err_regex(
        'Length of `inputs` and `expected_sizes` must match')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_size(converted, sizes)

  @parameterized.named_parameters(
      ('scalars', [1, 2], 1),
      ('vectors', [[1, 2], [3, 4, 5]], [2, 3]),
      ('matrices', [[[1, 2], [3, 4]]], 4),
      ('common_size_set', [[[1, 2], [3, 4]], [[1], [3]]], (4, {1, 2})),
  )
  def test_size_should_pass(self, arrays, sizes):
    asserts.assert_size(as_arrays(arrays), sizes)

  def test_pytypes_pass(self):
    arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
    # The second expected size may be a wildcard (`None`, `...`) or a set.
    for second in (None, {1, 2}, ...):
      asserts.assert_size(arrays, (4, second))

  @parameterized.named_parameters(
      ('single_ellipsis', [[1, 2, 3, 4], [1, 2]], (..., 2)),
      ('multiple_ellipsis', [[1, 2, 3], [1, 2, 3]], (..., ...)),
  )
  def test_ellipsis_should_pass(self, arrays, expected_size):
    asserts.assert_size(as_arrays(arrays), expected_size)
class EqualShapeAssertTest(parameterized.TestCase):
  """Tests for `assert_equal_shape` and `assert_equal_shape_prefix`."""

  @parameterized.named_parameters(
      ('not_scalar', [1, 2, [3]]),
      ('wrong_rank', [[1], [2], 3]),
      ('wrong_length', [[1], [2], [3, 4]]),
  )
  def test_equal_shape_should_fail(self, arrays):
    converted = as_arrays(arrays)
    expected = _get_err_regex('Arrays have different shapes')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_equal_shape(converted)

  @parameterized.named_parameters(
      ('scalars', [1, 2, 3]),
      ('vectors', [[1], [2], [3]]),
      ('matrices', [[[1], [2]], [[3], [4]]]),
  )
  def test_equal_shape_should_pass(self, arrays):
    asserts.assert_equal_shape(as_arrays(arrays))

  @parameterized.named_parameters(
      ('scalars', [1, 2, 3]),
      ('vectors', [[1], [2], [[3, 4]]]),
  )
  def test_equal_shape_prefix_should_pass(self, arrays):
    # Only the first `prefix_len` dimensions are compared.
    asserts.assert_equal_shape_prefix(as_arrays(arrays), prefix_len=1)

  @parameterized.named_parameters(
      ('scalars', [1, 2, [3]]),
      ('vectors', [[1], [2], [[3], [4]]]),
  )
  def test_equal_shape_prefix_should_fail(self, arrays):
    converted = as_arrays(arrays)
    expected = _get_err_regex('different shape prefixes')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_equal_shape_prefix(converted, prefix_len=1)

  @parameterized.named_parameters(
      ('first_dim', [[2, 3], [2, 4], [2, 5]], 0),
      ('last_dim', [[3, 5, 7], [2, 7], [4, 7]], -1),  # Note different ranks.
      ('first_few_dims', [[1, 2, 3], [1, 2, 4], [1, 2, 5]], [0, 1]),
      ('first_and_last', [[1, 2, 1], [1, 3, 1], [1, 4, 1]], [0, 2]),
      ('first_and_last_neg', [[1, 2, 3, 4], [1, 5, 4], [1, 4]], [0, -1]),
  )
  def test_equal_shape_at_dims_should_pass(self, shapes, dims):
    ones = [array_from_shape(*s) for s in shapes]
    asserts.assert_equal_shape(ones, dims=dims)

  @parameterized.named_parameters(
      ('first_dim', [[1, 2], [2, 2]], 0),
      ('last_dim', [[1, 3], [1, 4]], 1),
      ('last_dim_neg', [[1, 3], [1, 4]], -1),
      ('multiple_dims', [[1, 2, 3], [1, 2, 4]], [0, 2]),
  )
  def test_equal_shape_at_dims_should_fail(self, shapes, dims):
    ones = [array_from_shape(*s) for s in shapes]
    expected = _get_err_regex('have different shapes at dims')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_equal_shape(ones, dims=dims)
class ShapeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_shape`."""

  # (subTest label, container type) pairs. Each test converts the expected
  # shape(s) with every container, visiting them in the order listed.
  _LIST_THEN_TUPLE = (('list', list), ('tuple', tuple))
  _TUPLE_THEN_LIST = (('tuple', tuple), ('list', list))

  @parameterized.named_parameters(
      ('wrong_rank', [1], (1,)),
      ('wrong_shape', [1, 2], (1, 3)),
      ('some_wrong_shape', [[1, 2], [3, 4]], [(1, 2), (1, 3)]),
      ('wrong_common_shape', [[1, 2], [3, 4, 3]], (2,)),
      ('wrong_common_shape_2', [[1, 2, 3], [1, 2]], (2,)),
      ('some_wrong_shape_set', [[1, 2], [3, 4]], [(1, 2), (1, {3, 4})]),
  )
  def test_shape_should_fail(self, arrays, shapes):
    arrays = as_arrays(arrays)
    for label, container in self._LIST_THEN_TUPLE:
      with self.subTest(label):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('input .+ has shape .+ but expected .+')):
          asserts.assert_shape(arrays, container(shapes))

  @parameterized.named_parameters(
      ('too_many_shapes', [[1]], [(1,), (2,)]),
      ('not_enough_shapes', [[1, 2], [3, 4]], [(3,)]),
  )
  def test_shape_should_fail_wrong_length(self, arrays, shapes):
    arrays = as_arrays(arrays)
    for _, container in self._TUPLE_THEN_LIST:
      with self.assertRaisesRegex(
          AssertionError,
          _get_err_regex('Length of `inputs` and `expected_shapes` must match')):
        asserts.assert_shape(arrays, container(shapes))

  @parameterized.named_parameters(
      ('scalars', [1, 2], ()),
      ('vectors', [[1, 2], [3, 4, 5]], [(2,), (3,)]),
      ('matrices', [[[1, 2], [3, 4]]], (2, 2)),
      ('matrices_variable_shape', [[[1, 2], [3, 4]]], (None, 2)),
      ('vectors_common_shape', [[1, 2], [3, 4]], (2,)),
      ('variable_common_shape', [[[1, 2], [3, 4]], [[1], [3]]], (2, None)),
      ('common_shape_set', [[[1, 2], [3, 4]], [[1], [3]]], (2, {1, 2})),
  )
  def test_shape_should_pass(self, arrays, shapes):
    arrays = as_arrays(arrays)
    for label, container in self._TUPLE_THEN_LIST:
      with self.subTest(label):
        asserts.assert_shape(arrays, container(shapes))

  @parameterized.named_parameters(
      ('variable_shape', (2, None)),
      ('shape_set', (2, {1, 2})),
      ('suffix', (2, ...)),
  )
  def test_pytypes_pass(self, shape):
    arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
    for label, container in self._TUPLE_THEN_LIST:
      with self.subTest(label):
        asserts.assert_shape(arrays, container(shape))

  @parameterized.named_parameters(
      ('prefix_2', array_from_shape(2, 3, 4, 5, 6), (..., 4, 5, 6)),
      ('prefix_1', array_from_shape(3, 4, 5, 6), (..., 4, 5, 6)),
      ('prefix_0', array_from_shape(4, 5, 6), (..., 4, 5, 6)),
      ('inner_2', array_from_shape(2, 3, 4, 5, 6), (2, 3, ..., 6)),
      ('inner_1', array_from_shape(2, 3, 4, 6), (2, 3, ..., 6)),
      ('inner_0', array_from_shape(2, 3, 6), (2, 3, ..., 6)),
      ('suffix_2', array_from_shape(2, 3, 4, 5, 6), (2, 3, 4, ...)),
      ('suffix_1', array_from_shape(2, 3, 4, 5), (2, 3, 4, ...)),
      ('suffix_0', array_from_shape(2, 3, 4), (2, 3, 4, ...)),
  )
  def test_ellipsis_should_pass(self, array, expected_shape):
    # An ellipsis matches zero or more dimensions at its position.
    for label, container in self._LIST_THEN_TUPLE:
      with self.subTest(label):
        asserts.assert_shape(array, container(expected_shape))

  @parameterized.named_parameters(
      ('prefix', array_from_shape(3, 1, 5), (..., 4, 5, 6)),
      ('inner_bad_prefix', array_from_shape(2, 1, 4, 6), (2, 3, ..., 6)),
      ('inner_bad_suffix', array_from_shape(2, 3, 1, 5), (2, 3, ..., 6)),
      ('inner_both_bad', array_from_shape(2, 1, 4, 5), (2, 3, ..., 6)),
      ('suffix', array_from_shape(2, 3, 1, 5), (2, 3, 4, ...)),
      ('short_rank_prefix', array_from_shape(2, 3), (..., 4, 5, 6)),
      ('short_rank_inner', array_from_shape(2, 3), (2, 3, ..., 6)),
      ('short_rank_suffix', array_from_shape(2, 3), (2, 3, 4, ...)),
  )
  def test_ellipsis_should_fail(self, array, expected_shape):
    for label, container in self._TUPLE_THEN_LIST:
      with self.subTest(label):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('input .+ has shape .+ but expected .+')):
          asserts.assert_shape(array, container(expected_shape))

  @parameterized.named_parameters(
      ('prefix_and_suffix', array_from_shape(2, 3), (..., 2, 3, ...)),)
  def test_multiple_ellipses(self, array, expected_shape):
    # More than one ellipsis is a usage error, reported as `ValueError`.
    expected_msg = (
        '`expected_shape` may not contain more than one ellipsis, but got .+')
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, expected_msg):
      asserts.assert_shape(array, expected_shape)
def rank_array(n):
  """Returns a zero-filled NumPy array of rank `n` with every dimension of size 2."""
  return np.zeros((2,) * n)
class BroadcastAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_is_broadcastable`."""

  @parameterized.parameters(
      {'shape_a': (), 'shape_b': ()},
      {'shape_a': (), 'shape_b': (2, 3)},
      {'shape_a': (2, 3), 'shape_b': (2, 3)},
      {'shape_a': (1, 3), 'shape_b': (2, 3)},
      {'shape_a': (2, 1), 'shape_b': (2, 3)},
      {'shape_a': (4,), 'shape_b': (2, 3, 4)},
      {'shape_a': (3, 4), 'shape_b': (2, 3, 4)},
  )
  def test_shapes_are_broadcastable(self, shape_a, shape_b):
    asserts.assert_is_broadcastable(shape_a, shape_b)

  @parameterized.parameters(
      {'shape_a': (2,), 'shape_b': ()},
      {'shape_a': (2, 3, 4), 'shape_b': (3, 4)},
      {'shape_a': (3, 5), 'shape_b': (3, 4)},
      {'shape_a': (3, 4), 'shape_b': (3, 1)},
      {'shape_a': (3, 4), 'shape_b': (1, 4)},
  )
  def test_shapes_are_not_broadcastable(self, shape_a, shape_b):
    # Direction matters: `shape_a` must broadcast *to* `shape_b`, so the
    # reverse of a passing pair (e.g. (2, 3, 4) -> (3, 4)) fails.
    self.assertRaises(AssertionError, asserts.assert_is_broadcastable,
                      shape_a, shape_b)
class RankAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_rank`."""

  def test_rank_should_fail_array_expectations(self):
    # An array is not an accepted way to express expected rank(s).
    expected_msg = 'expected ranks should be .* but was an array'
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, expected_msg):
      asserts.assert_rank(rank_array(2), np.array([2]))

  def test_rank_should_fail_wrong_expectation_structure(self):
    # Nested lists are neither integers nor sets of integers.
    # pytype: disable=wrong-arg-types
    expected_msg = 'Expected ranks should be integers or sets of integers'
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, expected_msg):
      asserts.assert_rank(rank_array(2), [[1, 2]])
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, expected_msg):
      asserts.assert_rank([rank_array(1), rank_array(2)], [[1], [2]])
    # pytype: enable=wrong-arg-types

  @parameterized.named_parameters(
      ('rank_1', rank_array(1), 2),
      ('rank_2', rank_array(2), 1),
      ('rank_3', rank_array(3), {2, 4}),
  )
  def test_rank_should_fail_single(self, array, rank):
    as_np = np.asarray(array)
    expected = _get_err_regex('input .+ has rank .+ but expected .+')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_rank(as_np, rank)

  @parameterized.named_parameters(
      ('wrong_1', [rank_array(1), rank_array(2)], [2, 2]),
      ('wrong_2', [rank_array(1), rank_array(2)], [1, 3]),
      ('wrong_3', [rank_array(1), rank_array(2)], [{2, 3}, 2]),
      ('wrong_4', [rank_array(1), rank_array(2)], [1, {1, 3}]),
  )
  def test_assert_rank_should_fail_sequence(self, arrays, ranks):
    converted = as_arrays(arrays)
    expected = _get_err_regex('input .+ has rank .+ but expected .+')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_rank(converted, ranks)

  @parameterized.named_parameters(
      ('not_enough_ranks', [1, 3, 4], [1, 1]),
      ('too_many_ranks', [1, 2], [1, 1, 1]),
  )
  def test_rank_should_fail_wrong_length(self, array, rank):
    as_np = np.asarray(array)
    expected = _get_err_regex(
        'Length of inputs and expected_ranks must match.')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_rank(as_np, rank)

  @parameterized.named_parameters(
      ('rank_1', rank_array(1), 1),
      ('rank_2', rank_array(2), 2),
      ('rank_3', rank_array(3), {1, 2, 3}),
  )
  def test_rank_should_pass_single_input(self, array, rank):
    asserts.assert_rank(np.asarray(array), rank)

  @parameterized.named_parameters(
      ('rank_1', rank_array(1), 1),
      ('rank_2', rank_array(2), 2),
      ('rank_3', rank_array(3), {1, 2, 3}),
  )
  def test_rank_should_pass_repeated_input(self, array, rank):
    # Three copies of one array, checked against a single shared rank spec.
    asserts.assert_rank(as_arrays([array] * 3), rank)

  @parameterized.named_parameters(
      ('single_option', [rank_array(1), rank_array(2)], {1, 2}),
      ('seq_options_1', [rank_array(1), rank_array(2)], [{1, 2}, 2]),
      ('seq_options_2', [rank_array(1), rank_array(2)], [1, {1, 2}]),
  )
  def test_rank_should_pass_multiple_options(self, arrays, ranks):
    asserts.assert_rank(as_arrays(arrays), ranks)
class TypeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_type`."""

  @parameterized.named_parameters(
      ('one_float', 3., int),
      ('one_int', 3, float),
      ('many_floats', [1., 2., 3.], int),
      ('many_floats_verbose', [1., 2., 3.], [float, float, int]),
      ('one_bool_as_float', True, float),
      ('one_bool_as_int', True, int),
      ('one_float_as_bool', 3., bool),
      ('one_int_as_bool', 3, bool),
  )
  def test_type_should_fail_scalar(self, scalars, wrong_type):
    expected = _get_err_regex('input .+ has type .+ but expected .+')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_type(scalars, wrong_type)

  @variants.variants(with_device=True, without_device=True)
  @parameterized.named_parameters(
      ('one_float_array', [1., 2.], float, int),
      ('one_int_array', [1, 2], int, float),
      ('bfloat16_array', [1, 2], jnp.bfloat16, jnp.float32),
      ('int8_array', [1, 2], jnp.int8, jnp.int32),
      ('float32_array', [1, 2], jnp.float32, np.integer),
  )
  def test_type_should_fail_array(self, array, dtype, wrong_type):
    placed = self.variant(emplace)(array, dtype)
    expected = _get_err_regex('input .+ has type .+ but expected .+')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_type(placed, wrong_type)

  @parameterized.named_parameters(
      ('one_float', 3., float),
      ('one_int', 3, int),
      ('one_bool', True, bool),
      ('many_floats', [1., 2., 3.], float),
      ('many_floats_verbose', [1., 2., 3.], [float, float, float]),
  )
  def test_type_should_pass_scalar(self, array, expected_type):
    asserts.assert_type(array, expected_type)

  @variants.variants(with_device=True, without_device=True)
  @parameterized.named_parameters(
      ('one_float_array', [1., 2.], float, float),
      ('one_int_array', [1, 2], int, int),
      ('one_integer_array', [1, 2], int, np.integer),
      ('one_bool_array', [True], bool, bool),
  )
  def test_type_should_pass_array(self, array, dtype, expected_type):
    asserts.assert_type(self.variant(emplace)(array, dtype), expected_type)

  def _mixed_inputs(self):
    """Returns [Python float, Python int, float NumPy array, int JAX array]."""
    return [1., 2, np.asarray([3., 4.]), jnp.asarray([5, 6])]

  def test_type_should_fail_mixed(self):
    inputs = self._mixed_inputs()
    # The last input holds ints, so expecting `float` for it must fail.
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('input .+ has type .+ but expected .+')):
      asserts.assert_type(inputs, [float, int, float, float])

  def test_type_should_pass_mixed(self):
    inputs = self._mixed_inputs()
    asserts.assert_type(inputs, [float, int, float, int])

  @parameterized.named_parameters(
      ('too_many_types', [1., 2], [float, int, float]),
      ('not_enough_types', [1., 2], [float]),
  )
  def test_type_should_fail_wrong_length(self, array, wrong_type):
    expected = _get_err_regex(
        'Length of `inputs` and `expected_types` must match')
    with self.assertRaisesRegex(AssertionError, expected):
      asserts.assert_type(array, wrong_type)
class AxisDimensionAssertionsTest(parameterized.TestCase):
  """Tests for `assert_axis_dimension` and its `_gt/_gteq/_lt/_lteq` variants."""

  def _check_passes_on_every_axis(self, assert_fn, key, delta):
    """Calls `assert_fn` once per axis of a 4-D tensor; no call may fail.

    For axis `a` of size `s` the call is
    `assert_fn(tensor, axis=a, **{key: s + delta})`. Axes run from `-ndim` to
    `ndim - 1`, so both negative and non-negative indexing are exercised.
    """
    tensor = jnp.ones((3, 2, 7, 2))
    for axis in range(-tensor.ndim, tensor.ndim):
      size = tensor.shape[axis]
      assert_fn(tensor, axis=axis, **{key: size + delta})

  def _check_fails_on_every_axis(self, assert_fn, key, delta, message):
    """Same calls as `_check_passes_on_every_axis`, but each must raise `message`."""
    tensor = jnp.ones((3, 2, 7, 2))
    for axis in range(-tensor.ndim, tensor.ndim):
      size = tensor.shape[axis]
      with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
        assert_fn(tensor, axis=axis, **{key: size + delta})

  def _check_rejects_invalid_axes(self, assert_fn, **kwargs):
    """Checks axes 2 and -3, both out of range for a 2-D tensor, are rejected."""
    tensor = jnp.ones((3, 2))
    for axis in (2, -3):
      with self.assertRaisesRegex(AssertionError,
                                  _get_err_regex('not available')):
        assert_fn(tensor, axis=axis, **kwargs)

  def test_assert_axis_dimension_pass(self):
    self._check_passes_on_every_axis(
        asserts.assert_axis_dimension, 'expected', 0)

  def test_assert_axis_dimension_fail(self):
    self._check_fails_on_every_axis(
        asserts.assert_axis_dimension, 'expected', 1,
        'Expected tensor to have dimension')

  def test_assert_axis_dimension_axis_invalid(self):
    self._check_rejects_invalid_axes(asserts.assert_axis_dimension, expected=1)

  def test_assert_axis_dimension_gt_pass(self):
    self._check_passes_on_every_axis(
        asserts.assert_axis_dimension_gt, 'val', -1)

  def test_assert_axis_dimension_gt_fail(self):
    self._check_fails_on_every_axis(
        asserts.assert_axis_dimension_gt, 'val', 0,
        'Expected tensor to have dimension greater than')

  def test_assert_axis_dimension_gt_axis_invalid(self):
    self._check_rejects_invalid_axes(asserts.assert_axis_dimension_gt, val=0)

  def test_assert_axis_dimension_gteq_pass(self):
    self._check_passes_on_every_axis(
        asserts.assert_axis_dimension_gteq, 'val', 0)

  def test_assert_axis_dimension_gteq_fail(self):
    self._check_fails_on_every_axis(
        asserts.assert_axis_dimension_gteq, 'val', 1,
        'Expected tensor to have dimension greater than or')

  def test_assert_axis_dimension_gteq_axis_invalid(self):
    self._check_rejects_invalid_axes(asserts.assert_axis_dimension_gteq, val=0)

  def test_assert_axis_dimension_lt_pass(self):
    self._check_passes_on_every_axis(
        asserts.assert_axis_dimension_lt, 'val', 1)

  def test_assert_axis_dimension_lt_fail(self):
    self._check_fails_on_every_axis(
        asserts.assert_axis_dimension_lt, 'val', 0,
        'Expected tensor to have dimension less than')

  def test_assert_axis_dimension_lt_axis_invalid(self):
    self._check_rejects_invalid_axes(asserts.assert_axis_dimension_lt, val=0)

  def test_assert_axis_dimension_lteq_pass(self):
    self._check_passes_on_every_axis(
        asserts.assert_axis_dimension_lteq, 'val', 0)

  def test_assert_axis_dimension_lteq_fail(self):
    self._check_fails_on_every_axis(
        asserts.assert_axis_dimension_lteq, 'val', -1,
        'Expected tensor to have dimension less than or')

  def test_assert_axis_dimension_lteq_axis_invalid(self):
    self._check_rejects_invalid_axes(asserts.assert_axis_dimension_lteq, val=0)

  def test_assert_axis_dimension_string_tensor(self):
    # Both a plain list of strings and its NumPy array form are accepted.
    words = ['ab', 'cddd']
    for tensor in (words, np.array(words)):
      asserts.assert_axis_dimension(tensor, axis=0, expected=2)
class TreeAssertionsTest(parameterized.TestCase):
def _assert_tree_structs_validation(self, assert_fn):
  """Checks that assert_fn correctly processes invalid args' structs.

  Every call below passes trees whose structures differ. The error must name
  the first tree (index 0) and the index of the first mismatching tree.
  """
  make_leaf = lambda: jnp.zeros([3])
  tree1 = [[make_leaf(), make_leaf()], make_leaf()]
  tree2 = [[make_leaf(), make_leaf()], make_leaf()]
  tree3 = [make_leaf(), [make_leaf(), make_leaf()]]
  tree4 = [make_leaf(), [make_leaf(), make_leaf()], make_leaf()]
  tree5 = dict(x=1, y=2, z=3)
  # Same as `tree5` plus an extra entry whose value is `None`.
  tree6 = dict(x=1, y=2, z=3, n=None)

  cases = (
      ((tree1, tree5),
       'Error in tree structs equality check.*trees 0 and 1'),
      ((tree1, tree3),
       'Error in tree structs equality check.*trees 0 and 1'),
      (([], [], tree1),
       'Error in tree structs equality check.*trees 0 and 2'),
      ((tree2, tree1, tree2, tree3, tree1),
       'Error in tree structs equality check.*trees 0 and 3'),
      ((tree2, tree1, tree4),
       'Error in tree structs equality check.*trees 0 and 2'),
      # Test `None`s.
      ((tree5, tree6),
       'Error in tree structs equality check.*trees 0 and 1'),
  )
  for trees, message in cases:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      assert_fn(*trees)
def test_assert_tree_no_nones(self):
  """`assert_tree_no_nones` accepts `None`-free trees and rejects the rest."""
  with self.subTest('tree_no_nones'):
    asserts.assert_tree_no_nones({'a': [jnp.zeros((1,))], 'b': 1})

  # A tree with a `None` leaf, and a bare `None` input, must both be rejected.
  rejected = (
      ('tree_with_nones', {'a': [jnp.zeros((1,))], 'b': None}),
      ('input_none', None),
  )
  for label, tree in rejected:
    with self.subTest(label):
      with self.assertRaisesRegex(
          AssertionError, _get_err_regex('Tree contains `None`')
      ):
        asserts.assert_tree_no_nones(tree)
def test_tree_all_finite_passes_finite(self):
  """An all-finite tree passes both the eager and the jittable check."""
  finite = dict(a=jnp.ones((3,)), b=jnp.array([0.0, 0.0]))
  asserts.assert_tree_all_finite(finite)
  jittable_result = asserts._assert_tree_all_finite_jittable(finite)
  self.assertTrue(jittable_result)
def test_tree_all_finite_should_fail_inf(self):
  """A single infinite value fails both the eager and the jittable check."""
  tree_with_inf = dict(
      finite_var=jnp.ones((3,)),
      inf_var=jnp.array([0.0, jnp.inf]),
  )
  message = 'Tree contains non-finite value'
  with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
    asserts.assert_tree_all_finite(tree_with_inf)
  # The jittable variant reports the same message as a `ValueError`.
  with self.assertRaisesRegex(ValueError, message):
    asserts._assert_tree_all_finite_jittable(tree_with_inf)
def test_assert_trees_all_equal_passes_same_tree(self):
  """A tree equals itself, both as-is and after converting leaves to arrays."""
  tree = {
      'a': [jnp.zeros((1,))],
      'b': ([0], (0,), 0),
  }
  asserts.assert_trees_all_equal(tree, tree)
  # Convert every leaf to a JAX array before calling the jittable variant.
  # `jax.tree_util.tree_map` replaces `jax.tree_map`, which was deprecated in
  # JAX 0.4.26 and later removed.
  tree = jax.tree_util.tree_map(jnp.asarray, tree)
  self.assertTrue(asserts._assert_trees_all_equal_jittable(tree, tree))
def test_assert_trees_all_equal_passes_values_equal(self):
  """Two separately built trees with identical values compare equal."""
  lhs = (jnp.array([0.0, 0.0]),)
  rhs = (jnp.array([0.0, 0.0]),)
  asserts.assert_trees_all_equal(lhs, rhs)
  jittable_result = asserts._assert_trees_all_equal_jittable(lhs, rhs)
  self.assertTrue(jittable_result)
def test_assert_trees_all_equal_fail_values_close_but_not_equal(self):
  """Exact equality rejects values that differ only by a tiny amount."""
  exact = (jnp.array([1.0, 1.0]),)
  nudged = (jnp.array([1.0, 1.0 + 5e-7]),)
  message = 'Values not exactly equal'
  with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
    asserts.assert_trees_all_equal(exact, nudged)
  with self.assertRaisesRegex(ValueError, message):
    asserts._assert_trees_all_equal_jittable(exact, nudged)
def test_assert_trees_all_equal_strict_mode(self):
  # See 'notes' section of
  # https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html
  # for details about the 'strict' mode of `numpy.testing.assert_array_equal`.
  # tldr; it has special handling for scalar values (by default).
  vector_tree = {'a': jnp.array([1.0], dtype=jnp.float32), 'b': 0.0}
  scalar_tree = {'a': jnp.array(1.0, dtype=jnp.float32), 'b': 0.0}

  # Leaf 'a' has shape (1,) vs (); that is tolerated unless `strict=True`.
  asserts.assert_trees_all_equal(vector_tree, scalar_tree)
  asserts.assert_trees_all_equal(vector_tree, scalar_tree, strict=False)
  with self.assertRaisesRegex(
      AssertionError,
      _get_err_regex(r'Trees 0 and 1 differ in leaves \'a\'')):
    asserts.assert_trees_all_equal(vector_tree, scalar_tree, strict=True)

  with self.assertRaisesRegex(ValueError, r'Trees 0 and 1 differ in leaves'):
    asserts._assert_trees_all_equal_jittable(
        vector_tree, scalar_tree, strict=True)

  # We do not implement this special scalar handling in the jittable
  # assertion (it's possible, but doesn't seem worth the effort).
  with self.assertRaisesRegex(
      NotImplementedError, r'`strict=False` is not implemented'):
    asserts._assert_trees_all_equal_jittable(
        vector_tree, scalar_tree, strict=False)
def test_assert_trees_all_close_passes_same_tree(self):
  """A tree is close to itself, both as-is and after converting leaves."""
  tree = {
      'a': [jnp.zeros((1,))],
      'b': ([0], (0,), 0),
  }
  asserts.assert_trees_all_close(tree, tree)
  # Convert every leaf to a JAX array before calling the jittable variant.
  # `jax.tree_util.tree_map` replaces `jax.tree_map`, which was deprecated in
  # JAX 0.4.26 and later removed.
  tree = jax.tree_util.tree_map(jnp.asarray, tree)
  self.assertTrue(asserts._assert_trees_all_close_jittable(tree, tree))
def test_assert_trees_all_close_passes_values_equal(self):
  """Two separately built trees with identical values are close."""
  lhs = (jnp.array([0.0, 0.0]),)
  rhs = (jnp.array([0.0, 0.0]),)
  asserts.assert_trees_all_close(lhs, rhs)
  jittable_result = asserts._assert_trees_all_close_jittable(lhs, rhs)
  self.assertTrue(jittable_result)
def test_assert_trees_all_close_passes_values_close_but_not_equal(self):
  """A 5e-7 difference is within `rtol=1e-6`, so the trees count as close."""
  exact = (jnp.array([1.0, 1.0]),)
  nudged = (jnp.array([1.0, 1.0 + 5e-7]),)
  asserts.assert_trees_all_close(exact, nudged, rtol=1e-6)
  jittable_result = asserts._assert_trees_all_close_jittable(
      exact, nudged, rtol=1e-6)
  self.assertTrue(jittable_result)
def test_assert_trees_all_close_bfloat16(self):
  """bfloat16 leaves compare against both bfloat16 and float32 leaves."""
  bf16_tree = {'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16)}
  f32_tree = {
      'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16).astype(jnp.float32)
  }
  far_tree = {'a': jnp.asarray([0.8, 1.7], dtype=jnp.bfloat16)}

  asserts.assert_trees_all_close(bf16_tree, bf16_tree)
  asserts.assert_trees_all_close(bf16_tree, f32_tree)
  self.assertTrue(
      asserts._assert_trees_all_close_jittable(bf16_tree, f32_tree))

  err_msg = 'Values not approximately equal'
  err_regex = _get_err_regex(err_msg)
  # Both the bfloat16 and the upcast float32 tree differ from `far_tree`.
  for near_tree in (bf16_tree, f32_tree):
    with self.assertRaisesRegex(AssertionError, err_regex):
      asserts.assert_trees_all_close(near_tree, far_tree)
    with self.assertRaisesRegex(ValueError, err_msg):
      asserts._assert_trees_all_close_jittable(near_tree, far_tree)
def test_assert_trees_all_close_ulp_jittable_raises_valueerror(self):
  """The ULP assertion has no jittable implementation.

  NOTE(review): despite the test name, the expected exception type is
  `RuntimeError`, not `ValueError`.
  """
  single_leaf_tree = (jnp.array([1.0]),)
  unsupported = _get_err_regex(
      'assert_trees_all_close_ulp is not supported within JIT contexts.')
  with self.assertRaisesRegex(RuntimeError, unsupported):
    asserts._assert_trees_all_close_ulp_jittable(
        single_leaf_tree, single_leaf_tree)
def test_assert_trees_all_close_ulp_passes_same_tree(self):
  """A tree mixing arrays and Python numbers is ULP-close to itself."""
  tree = dict(a=[jnp.zeros((1,))], b=([0], (0,), 0))
  asserts.assert_trees_all_close_ulp(tree, tree)
def test_assert_trees_all_close_ulp_passes_values_equal(self):
  """Two separately built all-zero trees are within the default ULP bound."""
  first = (jnp.array([0.0, 0.0]),)
  second = (jnp.array([0.0, 0.0]),)
  try:
    asserts.assert_trees_all_close_ulp(first, second)
  except AssertionError:
    # Report a readable failure instead of the raw assertion message.
    self.fail('assert_trees_all_close_ulp raised AssertionError')
def test_assert_trees_all_close_ulp_passes_values_within_maxulp(self):
  """Leaves one ULP apart pass when up to two ULP are allowed."""
  # At 2**23 the float32 spacing is exactly one:
  # np.spacing(np.float32(1 << 23)) == 1.0, so +1.0 is a one-ULP step.
  base = np.float32(1 << 23)
  expected = (jnp.array([base, base]),)
  actual = (jnp.array([base, base + 1.0]),)
  assert actual[0][0] != actual[0][1]
  try:
    asserts.assert_trees_all_close_ulp(expected, actual, maxulp=2)
  except AssertionError:
    self.fail('assert_trees_all_close_ulp raised AssertionError')
def test_assert_trees_all_close_ulp_passes_values_maxulp_apart(self):
  """Leaves exactly `maxulp` (one) ULP apart still pass."""
  # np.spacing(np.float32(1 << 23)) == 1.0, so +1.0 is a one-ULP step.
  base = np.float32(1 << 23)
  expected = (jnp.array([base, base]),)
  actual = (jnp.array([base, base + 1.0]),)
  assert actual[0][0] != actual[0][1]
  try:
    asserts.assert_trees_all_close_ulp(expected, actual, maxulp=1)
  except AssertionError:
    self.fail('assert_trees_all_close_ulp raised AssertionError')
def test_assert_trees_all_close_ulp_fails_values_gt_maxulp_apart(self):
  """Leaves two ULP apart fail when only one ULP is allowed."""
  # np.spacing(np.float32(1 << 23)) == 1.0, so +2.0 is a two-ULP step.
  base = np.float32(1 << 23)
  expected = (jnp.array([base, base]),)
  actual = (jnp.array([base, base + 2.0]),)
  assert actual[0][0] != actual[0][1]
  too_far = _get_err_regex(
      re.escape('not almost equal up to 1 ULP (max difference is 2 ULP)'))
  with self.assertRaisesRegex(AssertionError, too_far):
    asserts.assert_trees_all_close_ulp(expected, actual, maxulp=1)
def test_assert_trees_all_close_ulp_fails_bfloat16(self):
  """ULP comparison rejects a bfloat16 tree against bfloat16 or float32."""
  f32_tree = (jnp.array([0.0]),)
  bf16_tree = (jnp.array([0.0], dtype=jnp.bfloat16),)
  unsupported = _get_err_regex(
      'ULP assertions are not currently supported for bfloat16.')
  for other_tree in (bf16_tree, f32_tree):
    with self.assertRaisesRegex(ValueError, unsupported):  # pylint: disable=g-error-prone-assert-raises
      asserts.assert_trees_all_close_ulp(bf16_tree, other_tree)
def test_assert_tree_has_only_ndarrays(self):
  """Trees of np/jnp arrays pass; Python-number leaves are reported by path."""
  valid_trees = (
      {'a': jnp.zeros(1), 'b': np.ones(3)},
      np.zeros(4),
      (),
  )
  for tree in valid_trees:
    asserts.assert_tree_has_only_ndarrays(tree)

  invalid_cases = (
      ('\'b\' is not an ndarray', {'a': jnp.zeros((1,)), 'b': 1}),
      ('\'b/1\' is not an ndarray', {'a': jnp.zeros(101), 'b': [1, 2]}),
  )
  for message, tree in invalid_cases:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_tree_has_only_ndarrays(tree)
def test_assert_tree_is_on_host(self):
  """Checks `assert_tree_is_on_host` on numpy, jax and replicated leaves.

  Exercises the `allow_cpu_device` and `allow_sharded_arrays` flags.
  """
  cpu = jax.local_devices(backend='cpu')[0]
  # Check Numpy arrays.
  # Numpy leaves pass whatever `allow_cpu_device` is, including an array that
  # was round-tripped through a device with `jax.device_get`.
  for flag in (False, True):
    asserts.assert_tree_is_on_host({'a': np.zeros(1), 'b': np.ones(3)},
                                   allow_cpu_device=flag)
    asserts.assert_tree_is_on_host(np.zeros(4), allow_cpu_device=flag)
    asserts.assert_tree_is_on_host(
        jax.device_get(jax.device_put(np.ones(3))), allow_cpu_device=flag)
    asserts.assert_tree_is_on_host((), allow_cpu_device=flag)
  # Check DeviceArray (for platforms other than CPU).
  # With `allow_cpu_device=False`, the test expects these jax arrays to be
  # rejected whatever the default device is.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'a\' resides on')):
    asserts.assert_tree_is_on_host({'a': jnp.zeros(1)},
                                   allow_cpu_device=False)
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'a\' resides on')):
    asserts.assert_tree_is_on_host({'a': jax.device_put(np.zeros(1))},
                                   allow_cpu_device=False)
  # Check Jax arrays on CPU.
  # These calls pass without `allow_cpu_device`, so the default presumably
  # permits CPU-resident jax arrays (TODO confirm against the assertion).
  cpu_arr = jax.device_put(np.ones(5), cpu)
  asserts.assert_tree_is_on_host({'a': cpu_arr})
  asserts.assert_tree_is_on_host({'a': np.zeros(1), 'b': cpu_arr})
  # Check sharded Jax arrays on CPUs.
  # A replicated array is accepted only when both flags are enabled.
  asserts.assert_tree_is_on_host(
      {'a': jax.device_put_replicated(np.zeros(1), (cpu,))},
      allow_cpu_device=True,
      allow_sharded_arrays=True,
  )
  # Disallow JAX arrays on CPU.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'a\' resides on.*CPU')):
    asserts.assert_tree_is_on_host({'a': cpu_arr},
                                   allow_cpu_device=False)
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'b\' resides on.*CPU')):
    asserts.assert_tree_is_on_host({'a': np.zeros(1), 'b': cpu_arr},
                                   allow_cpu_device=False)
  # Check incorrect inputs.
  # A Python scalar leaf is not an array at all.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'b\' is not an ndarray')):
    asserts.assert_tree_is_on_host({'a': np.zeros(1), 'b': 1})
  # ShardedArrays are disallowed.
  # `allow_sharded_arrays` is not passed here, so the default presumably
  # disallows them.
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex('sharded arrays are disallowed')
  ):
    asserts.assert_tree_is_on_host(
        {'a': jax.device_put_replicated(np.zeros(1), (cpu,))},
        allow_cpu_device=False,
    )
  # ShardedArrays on CPUs, CPUs disallowed.
  # Allowing sharded arrays is not enough when their devices are CPUs and
  # `allow_cpu_device=False`.
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex("'a' is sharded and resides on.*CPU")
  ):
    asserts.assert_tree_is_on_host(
        {'a': jax.device_put_replicated(np.zeros(1), (cpu,))},
        allow_cpu_device=False,
        allow_sharded_arrays=True,
    )
def test_assert_tree_is_on_device(self):
  """Checks `assert_tree_is_on_device` by explicit `device` and `platform`.

  The CPU checks always run. The TPU checks run only when more than one TPU
  device is available.
  """
  # Check CPU platform.
  cpu = jax.local_devices(backend='cpu')[0]
  to_cpu = lambda x: jax.device_put(x, cpu)
  cpu_tree = {'a': to_cpu(np.zeros(1)), 'b': to_cpu(np.ones(3))}
  # `platform` is given as a single name or a sequence of names. The empty
  # string is combined here with an explicit `device` (presumably meaning
  # "no platform constraint" — TODO confirm).
  asserts.assert_tree_is_on_device(cpu_tree, device=cpu)
  asserts.assert_tree_is_on_device(cpu_tree, platform='cpu')
  asserts.assert_tree_is_on_device(cpu_tree, platform=['cpu'])
  asserts.assert_tree_is_on_device(cpu_tree, device=cpu, platform='')
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'a\' resides on \'cpu\'')):
    asserts.assert_tree_is_on_device(cpu_tree, platform='tpu')
  # Matching on 'b' implies the message presumably lists every offending
  # leaf, not just the first.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'b\' resides on \'cpu\'')):
    asserts.assert_tree_is_on_device(cpu_tree, platform=('tpu', 'gpu'))
  # Check TPU platform (if available).
  if _num_devices_available('tpu') > 1:
    tpu_1, tpu_2 = jax.devices('tpu')[:2]
    to_tpu_1 = lambda x: jax.device_put(x, tpu_1)
    to_tpu_2 = lambda x: jax.device_put(x, tpu_2)
    tpu_1_tree = {'a': to_tpu_1(np.zeros(1)), 'b': to_tpu_1(np.ones(3))}
    tpu_2_tree = {'a': to_tpu_2(np.zeros(1)), 'b': to_tpu_2(np.ones(3))}
    # Leaves deliberately split across two TPU devices.
    tpu_1_2_tree = {'a': to_tpu_1(np.zeros(1)), 'b': to_tpu_2(np.ones(3))}
    # Device asserts.
    asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_1)
    asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_2)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=0")
    ):
      asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_2)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex(r"'a' resides on.*TpuDevice\(id=1")
    ):
      asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_1)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex("'a' resides on .*Cpu")
    ):
      asserts.assert_tree_is_on_device(cpu_tree, device=tpu_2)
    # Platform asserts.
    asserts.assert_tree_is_on_device(tpu_1_tree, platform='tpu')
    asserts.assert_tree_is_on_device(tpu_2_tree, platform='tpu')
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('\'a\' resides on \'tpu\'')):
      asserts.assert_tree_is_on_device(tpu_1_tree, platform='cpu')
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('\'a\' resides on \'tpu\'')):
      asserts.assert_tree_is_on_device(tpu_2_tree, platform='gpu')
    # Mixed cases.
    # A platform check passes even when leaves sit on different TPU devices.
    asserts.assert_tree_is_on_device(tpu_1_2_tree, platform='tpu')
    asserts.assert_tree_is_on_device((tpu_1_2_tree, cpu_tree),
                                     platform=('cpu', 'tpu'))
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('\'1/a\' resides on \'cpu\'')):
      asserts.assert_tree_is_on_device((tpu_1_2_tree, cpu_tree),
                                       platform='tpu')
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('\'0/a\' resides on \'tpu\'')):
      asserts.assert_tree_is_on_device((tpu_1_2_tree, cpu_tree),
                                       platform=('cpu', 'gpu'))
  # Check incorrect inputs.
  # A Python scalar leaf is not an array at all.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'b\' is not an ndarray')):
    asserts.assert_tree_is_on_device({'a': np.zeros(1), 'b': 1})
  # A host numpy array is reported as having an unexpected type.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'b\' has unexpected type')):
    asserts.assert_tree_is_on_device({'a': jnp.zeros(1), 'b': np.ones(3)})
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'a\' is a ShardedDeviceArra')):
    # ShardedArrays are disallowed.
    asserts.assert_tree_is_on_device(
        {'a': jax.device_put_replicated(np.zeros(1), (cpu,))}, device=cpu)
def test_assert_tree_is_sharded(self):
  """Checks `assert_tree_is_sharded` against expected device tuples.

  The single-CPU checks always run. The multi-device checks run only when
  more than one TPU device is available.
  """
  np_tree = {'a': np.zeros(1), 'b': np.ones(3)}
  def _format(*devs):
    # Regex-escaped repr of the device tuple, presumably as the assertion
    # prints it in its error messages.
    return re.escape(f'{devs}')
  # Check single-device case.
  cpu = jax.local_devices(backend='cpu')[0]
  cpu_tree = jax.device_put_replicated(np_tree, (cpu,))
  asserts.assert_tree_is_sharded(cpu_tree, devices=(cpu,))
  asserts.assert_tree_is_sharded((), devices=(cpu,))
  # The expected device tuple must match exactly: an empty tuple and a tuple
  # with a duplicated CPU both fail.
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex(r'\'a\' is sharded.*expected \(\)')):
    asserts.assert_tree_is_sharded(cpu_tree, devices=())
  with self.assertRaisesRegex(
      AssertionError,
      _get_err_regex(f'\'a\' is sharded across {_format(cpu)}.*'
                     f'expected {_format(cpu, cpu)}')):
    asserts.assert_tree_is_sharded(cpu_tree, devices=(cpu, cpu))
  # Check multiple-devices case (if available).
  if _num_devices_available('tpu') > 1:
    tpu_1, tpu_2 = jax.devices('tpu')[:2]
    tpu_1_tree = jax.device_put_replicated(np_tree, (tpu_1,))
    tpu_2_tree = jax.device_put_replicated(np_tree, (tpu_2,))
    tpu_1_2_tree = jax.device_put_replicated(np_tree, (tpu_1, tpu_2))
    tpu_2_1_tree = jax.device_put_replicated(np_tree, (tpu_2, tpu_1))
    asserts.assert_tree_is_sharded(tpu_1_2_tree, devices=(tpu_1, tpu_2))
    asserts.assert_tree_is_sharded(tpu_2_1_tree, devices=(tpu_2, tpu_1))
    # Wrong device.
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'a\' is sharded across {_format(tpu_1)}.*'
                       f'expected {_format(tpu_2)}')):
      asserts.assert_tree_is_sharded(tpu_1_tree, devices=(tpu_2,))
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'a\' is sharded across {_format(cpu)}.*'
                       f'expected {_format(tpu_2)}')):
      asserts.assert_tree_is_sharded(cpu_tree, devices=(tpu_2,))
    # Too many devices.
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'a\' is sharded across {_format(tpu_1)}.*'
                       f'expected {_format(tpu_1, tpu_2)}')):
      asserts.assert_tree_is_sharded(tpu_1_tree, devices=(tpu_1, tpu_2))
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'a\' is sharded across {_format(tpu_1, tpu_2)}.*'
                       f'expected {_format(tpu_1, tpu_2, cpu)}')):
      asserts.assert_tree_is_sharded(
          tpu_1_2_tree, devices=(tpu_1, tpu_2, cpu))
    # Wrong order.
    # Device order is significant: (tpu_2, tpu_1) != (tpu_1, tpu_2).
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'a\' is sharded across {_format(tpu_2, tpu_1)}.*'
                       f'expected {_format(tpu_1, tpu_2)}')):
      asserts.assert_tree_is_sharded(tpu_2_1_tree, devices=(tpu_1, tpu_2))
    # Mixed cases.
    # Leaf paths are prefixed with the index of the sub-tree ('0/a', '1/a').
    mixed_tree = (tpu_1_tree, tpu_2_tree)
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'0/a\' is sharded across {_format(tpu_1)}.*'
                       f'expected {_format(tpu_2)}')):
      asserts.assert_tree_is_sharded(mixed_tree, devices=(tpu_2,))
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'1/a\' is sharded across {_format(tpu_2)}.*'
                       f'expected {_format(tpu_1)}')):
      asserts.assert_tree_is_sharded(mixed_tree, devices=(tpu_1,))
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(f'\'0/a\' is sharded across {_format(tpu_1)}.*'
                       f'expected {_format(tpu_1, tpu_2)}')):
      asserts.assert_tree_is_sharded(mixed_tree, devices=(tpu_1, tpu_2))
  # Check incorrect inputs.
  # A Python int leaf, a host numpy array, an unsharded jax array, and a jax
  # array placed on a single CPU device are all rejected.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'1\' is not an ndarray')):
    asserts.assert_tree_is_sharded((cpu_tree, 1123), devices=(cpu,))
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex('\'a\' is not a jax.Array')):
    asserts.assert_tree_is_sharded({'a': np.zeros(1)}, devices=(cpu,))
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex('\'a\' is not sharded')):
    asserts.assert_tree_is_sharded({'a': jnp.zeros(1)}, devices=(cpu,))
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex("'a' is not sharded.*Cpu")
  ):
    asserts.assert_tree_is_sharded({'a': jax.device_put(np.zeros(1), cpu)},
                                   devices=(cpu,))
def test_assert_trees_all_close_fails_different_structure(self):
  """`assert_trees_all_close` rejects trees with mismatching structures."""
  assert_fn = asserts.assert_trees_all_close
  self._assert_tree_structs_validation(assert_fn)
def test_assert_trees_all_close_fails_values_differ(self):
  """A ~0.1 gap passes at tolerance 0.1 and fails at 0.01 (atol and rtol)."""
  lhs = jnp.array([0.0, 2.0])
  rhs = jnp.array([0.0, 2.1])
  not_close = _get_err_regex('Values not approximately equal')
  for tolerance_name in ('atol', 'rtol'):
    asserts.assert_trees_all_close(lhs, rhs, **{tolerance_name: 0.1})
    with self.assertRaisesRegex(AssertionError, not_close):
      asserts.assert_trees_all_close(lhs, rhs, **{tolerance_name: 0.01})
def test_assert_trees_all_equal_sizes(self):
  """Trees match when corresponding leaves have equal element counts."""
  def make_tree(a2_cols):
    # Only 'd/a2' varies: it holds 4 * a2_cols elements.
    return dict(a1=jnp.zeros([3, 1]),
                d=dict(a2=jnp.zeros([4, a2_cols]), a3=jnp.zeros([5, 3])))

  tree1, tree2, tree3 = make_tree(1), make_tree(1), make_tree(2)
  assert_fn = asserts.assert_trees_all_equal_sizes
  self._assert_tree_structs_validation(assert_fn)
  assert_fn(tree1, tree1)
  assert_fn(tree2, tree1)

  # The message names tree 0 and the index of the first differing tree.
  failures = (
      ((tree1, tree3),
       r'Trees 0 and 1 differ in leaves \'d/a2\': sizes: 4 != 8'),
      ((tree1, tree2, tree2, tree3, tree1),
       r'Trees 0 and 3 differ in leaves \'d/a2\': sizes: 4 != 8'),
  )
  for trees, message in failures:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      assert_fn(*trees)
def test_assert_trees_all_equal_shapes(self):
  """Trees match only when corresponding leaves have identical shapes."""
  def make_tree(a2_len):
    # Only 'd/a2' varies: it is a vector of length `a2_len`.
    return dict(a1=jnp.zeros([3]),
                d=dict(a2=jnp.zeros([a2_len]), a3=jnp.zeros([5])))

  tree1, tree2, tree3 = make_tree(4), make_tree(4), make_tree(7)
  assert_fn = asserts.assert_trees_all_equal_shapes
  self._assert_tree_structs_validation(assert_fn)
  assert_fn(tree1, tree1)
  assert_fn(tree2, tree1)

  # The message names tree 0 and the index of the first differing tree.
  failures = (
      ((tree1, tree3),
       r'Trees 0 and 1 differ in leaves \'d/a2\': shapes: \(4,\) != \(7,\)'),
      ((tree1, tree2, tree2, tree3, tree1),
       r'Trees 0 and 3 differ in leaves \'d/a2\': shapes: \(4,\) != \(7,\)'),
  )
  for trees, message in failures:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      assert_fn(*trees)
def test_assert_trees_all_equal_structs(self):
  """Identically nested trees match; a differently nested tree does not."""
  def leaf():
    return jnp.zeros([3])

  def pair_first():
    return [[leaf(), leaf()], leaf()]

  tree1 = pair_first()
  tree2 = pair_first()
  tree3 = [leaf(), [leaf(), leaf()]]
  asserts.assert_trees_all_equal_structs(tree1, tree2, tree2, tree1)
  asserts.assert_trees_all_equal_structs(tree3, tree3)
  self._assert_tree_structs_validation(asserts.assert_trees_all_equal_structs)
@parameterized.named_parameters(
    ('scalars', ()),
    ('vectors', (3,)),
    ('matrices', (3, 2)),
)
def test_assert_tree_shape_prefix(self, shape):
  """Every leaf's shape starts with `shape`, given as a tuple or a list."""
  tree = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
  for label, container in (('tuple', tuple), ('list', list)):
    with self.subTest(label):
      asserts.assert_tree_shape_prefix(tree, container(shape))
def test_leaf_shape_should_fail_wrong_length(self):
  """A prefix longer than a leaf's rank is reported for that leaf."""
  tree = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
  too_long = _get_err_regex(r'leaf \'x/y\' has a shape of length 2')
  for prefix in ((3, 2, 1), [3, 2, 1]):
    with self.assertRaisesRegex(AssertionError, too_long):
      asserts.assert_tree_shape_prefix(tree, prefix)
@parameterized.named_parameters(
    ('scalars', ()),
    ('vectors', (1,)),
    ('matrices', (2, 1)),
)
def test_assert_tree_shape_suffix_matching(self, shape):
  """Every leaf's shape ends with `shape`, given as a tuple or a list."""
  tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([2, 1])}
  for label, container in (('tuple', tuple), ('list', list)):
    with self.subTest(label):
      asserts.assert_tree_shape_suffix(tree, container(shape))
@parameterized.named_parameters(
    ('bad_suffix_leaf_1', 'z', (1, 1), (2, 1)),
    ('bad_suffix_leaf_2', 'x/y', (2, 1), (1, 1)),
)
def test_assert_tree_shape_suffix_mismatch(self, leaf, shape_true, shape):
  """The mismatching leaf is reported with its actual and expected suffix."""
  tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([1, 1])}
  mismatch_regex = _get_err_regex(
      r'Tree leaf \'' + str(leaf) + '\'.*different from expected: '
      + re.escape(str(shape_true)) + ' != ' + re.escape(str(shape))
  )
  for label, container in (('tuple', tuple), ('list', list)):
    with self.subTest(label):
      with self.assertRaisesRegex(AssertionError, mismatch_regex):
        asserts.assert_tree_shape_suffix(tree, container(shape))
def test_assert_tree_shape_suffix_long_suffix(self):
  """A suffix may span a leaf's full shape but may not exceed its rank."""
  tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([4, 2, 1])}
  for container in (tuple, list):
    asserts.assert_tree_shape_suffix(tree, container((4, 2, 1)))
  too_long = _get_err_regex('which is smaller than the expected')
  for container in (tuple, list):
    with self.assertRaisesRegex(AssertionError, too_long):
      asserts.assert_tree_shape_suffix(tree, container((3, 4, 2, 1)))
def test_assert_trees_all_equal_dtypes(self):
  """Leaf dtypes must match; shapes are ignored and np/jnp dtypes mix."""
  def differ(i, j):
    return _get_err_regex(f'Trees {i} and {j} differ')

  t_0 = dict(x=np.zeros(3, dtype=np.int16), y=np.ones(2, dtype=np.float32))
  t_1 = dict(x=np.zeros(5, dtype=np.uint16), y=np.ones(4, dtype=np.float32))
  with self.assertRaisesRegex(AssertionError, differ(0, 1)):
    asserts.assert_trees_all_equal_dtypes(t_0, t_1)

  # Same dtypes as `t_0`, different shapes.
  t_2 = dict(x=np.zeros(6, dtype=jnp.int16), y=np.ones(6, dtype=np.float32))
  asserts.assert_trees_all_equal_dtypes(t_0, t_2, t_0)
  # Position 4 (`t_1`) is the first tree whose dtypes differ from tree 0.
  with self.assertRaisesRegex(AssertionError, differ(0, 4)):
    asserts.assert_trees_all_equal_dtypes(t_0, t_0, t_2, t_0, t_1, t_2)

  # np vs jnp dtype objects are interchangeable.
  t_3 = dict(x=np.zeros(1, dtype=np.int16), y=np.ones(2, dtype=jnp.float32))
  t_4 = dict(x=np.zeros(1, dtype=jnp.int16), y=np.ones(2, dtype=jnp.float32))
  asserts.assert_trees_all_equal_dtypes(t_0, t_2, t_3, t_4)

  # bfloat16 is distinct from float16.
  t_5 = {'y': np.ones(2, dtype=np.float16)}
  t_6 = {'y': np.ones(2, dtype=jnp.bfloat16)}
  asserts.assert_trees_all_equal_dtypes(t_6, t_6)
  with self.assertRaisesRegex(AssertionError, differ(0, 1)):
    asserts.assert_trees_all_equal_dtypes(t_5, t_6)
def test_assert_trees_all_equal_shapes_and_dtypes(self):
  """A mismatch in either the dtype or the shape of a leaf is reported."""
  def reference():
    return {'x': np.zeros(3, dtype=np.int16),
            'y': np.ones(2, dtype=np.float32)}

  mismatching_trees = (
      # Dtype mismatch: 'x' is uint16 instead of int16.
      {'x': np.zeros(3, dtype=np.uint16), 'y': np.ones(2, dtype=np.float32)},
      # Shape mismatch: 'x' has 4 elements instead of 3.
      {'x': np.zeros(4, dtype=np.int16), 'y': np.ones(2, dtype=np.float32)},
  )
  differ = _get_err_regex('Trees 0 and 1 differ')
  for mismatching in mismatching_trees:
    t_0 = reference()
    with self.assertRaisesRegex(AssertionError, differ):
      asserts.assert_trees_all_equal_shapes_and_dtypes(t_0, mismatching)
    asserts.assert_trees_all_equal_shapes_and_dtypes(t_0, reference(), t_0)
def test_assert_trees_all_equal_wrong_usage(self):
  """Non-array leaves and single-tree calls are rejected."""
  # A Python float leaf is not a (j)np array.
  not_an_array = _get_err_regex(r'is not a \(j-\)np array')
  with self.assertRaisesRegex(AssertionError, not_an_array):
    asserts.assert_trees_all_equal_dtypes({'x': 1.}, {'x': np.array(1.)})

  # Passing a single tree is a usage error rather than a failed assertion.
  single_tree_msg = 'Assertions over only one tree does not make sense'
  with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
      ValueError, single_tree_msg):
    asserts.assert_trees_all_equal_dtypes({'x': 1.})
def test_assert_trees_all_equal_none(self):
  """`None` leaves count for tree structure but not for dtype comparison."""
  t_0 = {'x': None, 'y': np.array(2, dtype=np.int32)}
  t_1 = {'x': None, 'y': np.array([23], dtype=np.int32)}
  t_2 = {'x': None, 'y': np.array(3, dtype=np.float32)}
  t_3 = {'y': np.array([23], dtype=np.int32)}
  # Only 'y' of `t_2` has a different dtype.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('Trees 0 and 2 differ')):
    asserts.assert_trees_all_equal_dtypes(t_0, t_1, t_2)
  asserts.assert_trees_all_equal_dtypes(t_0, t_1)
  # A tree lacking the `None` leaf has a different structure. Check both
  # argument orders so the mismatch is detected symmetrically.
  for trees in ((t_0, t_3), (t_3, t_0)):
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('trees 0 and 1 do not match')):
      asserts.assert_trees_all_equal_dtypes(*trees)
class DevicesAssertTest(parameterized.TestCase):
  """Tests for `assert_devices_available` and the GPU/TPU shortcuts."""

  def _device_count(self, backend):
    """Returns `jax.device_count(backend)`, or 0 if JAX raises RuntimeError."""
    try:
      count = jax.device_count(backend)
    except RuntimeError:
      count = 0
    return count

  @parameterized.parameters('cpu', 'gpu', 'tpu')
  def test_not_less_than(self, devtype):
    """With `not_less_than=True`, fewer devices pass and more devices fail."""
    n = self._device_count(devtype)

    def check(count):
      asserts.assert_devices_available(
          count, devtype, backend=devtype, not_less_than=True)

    if not n:
      # No devices were found, so presumably the backend itself is missing
      # and the check errors out while initialising it.
      with self.assertRaisesRegex(RuntimeError,  # pylint: disable=g-error-prone-assert-raises
                                  '(failed to initialize)|(Unknown backend)'):
        check(n - 1)
      return

    check(n - 1)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex(f'Only {n} < {n + 1}')):
      check(n + 1)

  def test_unsupported_device(self):
    """An unknown device type is a usage error."""
    with self.assertRaisesRegex(ValueError, 'Unknown device type'):  # pylint: disable=g-error-prone-assert-raises
      asserts.assert_devices_available(1, 'unsupported_devtype')

  def test_gpu_assert(self):
    """GPU availability checks agree with the detected GPU count."""
    n_gpu = self._device_count('gpu')
    asserts.assert_devices_available(n_gpu, 'gpu')
    no_two_gpus = _get_err_regex('No 2 GPUs available')
    if not n_gpu:
      with self.assertRaisesRegex(AssertionError, no_two_gpus):
        asserts.assert_devices_available(2, 'gpu')
      with self.assertRaisesRegex(AssertionError,
                                  _get_err_regex('No GPU devices available')):
        asserts.assert_gpu_available()
    else:
      asserts.assert_gpu_available()
    # Counting GPUs on the CPU backend never finds any.
    with self.assertRaisesRegex(AssertionError, no_two_gpus):
      asserts.assert_devices_available(2, 'gpu', backend='cpu')

  def test_cpu_assert(self):
    """The CPU backend reports exactly `jax.device_count('cpu')` devices."""
    asserts.assert_devices_available(
        jax.device_count('cpu'), 'cpu', backend='cpu')

  def test_tpu_assert(self):
    """TPU availability checks agree with the detected TPU count."""
    n_tpu = self._device_count('tpu')
    asserts.assert_devices_available(n_tpu, 'tpu')
    no_three_tpus = _get_err_regex('No 3 TPUs available')
    if not n_tpu:
      with self.assertRaisesRegex(AssertionError, no_three_tpus):
        asserts.assert_devices_available(3, 'tpu')
      with self.assertRaisesRegex(AssertionError,
                                  _get_err_regex('No TPU devices available')):
        asserts.assert_tpu_available()
    else:
      asserts.assert_tpu_available()
    # Counting TPUs on the CPU backend never finds any.
    with self.assertRaisesRegex(AssertionError, no_three_tpus):
      asserts.assert_devices_available(3, 'tpu', backend='cpu')
class NumericalGradsAssertTest(parameterized.TestCase):
  """Tests for `assert_numerical_grads` on simple and looped functions."""

  def _test_fn(self, fn, init_args, seed, n=10):
    """Checks first-order gradients of `fn` at `n` random points.

    Args:
      fn: function under test; its arity matches `len(init_args)`.
      init_args: sequence of arrays; only their shapes are used.
      seed: seed for `jax.random.PRNGKey`.
      n: number of random evaluation points.
    """
    rng_key = jax.random.PRNGKey(seed)
    for _ in range(n):
      # One fresh key per argument; the first split key seeds the next round.
      rng_key, *arg_keys = jax.random.split(rng_key, len(init_args) + 1)
      args = jax.tree_util.tree_map(
          lambda key, arg: jax.random.uniform(key, shape=arg.shape),
          arg_keys, list(init_args))
      asserts.assert_numerical_grads(fn, args, order=1)

  @parameterized.parameters(([1], 24), ([5], 6), ([3, 5], 20))
  def test_easy(self, x_shape, seed):
    """A smooth quadratic."""
    f_easy = lambda x: jnp.sum(x**2 - 2 * x + 10)
    init_args = (jnp.zeros(x_shape),)
    self._test_fn(f_easy, init_args, seed)

  @parameterized.parameters(([1], 24), ([5], 6), ([3, 5], 20))
  def test_easy_with_stop_gradient(self, x_shape, seed):
    """A quadratic whose x**2 term is wrapped in `stop_gradient`."""
    f_easy_sg = lambda x: jnp.sum(jax.lax.stop_gradient(x**2) - 2 * x + 10)
    init_args = (jnp.zeros(x_shape),)
    self._test_fn(f_easy_sg, init_args, seed)

  @parameterized.parameters(([1], 24), ([5], 6), ([3, 5], 20))
  def test_hard(self, x_shape, seed):
    """An unrolled inner gradient-descent loop with data-dependent select."""
    # No `stop_gradient` here, hence `f_hard` (cf. the `_with_sg` variant).
    def f_hard(lr, x):
      inner_loss = lambda y: jnp.sum((y - 1.0)**2)
      inner_loss_grad = jax.grad(inner_loss)
      def fu(lr, x):
        for _ in range(10):
          x1 = x - lr * inner_loss_grad(x) + 100 * lr**2
          x2 = x - lr * inner_loss_grad(x) - 100 * lr**2
          x = jax.lax.select((x > 3.).any(), x1, x2 + lr)
        return x
      y = fu(lr, x)
      return jnp.sum(inner_loss(y))
    # `lr` has the same rank as `x` (all dims 1) so broadcasting needs no
    # rank promotion.
    lr = jnp.zeros([1] * len(x_shape))
    x = jnp.zeros(x_shape)
    self._test_fn(f_hard, (lr, x), seed)

  @parameterized.parameters(([1], 24), ([5], 6), ([3, 5], 20))
  def test_hard_with_stop_gradient(self, x_shape, seed):
    """As `test_hard`, with `stop_gradient` applied to some `lr` terms."""
    def f_hard_with_sg(lr, x):
      inner_loss = lambda y: jnp.sum((y - 1.0)**2)
      inner_loss_grad = jax.grad(inner_loss)
      def fu(lr, x):
        for _ in range(10):
          x1 = x - lr * inner_loss_grad(x) + 100 * jax.lax.stop_gradient(lr)**2
          x2 = x - lr * inner_loss_grad(x) - 100 * lr**2
          x = jax.lax.select((x > 3.).any(), x1, x2 + jax.lax.stop_gradient(lr))
        return x
      y = fu(lr, x)
      return jnp.sum(inner_loss(y))
    lr = jnp.zeros([1] * len(x_shape))
    x = jnp.zeros(x_shape)
    self._test_fn(f_hard_with_sg, (lr, x), seed)
class EqualAssertionsTest(parameterized.TestCase):
  """Tests for `assert_equal` on dtypes, containers and arrays."""

  @parameterized.named_parameters(
      ('dtypes', jnp.int32, jnp.int32),
      ('lists', [1, 2], [1, 2]),
      ('dicts', dict(a=[7, jnp.int32]), dict(a=[7, jnp.int32])),
  )
  def test_assert_equal_pass(self, first, second):
    """Equal dtypes and equal nested containers pass."""
    asserts.assert_equal(first, second)

  def test_assert_equal_pass_on_arrays(self):
    """Arrays holding equal values pass, even across np/jnp and dtypes."""
    # Not using named_parameters, because JAX cannot be used before app.run().
    asserts.assert_equal(jnp.ones([]), np.ones([]))
    asserts.assert_equal(
        jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))

  @parameterized.named_parameters(
      ('dtypes', jnp.int32, jnp.float32),
      ('lists', [1, 2], [1, 7]),
      ('lists2', [1, 2], [1]),
      ('dicts1', dict(a=[7, jnp.int32]), dict(b=[7, jnp.int32])),
      ('dicts2', dict(a=[7, jnp.int32]), dict(b=[1, jnp.int32])),
      ('dicts3', dict(a=[7, jnp.int32]), dict(a=[1, jnp.int32], b=2)),
      ('dicts4', dict(a=[7, jnp.int32]), dict(a=[1, jnp.float32])),
      ('arrays', np.zeros([]), np.ones([])),
  )
  def test_assert_equal_fail(self, first, second):
    """Different dtypes, values, lengths, keys or array values fail."""
    with self.assertRaises(AssertionError):
      asserts.assert_equal(first, second)
class IsDivisibleTest(parameterized.TestCase):
  """Tests for `assert_is_divisible`."""

  def test_assert_is_divisible(self):
    """6 is a multiple of 3, so no error is raised."""
    numerator, denominator = 6, 3
    asserts.assert_is_divisible(numerator, denominator)

  def test_assert_is_divisible_fail(self):
    """7 is not a multiple of 3, so an `AssertionError` is raised."""
    self.assertRaises(AssertionError, asserts.assert_is_divisible, 7, 3)
if __name__ == '__main__':
  # Make implicit rank promotion an error, so accidental broadcasting between
  # arrays of different rank fails loudly in these tests.
  jax.config.update('jax_numpy_rank_promotion', 'raise')
  absltest.main()