# PTWZ's picture
# Upload folder using huggingface_hub
# f5f3483 verified
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `variants.py`.
To run tests in multi-CPU mode, set the flag `--chex_n_cpu_devices=N`.
"""
import inspect
import itertools
import unittest
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from chex._src import asserts
from chex._src import fake
from chex._src import pytypes
from chex._src import variants
import jax
import jax.numpy as jnp
import numpy as np
FLAGS = flags.FLAGS
ArrayBatched = pytypes.ArrayBatched


def DEFAULT_FN(arg_0, arg_1):  # pylint: disable=invalid-name
  """Default function under test: returns `arg_1 - arg_0`."""
  return arg_1 - arg_0


# Triples `(arg_0, arg_1, expected)` with `expected == DEFAULT_FN(arg_0, arg_1)`.
DEFAULT_PARAMS = ((1, 2, 1), (4, 6, 2))
# Shape to which scalar parameters are broadcast in the ndarray tests.
DEFAULT_NDARRAY_PARAMS_SHAPE = (5, 7)
# The cases of `DEFAULT_PARAMS`, each prefixed with a test-case name.
DEFAULT_NAMED_PARAMS = (('case_0', 1, 2, 1), ('case_1', 4, 6, 2))
# Set `FLAGS.chex_n_cpu_devices` CPU devices for all tests.
def setUpModule():
  """Sets up the fake CPU devices, then asserts that they are available."""
  fake.set_n_cpu_devices()
  requested_cpus = FLAGS['chex_n_cpu_devices'].value
  asserts.assert_devices_available(requested_cpus, 'cpu', backend='cpu')
def _scalar_to_ndarray(x, shape=None):
return np.broadcast_to(x, shape or DEFAULT_NDARRAY_PARAMS_SHAPE)
def _variant_default_tests_generator(fn, is_jit_context, which_variants,
                                     **var_kwargs):
  """Returns a generator with standard tests.

  For internal usage. Allows dynamically generating common tests.
  See the tests' names and comments for more information.

  Args:
    fn: a separate function to be tested (without `self` argument).
    is_jit_context: whether the selected variants are expected to JIT-compile
      `fn`; this relaxes the tracing-counter checks.
    which_variants: chex variants to use in tests generation; passed as
      keyword arguments to `variants.variants`.
    **var_kwargs: kwargs for variants wrappers; forwarded to `self.variant`.

  Returns:
    An `itertools.chain` over all the test generators defined below.
  """
  # All generated tests use default arguments (defined at the top of this file).
  arg_0, arg_1, expected = DEFAULT_PARAMS[0]
  # The same defaults broadcast to `DEFAULT_NDARRAY_PARAMS_SHAPE`.
  varg_0, varg_1, vexpected = (
      _scalar_to_ndarray(a) for a in (arg_0, arg_1, expected))

  # We test whether the function has been jitted by introducing a counter
  # variable as a side-effect. When the function is repeatedly called, jitted
  # code will only execute the side-effect once.
  python_execution_count = 0

  def fn_with_counter(*args, **kwargs):
    # Wraps `fn`, counting how many times its Python body actually runs.
    nonlocal python_execution_count
    python_execution_count += 1
    return fn(*args, **kwargs)

  def exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1):
    # Calls `var_fn` twice, checking the counter after each call, and returns
    # the result of the second call. Callers reset the counter to 0 first.
    self.assertEqual(python_execution_count, 0)
    _ = var_fn(arg_0, arg_1)
    # In jit context, JAX can omit retracing a function from the previous
    # test, hence `python_execution_count` will be equal to 0.
    # In non-jit context, `python_execution_count` must always increase.
    if not is_jit_context:
      self.assertEqual(python_execution_count, 1)
    actual = var_fn(arg_0, arg_1)
    if is_jit_context:
      # Either 1 (initial tracing) or 0 (function reuse).
      self.assertLess(python_execution_count, 2)
    else:
      self.assertEqual(python_execution_count, 2)
    return actual

  # Various tests follow; their names are intended to be self-descriptive.
  # Each test resets the shared counter before exercising the variant.
  @variants.variants(**which_variants)
  def test_with_scalar_args(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  def test_called_variant(self):
    nonlocal python_execution_count
    python_execution_count = 0
    # `self.variant` called with kwargs only, then applied to the function.
    var_fn = self.variant(**var_kwargs)(fn_with_counter)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  def test_with_kwargs(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(
        self, var_fn, arg_1=arg_1, arg_0=arg_0)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_scalar_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_named_scalar_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, arg_0, arg_1)
    self.assertEqual(actual, expected)

  @variants.variants(**which_variants)
  def test_with_ndarray_args(self):
    nonlocal python_execution_count
    python_execution_count = 0
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, varg_0, varg_1)
    vexpected_ = vexpected
    # pmap variant case: the result has one extra leading axis.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected_ = jnp.broadcast_to(vexpected_, actual.shape)
    np.testing.assert_array_equal(actual, vexpected_)

  @variants.variants(**which_variants)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_ndarray_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    varg_0, varg_1, vexpected = (
        _scalar_to_ndarray(a) for a in (arg_0, arg_1, expected))
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, varg_0, varg_1)
    # pmap variant case: the result has one extra leading axis.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected = jnp.broadcast_to(vexpected, actual.shape)
    np.testing.assert_array_equal(actual, vexpected)

  @variants.variants(**which_variants)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_ndarray_named_parameters(self, arg_0, arg_1, expected):
    nonlocal python_execution_count
    python_execution_count = 0
    varg_0, varg_1, vexpected = (
        _scalar_to_ndarray(a) for a in (arg_0, arg_1, expected))
    var_fn = self.variant(fn_with_counter, **var_kwargs)
    actual = exec_with_tracing_counter_checks(self, var_fn, varg_0, varg_1)
    # pmap variant case: the result has one extra leading axis.
    if len(actual.shape) == len(DEFAULT_NDARRAY_PARAMS_SHAPE) + 1:
      vexpected = jnp.broadcast_to(vexpected, actual.shape)
    np.testing.assert_array_equal(actual, vexpected)

  all_tests = (test_with_scalar_args, test_called_variant, test_with_kwargs,
               test_scalar_parameters, test_named_scalar_parameters,
               test_with_ndarray_args, test_ndarray_parameters,
               test_ndarray_named_parameters)

  # Each test is a generator itself, hence we use chaining from itertools.
  return itertools.chain(*all_tests)
class ParamsProductTest(absltest.TestCase):
  """Checks `variants.params_product` in both unnamed and named modes."""

  def test_product(self):
    xs = (('x1', 1, 10), ('x2', 2, 20))
    ys = (('y1', 3), ('y2', 4))
    zs = (('z1', 5, 50), ('z2', 6, 60))
    aux = (('aux', 'AUX'),)
    tables = (xs, ys, zs, aux)

    # Unnamed mode: every combination is the plain concatenation of its rows.
    self.assertEqual(
        list(variants.params_product(*tables, named=False)),
        [
            ('x1', 1, 10, 'y1', 3, 'z1', 5, 50, 'aux', 'AUX'),
            ('x1', 1, 10, 'y1', 3, 'z2', 6, 60, 'aux', 'AUX'),
            ('x1', 1, 10, 'y2', 4, 'z1', 5, 50, 'aux', 'AUX'),
            ('x1', 1, 10, 'y2', 4, 'z2', 6, 60, 'aux', 'AUX'),
            ('x2', 2, 20, 'y1', 3, 'z1', 5, 50, 'aux', 'AUX'),
            ('x2', 2, 20, 'y1', 3, 'z2', 6, 60, 'aux', 'AUX'),
            ('x2', 2, 20, 'y2', 4, 'z1', 5, 50, 'aux', 'AUX'),
            ('x2', 2, 20, 'y2', 4, 'z2', 6, 60, 'aux', 'AUX'),
        ])

    # Named mode: the rows' leading names are joined with '_' into a single
    # test name, followed by the remaining values of each row.
    self.assertEqual(
        list(variants.params_product(*tables, named=True)),
        [
            ('x1_y1_z1_aux', 1, 10, 3, 5, 50, 'AUX'),
            ('x1_y1_z2_aux', 1, 10, 3, 6, 60, 'AUX'),
            ('x1_y2_z1_aux', 1, 10, 4, 5, 50, 'AUX'),
            ('x1_y2_z2_aux', 1, 10, 4, 6, 60, 'AUX'),
            ('x2_y1_z1_aux', 2, 20, 3, 5, 50, 'AUX'),
            ('x2_y1_z2_aux', 2, 20, 3, 6, 60, 'AUX'),
            ('x2_y2_z1_aux', 2, 20, 4, 5, 50, 'AUX'),
            ('x2_y2_z2_aux', 2, 20, 4, 6, 60, 'AUX'),
        ])
class FailedTestsTest(absltest.TestCase):
  """Checks that failures and errors of variant tests name the Chex variant."""

  # Inner class prevents FailedTest being run by `absltest.main()`.
  class FailedTest(variants.TestCase):

    @variants.variants(without_jit=True)
    def test_failure(self):
      self.assertEqual('meaning of life', 1337)

    @variants.variants(without_jit=True)
    def test_error(self):
      raise ValueError('this message does not specify the Chex variant')

  def setUp(self):
    super().setUp()
    self.chex_info = str(variants.ChexVariantType.WITHOUT_JIT)
    self.res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.FailedTest)
    ts.run(self.res)

  def test_useful_failures(self):
    # `failures` is always a list, so `assertIsNotNone` could never fail and an
    # empty list would make the loop below pass vacuously.
    self.assertNotEmpty(self.res.failures)
    for test_method, _ in self.res.failures:
      self.assertIn(self.chex_info, test_method._testMethodName)

  def test_useful_errors(self):
    self.assertNotEmpty(self.res.errors)
    for test_method, msg in self.res.errors:
      self.assertIn(self.chex_info, test_method._testMethodName)
      self.assertIn('this message does not specify the Chex variant', msg)
class OneFailedVariantTest(variants.TestCase):
  """Checks that only the failing variant is reported, by its own name."""

  # Inner class prevents MaybeFailedTest being run by `absltest.main()`.
  class MaybeFailedTest(variants.TestCase):

    @variants.variants(with_device=True, without_device=True)
    def test_failure(self):

      @self.variant
      def fails_for_without_device_variant(x):
        # Expected to hold only when the variant places `x` on a device.
        self.assertIsInstance(x, jax.Array)

      fails_for_without_device_variant(42)

  def test_useful_failure(self):
    expected_info = str(variants.ChexVariantType.WITHOUT_DEVICE)
    unexpected_info = str(variants.ChexVariantType.WITH_DEVICE)
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.MaybeFailedTest)
    ts.run(res)
    self.assertLen(res.failures, 1)
    for test_method, _ in res.failures:
      self.assertIn(expected_info, test_method._testMethodName)
      self.assertNotIn(unexpected_info, test_method._testMethodName)
class WrongBaseClassTest(variants.TestCase):
  """Variants used in a non-Chex `TestCase` must produce a clear error."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(absltest.TestCase):

    @variants.all_variants
    def test_failure(self):
      pass

  def test_wrong_base_class(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    self.assertLen(res.errors, 1)
    for _, msg in res.errors:
      self.assertRegex(msg,
                       'RuntimeError.+make sure.+inherit from `chex.TestCase`')
class BaseClassesTest(parameterized.TestCase):
  """Tests different combinations of base classes for a variants test."""

  def generate_test_class(self, base_1, base_2):
    """Returns a test class derived from the specified bases."""

    class InnerBaseClassTest(base_1, base_2):

      @variants.all_variants(with_pmap=False)
      @parameterized.parameters(*DEFAULT_PARAMS)
      def test_should_pass(self, arg_0, arg_1, expected):
        actual = self.variant(DEFAULT_FN)(arg_0, arg_1)
        self.assertEqual(actual, expected)

    return InnerBaseClassTest

  @parameterized.named_parameters(
      ('parameterized', (parameterized.TestCase, object)),
      ('variants', (variants.TestCase, object)),
      ('variants_and_parameterized',
       (variants.TestCase, parameterized.TestCase)),
  )
  def test_inheritance(self, base_classes):
    res = unittest.TestResult()
    test_class = self.generate_test_class(*base_classes)
    for base_class in base_classes:
      self.assertTrue(issubclass(test_class, base_class))
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(test_class)
    ts.run(res)
    # 4 variants (pmap excluded) x 2 parameter sets.
    self.assertEqual(res.testsRun, 8)
    # Checked separately so a failure report shows which list is non-empty.
    self.assertEmpty(res.errors)
    self.assertEmpty(res.failures)
class VariantsTestCaseWithParameterizedTest(absltest.TestCase):
  """`variants.TestCase` must support stacked parameterized decorators."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):

    @variants.all_variants(with_pmap=False)
    @parameterized.parameters(*DEFAULT_PARAMS)
    def test_should_pass(self, arg_0, arg_1, expected):
      actual = self.variant(DEFAULT_FN)(arg_0, arg_1)
      self.assertEqual(actual, expected)

  def test_should_pass(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    # 4 variants (pmap excluded) x 2 parameter sets.
    self.assertEqual(res.testsRun, 8)
    # Checked separately so a failure report shows which list is non-empty.
    self.assertEmpty(res.errors)
    self.assertEmpty(res.failures)
class WrongWrappersOrderTest(parameterized.TestCase):
  """A parameterized decorator applied on top of a variants one must fail."""

  def setUp(self):
    super().setUp()
    self._error_msg = ('A test wrapper attempts to access __name__ of '
                       'VariantsTestCaseGenerator')

  def _assert_rejected(self, make_outer, make_inner):
    """Decorates a dummy function as `@outer` over `@inner`; expects an error.

    Both decorator factories are called inside the `assertRaisesRegex` block,
    outer first, mirroring the evaluation order of stacked `@` decorators.
    """
    with self.assertRaisesRegex(RuntimeError, self._error_msg):
      outer = make_outer()
      inner = make_inner()

      def _(*unused_args):
        pass

      outer(inner(_))

  def test_incorrect_wrapping_order_named_all_variants(self):
    self._assert_rejected(
        lambda: parameterized.named_parameters(*DEFAULT_NAMED_PARAMS),
        variants.all_variants)

  def test_incorrect_wrapping_order_named_some_variants(self):
    self._assert_rejected(
        lambda: parameterized.named_parameters(*DEFAULT_NAMED_PARAMS),
        lambda: variants.variants(with_jit=True, with_device=True))

  def test_incorrect_wrapping_order_all_variants(self):
    self._assert_rejected(
        lambda: parameterized.parameters(*DEFAULT_PARAMS),
        variants.all_variants)

  def test_incorrect_wrapping_order_some_variants(self):
    self._assert_rejected(
        lambda: parameterized.parameters(*DEFAULT_PARAMS),
        lambda: variants.variants(without_jit=True, without_device=True))
class UnusedVariantTest(absltest.TestCase):
  """A variant test that never calls `self.variant` must raise an error."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):

    @variants.all_variants(with_pmap=False)
    def test_noop(self):
      pass

  def test_unused_variant(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    # One error per variant (pmap excluded).
    self.assertLen(res.errors, 4)
    for _, msg in res.errors:
      self.assertRegex(
          msg, 'RuntimeError: Test is wrapped .+ but never calls self.variant')
class NoVariantsTest(absltest.TestCase):
  """Checks that Chex raises ValueError when no variants are selected."""

  def test_no_variants(self):
    # The whole class definition sits inside the `with` block, so the error
    # is caught whether it is raised when the decorator is applied or while
    # the class itself is being created.
    with self.assertRaisesRegex(ValueError, 'No variants selected'):

      class InnerTest(variants.TestCase):  # pylint:disable=unused-variable

        @variants.variants()
        def test_noop(self):
          pass
class UnknownVariantArgumentsTest(absltest.TestCase):
  """Unknown kwargs passed to `self.variant` must be reported."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):

    @variants.all_variants(with_pmap=False)
    def test_arg(self):
      self.variant(lambda: None, some_unknown_arg=16)

  def test_unknown_argument(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)
    # One error per variant (pmap excluded).
    self.assertLen(res.errors, 4)
    for _, msg in res.errors:
      self.assertRegex(msg, 'Unknown arguments in .+some_unknown_arg')
class VariantTypesTest(absltest.TestCase):
  """Checks `self.variant.type` and the variant registry's consistency."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):
    # Class-level set, shared by all runs, collecting each run's variant type.
    var_types = set()

    @variants.all_variants()
    def test_var_type(self):
      self.variant(lambda: None)
      self.var_types.add(self.variant.type)

  def test_var_type_fetch(self):
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(unittest.TestResult())
    expected_types = set(variants.ChexVariantType)
    # The pmap variant is not expected to run with a single device.
    if jax.device_count() == 1:
      expected_types.remove(variants.ChexVariantType.WITH_PMAP)
    self.assertSetEqual(self.InnerTest.var_types, expected_types)

  def test_consistency(self):
    # One decorator per variant type, and every `variants.variants` argument
    # (other than `test_method`) names a variant type.
    self.assertLen(variants._variant_decorators, len(variants.ChexVariantType))
    for arg in inspect.getfullargspec(variants.variants).args:
      if arg == 'test_method':
        continue
      self.assertTrue(hasattr(variants.ChexVariantType, arg.upper()))
class CountVariantsTest(absltest.TestCase):
  """Counts how many times each variant-decorated test actually runs."""

  # Inner class prevents InnerTest being run by `absltest.main()`.
  class InnerTest(variants.TestCase):
    test_1_count = 0
    test_2_count = 0
    test_3_count = 0
    test_4_count = 0

    @variants.all_variants
    def test_1(self):
      type(self).test_1_count += 1

    @variants.all_variants(with_pmap=False)
    def test_2(self):
      type(self).test_2_count += 1

    @variants.variants(with_jit=True)
    def test_3(self):
      type(self).test_3_count += 1

    # Stacked decorators: only `with_jit` and `with_device` are enabled.
    @variants.variants(with_jit=True)
    @variants.variants(without_jit=False)
    @variants.variants(with_device=True)
    @variants.variants(without_device=False)
    def test_4(self):
      type(self).test_4_count += 1

  def test_counters(self):
    res = unittest.TestResult()
    # `unittest.makeSuite` is deprecated (3.11) and removed in Python 3.13.
    ts = unittest.defaultTestLoader.loadTestsFromTestCase(self.InnerTest)
    ts.run(res)

    active_pmap = int(jax.device_count() > 1)
    self.assertEqual(self.InnerTest.test_1_count, 4 + active_pmap)
    self.assertEqual(self.InnerTest.test_2_count, 4)
    self.assertEqual(self.InnerTest.test_3_count, 1)
    self.assertEqual(self.InnerTest.test_4_count, 2)

    # Test methods do not use `self.variant`, so every run reports an error.
    self.assertLen(res.errors, 1 + 2 + 4 + 4 + active_pmap)
    for _, msg in res.errors:
      self.assertRegex(
          msg, 'RuntimeError: Test is wrapped .+ but never calls self.variant')
class MultipleVariantsTest(parameterized.TestCase):
  """Runs trivial tests under several combinations of variants."""

  @variants.all_variants()
  def test_all_variants(self):
    # Each variant test has to call `self.variant` at least once.
    identity = self.variant(lambda v: v)
    identity(0)
    self.assertNotEqual('meaning of life', 1337)

  @variants.all_variants
  def test_all_variants_no_parens(self):
    # Same as above, with the decorator applied without parentheses.
    identity = self.variant(lambda v: v)
    identity(0)
    self.assertNotEqual('meaning of life', 1337)

  @variants.variants(
      with_jit=True, without_jit=True, with_device=True, without_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_many_variants(self, arg_0, arg_1, expected):

    def subtract(lhs, rhs):
      return rhs - lhs

    self.assertEqual(self.variant(subtract)(arg_0, arg_1), expected)
class VmappedFunctionTest(parameterized.TestCase):
  """Every variant must be able to wrap an already `jax.vmap`-ed function."""

  @variants.all_variants(with_pmap=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_vmapped_fn_named_params(self, arg_0, arg_1, expected):
    batched_0 = _scalar_to_ndarray(arg_0)
    batched_1 = _scalar_to_ndarray(arg_1)
    target = _scalar_to_ndarray(expected)
    result = self.variant(jax.vmap(DEFAULT_FN))(batched_0, batched_1)
    # The pmap variant yields one extra leading axis.
    if len(result.shape) - 1 == len(DEFAULT_NDARRAY_PARAMS_SHAPE):
      target = jnp.broadcast_to(target, result.shape)
    np.testing.assert_array_equal(result, target)
class WithoutJitTest(parameterized.TestCase):
  """Standard generated tests for the `without_jit` variant."""

  tests = _variant_default_tests_generator(
      DEFAULT_FN, False, {'without_jit': True})
class WithJitTest(parameterized.TestCase):
  """Tests for the `with_jit` variant, including static-argument handling."""

  # Standard generated tests, run in a JIT context.
  tests = _variant_default_tests_generator(
      fn=DEFAULT_FN, is_jit_context=True, which_variants=dict(with_jit=True))

  @variants.variants(with_jit=True)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_different_jit_kwargs(self, arg_0, arg_1, expected):
    # Arguments marked static (via `static_argnums` / `static_argnames`) are
    # asserted to keep their Python type inside the jitted function; the
    # others are asserted NOT to (presumably they arrive as JAX tracers).
    kwarg_0 = arg_0
    kwarg_1 = arg_1

    arg_0_type = type(arg_0)
    arg_1_type = type(arg_1)
    kwarg_0_type = type(kwarg_0)
    kwarg_1_type = type(kwarg_1)

    # Static: positional arg 0 and keyword `kwarg_1`.
    @self.variant(static_argnums=(0,), static_argnames=('kwarg_1',))
    def fn_0(arg_0, arg_1, kwarg_0, kwarg_1):
      self.assertIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      self.assertNotIsInstance(kwarg_0, kwarg_0_type)
      self.assertIsInstance(kwarg_1, kwarg_1_type)
      return DEFAULT_FN(arg_0 + kwarg_0, arg_1 + kwarg_1)

    # kwarg_i == arg_i, so every fn_* computes DEFAULT_FN on doubled inputs.
    actual_0 = fn_0(arg_0, arg_1, kwarg_0=kwarg_0, kwarg_1=kwarg_1)
    self.assertEqual(actual_0, 2 * expected)

    # Static: positional arg 1 (index 3 is passed by keyword here) and
    # keyword `kwarg_1`.
    @self.variant(static_argnums=(1, 3), static_argnames=('kwarg_1',))
    def fn_1(arg_0, arg_1, kwarg_0, kwarg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertIsInstance(arg_1, arg_1_type)
      self.assertNotIsInstance(kwarg_0, kwarg_0_type)
      self.assertIsInstance(kwarg_1, kwarg_1_type)
      return DEFAULT_FN(arg_0 + kwarg_0, arg_1 + kwarg_1)

    actual_1 = fn_1(arg_0, arg_1, kwarg_0=kwarg_0, kwarg_1=kwarg_1)
    self.assertEqual(actual_1, 2 * expected)

    # Static: keyword `kwarg_0` only.
    @self.variant(static_argnums=(), static_argnames=('kwarg_0',))
    def fn_2(arg_0, arg_1, kwarg_0, kwarg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      self.assertIsInstance(kwarg_0, kwarg_0_type)
      self.assertNotIsInstance(kwarg_1, kwarg_1_type)
      return DEFAULT_FN(arg_0 + kwarg_0, arg_1 + kwarg_1)

    actual_2 = fn_2(arg_0, arg_1, kwarg_0=kwarg_0, kwarg_1=kwarg_1)
    self.assertEqual(actual_2, 2 * expected)

    # `arg_0` is expected to be static in each wrapping below: by position
    # and name (v0), by position only (v1), or by name only (v2); it is
    # passed positionally and/or by keyword in the calls.
    def fn_3(arg_0, arg_1):
      self.assertIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      return DEFAULT_FN(arg_0, arg_1)

    fn_3_v0 = self.variant(static_argnums=0, static_argnames='arg_0')(fn_3)
    fn_3_v1 = self.variant(static_argnums=0)(fn_3)
    fn_3_v2 = self.variant(static_argnums=(), static_argnames='arg_0')(fn_3)
    self.assertEqual(fn_3_v0(arg_0, arg_1), expected)
    self.assertEqual(fn_3_v1(arg_0=arg_0, arg_1=arg_1), expected)
    self.assertEqual(fn_3_v1(arg_0, arg_1=arg_1), expected)
    self.assertEqual(fn_3_v2(arg_0=arg_0, arg_1=arg_1), expected)
def _test_fn_without_device(arg_0, arg_1):
  """Asserts neither argument is a `jax.Array`, then applies `DEFAULT_FN`."""
  checker = unittest.TestCase()
  for value in (arg_0, arg_1):
    checker.assertNotIsInstance(value, jax.Array)
  return DEFAULT_FN(arg_0, arg_1)
class WithoutDeviceTest(parameterized.TestCase):
  """Tests for the `without_device` variant."""

  tests = _variant_default_tests_generator(
      _test_fn_without_device, False, {'without_device': True})

  @variants.variants(without_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_emplace(self, arg_0, arg_1, expected):
    # Both arguments pass through the variant together as one tuple.
    first, second = self.variant(lambda pair: pair)((arg_0, arg_1))
    self.assertEqual(_test_fn_without_device(first, second), expected)
def _test_fn_with_device(arg_0, arg_1):
  """Asserts both arguments are `jax.Array`s, then applies `DEFAULT_FN`."""
  checker = unittest.TestCase()
  for value in (arg_0, arg_1):
    checker.assertIsInstance(value, jax.Array)
  return DEFAULT_FN(arg_0, arg_1)
class WithDeviceTest(parameterized.TestCase):
  """Tests for the `with_device` variant."""

  tests = _variant_default_tests_generator(
      _test_fn_with_device, False, {'with_device': True})

  @variants.variants(with_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_emplace(self, arg_0, arg_1, expected):
    # Both arguments pass through the variant together as one tuple.
    first, second = self.variant(lambda pair: pair)((arg_0, arg_1))
    self.assertEqual(_test_fn_with_device(first, second), expected)

  @variants.variants(with_device=True)
  @parameterized.named_parameters(*DEFAULT_NAMED_PARAMS)
  def test_ignore_argnums(self, arg_0, arg_1, expected):
    # Positions 0 and 2 are ignored and must keep their Python types, while
    # position 1 is expected to arrive as a `jax.Array`.
    host_type = type(arg_0)

    @self.variant(ignore_argnums=(0, 2))
    def fn(arg_0, arg_1, float_arg):
      self.assertIsInstance(arg_0, host_type)
      self.assertIsInstance(arg_1, jax.Array)
      self.assertIsInstance(float_arg, float)
      return DEFAULT_FN(arg_0, arg_1)

    self.assertEqual(fn(arg_0, arg_1, 5.3), expected)
def _test_fn_single_device(arg_0, arg_1):
  """Checks argument shapes; returns `psum` of `DEFAULT_FN` over axis 'i'."""
  checker = unittest.TestCase()
  allowed_shapes = {(), DEFAULT_NDARRAY_PARAMS_SHAPE}
  checker.assertIn(np.shape(arg_0), allowed_shapes)
  checker.assertIn(np.shape(arg_1), allowed_shapes)
  return jax.lax.psum(DEFAULT_FN(arg_0, arg_1), axis_name='i')
class WithPmapSingleDeviceTest(parameterized.TestCase):
  """Standard generated tests for the `with_pmap` variant on one device."""

  tests_single_device = _variant_default_tests_generator(
      _test_fn_single_device,
      True,
      {'with_pmap': True},
      n_devices=1,
  )
class WithPmapAllAvailableDeviceTest(parameterized.TestCase):
  """Tests for the `with_pmap` variant using all devices of one backend."""

  def setUp(self):
    super().setUp()
    # Choose devices and a backend.
    # TPU is preferred over GPU, and either is used only when more than one
    # such device is available; otherwise fall back to
    # `FLAGS['chex_n_cpu_devices']` CPU devices.
    n_tpu = asserts._ai.num_devices_available('tpu')
    n_gpu = asserts._ai.num_devices_available('gpu')
    if n_tpu > 1:
      self.n_devices, self.backend = n_tpu, 'tpu'
    elif n_gpu > 1:
      self.n_devices, self.backend = n_gpu, 'gpu'
    else:
      self.n_devices, self.backend = FLAGS['chex_n_cpu_devices'].value, 'cpu'

  @variants.variants(with_pmap=True)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_pmap(self, arg_0, arg_1, expected):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 3
    arg_0_type = type(arg_0)
    arg_1_type = type(arg_1)

    # As the shape assertions below check, each device receives a full
    # `[n_copies]` argument, and with `reduce_fn=None` the outputs come back
    # stacked with a leading `n_devices` axis.
    @self.variant(reduce_fn=None, n_devices=n_devices, backend=backend)
    def fn(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      asserts.assert_shape(arg_0, [n_copies])
      asserts.assert_shape(arg_1, [n_copies])

      res = arg_1 - arg_0
      psum_res = jax.lax.psum(res, axis_name='i')
      return psum_res

    arg_0 = jnp.zeros((n_copies,)) + arg_0
    arg_1 = jnp.zeros((n_copies,)) + arg_1
    actual = fn(arg_0, arg_1)
    self.assertEqual(actual.shape, (n_devices, n_copies))
    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**2 * expected)

  @variants.variants(with_pmap=True)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_pmap_vmapped_fn(self, arg_0, arg_1, expected):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 7
    actual_shape = (n_copies,) + DEFAULT_NDARRAY_PARAMS_SHAPE
    varg_0 = _scalar_to_ndarray(arg_0, actual_shape)
    varg_1 = _scalar_to_ndarray(arg_1, actual_shape)
    vexpected = _scalar_to_ndarray(expected)
    arg_0_type = type(varg_0)
    arg_1_type = type(varg_1)

    @self.variant(reduce_fn=None, n_devices=n_devices, backend=backend)
    def fn(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)

      # vmap over the `n_copies` axis; each slice has the default ndarray shape.
      @jax.vmap
      def vmapped_fn(arg_0, arg_1):
        self.assertIsInstance(arg_0, ArrayBatched)
        self.assertIsInstance(arg_1, ArrayBatched)
        asserts.assert_shape(arg_0, actual_shape[1:])
        asserts.assert_shape(arg_1, actual_shape[1:])
        return arg_1 - arg_0

      res = vmapped_fn(arg_0, arg_1)
      psum_res = jax.lax.psum(res, axis_name='i')
      return psum_res

    actual = fn(varg_0, varg_1)
    self.assertEqual(actual.shape, (n_devices,) + actual_shape)

    # Sum over `n_devices` and `n_copies` axes.
    actual = actual.sum(axis=0).sum(axis=0)
    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    np.testing.assert_array_equal(actual, n_copies * n_devices**2 * vexpected)

  @variants.variants(with_pmap=True)
  @parameterized.parameters(*DEFAULT_PARAMS)
  def test_pmap_static_argnums(self, arg_0, arg_1, expected):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 5
    actual_shape = (n_copies,)
    varg_0 = _scalar_to_ndarray(arg_0, actual_shape)
    arg_0_type = type(varg_0)
    arg_1_type = type(arg_1)

    # `arg_1` is static (keeps its Python type); a custom axis name 'j' is
    # used for the collective.
    @self.variant(
        reduce_fn=None,
        n_devices=n_devices,
        backend=backend,
        static_argnums=(1,),
        axis_name='j',
    )
    def fn_static(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertIsInstance(arg_1, arg_1_type)
      asserts.assert_shape(arg_0, [n_copies])

      arg_1 = _scalar_to_ndarray(arg_1, actual_shape)
      asserts.assert_shape(arg_1, [n_copies])
      arg_1 = np.array(arg_1)  # don't stage out operations on arg_1
      psum_arg_1 = np.sum(jax.lax.psum(arg_1, axis_name='j'))
      self.assertEqual(psum_arg_1, arg_1[0] * (n_copies * n_devices))

      res = arg_1 - arg_0
      psum_res = jax.lax.psum(res, axis_name='j')
      return psum_res

    actual = fn_static(varg_0, arg_1)
    self.assertEqual(actual.shape, (n_devices, n_copies))

    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**2 * expected)

  @variants.variants(with_pmap=True)
  def test_pmap_static_argnums_zero(self):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 5
    varg_0 = 10
    varg_1 = jnp.zeros(n_copies) + 20
    arg_0_type = type(varg_0)
    arg_1_type = type(varg_1)

    # `static_argnums` given as a bare int (not a tuple).
    @self.variant(
        reduce_fn=None,
        n_devices=n_devices,
        backend=backend,
        static_argnums=0,
    )
    def fn_static(arg_0, arg_1):
      self.assertIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      arg_0 = jnp.zeros(n_copies) + arg_0
      asserts.assert_shape(arg_0, [n_copies])
      asserts.assert_shape(arg_1, [n_copies])
      res = arg_1 - arg_0
      return jax.lax.psum(res, axis_name='i')

    actual = fn_static(varg_0, varg_1)
    self.assertEqual(actual.shape, (n_devices, n_copies))

    # Exponents of `n_devices`:
    #   +1: psum() inside fn()
    #   +1: jnp.sum() to aggregate results
    # Per-element difference is 20 - 10 = 10.
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**2 * 10)

  @variants.variants(with_pmap=True)
  def test_pmap_in_axes(self):
    n_devices, backend = self.n_devices, self.backend
    n_copies = 7
    varg_0 = jnp.zeros((n_devices, n_copies)) + 1
    varg_1 = jnp.zeros((n_devices, n_copies)) + 2
    arg_0_type = type(varg_0)
    arg_1_type = type(varg_1)

    # Without broadcasting, `in_axes=(0, None)` is expected to split `arg_0`
    # along axis 0 and hand the whole `arg_1` to every device (see the shape
    # assertions inside `fn`).
    @self.variant(
        broadcast_args_to_devices=False,
        reduce_fn=None,
        n_devices=n_devices,
        backend=backend,
        # Only 0 or None are supported (06/2020).
        in_axes=(0, None),
    )
    def fn(arg_0, arg_1):
      self.assertNotIsInstance(arg_0, arg_0_type)
      self.assertNotIsInstance(arg_1, arg_1_type)
      asserts.assert_shape(arg_0, [n_copies])
      asserts.assert_shape(arg_1, [n_devices, n_copies])

      res = arg_1 - arg_0
      psum_res = jax.lax.psum(res, axis_name='i')
      return psum_res

    actual = fn(varg_0, varg_1)
    self.assertEqual(actual.shape, (n_devices, n_devices, n_copies))
    # Every element of `res` is 2 - 1 = 1 and becomes `n_devices` after psum;
    # the result holds n_devices**2 * n_copies such elements.
    self.assertEqual(jnp.sum(actual), n_copies * n_devices**3)

  @variants.variants(with_pmap=True)
  def test_pmap_wrong_axis_size(self):
    n_devices, backend = self.n_devices, self.backend

    @self.variant(
        broadcast_args_to_devices=False,
        n_devices=n_devices,
        backend=backend,
        # Only 0 or None are supported (06/2020).
        in_axes=(None, 0),
    )
    def fn(arg_0, arg_1):
      raise RuntimeError('This line should not be executed.')

    varg_0 = jnp.zeros(n_devices + 1)
    # The mapped argument's leading dim (n_devices + 2) does not match
    # `n_devices`, so the call is expected to fail before `fn` runs.
    varg_1 = jnp.zeros(n_devices + 2)
    with self.assertRaisesRegex(
        ValueError, 'Pmappable.* axes size must be equal to number of devices.*'
        f'expected the first dim to be {n_devices}'):
      fn(varg_0, varg_1)
if __name__ == '__main__':
  # Run the tests in this module with absltest's test runner.
  absltest.main()