# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `asserts.py`.""" | |
import functools | |
import re | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
from chex._src import asserts | |
from chex._src import asserts_internal | |
from chex._src import variants | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
# Short aliases for helpers from `asserts_internal` used throughout the tests.
_get_err_regex = asserts_internal.get_err_regex
_num_devices_available = asserts_internal.num_devices_available
def as_arrays(arrays):
  """Converts every element of `arrays` to a NumPy array.

  Args:
    arrays: an iterable of array-likes (nested lists, scalars, arrays).

  Returns:
    A list holding `np.asarray(a)` for each `a` in `arrays`, in order.
  """
  return [np.asarray(a) for a in arrays]
def array_from_shape(*shape):
  """Returns an all-ones NumPy array whose dimensions are the given ints.

  Args:
    *shape: the array dimensions, passed as separate positional arguments.
  """
  return np.ones(shape=shape)
def emplace(arrays, dtype):
  """Builds a JAX array of `dtype` from `arrays`.

  Kept as a standalone function so tests can wrap it with `self.variant(...)`
  and create the array inside the active variant.

  Args:
    arrays: any input accepted by `jnp.array`.
    dtype: the dtype of the resulting array.
  """
  return jnp.array(arrays, dtype=dtype)
class AssertsSwitchTest(parameterized.TestCase):
  """Tests for enable/disable_asserts."""

  def test_enable_disable_asserts(self):
    with self.assertRaisesRegex(AssertionError, _get_err_regex('scalar')):
      asserts.assert_scalar('test')

    # `try`/`finally` guarantees asserts are re-enabled even if a check below
    # fails; otherwise the disabled switch would persist into later tests.
    try:
      asserts.disable_asserts()
      # While disabled, a failing assertion must be a no-op.
      asserts.assert_scalar('test')

      asserts.enable_asserts()
      with self.assertRaisesRegex(AssertionError, _get_err_regex('scalar')):
        asserts.assert_scalar('test')

      asserts.disable_asserts()
      asserts.assert_is_divisible(13, 5)
    finally:
      # To avoid side effects.
      asserts.enable_asserts()
class AssertMaxTracesTest(variants.TestCase):
  """Tests for `asserts.assert_max_traces` and `asserts.clear_trace_counter`."""

  def setUp(self):
    super().setUp()
    # Reset trace counting so that every test starts from a clean state.
    asserts.clear_trace_counter()

  def _init(self, fn_, init_type, max_traces, kwargs, static_arg):
    """Initializes common test cases.

    Args:
      fn_: the two-argument function under test.
      init_type: 't1', 't2' or 't3'; selects how `fn`/`fn_jitted` are built.
      max_traces: the trace limit for `asserts.assert_max_traces`.
      kwargs: if truthy, pass the limit as the keyword `n=`; otherwise pass it
        positionally.
      static_arg: if truthy, request `static_argnums=1` from the variant and
        check inside `fn` that the second argument is not a tracer.

    Returns:
      A pair `(fn, fn_jitted)`.

    Raises:
      ValueError: if `init_type` is not one of the known values.
    """
    variant_kwargs = {}
    if static_arg:
      variant_kwargs['static_argnums'] = 1

    # The trace limit can be given positionally or as `n=`; exercise both.
    if kwargs:
      args, kwargs = [], {'n': max_traces}
    else:
      args, kwargs = [max_traces], {}

    if init_type == 't1':
      # NOTE(review): `args`/`kwargs` are unused in this branch as written,
      # which suggests an `asserts.assert_max_traces` decorator on `fn` was
      # lost from this copy. Confirm against upstream.
      def fn(x, y):
        if static_arg:
          self.assertNotIsInstance(y, jax.core.Tracer)
        return fn_(x, y)

      fn_jitted = self.variant(fn, **variant_kwargs)
    elif init_type == 't2':

      def fn(x, y):
        if static_arg:
          self.assertNotIsInstance(y, jax.core.Tracer)
        return fn_(x, y)

      # Wrap by calling `assert_max_traces` with the function as first arg.
      fn = asserts.assert_max_traces(fn, *args, **kwargs)
      fn_jitted = self.variant(fn, **variant_kwargs)
    elif init_type == 't3':

      def fn(x, y):
        if static_arg:
          self.assertNotIsInstance(y, jax.core.Tracer)
        return fn_(x, y)

      # NOTE(review): `fn_jitted` asserts that `x` is a tracer, yet as written
      # it is never transformed and `args`/`kwargs`/`variant_kwargs` go
      # unused; its wrapping decorators appear to have been lost from this
      # copy. Confirm against upstream.
      def fn_jitted(x, y):
        self.assertIsInstance(x, jax.core.Tracer)
        return fn_(x, y)
    else:
      raise ValueError(f'Unknown type {init_type}.')

    return fn, fn_jitted

  # NOTE(review): this method takes parameters but carries no
  # variants/parameterized decorators in this copy; restore from upstream.
  def test_assert(self, init_type, kwargs, static_arg, max_traces):
    fn_ = lambda x, y: x + y
    fn, fn_jitted = self._init(fn_, init_type, max_traces, kwargs, static_arg)

    # Original function.
    for _ in range(max_traces + 3):
      self.assertEqual(fn(1, 2), 3)

    # Each outer iteration uses a new input shape `(i + 1,)` and therefore a
    # new trace; the inner loop repeats that same shape with other values.
    for i in range(max_traces):
      for k in range(5):
        arg = jnp.zeros(i + 1) + k
        np.testing.assert_array_equal(fn_jitted(arg, 2), arg + 2)

    # Original function.
    for _ in range(max_traces + 3):
      self.assertEqual(fn(1, 2), 3)
      self.assertEqual(fn([1], [2]), [1, 2])
      self.assertEqual(fn('a', 'b'), 'ab')

    # (max_traces + 1)-th re-tracing. Only the call is inside the context.
    arg = jnp.zeros(max_traces + 1)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('fn.* is traced > .* times!')):
      fn_jitted(arg, 2)

  def test_incorrect_ordering(self):
    # NOTE(review): as written nothing inside these `assertRaisesRegex` blocks
    # raises (`dummy_wrapper` is never applied and no function is decorated),
    # so this test would fail; the wrapper decorators appear to have been
    # stripped from this copy. Restore from upstream.
    # pylint:disable=g-error-prone-assert-raises,unused-variable
    with self.assertRaisesRegex(ValueError, 'change wrappers ordering'):

      def fn(_):
        pass

    def dummy_wrapper(fn):

      def fn_wrapped():
        return fn()

      return fn_wrapped

    with self.assertRaisesRegex(ValueError, 'change wrappers ordering'):

      def fn_2():
        pass

    # pylint:enable=g-error-prone-assert-raises,unused-variable

  def test_redefined_traced_function(self):
    # NOTE(review): as written `outer_fn(1)` evaluates `(2).sum()` on a Python
    # int and nothing counts traces; the jit/`assert_max_traces` decorators
    # this test relies on appear to have been lost. Restore from upstream.
    def outer_fn(x):

      def inner_fn(y):
        return y.sum()

      return inner_fn(2 * x)

    self.assertEqual(outer_fn(1), 2)
    self.assertEqual(outer_fn(2), 4)
    self.assertEqual(outer_fn(3), 6)

    # Fails since the traced inner function is redefined at each call.
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('fn.* is traced > .* times!')):
      outer_fn(4)

    # After a reset, the first three calls succeed and every later one fails.
    asserts.clear_trace_counter()
    for i in range(10):
      if i > 2:
        with self.assertRaisesRegex(
            AssertionError, _get_err_regex('fn.* is traced > .* times!')):
          outer_fn(1)
      else:
        outer_fn(1)

  def test_nested_functions(self):
    # NOTE(review): despite its name, `jitted_outer_fn` is not jitted here
    # (calling it with an int would evaluate `int.sum()`), and `inner_fn` has
    # no `assert_max_traces` wrapper; the decorators appear to have been lost.
    def jitted_outer_fn(x):

      def inner_fn(y):
        return y.sum()

      return inner_fn(2 * x)

    # Inner assert_max_traces have no effect since the outer_fn is traced once.
    for i in range(10):
      self.assertEqual(jitted_outer_fn(i), 2 * i)
