Spaces:
Building
Building
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Tests for `variants.py`. | |
To run tests in multi-cpu regime, one need to set the flag `--n_cpu_devices=N`. | |
""" | |
import inspect | |
import itertools | |
import unittest | |
from absl import flags | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
from chex._src import asserts | |
from chex._src import fake | |
from chex._src import pytypes | |
from chex._src import variants | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
FLAGS = flags.FLAGS
ArrayBatched = pytypes.ArrayBatched

# Function exercised by most variant tests below.
DEFAULT_FN = lambda arg_0, arg_1: arg_1 - arg_0
# `(arg_0, arg_1, expected)` triples with `DEFAULT_FN(arg_0, arg_1) == expected`.
DEFAULT_PARAMS = ((1, 2, 1), (4, 6, 2))
# Shape that scalar parameters are broadcast to for the ndarray tests.
DEFAULT_NDARRAY_PARAMS_SHAPE = (5, 7)
# `DEFAULT_PARAMS` prefixed with a `case_<index>` name for `named_parameters`.
DEFAULT_NAMED_PARAMS = tuple(
    (f'case_{idx}',) + params for idx, params in enumerate(DEFAULT_PARAMS))
def setUpModule():
  """Sets up `FLAGS.chex_n_cpu_devices` CPU devices for all tests.

  The device count is controlled by the `--chex_n_cpu_devices=N` flag; after
  creating the fake devices, `asserts.assert_devices_available` checks that the
  requested number of CPU devices is present.
  """
  fake.set_n_cpu_devices()
  requested_cpus = FLAGS['chex_n_cpu_devices'].value
  asserts.assert_devices_available(requested_cpus, 'cpu', backend='cpu')
def _scalar_to_ndarray(x, shape=None): | |
return np.broadcast_to(x, shape or DEFAULT_NDARRAY_PARAMS_SHAPE) | |
def _variant_default_tests_generator(fn, is_jit_context, which_variants,
                                     **var_kwargs):
  """Returns a generator with standard tests.

  For internal usage. Allows to dynamically generate common tests.
  See tests' names and comments for more information.

  Args:
    fn: a separate function to be tested (without `self` argument).
    is_jit_context: whether `fn` is supposed to be JIT-ted; relaxes the
      retracing-counter checks in `exec_with_tracing_counter_checks`.
    which_variants: chex variants to use in tests generation; forwarded as
      keyword arguments to `variants.variants`.
    **var_kwargs: kwargs for variants wrappers (forwarded to `self.variant`).

  Returns:
    A generator with tests.
  """
  # All generated tests use default arguments (defined at the top of this file).
  arg_0, arg_1, expected = DEFAULT_PARAMS[0]
  varg_0, varg_1, vexpected = (
      _scalar_to_ndarray(a) for a in (arg_0, arg_1, expected))

  # We test whether the function has been jitted by introducing a counter
  # variable as a side-effect. When the function is repeatedly called, jitted
  # code will only execute the side-effect once.
  python_execution_count = 0

  def fn_with_counter(*args, **kwargs):
    nonlocal python_execution_count
    python_execution_count += 1
    return fn(*args, **kwargs)

  def exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1):
    self.assertEqual(python_execution_count, 0)
    _ = var_fn(arg_0, arg_1)
    # In jit context, JAX can omit retracing a function from the previous
    # test, hence `python_execution_count` will be equal to 0.
    # In non-jit context, `python_execution_count` must always increase.
    if not is_jit_context:
      self.assertEqual(python_execution_count, 1)

    actual = var_fn(arg_0, arg_1)
    if is_jit_context:
      # Either 1 (initial tracing) or 0 (function reuse).
      self.assertLess(python_execution_count, 2)
    else:
      self.assertEqual(python_execution_count, 2)

    return actual

  # Here, various tests follow. Tests' names intended to be self-descriptive.
  # Each test is wrapped by `variants.variants(**which_variants)`, which makes
  # it a generator of per-variant tests (required by `itertools.chain` below).
  # `parameterized` decorators must sit *inside* the variants wrapper.

  @variants.variants(**which_variants)
  def test_with_scalar_args(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  def test_called_variant(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(**var_kwargs)(fn_with_counter)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  def test_with_kwargs(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(
        self, var_fn, arg_1=arg_1, arg_0=arg_0)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_scalar_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_named_scalar_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  def test_with_ndarray_args(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, varg_0, varg_1)
    vexpected_ = vexpected
    # pmap variant case.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected_ = jnp.broadcast_to(vexpected_, actual.shape)
    np.testing.assert_array_equal(actual, vexpected_)

  @variants.variants(**which_variants)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_ndarray_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    varg_0, varg_1, vexpected = (
        _scalar_to_ndarray(a) for a in (arg_0, arg_1, expected))
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, varg_0, varg_1)
    # pmap variant case.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected = jnp.broadcast_to(vexpected, actual.shape)
    np.testing.assert_array_equal(actual, vexpected)

  @variants.variants(**which_variants)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_ndarray_named_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    varg_0, varg_1, vexpected = (
        _scalar_to_ndarray(a) for a in (arg_0, arg_1, expected))
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, varg_0, varg_1)
    # pmap variant case.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected = jnp.broadcast_to(vexpected, actual.shape)
    np.testing.assert_array_equal(actual, vexpected)

  all_tests = (test_with_scalar_args, test_called_variant, test_with_kwargs,
               test_scalar_parameters, test_named_scalar_parameters,
               test_with_ndarray_args, test_ndarray_parameters,
               test_ndarray_named_parameters)

  # Each test is a generator itself, hence we use chaining from itertools.
  return itertools.chain(*all_tests)
class ParamsProductTest(absltest.TestCase):
  """Tests `variants.params_product` in unnamed and named modes."""

  def test_product(self):
    xs = (('x1', 1, 10), ('x2', 2, 20))
    ys = (('y1', 3), ('y2', 4))
    zs = (('z1', 5, 50), ('z2', 6, 60))
    aux = (('aux', 'AUX'),)

    # Unnamed mode: every combination, with the entries simply concatenated.
    expected_flat = [
        ('x1', 1, 10, 'y1', 3, 'z1', 5, 50, 'aux', 'AUX'),
        ('x1', 1, 10, 'y1', 3, 'z2', 6, 60, 'aux', 'AUX'),
        ('x1', 1, 10, 'y2', 4, 'z1', 5, 50, 'aux', 'AUX'),
        ('x1', 1, 10, 'y2', 4, 'z2', 6, 60, 'aux', 'AUX'),
        ('x2', 2, 20, 'y1', 3, 'z1', 5, 50, 'aux', 'AUX'),
        ('x2', 2, 20, 'y1', 3, 'z2', 6, 60, 'aux', 'AUX'),
        ('x2', 2, 20, 'y2', 4, 'z1', 5, 50, 'aux', 'AUX'),
        ('x2', 2, 20, 'y2', 4, 'z2', 6, 60, 'aux', 'AUX'),
    ]
    flat_product = variants.params_product(xs, ys, zs, aux, named=False)
    self.assertEqual(list(flat_product), expected_flat)

    # Named mode: the leading names are joined with '_' and values follow.
    expected_named = [
        ('x1_y1_z1_aux', 1, 10, 3, 5, 50, 'AUX'),
        ('x1_y1_z2_aux', 1, 10, 3, 6, 60, 'AUX'),
        ('x1_y2_z1_aux', 1, 10, 4, 5, 50, 'AUX'),
        ('x1_y2_z2_aux', 1, 10, 4, 6, 60, 'AUX'),
        ('x2_y1_z1_aux', 2, 20, 3, 5, 50, 'AUX'),
        ('x2_y1_z2_aux', 2, 20, 3, 6, 60, 'AUX'),
        ('x2_y2_z1_aux', 2, 20, 4, 5, 50, 'AUX'),
        ('x2_y2_z2_aux', 2, 20, 4, 6, 60, 'AUX'),
    ]
    named_product = variants.params_product(xs, ys, zs, aux, named=True)
    self.assertEqual(list(named_product), expected_named)
class FailedTestsTest(absltest.TestCase):
  """Checks that failure and error reports mention the failing Chex variant."""

  # Inner class prevents FailedTest being run by `absltest.main()`.
  class FailedTest(variants.TestCase):

    @variants.variants(without_jit=True)
    def test_failure(self):
      self.assertEqual('meaning of life', 1337)

    @variants.variants(without_jit=True)
    def test_error(self):
      raise ValueError('this message does not specify the Chex variant')

  def setUp(self):
    super().setUp()
    self.chex_info = str(variants.ChexVariantType.WITHOUT_JIT)
    self.res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.FailedTest)
    ts.run(self.res)

  def test_useful_failures(self):
    # `failures` is a list and never `None`; require it to be non-empty so the
    # loop below cannot pass vacuously.
    self.assertNotEmpty(self.res.failures)
    for test_method, _ in self.res.failures:
      self.assertIn(self.chex_info, test_method._testMethodName)

  def test_useful_errors(self):
    # Same reasoning as above: an empty error list would make the loop a no-op.
    self.assertNotEmpty(self.res.errors)
    for test_method, msg in self.res.errors:
      self.assertIn(self.chex_info, test_method._testMethodName)
      self.assertIn('this message does not specify the Chex variant', msg)
class OneFailedVariantTest(variants.TestCase):
  """Checks that only the failing variant of a multi-variant test is reported."""

  # Inner class prevents MaybeFailedTest being run by `absltest.main()`.
  class MaybeFailedTest(variants.TestCase):

    @variants.variants(with_device=True, without_device=True)
    def test_failure(self):

      @self.variant
      def fails_for_without_device_variant(x):
        # Passes under `with_device` (input placed on a device) and fails under
        # `without_device` (input stays a Python int).
        self.assertIsInstance(x, jax.Array)

      fails_for_without_device_variant(42)

  def test_useful_failure(self):
    expected_info = str(variants.ChexVariantType.WITHOUT_DEVICE)
    unexpected_info = str(variants.ChexVariantType.WITH_DEVICE)

    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.MaybeFailedTest)
    ts.run(res)
    self.assertLen(res.failures, 1)
    for test_method, _ in res.failures:
      self.assertIn(expected_info, test_method._testMethodName)
      self.assertNotIn(unexpected_info, test_method._testMethodName)
class WrongBaseClassTest(variants.TestCase):
  """Checks the error raised when variants are used outside `chex.TestCase`."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(absltest.TestCase):

    @variants.all_variants
    def test_failure(self):
      pass

  def test_wrong_base_class(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    self.assertLen(res.errors, 1)
    for _, msg in res.errors:
      self.assertRegex(msg,
                       'RuntimeError.+make sure.+inherit from `chex.TestCase`')
class BaseClassesTest(parameterized.TestCase):
  """Tests different combinations of base classes for a variants test."""

  def generate_test_class(self, base_1, base_2):
    """Returns a test class derived from the specified bases."""

    class InnerBaseClassTest(base_1, base_2):

      @variants.all_variants(with_pmap=False)
      @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
      def test_should_pass(self, arg_0, arg_1, expected):
        actual = self.variant(DEFAULT_FN)(arg_0, arg_1)
        self.assertEqual(actual, expected)

    return InnerBaseClassTest

  @parameterized.named_parameters(
      ('parameterized', (parameterized.TestCase, object)),
      ('variants', (variants.TestCase, object)),
      ('variants_and_parameterized',
       (variants.TestCase, parameterized.TestCase)),
  )
  def test_inheritance(self, base_classes):
    res = unittest.TestResult()
    test_class = self.generate_test_class(*base_classes)
    for base_class in base_classes:
      self.assertTrue(issubclass(test_class, base_class))
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(test_class)
    ts.run(res)
    # 4 non-pmap variants x 2 named parameter sets.
    self.assertEqual(res.testsRun, 8)
    self.assertEmpty(res.errors)
    self.assertEmpty(res.failures)
class VariantsTestCaseWithParameterizedTest(absltest.TestCase):
  """Checks that variants compose with `parameterized` on `chex.TestCase`."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):

    @variants.all_variants(with_pmap=False)
    @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
    def test_should_pass(self, arg_0, arg_1, expected):
      actual = self.variant(DEFAULT_FN)(arg_0, arg_1)
      self.assertEqual(actual, expected)

  def test_should_pass(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    # 4 non-pmap variants x 2 named parameter sets.
    self.assertEqual(res.testsRun, 8)
    self.assertEmpty(res.errors)
    self.assertEmpty(res.failures)
class WrongWrappersOrderTest(parameterized.TestCase):
  """Checks that `parameterized` must not wrap a Chex variants generator.

  Each test applies the decorators in the wrong order (`parameterized`
  outermost) and expects a `RuntimeError` at decoration time.
  """

  def setUp(self):
    super().setUp()
    self._error_msg = ('A test wrapper attempts to access __name__ of '
                       'VariantsTestCaseGenerator')

  def test_incorrect_wrapping_order_named_all_variants(self):
    with self.assertRaisesRegex(RuntimeError, self._error_msg):

      @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
      @variants.all_variants()
      def _(*unused_args):
        pass

  def test_incorrect_wrapping_order_named_some_variants(self):
    with self.assertRaisesRegex(RuntimeError, self._error_msg):

      @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
      @variants.variants(with_jit=True, with_device=True)
      def _(*unused_args):
        pass

  def test_incorrect_wrapping_order_all_variants(self):
    with self.assertRaisesRegex(RuntimeError, self._error_msg):

      @parameterized.parameters(*DEFAULT_PARAMS)
      @variants.all_variants()
      def _(*unused_args):
        pass

  def test_incorrect_wrapping_order_some_variants(self):
    with self.assertRaisesRegex(RuntimeError, self._error_msg):

      @parameterized.parameters(*DEFAULT_PARAMS)
      @variants.variants(without_jit=True, without_device=True)
      def _(*unused_args):
        pass
class UnusedVariantTest(absltest.TestCase):
  """Checks the error raised when a variant-wrapped test never uses it."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):

    @variants.all_variants(with_pmap=False)
    def test_noop(self):
      pass

  def test_unused_variant(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    # One error per (non-pmap) variant.
    self.assertLen(res.errors, 4)
    for _, msg in res.errors:
      self.assertRegex(
          msg, 'RuntimeError: Test is wrapped .+ but never calls self.variant')
class NoVariantsTest(absltest.TestCase):
  """Checks that Chex raises ValueError when no variants are selected."""

  def test_no_variants(self):
    with self.assertRaisesRegex(ValueError, 'No variants selected'):

      class InnerTest(variants.TestCase):  # pylint:disable=unused-variable

        # Deliberately selects no variants at all.
        @variants.variants()
        def test_noop(self):
          pass
class UnknownVariantArgumentsTest(absltest.TestCase):
  """Checks that `self.variant` rejects unknown keyword arguments."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):

    @variants.all_variants(with_pmap=False)
    def test_arg(self):
      self.variant(lambda: None, some_unknown_arg=16)

  def test_unknown_argument(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    # One error per (non-pmap) variant.
    self.assertLen(res.errors, 4)
    for _, msg in res.errors:
      self.assertRegex(msg, 'Unknown arguments in .+some_unknown_arg')
class VariantTypesTest(absltest.TestCase):
  """Checks `self.variant.type` and the variant registry's consistency."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):
    # Shared across all generated variant tests to collect observed types.
    var_types = set()

    @variants.all_variants
    def test_var_type(self):
      self.variant(lambda: None)
      self.var_types.add(self.variant.type)

  def test_var_type_fetch(self):
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(unittest.TestResult())
    expected_types = set(variants.ChexVariantType)
    # The pmap variant is not run when only a single device is available.
    if jax.device_count() == 1:
      expected_types.remove(variants.ChexVariantType.WITH_PMAP)
    self.assertSetEqual(self.InnerTest.var_types, expected_types)

  def test_consistency(self):
    self.assertLen(variants._variant_decorators, len(variants.ChexVariantType))
    for arg in inspect.getfullargspec(variants.variants).args:
      if arg == 'test_method':
        continue
      self.assertTrue(hasattr(variants.ChexVariantType, arg.upper()))
class CountVariantsTest(absltest.TestCase):
  """Checks how many test instances each variants decorator generates."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):
    test_1_count = 0
    test_2_count = 0
    test_3_count = 0
    test_4_count = 0

    # All variants, including pmap when more than one device is available.
    @variants.all_variants
    def test_1(self):
      type(self).test_1_count += 1

    # All variants except pmap: 4 instances.
    @variants.all_variants(with_pmap=False)
    def test_2(self):
      type(self).test_2_count += 1

    # A single variant: 1 instance.
    @variants.variants(with_jit=True)
    def test_3(self):
      type(self).test_3_count += 1

    # Two variants: 2 instances.
    @variants.variants(with_jit=True, without_device=True)
    def test_4(self):
      type(self).test_4_count += 1

  def test_counters(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated and was removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)

    active_pmap = int(jax.device_count() > 1)
    self.assertEqual(self.InnerTest.test_1_count, 4 + active_pmap)
    self.assertEqual(self.InnerTest.test_2_count, 4)
    self.assertEqual(self.InnerTest.test_3_count, 1)
    self.assertEqual(self.InnerTest.test_4_count, 2)

    # Test methods do not use `self.variant`.
    self.assertLen(res.errors, 1 + 2 + 4 + 4 + active_pmap)
    for _, msg in res.errors:
      self.assertRegex(
          msg, 'RuntimeError: Test is wrapped .+ but never calls self.variant')
class MultipleVariantsTest(parameterized.TestCase):
  """Checks `all_variants` with/without parentheses and a variants subset."""

  @variants.all_variants()
  def test_all_variants(self):
    # self.variant must be used at least once.
    self.variant(lambda x: x)(0)
    self.assertNotEqual('meaning of life', 1337)

  @variants.all_variants
  def test_all_variants_no_parens(self):
    # self.variant must be used at least once.
    self.variant(lambda x: x)(0)
    self.assertNotEqual('meaning of life', 1337)

  @variants.variants(
      with_jit=True, without_jit=True, with_device=True, without_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_many_variants(self, arg_0, arg_1, expected):

    @self.variant
    def fn(arg_0, arg_1):
      return arg_1 - arg_0

    actual = fn(arg_0, arg_1)
    self.assertEqual(actual, expected)
class VmappedFunctionTest(parameterized.TestCase):
  """Checks that every variant can wrap an already-vmapped function."""

  @variants.all_variants
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_vmapped_fn_named_params(self, arg_0, arg_1, expected):
    varg_0, varg_1, vexpected = (
        _scalar_to_ndarray(x) for x in (arg_0, arg_1, expected))
    vmapped_fn = jax.vmap(DEFAULT_FN)
    actual = self.variant(vmapped_fn)(varg_0, varg_1)

    # pmap variant: the output may carry an extra leading device axis.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected = jnp.broadcast_to(vexpected, actual.shape)
    np.testing.assert_array_equal(actual, vexpected)
class WithoutJitTest(parameterized.TestCase):
  """Standard generated tests for the `without_jit` variant."""

  tests = _variant_default_tests_generator(
      fn=DEFAULT_FN,
      is_jit_context=False,
      which_variants={'without_jit': True},
  )
class WithJitTest(parameterized.TestCase):
  """Standard tests for `with_jit` plus checks of static arg handling."""

  tests = _variant_default_tests_generator(
      fn=DEFAULT_FN, is_jit_context=True, which_variants=dict(with_jit=True))

  @variants.variants(with_jit=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_different_jit_kwargs(self, arg_0, arg_1, expected):
    kwarg_0 = arg_0
    kwarg_1 = arg_1
    arg_0_type = type(arg_0)
    arg_1_type = type(arg_1)
    kwarg_0_type = type(kwarg_0)
    kwarg_1_type = type(kwarg_1)

    # Static args keep their Python type inside the jitted function; traced
    # args do not. Each fn below marks a different subset as static.

    # Static: positional `arg_0` and keyword `kwarg_1`.
    @self.variant(static_argnums=0, static_argnames='kwarg_1')
    def fn_0(arg_0, arg_1, kwarg_0, kwarg_1):
      self.assertIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      self.assertNotIsInstance(kwarg_0, kwarg_0_type)
      self.assertIsInstance(kwarg_1, kwarg_1_type)
      return DEFAULT_FN(arg_0 + kwarg_0, arg_1 + kwarg_1)

    actual_0 = fn_0(arg_0, arg_1, kwarg_0=kwarg_0, kwarg_1=kwarg_1)
    self.assertEqual(actual_0, 2 * expected)

    # Static: positional `arg_1` and keyword `kwarg_1`.
    @self.variant(static_argnums=1, static_argnames='kwarg_1')
    def fn_1(arg_0, arg_1, kwarg_0, kwarg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertIsInstance(arg_1, arg_1_type)
      self.assertNotIsInstance(kwarg_0, kwarg_0_type)
      self.assertIsInstance(kwarg_1, kwarg_1_type)
      return DEFAULT_FN(arg_0 + kwarg_0, arg_1 + kwarg_1)

    actual_1 = fn_1(arg_0, arg_1, kwarg_0=kwarg_0, kwarg_1=kwarg_1)
    self.assertEqual(actual_1, 2 * expected)

    # Static: keyword `kwarg_0` only.
    @self.variant(static_argnums=(), static_argnames='kwarg_0')
    def fn_2(arg_0, arg_1, kwarg_0, kwarg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      self.assertIsInstance(kwarg_0, kwarg_0_type)
      self.assertNotIsInstance(kwarg_1, kwarg_1_type)
      return DEFAULT_FN(arg_0 + kwarg_0, arg_1 + kwarg_1)

    actual_2 = fn_2(arg_0, arg_1, kwarg_0=kwarg_0, kwarg_1=kwarg_1)
    self.assertEqual(actual_2, 2 * expected)

    def fn_3(arg_0, arg_1):
      self.assertIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      return DEFAULT_FN(arg_0, arg_1)

    fn_3_v0 = self.variant(static_argnums=0, static_argnames='arg_0')(fn_3)
    fn_3_v1 = self.variant(static_argnums=0)(fn_3)
    fn_3_v2 = self.variant(static_argnums=(), static_argnames='arg_0')(fn_3)
    self.assertEqual(fn_3_v0(arg_0, arg_1), expected)
    self.assertEqual(fn_3_v1(arg_0=arg_0, arg_1=arg_1), expected)
    self.assertEqual(fn_3_v1(arg_0, arg_1=arg_1), expected)
    self.assertEqual(fn_3_v2(arg_0=arg_0, arg_1=arg_1), expected)
def _test_fn_without_device(arg_0, arg_1):
  """Asserts neither argument is a `jax.Array`, then applies `DEFAULT_FN`."""
  checker = unittest.TestCase()
  for arg in (arg_0, arg_1):
    checker.assertNotIsInstance(arg, jax.Array)
  return DEFAULT_FN(arg_0, arg_1)
class WithoutDeviceTest(parameterized.TestCase):
  """Standard tests for `without_device`: inputs must stay off-device."""

  tests = _variant_default_tests_generator(
      fn=_test_fn_without_device,
      is_jit_context=False,
      which_variants=dict(without_device=True))

  @variants.variants(without_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_emplace(self, arg_0, arg_1, expected):
    (arg_0, arg_1) = self.variant(lambda x: x)((arg_0, arg_1))
    actual = _test_fn_without_device(arg_0, arg_1)
    self.assertEqual(actual, expected)
def _test_fn_with_device(arg_0, arg_1):
  """Asserts both arguments are `jax.Array`s, then applies `DEFAULT_FN`."""
  checker = unittest.TestCase()
  for arg in (arg_0, arg_1):
    checker.assertIsInstance(arg, jax.Array)
  return DEFAULT_FN(arg_0, arg_1)
class WithDeviceTest(parameterized.TestCase):
  """Standard tests for `with_device` plus `ignore_argnums` handling."""

  tests = _variant_default_tests_generator(
      fn=_test_fn_with_device,
      is_jit_context=False,
      which_variants=dict(with_device=True))

  @variants.variants(with_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_emplace(self, arg_0, arg_1, expected):
    (arg_0, arg_1) = self.variant(lambda x: x)((arg_0, arg_1))
    actual = _test_fn_with_device(arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(with_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_ignore_argnums(self, arg_0, arg_1, expected):
    static_type = type(arg_0)

    # Positions 0 and 2 are left untouched; only `arg_1` goes to the device.
    @self.variant(ignore_argnums=(0, 2))
    def fn(arg_0, arg_1, float_arg):
      self.assertIsInstance(arg_0, static_type)
      self.assertIsInstance(arg_1, jax.Array)
      self.assertIsInstance(float_arg, float)
      return DEFAULT_FN(arg_0, arg_1)

    actual = fn(arg_0, arg_1, 5.3)
    self.assertEqual(actual, expected)
def _test_fn_single_device(arg_0, arg_1):
  """Checks argument shapes, then returns `psum` of `DEFAULT_FN` over axis 'i'.

  Each argument must be a scalar or have shape `DEFAULT_NDARRAY_PARAMS_SHAPE`.
  """
  checker = unittest.TestCase()
  allowed_shapes = {(), DEFAULT_NDARRAY_PARAMS_SHAPE}
  for arg in (arg_0, arg_1):
    checker.assertIn(np.shape(arg), allowed_shapes)
  return jax.lax.psum(DEFAULT_FN(arg_0, arg_1), axis_name='i')
class WithPmapSingleDeviceTest(parameterized.TestCase):
  """Standard generated tests for `with_pmap` restricted to one device."""

  tests_single_device = _variant_default_tests_generator(
      fn=_test_fn_single_device,
      is_jit_context=True,
      which_variants={'with_pmap': True},
      n_devices=1,
  )
class WithPmapAllAvailableDeviceTest(parameterized.TestCase):
  """Tests the `with_pmap` variant across all available devices.

  `setUp` prefers TPU, then GPU (when more than one such device exists), and
  otherwise uses `FLAGS['chex_n_cpu_devices']` CPU devices.
  """

  def setUp(self):
    super().setUp()
    # Choose devices and a backend.
    n_tpu = asserts._ai.num_devices_available('tpu')
    n_gpu = asserts._ai.num_devices_available('gpu')
    if n_tpu > 1:
      self.n_devices, self.backend = n_tpu, 'tpu'
    elif n_gpu > 1:
      self.n_devices, self.backend = n_gpu, 'gpu'
    else:
      self.n_devices, self.backend = FLAGS['chex_n_cpu_devices'].value, 'cpu'

  @variants.variants(with_pmap=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_pmap(self, arg_0, arg_1, expected):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 3
    arg_0_type = type(arg_0)
    arg_1_type = type(arg_1)

    # `reduce_fn=None` so the result keeps a leading device axis (asserted
    # below as `(n_devices, n_copies)`).
    @self.variant(reduce_fn=None, n_devices=n_devices, backend=backend)
    def fn(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      asserts.assert_shape(arg_0, [n_copies])
      asserts.assert_shape(arg_1, [n_copies])

      res = arg_1 - arg_0
      psum_res = jax.lax.psum(res, axis_name='i')
      return psum_res

    arg_0 = jnp.zeros((n_copies,)) + arg_0
    arg_1 = jnp.zeros((n_copies,)) + arg_1
    actual = fn(arg_0, arg_1)
    self.assertEqual(actual.shape, (n_devices, n_copies))

    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**2 * expected)

  @variants.variants(with_pmap=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_pmap_vmapped_fn(self, arg_0, arg_1, expected):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 7
    actual_shape = (n_copies,) + DEFAULT_NDARRAY_PARAMS_SHAPE
    varg_0 = _scalar_to_ndarray(arg_0, actual_shape)
    varg_1 = _scalar_to_ndarray(arg_1, actual_shape)
    vexpected = _scalar_to_ndarray(expected)

    arg_0_type = type(varg_0)
    arg_1_type = type(varg_1)

    @self.variant(reduce_fn=None, n_devices=n_devices, backend=backend)
    def fn(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)

      # Maps over the leading `n_copies` axis inside each device.
      @jax.vmap
      def vmapped_fn(arg_0, arg_1):
        self.assertIsInstance(arg_0, ArrayBatched)
        self.assertIsInstance(arg_1, ArrayBatched)
        asserts.assert_shape(arg_0, actual_shape[1:])
        asserts.assert_shape(arg_1, actual_shape[1:])
        return arg_1 - arg_0

      res = vmapped_fn(arg_0, arg_1)
      psum_res = jax.lax.psum(res, axis_name='i')
      return psum_res

    actual = fn(varg_0, varg_1)
    self.assertEqual(actual.shape, (n_devices,) + actual_shape)

    # Sum over `n_devices` and `n_copies` axes.
    actual = actual.sum(axis=0).sum(axis=0)

    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    np.testing.assert_array_equal(actual, n_copies * n_devices**2 * vexpected)

  @variants.variants(with_pmap=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_pmap_static_argnums(self, arg_0, arg_1, expected):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 5
    actual_shape = (n_copies,)
    varg_0 = _scalar_to_ndarray(arg_0, actual_shape)
    arg_0_type = type(varg_0)
    arg_1_type = type(arg_1)

    # `arg_1` is static; the collective axis is named 'j' instead of 'i'.
    @self.variant(
        static_argnums=(1,),
        axis_name='j',
        reduce_fn=None,
        n_devices=n_devices,
        backend=backend)
    def fn_static(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertIsInstance(arg_1, arg_1_type)
      asserts.assert_shape(arg_0, [n_copies])

      arg_1 = _scalar_to_ndarray(arg_1, actual_shape)
      asserts.assert_shape(arg_1, [n_copies])
      arg_1 = np.array(arg_1)  # don't stage out operations on arg_1
      psum_arg_1 = np.sum(jax.lax.psum(arg_1, axis_name='j'))
      self.assertEqual(psum_arg_1, arg_1[0] * (n_copies * n_devices))

      res = arg_1 - arg_0
      psum_res = jax.lax.psum(res, axis_name='j')
      return psum_res

    actual = fn_static(varg_0, arg_1)
    self.assertEqual(actual.shape, (n_devices, n_copies))

    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**2 * expected)

  @variants.variants(with_pmap=True)
  def test_pmap_static_argnums_zero(self):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 5
    varg_0 = 10
    varg_1 = jnp.zeros(n_copies) + 20
    arg_0_type = type(varg_0)
    arg_1_type = type(varg_1)

    @self.variant(
        static_argnums=(0,), reduce_fn=None, n_devices=n_devices,
        backend=backend)
    def fn_static(arg_0, arg_1):
      self.assertIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      arg_0 = jnp.zeros(n_copies) + arg_0
      asserts.assert_shape(arg_0, [n_copies])
      asserts.assert_shape(arg_1, [n_copies])

      res = arg_1 - arg_0
      return jax.lax.psum(res, axis_name='i')

    actual = fn_static(varg_0, varg_1)
    self.assertEqual(actual.shape, (n_devices, n_copies))

    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**2 * 10)

  @variants.variants(with_pmap=True)
  def test_pmap_in_axes(self):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 7
    varg_0 = jnp.zeros((n_devices, n_copies)) + 1
    varg_1 = jnp.zeros((n_devices, n_copies)) + 2
    arg_0_type = type(varg_0)
    arg_1_type = type(varg_1)

    # `arg_0` is split across devices, `arg_1` is passed whole to each device;
    # broadcasting is disabled so the inputs' leading axis is the device axis.
    @self.variant(
        in_axes=(0, None),
        broadcast_args_to_devices=False,
        reduce_fn=None,
        n_devices=n_devices,
        backend=backend)
    def fn(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      asserts.assert_shape(arg_0, [n_copies])
      asserts.assert_shape(arg_1, [n_devices, n_copies])

      res = arg_1 - arg_0
      psum_res = jax.lax.psum(res, axis_name='i')
      return psum_res

    actual = fn(varg_0, varg_1)
    self.assertEqual(actual.shape, (n_devices, n_devices, n_copies))
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**3)

  @variants.variants(with_pmap=True)
  def test_pmap_wrong_axis_size(self):
    n_devices, backend = self.n_devices, self.backend

    @self.variant(
        broadcast_args_to_devices=False, n_devices=n_devices, backend=backend)
    def fn(arg_0, arg_1):
      del arg_0, arg_1  # Unused: the call must fail before reaching here.
      raise RuntimeError('This line should not be executed.')

    varg_0 = jnp.zeros(n_devices + 1)
    varg_1 = jnp.zeros(n_devices + 2)
    with self.assertRaisesRegex(
        ValueError, 'Pmappable.* axes size must be equal to number of devices.*'
        f'expected the first dim to be {n_devices}'):
      fn(varg_0, varg_1)
if __name__ == '__main__':
  # Entry point when the module is executed directly: hand control to absl's
  # test runner.
  absltest.main()