class ScalarAssertTest(parameterized.TestCase):
  """Tests for the `asserts.assert_scalar*` family."""

  def test_scalar(self):
    asserts.assert_scalar(1)
    asserts.assert_scalar(1.)
    # A 0-d NumPy array is expected to be rejected.
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('must be a scalar')):
      asserts.assert_scalar(np.array(1.))  # pytype: disable=wrong-arg-types

  def test_scalar_positive(self):
    asserts.assert_scalar_positive(0.5)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('must be positive')):
      asserts.assert_scalar_positive(-0.5)

  def test_scalar_non_negative(self):
    asserts.assert_scalar_non_negative(0.5)
    asserts.assert_scalar_non_negative(0.)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('must be non-negative')):
      asserts.assert_scalar_non_negative(-0.5)

  def test_scalar_negative(self):
    asserts.assert_scalar_negative(-0.5)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('argument must be negative')):
      asserts.assert_scalar_negative(0.5)

  def test_scalar_in(self):
    asserts.assert_scalar_in(0.5, 0, 1)
    # Values on either side of the range are rejected.
    for value in (-0.5, 1.5):
      with self.subTest(value=value):
        with self.assertRaisesRegex(AssertionError,
                                    _get_err_regex('argument must be in')):
          asserts.assert_scalar_in(value, 0, 1)

  def test_scalar_in_excluded(self):
    asserts.assert_scalar_in(0.5, 0, 1, included=False)
    # With `included=False` the bounds themselves are rejected.
    for value in (0, 1):
      with self.subTest(value=value):
        with self.assertRaisesRegex(AssertionError,
                                    _get_err_regex('argument must be in')):
          asserts.assert_scalar_in(value, 0, 1, included=False)
class EqualSizeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_equal_size`.

  NOTE(review): both methods take an `arrays` argument but carry no
  `@parameterized...` decorator in this copy; restore it from upstream.
  """

  def test_equal_size_should_fail(self, arrays):
    arrays = as_arrays(arrays)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('Arrays have different sizes')
    ):
      asserts.assert_equal_size(arrays)

  def test_equal_size_should_pass(self, arrays):
    arrays = as_arrays(arrays)
    asserts.assert_equal_size(arrays)
class SizeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_size`.

  NOTE(review): every method except `test_pytypes_pass` takes extra
  arguments but has no `@parameterized...` decorator in this copy; restore
  them from upstream.
  """

  def test_size_should_fail(self, arrays, sizes):
    arrays = as_arrays(arrays)
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex('input .+ has size .+ but expected .+')):
      asserts.assert_size(arrays, sizes)

  def test_size_should_fail_wrong_length(self, arrays, sizes):
    arrays = as_arrays(arrays)
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex('Length of `inputs` and `expected_sizes` must match')):
      asserts.assert_size(arrays, sizes)

  def test_size_should_pass(self, arrays, sizes):
    arrays = as_arrays(arrays)
    asserts.assert_size(arrays, sizes)

  def test_pytypes_pass(self):
    # Inputs of size 4 and 2. For the second one, `None`, a set containing
    # its size, and `...` are all accepted as expectations.
    arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
    asserts.assert_size(arrays, (4, None))
    asserts.assert_size(arrays, (4, {1, 2}))
    asserts.assert_size(arrays, (4, ...))

  def test_ellipsis_should_pass(self, arrays, expected_size):
    arrays = as_arrays(arrays)
    asserts.assert_size(arrays, expected_size)
class EqualShapeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_equal_shape*`.

  NOTE(review): every method takes extra arguments but no
  `@parameterized...` decorator is present in this copy; restore them from
  upstream.
  """

  def test_equal_shape_should_fail(self, arrays):
    arrays = as_arrays(arrays)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('Arrays have different shapes')):
      asserts.assert_equal_shape(arrays)

  def test_equal_shape_should_pass(self, arrays):
    arrays = as_arrays(arrays)
    asserts.assert_equal_shape(arrays)

  def test_equal_shape_prefix_should_pass(self, arrays):
    arrays = as_arrays(arrays)
    asserts.assert_equal_shape_prefix(arrays, prefix_len=1)

  def test_equal_shape_prefix_should_fail(self, arrays):
    arrays = as_arrays(arrays)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('different shape prefixes')):
      asserts.assert_equal_shape_prefix(arrays, prefix_len=1)

  def test_equal_shape_at_dims_should_pass(self, shapes, dims):
    arrays = [array_from_shape(*shape) for shape in shapes]
    asserts.assert_equal_shape(arrays, dims=dims)

  def test_equal_shape_at_dims_should_fail(self, shapes, dims):
    arrays = [array_from_shape(*shape) for shape in shapes]
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('have different shapes at dims')):
      asserts.assert_equal_shape(arrays, dims=dims)
class ShapeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_shape`.

  Expected shapes are checked both as lists and as tuples, each in its own
  `subTest` so a failure in one container type does not hide the other.

  NOTE(review): every method takes extra arguments but no
  `@parameterized...` decorator is present in this copy; restore them from
  upstream.
  """

  def test_shape_should_fail(self, arrays, shapes):
    arrays = as_arrays(arrays)
    for container in (list, tuple):
      with self.subTest(container.__name__):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('input .+ has shape .+ but expected .+')):
          asserts.assert_shape(arrays, container(shapes))

  def test_shape_should_fail_wrong_length(self, arrays, shapes):
    arrays = as_arrays(arrays)
    for container in (tuple, list):
      with self.subTest(container.__name__):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex(
                'Length of `inputs` and `expected_shapes` must match')):
          asserts.assert_shape(arrays, container(shapes))

  def test_shape_should_pass(self, arrays, shapes):
    arrays = as_arrays(arrays)
    for container in (tuple, list):
      with self.subTest(container.__name__):
        asserts.assert_shape(arrays, container(shapes))

  def test_pytypes_pass(self, shape):
    arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
    for container in (tuple, list):
      with self.subTest(container.__name__):
        asserts.assert_shape(arrays, container(shape))

  def test_ellipsis_should_pass(self, array, expected_shape):
    for container in (list, tuple):
      with self.subTest(container.__name__):
        asserts.assert_shape(array, container(expected_shape))

  def test_ellipsis_should_fail(self, array, expected_shape):
    for container in (tuple, list):
      with self.subTest(container.__name__):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('input .+ has shape .+ but expected .+')):
          asserts.assert_shape(array, container(expected_shape))

  def test_multiple_ellipses(self, array, expected_shape):
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError,
        '`expected_shape` may not contain more than one ellipsis, but got .+'):
      asserts.assert_shape(array, expected_shape)
def rank_array(n):
  """Returns an all-zeros NumPy array of rank `n` with every dimension 2.

  Args:
    n: the desired rank (0 gives a 0-d array).
  """
  return np.zeros(shape=[2] * n)
class BroadcastAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_is_broadcastable`.

  NOTE(review): both methods take `shape_a`/`shape_b` but no
  `@parameterized...` decorator is present in this copy; restore it from
  upstream.
  """

  def test_shapes_are_broadcastable(self, shape_a, shape_b):
    asserts.assert_is_broadcastable(shape_a, shape_b)

  def test_shapes_are_not_broadcastable(self, shape_a, shape_b):
    with self.assertRaises(AssertionError):
      asserts.assert_is_broadcastable(shape_a, shape_b)
class RankAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_rank`.

  NOTE(review): the methods that take extra arguments (all except the two
  `*_expectation*` tests) carry no `@parameterized...` decorator in this
  copy; restore them from upstream.
  """

  def test_rank_should_fail_array_expectations(self):
    # An array is not an accepted form for the expected rank.
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError,
        'expected ranks should be .* but was an array'):
      asserts.assert_rank(rank_array(2), np.array([2]))

  def test_rank_should_fail_wrong_expectation_structure(self):
    # pytype: disable=wrong-arg-types
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, 'Expected ranks should be integers or sets of integers'):
      asserts.assert_rank(rank_array(2), [[1, 2]])
    with self.assertRaisesRegex(  # pylint: disable=g-error-prone-assert-raises
        ValueError, 'Expected ranks should be integers or sets of integers'):
      asserts.assert_rank([rank_array(1), rank_array(2)], [[1], [2]])
    # pytype: enable=wrong-arg-types

  def test_rank_should_fail_single(self, array, rank):
    array = np.asarray(array)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('input .+ has rank .+ but expected .+')):
      asserts.assert_rank(array, rank)

  def test_assert_rank_should_fail_sequence(self, arrays, ranks):
    arrays = as_arrays(arrays)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('input .+ has rank .+ but expected .+')):
      asserts.assert_rank(arrays, ranks)

  def test_rank_should_fail_wrong_length(self, array, rank):
    array = np.asarray(array)
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex('Length of inputs and expected_ranks must match.')):
      asserts.assert_rank(array, rank)

  def test_rank_should_pass_single_input(self, array, rank):
    array = np.asarray(array)
    asserts.assert_rank(array, rank)

  def test_rank_should_pass_repeated_input(self, array, rank):
    # The same input three times, checked against a single expected rank.
    arrays = as_arrays([array] * 3)
    asserts.assert_rank(arrays, rank)

  def test_rank_should_pass_multiple_options(self, arrays, ranks):
    arrays = as_arrays(arrays)
    asserts.assert_rank(arrays, ranks)
class TypeAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_type`.

  NOTE(review): the methods that take extra arguments carry no
  `@parameterized...` decorator in this copy, and `test_type_should_*_array`
  call `self.variant` without any `@variants.variants(...)` decorator. They
  look to have been lost and should be restored from upstream.
  """

  def test_type_should_fail_scalar(self, scalars, wrong_type):
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('input .+ has type .+ but expected .+')):
      asserts.assert_type(scalars, wrong_type)

  def test_type_should_fail_array(self, array, dtype, wrong_type):
    # Create the array through the active variant.
    array = self.variant(emplace)(array, dtype)
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('input .+ has type .+ but expected .+')):
      asserts.assert_type(array, wrong_type)

  def test_type_should_pass_scalar(self, array, expected_type):
    asserts.assert_type(array, expected_type)

  def test_type_should_pass_array(self, array, dtype, expected_type):
    array = self.variant(emplace)(array, dtype)
    asserts.assert_type(array, expected_type)

  def test_type_should_fail_mixed(self):
    a_float = 1.
    an_int = 2
    a_np_float = np.asarray([3., 4.])
    a_jax_int = jnp.asarray([5, 6])
    # The last input holds integers but `float` is expected for it.
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('input .+ has type .+ but expected .+')):
      asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                          [float, int, float, float])

  def test_type_should_pass_mixed(self):
    a_float = 1.
    an_int = 2
    a_np_float = np.asarray([3., 4.])
    a_jax_int = jnp.asarray([5, 6])
    asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                        [float, int, float, int])

  def test_type_should_fail_wrong_length(self, array, wrong_type):
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex('Length of `inputs` and `expected_types` must match')):
      asserts.assert_type(array, wrong_type)
class AxisDimensionAssertionsTest(parameterized.TestCase):
  """Tests for the `asserts.assert_axis_dimension*` family.

  Per-axis checks run in a `subTest` so that a failure names the axis.
  """

  def _assert_invalid_axes_raise(self, assert_fn, **kwargs):
    """Checks that `assert_fn` rejects out-of-range axes of a rank-2 tensor.

    Args:
      assert_fn: one of the `asserts.assert_axis_dimension*` functions.
      **kwargs: the comparison value (`expected=` or `val=`) for `assert_fn`.
    """
    tensor = jnp.ones((3, 2))
    # Valid axes of a rank-2 tensor are -2..1, so both 2 and -3 are invalid.
    for i in (2, -3):
      with self.subTest(axis=i):
        with self.assertRaisesRegex(AssertionError,
                                    _get_err_regex('not available')):
          assert_fn(tensor, axis=i, **kwargs)

  def test_assert_axis_dimension_pass(self):
    tensor = jnp.ones((3, 2, 7, 2))
    # Every axis is addressed by both its negative and non-negative index.
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        asserts.assert_axis_dimension(
            tensor, axis=i, expected=tensor.shape[i])

  def test_assert_axis_dimension_fail(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('Expected tensor to have dimension')):
          asserts.assert_axis_dimension(
              tensor, axis=i, expected=tensor.shape[i] + 1)

  def test_assert_axis_dimension_axis_invalid(self):
    self._assert_invalid_axes_raise(asserts.assert_axis_dimension, expected=1)

  def test_assert_axis_dimension_gt_pass(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        asserts.assert_axis_dimension_gt(
            tensor, axis=i, val=tensor.shape[i] - 1)

  def test_assert_axis_dimension_gt_fail(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('Expected tensor to have dimension greater than')):
          asserts.assert_axis_dimension_gt(tensor, axis=i, val=tensor.shape[i])

  def test_assert_axis_dimension_gt_axis_invalid(self):
    self._assert_invalid_axes_raise(asserts.assert_axis_dimension_gt, val=0)

  def test_assert_axis_dimension_gteq_pass(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        asserts.assert_axis_dimension_gteq(tensor, axis=i, val=tensor.shape[i])

  def test_assert_axis_dimension_gteq_fail(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('Expected tensor to have dimension greater than or')):
          asserts.assert_axis_dimension_gteq(
              tensor, axis=i, val=tensor.shape[i] + 1)

  def test_assert_axis_dimension_gteq_axis_invalid(self):
    self._assert_invalid_axes_raise(asserts.assert_axis_dimension_gteq, val=0)

  def test_assert_axis_dimension_lt_pass(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        asserts.assert_axis_dimension_lt(
            tensor, axis=i, val=tensor.shape[i] + 1)

  def test_assert_axis_dimension_lt_fail(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('Expected tensor to have dimension less than')):
          asserts.assert_axis_dimension_lt(tensor, axis=i, val=tensor.shape[i])

  def test_assert_axis_dimension_lt_axis_invalid(self):
    self._assert_invalid_axes_raise(asserts.assert_axis_dimension_lt, val=0)

  def test_assert_axis_dimension_lteq_pass(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        asserts.assert_axis_dimension_lteq(tensor, axis=i, val=tensor.shape[i])

  def test_assert_axis_dimension_lteq_fail(self):
    tensor = jnp.ones((3, 2, 7, 2))
    for i in range(-tensor.ndim, tensor.ndim):
      with self.subTest(axis=i):
        with self.assertRaisesRegex(
            AssertionError,
            _get_err_regex('Expected tensor to have dimension less than or')):
          asserts.assert_axis_dimension_lteq(
              tensor, axis=i, val=tensor.shape[i] - 1)

  def test_assert_axis_dimension_lteq_axis_invalid(self):
    self._assert_invalid_axes_raise(asserts.assert_axis_dimension_lteq, val=0)

  def test_assert_axis_dimension_string_tensor(self):
    # Non-numeric inputs: a plain list of strings and its NumPy counterpart.
    tensor = ['ab', 'cddd']
    asserts.assert_axis_dimension(tensor, axis=0, expected=2)
    asserts.assert_axis_dimension(np.array(tensor), axis=0, expected=2)
class TreeAssertionsTest(parameterized.TestCase): | |
def _assert_tree_structs_validation(self, assert_fn): | |
"""Checks that assert_fn correctly processes invalid args' structs.""" | |
get_val = lambda: jnp.zeros([3]) | |
tree1 = [[get_val(), get_val()], get_val()] | |
tree2 = [[get_val(), get_val()], get_val()] | |
tree3 = [get_val(), [get_val(), get_val()]] | |
tree4 = [get_val(), [get_val(), get_val()], get_val()] | |
tree5 = dict(x=1, y=2, z=3) | |
tree6 = dict(x=1, y=2, z=3, n=None) | |
with self.assertRaisesRegex( | |
AssertionError, | |
_get_err_regex('Error in tree structs equality check.*trees 0 and 1')): | |
assert_fn(tree1, tree5) | |
with self.assertRaisesRegex( | |
AssertionError, | |
_get_err_regex('Error in tree structs equality check.*trees 0 and 1')): | |
assert_fn(tree1, tree3) | |
with self.assertRaisesRegex( | |
AssertionError, | |
_get_err_regex('Error in tree structs equality check.*trees 0 and 2')): | |
assert_fn([], [], tree1) | |
with self.assertRaisesRegex( | |
AssertionError, | |
_get_err_regex('Error in tree structs equality check.*trees 0 and 3')): | |
assert_fn(tree2, tree1, tree2, tree3, tree1) | |
with self.assertRaisesRegex( | |
AssertionError, | |
_get_err_regex('Error in tree structs equality check.*trees 0 and 2')): | |
assert_fn(tree2, tree1, tree4) | |
# Test `None`s. | |
with self.assertRaisesRegex( | |
AssertionError, | |
_get_err_regex('Error in tree structs equality check.*trees 0 and 1')): | |
assert_fn(tree5, tree6) | |
def test_assert_tree_no_nones(self): | |
with self.subTest('tree_no_nones'): | |
tree_ok = {'a': [jnp.zeros((1,))], 'b': 1} | |
asserts.assert_tree_no_nones(tree_ok) | |
with self.subTest('tree_with_nones'): | |
tree_with_none = {'a': [jnp.zeros((1,))], 'b': None} | |
with self.assertRaisesRegex( | |
AssertionError, _get_err_regex('Tree contains `None`') | |
): | |
asserts.assert_tree_no_nones(tree_with_none) | |
# Check `None`. | |
with self.subTest('input_none'): | |
with self.assertRaisesRegex( | |
AssertionError, _get_err_regex('Tree contains `None`') | |
): | |
asserts.assert_tree_no_nones(None) | |
def test_tree_all_finite_passes_finite(self): | |
finite_tree = {'a': jnp.ones((3,)), 'b': jnp.array([0.0, 0.0])} | |
asserts.assert_tree_all_finite(finite_tree) | |
self.assertTrue(asserts._assert_tree_all_finite_jittable(finite_tree)) | |
def test_tree_all_finite_should_fail_inf(self): | |
inf_tree = { | |
'finite_var': jnp.ones((3,)), | |
'inf_var': jnp.array([0.0, jnp.inf]), | |
} | |
err_msg = 'Tree contains non-finite value' | |
with self.assertRaisesRegex(AssertionError, _get_err_regex(err_msg)): | |
asserts.assert_tree_all_finite(inf_tree) | |
with self.assertRaisesRegex(ValueError, err_msg): | |
asserts._assert_tree_all_finite_jittable(inf_tree) | |
def test_assert_trees_all_equal_passes_same_tree(self): | |
tree = { | |
'a': [jnp.zeros((1,))], | |
'b': ([0], (0,), 0), | |
} | |
asserts.assert_trees_all_equal(tree, tree) | |
tree = jax.tree_map(jnp.asarray, tree) | |
self.assertTrue(asserts._assert_trees_all_equal_jittable(tree, tree)) | |
def test_assert_trees_all_equal_passes_values_equal(self): | |
tree1 = (jnp.array([0.0, 0.0]),) | |
tree2 = (jnp.array([0.0, 0.0]),) | |
asserts.assert_trees_all_equal(tree1, tree2) | |
self.assertTrue(asserts._assert_trees_all_equal_jittable(tree1, tree2)) | |
def test_assert_trees_all_equal_fail_values_close_but_not_equal(self): | |
tree1 = (jnp.array([1.0, 1.0]),) | |
tree2 = (jnp.array([1.0, 1.0 + 5e-7]),) | |
error_msg = 'Values not exactly equal' | |
with self.assertRaisesRegex(AssertionError, _get_err_regex(error_msg)): | |
asserts.assert_trees_all_equal(tree1, tree2) | |
with self.assertRaisesRegex(ValueError, error_msg): | |
asserts._assert_trees_all_equal_jittable(tree1, tree2) | |
def test_assert_trees_all_equal_strict_mode(self): | |
# See 'notes' section of | |
# https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_array_equal.html | |
# for details about the 'strict' mode of `numpy.testing.assert_array_equal`. | |
# tldr; it has special handling for scalar values (by default). | |
tree1 = {'a': jnp.array([1.0], dtype=jnp.float32), 'b': 0.0} | |
tree2 = {'a': jnp.array(1.0, dtype=jnp.float32), 'b': 0.0} | |
asserts.assert_trees_all_equal(tree1, tree2) | |
asserts.assert_trees_all_equal(tree1, tree2, strict=False) | |
err_regex = _get_err_regex(r'Trees 0 and 1 differ in leaves \'a\'') | |
with self.assertRaisesRegex(AssertionError, err_regex): | |
asserts.assert_trees_all_equal(tree1, tree2, strict=True) | |
err_regex = r'Trees 0 and 1 differ in leaves' | |
with self.assertRaisesRegex(ValueError, err_regex): | |
asserts._assert_trees_all_equal_jittable(tree1, tree2, strict=True) | |
# We do not implement this special scalar handling in the jittable | |
# assertion (it's possible, but doesn't seem worth the effort). | |
err_regex = r'`strict=False` is not implemented' | |
with self.assertRaisesRegex(NotImplementedError, err_regex): | |
asserts._assert_trees_all_equal_jittable(tree1, tree2, strict=False) | |
def test_assert_trees_all_close_passes_same_tree(self): | |
tree = { | |
'a': [jnp.zeros((1,))], | |
'b': ([0], (0,), 0), | |
} | |
asserts.assert_trees_all_close(tree, tree) | |
tree = jax.tree_map(jnp.asarray, tree) | |
self.assertTrue(asserts._assert_trees_all_close_jittable(tree, tree)) | |
def test_assert_trees_all_close_passes_values_equal(self): | |
tree1 = (jnp.array([0.0, 0.0]),) | |
tree2 = (jnp.array([0.0, 0.0]),) | |
asserts.assert_trees_all_close(tree1, tree2) | |
self.assertTrue(asserts._assert_trees_all_close_jittable(tree1, tree2)) | |
def test_assert_trees_all_close_passes_values_close_but_not_equal(self): | |
tree1 = (jnp.array([1.0, 1.0]),) | |
tree2 = (jnp.array([1.0, 1.0 + 5e-7]),) | |
asserts.assert_trees_all_close(tree1, tree2, rtol=1e-6) | |
self.assertTrue( | |
asserts._assert_trees_all_close_jittable(tree1, tree2, rtol=1e-6)) | |
def test_assert_trees_all_close_bfloat16(self): | |
tree1 = {'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16)} | |
tree2 = { | |
'a': jnp.asarray([0.8, 1.6], dtype=jnp.bfloat16).astype(jnp.float32) | |
} | |
tree3 = {'a': jnp.asarray([0.8, 1.7], dtype=jnp.bfloat16)} | |
asserts.assert_trees_all_close(tree1, tree1) | |
asserts.assert_trees_all_close(tree1, tree2) | |
self.assertTrue(asserts._assert_trees_all_close_jittable(tree1, tree2)) | |
err_msg = 'Values not approximately equal' | |
err_regex = _get_err_regex(err_msg) | |
with self.assertRaisesRegex(AssertionError, err_regex): | |
asserts.assert_trees_all_close(tree1, tree3) | |
with self.assertRaisesRegex(ValueError, err_msg): | |
asserts._assert_trees_all_close_jittable(tree1, tree3) | |
with self.assertRaisesRegex(AssertionError, err_regex): | |
asserts.assert_trees_all_close(tree2, tree3) | |
with self.assertRaisesRegex(ValueError, err_msg): | |
asserts._assert_trees_all_close_jittable(tree2, tree3) | |
def test_assert_trees_all_close_ulp_jittable_raises_valueerror(self): | |
tree = (jnp.array([1.0]),) | |
err_msg = 'assert_trees_all_close_ulp is not supported within JIT contexts.' | |
err_regex = _get_err_regex(err_msg) | |
with self.assertRaisesRegex(RuntimeError, err_regex): | |
asserts._assert_trees_all_close_ulp_jittable(tree, tree) | |
def test_assert_trees_all_close_ulp_passes_same_tree(self): | |
tree = { | |
'a': [jnp.zeros((1,))], | |
'b': ([0], (0,), 0), | |
} | |
asserts.assert_trees_all_close_ulp(tree, tree) | |
def test_assert_trees_all_close_ulp_passes_values_equal(self): | |
tree1 = (jnp.array([0.0, 0.0]),) | |
tree2 = (jnp.array([0.0, 0.0]),) | |
try: | |
asserts.assert_trees_all_close_ulp(tree1, tree2) | |
except AssertionError: | |
self.fail('assert_trees_all_close_ulp raised AssertionError') | |
def test_assert_trees_all_close_ulp_passes_values_within_maxulp(self): | |
# np.spacing(np.float32(1 << 23)) == 1.0. | |
value_where_ulp_is_1 = np.float32(1 << 23) | |
tree1 = (jnp.array([value_where_ulp_is_1, value_where_ulp_is_1]),) | |
tree2 = (jnp.array([value_where_ulp_is_1, value_where_ulp_is_1 + 1.0]),) | |
assert tree2[0][0] != tree2[0][1] | |
try: | |
asserts.assert_trees_all_close_ulp(tree1, tree2, maxulp=2) | |
except AssertionError: | |
self.fail('assert_trees_all_close_ulp raised AssertionError') | |
def test_assert_trees_all_close_ulp_passes_values_maxulp_apart(self): | |
# np.spacing(np.float32(1 << 23)) == 1.0. | |
value_where_ulp_is_1 = np.float32(1 << 23) | |
tree1 = (jnp.array([value_where_ulp_is_1, value_where_ulp_is_1]),) | |
tree2 = (jnp.array([value_where_ulp_is_1, value_where_ulp_is_1 + 1.0]),) | |
assert tree2[0][0] != tree2[0][1] | |
try: | |
asserts.assert_trees_all_close_ulp(tree1, tree2, maxulp=1) | |
except AssertionError: | |
self.fail('assert_trees_all_close_ulp raised AssertionError') | |
def test_assert_trees_all_close_ulp_fails_values_gt_maxulp_apart(self): | |
# np.spacing(np.float32(1 << 23)) == 1.0. | |
value_where_ulp_is_1 = np.float32(1 << 23) | |
tree1 = (jnp.array([value_where_ulp_is_1, value_where_ulp_is_1]),) | |
tree2 = (jnp.array([value_where_ulp_is_1, value_where_ulp_is_1 + 2.0]),) | |
assert tree2[0][0] != tree2[0][1] | |
err_msg = re.escape( | |
'not almost equal up to 1 ULP (max difference is 2 ULP)' | |
) | |
err_regex = _get_err_regex(err_msg) | |
with self.assertRaisesRegex(AssertionError, err_regex): | |
asserts.assert_trees_all_close_ulp(tree1, tree2, maxulp=1) | |
def test_assert_trees_all_close_ulp_fails_bfloat16(self): | |
tree_f32 = (jnp.array([0.0]),) | |
tree_bf16 = (jnp.array([0.0], dtype=jnp.bfloat16),) | |
err_msg = 'ULP assertions are not currently supported for bfloat16.' | |
err_regex = _get_err_regex(err_msg) | |
with self.assertRaisesRegex(ValueError, err_regex): # pylint: disable=g-error-prone-assert-raises | |
asserts.assert_trees_all_close_ulp(tree_bf16, tree_bf16) | |
with self.assertRaisesRegex(ValueError, err_regex): # pylint: disable=g-error-prone-assert-raises | |
asserts.assert_trees_all_close_ulp(tree_bf16, tree_f32) | |
def test_assert_tree_has_only_ndarrays(self): | |
# Check correct inputs. | |
asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros(1), 'b': np.ones(3)}) | |
asserts.assert_tree_has_only_ndarrays(np.zeros(4)) | |
asserts.assert_tree_has_only_ndarrays(()) | |
# Check incorrect inputs. | |
with self.assertRaisesRegex(AssertionError, | |
_get_err_regex('\'b\' is not an ndarray')): | |
asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros((1,)), 'b': 1}) | |
with self.assertRaisesRegex(AssertionError, | |
_get_err_regex('\'b/1\' is not an ndarray')): | |
asserts.assert_tree_has_only_ndarrays({'a': jnp.zeros(101), 'b': [1, 2]}) | |
def test_assert_tree_is_on_host(self):
  """Checks host residency for NumPy, CPU-placed and sharded JAX arrays."""
  cpu = jax.local_devices(backend='cpu')[0]

  def replicated_on_cpu():
    # A fresh array replicated over the single CPU device.
    return jax.device_put_replicated(np.zeros(1), (cpu,))

  # NumPy arrays (including device round-trips) pass with either CPU setting.
  for allow_cpu in (False, True):
    host_trees = (
        {'a': np.zeros(1), 'b': np.ones(3)},
        np.zeros(4),
        jax.device_get(jax.device_put(np.ones(3))),
        (),
    )
    for tree in host_trees:
      asserts.assert_tree_is_on_host(tree, allow_cpu_device=allow_cpu)

  # JAX arrays on the default device fail when CPU devices are disallowed.
  for make_leaf in (lambda: jnp.zeros(1), lambda: jax.device_put(np.zeros(1))):
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('\'a\' resides on')):
      asserts.assert_tree_is_on_host({'a': make_leaf()},
                                     allow_cpu_device=False)

  # JAX arrays explicitly placed on the CPU pass with default arguments.
  cpu_arr = jax.device_put(np.ones(5), cpu)
  asserts.assert_tree_is_on_host({'a': cpu_arr})
  asserts.assert_tree_is_on_host({'a': np.zeros(1), 'b': cpu_arr})

  # Sharded CPU arrays pass once both CPU and sharded arrays are allowed.
  asserts.assert_tree_is_on_host(
      {'a': replicated_on_cpu()},
      allow_cpu_device=True,
      allow_sharded_arrays=True,
  )

  # Disallowing the CPU device rejects the CPU-placed leaf, named by path.
  cpu_rejections = (
      ('\'a\' resides on.*CPU', {'a': cpu_arr}),
      ('\'b\' resides on.*CPU', {'a': np.zeros(1), 'b': cpu_arr}),
  )
  for message, tree in cpu_rejections:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_tree_is_on_host(tree, allow_cpu_device=False)

  # Non-array leaves are rejected.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'b\' is not an ndarray')):
    asserts.assert_tree_is_on_host({'a': np.zeros(1), 'b': 1})

  # Without allow_sharded_arrays, sharded arrays are rejected outright.
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex('sharded arrays are disallowed')
  ):
    asserts.assert_tree_is_on_host(
        {'a': replicated_on_cpu()},
        allow_cpu_device=False,
    )

  # Allowed sharded arrays still fail when they live on a disallowed CPU.
  with self.assertRaisesRegex(
      AssertionError, _get_err_regex("'a' is sharded and resides on.*CPU")
  ):
    asserts.assert_tree_is_on_host(
        {'a': replicated_on_cpu()},
        allow_cpu_device=False,
        allow_sharded_arrays=True,
    )
def test_assert_tree_is_on_device(self):
  """Checks device and platform placement on CPU and, if present, two TPUs."""
  cpu = jax.local_devices(backend='cpu')[0]

  def tree_on(dev_a, dev_b):
    # Leaf 'a' is placed on `dev_a` and leaf 'b' on `dev_b`.
    return {'a': jax.device_put(np.zeros(1), dev_a),
            'b': jax.device_put(np.ones(3), dev_b)}

  cpu_tree = tree_on(cpu, cpu)
  for kwargs in (
      dict(device=cpu),
      dict(platform='cpu'),
      dict(platform=['cpu']),
      dict(device=cpu, platform=''),
  ):
    asserts.assert_tree_is_on_device(cpu_tree, **kwargs)
  cpu_failures = (
      ('\'a\' resides on \'cpu\'', dict(platform='tpu')),
      ('\'b\' resides on \'cpu\'', dict(platform=('tpu', 'gpu'))),
  )
  for message, kwargs in cpu_failures:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_tree_is_on_device(cpu_tree, **kwargs)

  # TPU checks only run when at least two TPU devices are available.
  if _num_devices_available('tpu') > 1:
    tpu_1, tpu_2 = jax.devices('tpu')[:2]
    tpu_1_tree = tree_on(tpu_1, tpu_1)
    tpu_2_tree = tree_on(tpu_2, tpu_2)
    tpu_1_2_tree = tree_on(tpu_1, tpu_2)

    # Device-level checks.
    asserts.assert_tree_is_on_device(tpu_1_tree, device=tpu_1)
    asserts.assert_tree_is_on_device(tpu_2_tree, device=tpu_2)
    device_mismatches = (
        (r"'a' resides on.*TpuDevice\(id=0", tpu_1_tree, tpu_2),
        (r"'a' resides on.*TpuDevice\(id=1", tpu_2_tree, tpu_1),
        ("'a' resides on .*Cpu", cpu_tree, tpu_2),
    )
    for message, tree, device in device_mismatches:
      with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
        asserts.assert_tree_is_on_device(tree, device=device)

    # Platform-level checks.
    for tree in (tpu_1_tree, tpu_2_tree):
      asserts.assert_tree_is_on_device(tree, platform='tpu')
    for tree, platform in ((tpu_1_tree, 'cpu'), (tpu_2_tree, 'gpu')):
      with self.assertRaisesRegex(AssertionError,
                                  _get_err_regex('\'a\' resides on \'tpu\'')):
        asserts.assert_tree_is_on_device(tree, platform=platform)

    # Trees whose leaves span several devices and platforms.
    mixed_tree = (tpu_1_2_tree, cpu_tree)
    asserts.assert_tree_is_on_device(tpu_1_2_tree, platform='tpu')
    asserts.assert_tree_is_on_device(mixed_tree, platform=('cpu', 'tpu'))
    mixed_failures = (
        ('\'1/a\' resides on \'cpu\'', 'tpu'),
        ('\'0/a\' resides on \'tpu\'', ('cpu', 'gpu')),
    )
    for message, platform in mixed_failures:
      with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
        asserts.assert_tree_is_on_device(mixed_tree, platform=platform)

  # Malformed inputs: a non-array leaf, and a NumPy leaf among JAX arrays.
  invalid_cases = (
      ('\'b\' is not an ndarray', {'a': np.zeros(1), 'b': 1}),
      ('\'b\' has unexpected type', {'a': jnp.zeros(1), 'b': np.ones(3)}),
  )
  for message, tree in invalid_cases:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_tree_is_on_device(tree)

  # Sharded (replicated) arrays are disallowed.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('\'a\' is a ShardedDeviceArra')):
    asserts.assert_tree_is_on_device(
        {'a': jax.device_put_replicated(np.zeros(1), (cpu,))}, device=cpu)
def test_assert_tree_is_sharded(self):
  """Checks that sharding is compared against an exact, ordered device tuple."""
  np_tree = {'a': np.zeros(1), 'b': np.ones(3)}

  def _format(*devs):
    # Regex-escaped repr of the device tuple, used in the expected messages.
    return re.escape(f'{devs}')

  def expect_failure(message, tree, devices):
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_tree_is_sharded(tree, devices=devices)

  # Single-device case.
  cpu = jax.local_devices(backend='cpu')[0]
  cpu_tree = jax.device_put_replicated(np_tree, (cpu,))
  asserts.assert_tree_is_sharded(cpu_tree, devices=(cpu,))
  asserts.assert_tree_is_sharded((), devices=(cpu,))
  expect_failure(r'\'a\' is sharded.*expected \(\)', cpu_tree, ())
  expect_failure(
      f'\'a\' is sharded across {_format(cpu)}.*expected {_format(cpu, cpu)}',
      cpu_tree, (cpu, cpu))

  # Multiple-device case; only runs with at least two TPUs.
  if _num_devices_available('tpu') > 1:
    tpu_1, tpu_2 = jax.devices('tpu')[:2]
    tpu_1_tree = jax.device_put_replicated(np_tree, (tpu_1,))
    tpu_2_tree = jax.device_put_replicated(np_tree, (tpu_2,))
    tpu_1_2_tree = jax.device_put_replicated(np_tree, (tpu_1, tpu_2))
    tpu_2_1_tree = jax.device_put_replicated(np_tree, (tpu_2, tpu_1))
    asserts.assert_tree_is_sharded(tpu_1_2_tree, devices=(tpu_1, tpu_2))
    asserts.assert_tree_is_sharded(tpu_2_1_tree, devices=(tpu_2, tpu_1))
    mixed_tree = (tpu_1_tree, tpu_2_tree)

    # Each row: (leaf path, devices actually sharded across,
    #            devices requested, tree under test).
    failures = (
        # Wrong device.
        ('a', (tpu_1,), (tpu_2,), tpu_1_tree),
        ('a', (cpu,), (tpu_2,), cpu_tree),
        # Too many devices.
        ('a', (tpu_1,), (tpu_1, tpu_2), tpu_1_tree),
        ('a', (tpu_1, tpu_2), (tpu_1, tpu_2, cpu), tpu_1_2_tree),
        # Wrong order.
        ('a', (tpu_2, tpu_1), (tpu_1, tpu_2), tpu_2_1_tree),
        # Mixed cases.
        ('0/a', (tpu_1,), (tpu_2,), mixed_tree),
        ('1/a', (tpu_2,), (tpu_1,), mixed_tree),
        ('0/a', (tpu_1,), (tpu_1, tpu_2), mixed_tree),
    )
    for path, actual, requested, tree in failures:
      expect_failure(
          f'\'{path}\' is sharded across {_format(*actual)}.*'
          f'expected {_format(*requested)}',
          tree, requested)

  # Malformed inputs.
  invalid_cases = (
      ('\'1\' is not an ndarray', (cpu_tree, 1123)),
      ('\'a\' is not a jax.Array', {'a': np.zeros(1)}),
      ('\'a\' is not sharded', {'a': jnp.zeros(1)}),
      ("'a' is not sharded.*Cpu", {'a': jax.device_put(np.zeros(1), cpu)}),
  )
  for message, tree in invalid_cases:
    expect_failure(message, tree, (cpu,))
def test_assert_trees_all_close_fails_different_structure(self):
  """Runs the shared tree-structure validation on `assert_trees_all_close`."""
  assertion_under_test = asserts.assert_trees_all_close
  self._assert_tree_structs_validation(assertion_under_test)
def test_assert_trees_all_close_fails_values_differ(self):
  """A ~0.1 gap passes with tolerance 0.1 and fails with 0.01 (atol, rtol)."""
  tree1 = jnp.array([0.0, 2.0])
  tree2 = jnp.array([0.0, 2.1])
  for tolerance_name in ('atol', 'rtol'):
    asserts.assert_trees_all_close(tree1, tree2, **{tolerance_name: 0.1})
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('Values not approximately equal')):
      asserts.assert_trees_all_close(tree1, tree2, **{tolerance_name: 0.01})
def test_assert_trees_all_equal_sizes(self):
  """Leaf sizes must agree; a mismatch names the tree indices and leaf."""
  def make_tree(a2_cols):
    # Only leaf 'd/a2' varies: it has shape (4, a2_cols).
    return dict(a1=jnp.zeros([3, 1]),
                d=dict(a2=jnp.zeros([4, a2_cols]), a3=jnp.zeros([5, 3])))

  tree1, tree2, tree3 = make_tree(1), make_tree(1), make_tree(2)
  self._assert_tree_structs_validation(asserts.assert_trees_all_equal_sizes)
  asserts.assert_trees_all_equal_sizes(tree1, tree1)
  asserts.assert_trees_all_equal_sizes(tree2, tree1)

  mismatches = (
      (r'Trees 0 and 1 differ in leaves \'d/a2\': sizes: 4 != 8',
       (tree1, tree3)),
      (r'Trees 0 and 3 differ in leaves \'d/a2\': sizes: 4 != 8',
       (tree1, tree2, tree2, tree3, tree1)),
  )
  for message, trees in mismatches:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_trees_all_equal_sizes(*trees)
def test_assert_trees_all_equal_shapes(self):
  """Leaf shapes must agree; a mismatch names the tree indices and leaf."""
  def make_tree(a2_len):
    # Only leaf 'd/a2' varies: it has shape (a2_len,).
    return dict(a1=jnp.zeros([3]),
                d=dict(a2=jnp.zeros([a2_len]), a3=jnp.zeros([5])))

  tree1, tree2, tree3 = make_tree(4), make_tree(4), make_tree(7)
  self._assert_tree_structs_validation(asserts.assert_trees_all_equal_shapes)
  asserts.assert_trees_all_equal_shapes(tree1, tree1)
  asserts.assert_trees_all_equal_shapes(tree2, tree1)

  mismatches = (
      (r'Trees 0 and 1 differ in leaves \'d/a2\': shapes: \(4,\) != \(7,\)',
       (tree1, tree3)),
      (r'Trees 0 and 3 differ in leaves \'d/a2\': shapes: \(4,\) != \(7,\)',
       (tree1, tree2, tree2, tree3, tree1)),
  )
  for message, trees in mismatches:
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_trees_all_equal_shapes(*trees)
def test_assert_trees_all_equal_structs(self):
  """Identically nested trees pass; differently nested trees must fail."""
  get_val = lambda: jnp.zeros([3])
  tree1 = [[get_val(), get_val()], get_val()]
  tree2 = [[get_val(), get_val()], get_val()]
  tree3 = [get_val(), [get_val(), get_val()]]
  asserts.assert_trees_all_equal_structs(tree1, tree2, tree2, tree1)
  asserts.assert_trees_all_equal_structs(tree3, tree3)
  # `tree3` nests its sub-list in the second position rather than the first,
  # so its structure differs from `tree1` and the comparison must fail.
  with self.assertRaises(AssertionError):
    asserts.assert_trees_all_equal_structs(tree1, tree3)
  self._assert_tree_structs_validation(asserts.assert_trees_all_equal_structs)
def test_assert_tree_shape_prefix(self, shape):
  """Every leaf shape must start with `shape`, given as a tuple or a list."""
  # NOTE(review): `shape` is expected to come from a parameterization
  # decorator that is not visible in this view; confirm it is applied.
  tree = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
  for label, as_sequence in (('tuple', tuple), ('list', list)):
    with self.subTest(label):
      asserts.assert_tree_shape_prefix(tree, as_sequence(shape))
def test_leaf_shape_should_fail_wrong_length(self):
  """A prefix longer than a leaf's rank fails and names that leaf."""
  tree = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
  for too_long_prefix in ((3, 2, 1), [3, 2, 1]):
    with self.assertRaisesRegex(
        AssertionError,
        _get_err_regex(r'leaf \'x/y\' has a shape of length 2')):
      asserts.assert_tree_shape_prefix(tree, too_long_prefix)
def test_assert_tree_shape_suffix_matching(self, shape):
  """Every leaf shape must end with `shape`, given as a tuple or a list."""
  # NOTE(review): `shape` is expected to come from a parameterization
  # decorator that is not visible in this view; confirm it is applied.
  tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([2, 1])}
  for label, as_sequence in (('tuple', tuple), ('list', list)):
    with self.subTest(label):
      asserts.assert_tree_shape_suffix(tree, as_sequence(shape))
def test_assert_tree_shape_suffix_mismatch(self, leaf, shape_true, shape):
  """A wrong suffix is reported with the leaf path and both shapes."""
  # NOTE(review): arguments are expected to come from a parameterization
  # decorator that is not visible in this view; confirm it is applied.
  tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([1, 1])}
  error_msg = ''.join((
      r'Tree leaf \'', str(leaf), '\'.*different from expected: ',
      re.escape(str(shape_true)), ' != ', re.escape(str(shape)),
  ))
  for label, as_sequence in (('tuple', tuple), ('list', list)):
    with self.subTest(label):
      with self.assertRaisesRegex(AssertionError, _get_err_regex(error_msg)):
        asserts.assert_tree_shape_suffix(tree, as_sequence(shape))
def test_assert_tree_shape_suffix_long_suffix(self):
  """A suffix equal to the full shape passes; a longer one fails."""
  tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([4, 2, 1])}
  for full_shape in ((4, 2, 1), [4, 2, 1]):
    asserts.assert_tree_shape_suffix(tree, full_shape)
  for too_long in ((3, 4, 2, 1), [3, 4, 2, 1]):
    with self.assertRaisesRegex(
        AssertionError, _get_err_regex('which is smaller than the expected')):
      asserts.assert_tree_shape_suffix(tree, too_long)
def test_assert_trees_all_equal_dtypes(self):
  """Leaf dtypes must agree across trees; leaf shapes are not compared."""
  def expect_mismatch(message, *trees):
    with self.assertRaisesRegex(AssertionError, _get_err_regex(message)):
      asserts.assert_trees_all_equal_dtypes(*trees)

  t_0 = {'x': np.zeros(3, dtype=np.int16), 'y': np.ones(2, dtype=np.float32)}
  t_1 = {'x': np.zeros(5, dtype=np.uint16), 'y': np.ones(4, dtype=np.float32)}
  expect_mismatch('Trees 0 and 1 differ', t_0, t_1)

  t_2 = {'x': np.zeros(6, dtype=jnp.int16), 'y': np.ones(6, dtype=np.float32)}
  asserts.assert_trees_all_equal_dtypes(t_0, t_2, t_0)
  expect_mismatch('Trees 0 and 4 differ', t_0, t_0, t_2, t_0, t_1, t_2)

  # NumPy and jax.numpy dtype specifications are interchangeable.
  t_3 = {'x': np.zeros(1, dtype=np.int16), 'y': np.ones(2, dtype=jnp.float32)}
  t_4 = {
      'x': np.zeros(1, dtype=jnp.int16),
      'y': np.ones(2, dtype=jnp.float32),
  }
  asserts.assert_trees_all_equal_dtypes(t_0, t_2, t_3, t_4)

  # float16 and bfloat16 must be treated as different dtypes.
  t_5 = {'y': np.ones(2, dtype=np.float16)}
  t_6 = {'y': np.ones(2, dtype=jnp.bfloat16)}
  asserts.assert_trees_all_equal_dtypes(t_6, t_6)
  expect_mismatch('Trees 0 and 1 differ', t_5, t_6)
def test_assert_trees_all_equal_shapes_and_dtypes(self):
  """Trees must agree on both the shape and the dtype of every leaf."""
  def make_tree(x_len=3, x_dtype=np.int16):
    # Only leaf 'x' varies; 'y' is always a float32 vector of length 2.
    return {'x': np.zeros(x_len, dtype=x_dtype),
            'y': np.ones(2, dtype=np.float32)}

  reference = make_tree()
  # First a dtype mismatch on 'x', then a shape mismatch on 'x'. After each
  # one, identical trees must still pass.
  for different in (make_tree(x_dtype=np.uint16), make_tree(x_len=4)):
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('Trees 0 and 1 differ')):
      asserts.assert_trees_all_equal_shapes_and_dtypes(reference, different)
    asserts.assert_trees_all_equal_shapes_and_dtypes(
        reference, make_tree(), reference)
def test_assert_trees_all_equal_wrong_usage(self):
  """Non-array leaves fail the assertion; one tree alone is a ValueError."""
  # A Python float leaf is not a (j-)np array.
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex(r'is not a \(j-\)np array')):
    asserts.assert_trees_all_equal_dtypes({'x': 1.}, {'x': np.array(1.)})
  # A single tree gives nothing to compare against.
  lone_tree = {'x': 1.}
  with self.assertRaisesRegex(  # pylint:disable=g-error-prone-assert-raises
      ValueError, 'Assertions over only one tree does not make sense'):
    asserts.assert_trees_all_equal_dtypes(lone_tree)
def test_assert_trees_all_equal_none(self):
  """Trees with matching `None` entries compare; a missing entry does not."""
  t_0 = {'x': None, 'y': np.array(2, dtype=np.int32)}
  t_1 = {'x': None, 'y': np.array([23], dtype=np.int32)}
  t_2 = {'x': None, 'y': np.array(3, dtype=np.float32)}
  t_3 = {'y': np.array([23], dtype=np.int32)}
  with self.assertRaisesRegex(AssertionError,
                              _get_err_regex('Trees 0 and 2 differ')):
    asserts.assert_trees_all_equal_dtypes(t_0, t_1, t_2)
  asserts.assert_trees_all_equal_dtypes(t_0, t_1)
  # `t_3` lacks the 'x' entry, so its structure differs from `t_0`. Check
  # that the mismatch is reported whichever of the two trees comes first.
  for trees in ((t_0, t_3), (t_3, t_0)):
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('trees 0 and 1 do not match')):
      asserts.assert_trees_all_equal_dtypes(*trees)
class DevicesAssertTest(parameterized.TestCase):
  """Tests for `assert_devices_available` and the GPU/TPU shortcuts."""

  def _device_count(self, backend):
    """Returns `jax.device_count(backend)`, or 0 if that raises RuntimeError."""
    try:
      count = jax.device_count(backend)
    except RuntimeError:
      count = 0
    return count

  # NOTE(review): `devtype` is expected to be supplied by a parameterization
  # decorator that is not visible in this view; confirm it is applied.
  def test_not_less_than(self, devtype):
    n = self._device_count(devtype)
    if n <= 0:
      # Without a usable backend, the check is expected to raise RuntimeError.
      with self.assertRaisesRegex(RuntimeError,  # pylint: disable=g-error-prone-assert-raises
                                  '(failed to initialize)|(Unknown backend)'):
        asserts.assert_devices_available(
            n - 1, devtype, backend=devtype, not_less_than=True)
      return
    # At least n - 1 devices are available, but not n + 1.
    asserts.assert_devices_available(
        n - 1, devtype, backend=devtype, not_less_than=True)
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex(f'Only {n} < {n + 1}')):
      asserts.assert_devices_available(
          n + 1, devtype, backend=devtype, not_less_than=True)

  def test_unsupported_device(self):
    bogus_devtype = 'unsupported_devtype'
    with self.assertRaisesRegex(ValueError, 'Unknown device type'):  # pylint: disable=g-error-prone-assert-raises
      asserts.assert_devices_available(1, bogus_devtype)

  def test_gpu_assert(self):
    gpu_count = self._device_count('gpu')
    asserts.assert_devices_available(gpu_count, 'gpu')
    if gpu_count:
      asserts.assert_gpu_available()
      return
    # No GPUs: every GPU availability check must fail.
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('No 2 GPUs available')):
      asserts.assert_devices_available(2, 'gpu')
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('No GPU devices available')):
      asserts.assert_gpu_available()
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('No 2 GPUs available')):
      asserts.assert_devices_available(2, 'gpu', backend='cpu')

  def test_cpu_assert(self):
    asserts.assert_devices_available(
        jax.device_count('cpu'), 'cpu', backend='cpu')

  def test_tpu_assert(self):
    tpu_count = self._device_count('tpu')
    asserts.assert_devices_available(tpu_count, 'tpu')
    if tpu_count:
      asserts.assert_tpu_available()
      return
    # No TPUs: every TPU availability check must fail.
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('No 3 TPUs available')):
      asserts.assert_devices_available(3, 'tpu')
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('No TPU devices available')):
      asserts.assert_tpu_available()
    with self.assertRaisesRegex(AssertionError,
                                _get_err_regex('No 3 TPUs available')):
      asserts.assert_devices_available(3, 'tpu', backend='cpu')
class NumericalGradsAssertTest(parameterized.TestCase):
  """Tests for `asserts.assert_numerical_grads`.

  NOTE(review): the `test_*` methods take `x_shape`/`seed`, but no
  parameterization decorators are visible in this view; confirm they are
  applied.
  """

  def _test_fn(self, fn, init_args, seed, n=10):
    """Checks first-order grads of `fn` at `n` random points.

    Args:
      fn: function under test, called with one argument per `init_args` entry.
      init_args: sequence of arrays; only their shapes are used.
      seed: seed for `jax.random.PRNGKey`.
      n: number of random evaluation points.
    """
    rng_key = jax.random.PRNGKey(seed)
    for _ in range(n):
      # One fresh key per argument; `rng_key` is carried to the next round.
      rng_key, *tree_keys = jax.random.split(rng_key, len(init_args) + 1)
      x = jax.tree_util.tree_map(
          lambda k, leaf: jax.random.uniform(k, shape=leaf.shape),
          tree_keys, list(init_args))
      asserts.assert_numerical_grads(fn, x, order=1)

  def test_easy(self, x_shape, seed):
    f_easy = lambda x: jnp.sum(x**2 - 2 * x + 10)
    init_args = (jnp.zeros(x_shape),)
    self._test_fn(f_easy, init_args, seed)

  def test_easy_with_stop_gradient(self, x_shape, seed):
    f_easy_sg = lambda x: jnp.sum(jax.lax.stop_gradient(x**2) - 2 * x + 10)
    init_args = (jnp.zeros(x_shape),)
    self._test_fn(f_easy_sg, init_args, seed)

  def test_hard(self, x_shape, seed):
    # Unrolled inner optimisation loop with a data-dependent `select`; this
    # variant uses no stop_gradient.
    def f_hard(lr, x):
      inner_loss = lambda y: jnp.sum((y - 1.0)**2)
      inner_loss_grad = jax.grad(inner_loss)

      def fu(lr, x):
        for _ in range(10):
          x1 = x - lr * inner_loss_grad(x) + 100 * lr**2
          x2 = x - lr * inner_loss_grad(x) - 100 * lr**2
          # Take the `x1` branch if any element of `x` exceeds 3.
          x = jax.lax.select((x > 3.).any(), x1, x2 + lr)
        return x

      y = fu(lr, x)
      return jnp.sum(inner_loss(y))

    lr = jnp.zeros([1] * len(x_shape))
    x = jnp.zeros(x_shape)
    self._test_fn(f_hard, (lr, x), seed)

  def test_hard_with_stop_gradient(self, x_shape, seed):
    # Same loop as `test_hard`, with stop_gradient on some uses of `lr`.
    def f_hard_with_sg(lr, x):
      inner_loss = lambda y: jnp.sum((y - 1.0)**2)
      inner_loss_grad = jax.grad(inner_loss)

      def fu(lr, x):
        for _ in range(10):
          x1 = x - lr * inner_loss_grad(x) + 100 * jax.lax.stop_gradient(lr)**2
          x2 = x - lr * inner_loss_grad(x) - 100 * lr**2
          x = jax.lax.select((x > 3.).any(), x1, x2 + jax.lax.stop_gradient(lr))
        return x

      y = fu(lr, x)
      return jnp.sum(inner_loss(y))

    lr = jnp.zeros([1] * len(x_shape))
    x = jnp.zeros(x_shape)
    self._test_fn(f_hard_with_sg, (lr, x), seed)
class EqualAssertionsTest(parameterized.TestCase):
  """Tests for `asserts.assert_equal`."""

  # NOTE(review): `first`/`second` are expected from a parameterization
  # decorator that is not visible in this view; confirm it is applied.
  def test_assert_equal_pass(self, first, second):
    asserts.assert_equal(first, second)

  def test_assert_equal_pass_on_arrays(self):
    # Not using named_parameters, because JAX cannot be used before app.run().
    asserts.assert_equal(jnp.ones([]), np.ones([]))
    asserts.assert_equal(
        jnp.ones([], dtype=jnp.int32), np.ones([], dtype=np.float64))

  def test_assert_equal_fail(self, first, second):
    with self.assertRaises(AssertionError):
      asserts.assert_equal(first, second)
class IsDivisibleTest(parameterized.TestCase):
  """Tests for `asserts.assert_is_divisible`."""

  def test_assert_is_divisible(self):
    numerator, denominator = 6, 3
    asserts.assert_is_divisible(numerator, denominator)

  def test_assert_is_divisible_fail(self):
    numerator, denominator = 7, 3  # 7 % 3 == 1, so this must fail.
    with self.assertRaises(AssertionError):
      asserts.assert_is_divisible(numerator, denominator)
if __name__ == '__main__':
  # Turn implicit rank promotion into an error for every test in this module.
  jax.config.update('jax_numpy_rank_promotion', 'raise')
  absltest.main()