PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chex variants utilities."""
import enum
import functools
import inspect
import itertools
from typing import Any, Sequence
import unittest
from absl import flags
from absl.testing import parameterized
from chex._src import fake
from chex._src import pytypes
import jax
from jax import tree_util
import jax.numpy as jnp
import toolz
FLAGS = flags.FLAGS
flags.DEFINE_bool(
"chex_skip_pmap_variant_if_single_device", True,
"Whether to skip pmap variant if only one device is available.")
# We choose to subclass instead of a simple alias, as Python doesn't allow
# multiple inheritance from the same class, and users may want to subclass their
# tests from both `chex.TestCase` and `parameterized.TestCase`.
#
# User is free to use any base class that supports generators unrolling
# instead of `variants.TestCase` or `parameterized.TestCase`. If a base class
# doesn't support this feature variant test fails with a corresponding error.
class TestCase(parameterized.TestCase):
"""A class for Chex tests that use variants.
See the docstring for ``chex.variants`` for more information.
Note: ``chex.variants`` returns a generator producing one test per variant.
Therefore, the used test class must support dynamic unrolling of these
generators during module import. It is implemented (and battle-tested) in
``absl.parameterized.TestCase``, and here we subclass from it.
"""
def variant(self, *args, **kwargs):
"""Raises a RuntimeError if not overriden or redefined."""
raise RuntimeError(
"self.variant is not defined: forgot to wrap a test in @chex.variants?")
class ChexVariantType(enum.Enum):
"""An enumeration of available Chex variants.
Use ``self.variant.type`` to get type of the current test variant.
See the docstring of ``chex.variants`` for more information.
"""
WITH_JIT = 1
WITHOUT_JIT = 2
WITH_DEVICE = 3
WITHOUT_DEVICE = 4
WITH_PMAP = 5
def __str__(self) -> str:
return "_" + self.name.lower()
tree_map = tree_util.tree_map
def params_product(*params_lists: Sequence[Sequence[Any]],
named: bool = False) -> Sequence[Sequence[Any]]:
"""Generates a cartesian product of `params_lists`.
See tests from ``variants_test.py`` for examples of usage.
Args:
*params_lists: A list of params combinations.
named: Whether to generate test names (for
`absl.parameterized.named_parameters(...)`).
Returns:
A cartesian product of `params_lists` combinations.
"""
def generate():
for combination in itertools.product(*params_lists):
if named:
name = "_".join(t[0] for t in combination)
args_tuples = (t[1:] for t in combination)
args = sum(args_tuples, ())
yield (name, *args)
else:
yield sum(combination, ())
return list(generate())
def count_num_calls(fn):
"""Counts the number of times the function was called."""
num_calls = 0
@functools.wraps(fn)
def fn_wrapped(*args, **kwargs):
nonlocal num_calls
num_calls += 1
return fn(*args, **kwargs)
return fn_wrapped, lambda: num_calls
class VariantsTestCaseGenerator:
"""TestCase generator for chex variants. Supports sharding."""
def __init__(self, test_object, which_variants):
self._which_variants = which_variants
self._generated_names_freq = {}
if hasattr(test_object, "__iter__"):
# `test_object` is a generator (e.g. parameterised test).
self._test_methods = list(test_object)
else:
# `test_object` is a single test method.
self._test_methods = [test_object]
def add_variants(self, which_variants):
"""Merge variants."""
for var, incl in which_variants.items():
self._which_variants[var] = self._which_variants.get(var, False) or incl
@property
def __name__(self):
msg = ("A test wrapper attempts to access __name__ of "
"VariantsTestCaseGenerator. Usually, this happens when "
"@parameterized wraps @variants.variants. Make sure that the "
"@variants.variants wrapper is an outer one, i.e. nothing wraps it.")
raise RuntimeError(msg)
def __call__(self):
msg = ("A test wrapper attempts to invoke __call__ of "
"VariantsTestCaseGenerator: make sure that all `TestCase` instances "
"that use variants inherit from `chex.TestCase`.")
raise RuntimeError(msg)
def _set_test_name(self, test_method, variant):
"""Set a name for the generated test."""
name = getattr(test_method, "__name__", "")
params_repr = getattr(test_method, "__x_params_repr__", "")
chex_suffix = f"{variant}"
candidate_name = "_".join(filter(None, [name, params_repr, chex_suffix]))
name_freq = self._generated_names_freq.get(candidate_name, 0)
if name_freq:
# Ensure that test names are unique.
new_name = name + "_" + str(name_freq)
unique_name = "_".join(filter(None, [new_name, params_repr, chex_suffix]))
else:
unique_name = candidate_name
self._generated_names_freq[candidate_name] = name_freq + 1
# Always use name for compatibility with `absl.testing.parameterized`.
setattr(test_method, "__name__", unique_name)
setattr(test_method, "__x_params_repr__", "")
setattr(test_method, "__x_use_name__", True)
return test_method
def _inner_iter(self, test_method):
"""Generate chex variants for a single test."""
def make_test(variant: ChexVariantType):
@functools.wraps(test_method)
def test(self, *args, **kwargs):
# Skip pmap variant if only one device is available.
if (variant is ChexVariantType.WITH_PMAP and
FLAGS["chex_skip_pmap_variant_if_single_device"].value and
jax.device_count() < 2):
raise unittest.SkipTest(
f"Only 1 device is available ({jax.devices()}).")
# n_cpu_devices assert.
if FLAGS["chex_assert_multiple_cpu_devices"].value:
required_n_cpus = fake.get_n_cpu_devices_from_xla_flags()
if required_n_cpus < 2:
raise RuntimeError(
f"Required number of CPU devices is {required_n_cpus} < 2."
"Consider setting up your test module to use multiple CPU "
" devices (see README.md) or disabling "
"`chex_assert_multiple_cpu_devices` flag.")
available_n_cpus = jax.device_count("cpu")
if required_n_cpus != available_n_cpus:
raise RuntimeError(
"Number of available CPU devices is not equal to the required: "
f"{available_n_cpus} != {required_n_cpus}")
# Set up the variant.
self.variant, num_calls = count_num_calls(_variant_decorators[variant])
self.variant.type = variant
res = test_method(self, *args, **kwargs)
if num_calls() == 0:
raise RuntimeError(
"Test is wrapped in @chex.variants, but never calls self.variant."
" Consider debugging the test or removing @chex.variants wrapper."
f" (variant: {variant})")
return res
self._set_test_name(test, variant)
return test
selected_variants = [
var_name for var_name, is_included in self._which_variants.items()
if is_included
]
if not selected_variants:
raise ValueError(f"No variants selected for test: {test_method}.")
return (make_test(var_name) for var_name in selected_variants)
def __iter__(self):
"""Generate chex variants for each test case."""
return itertools.chain(*(self._inner_iter(m) for m in self._test_methods))
@toolz.curry
def _variants_fn(test_object, **which_variants) -> VariantsTestCaseGenerator:
"""Implements `variants` and `all_variants`."""
# Convert keys to enum entries.
which_variants = {
ChexVariantType[name.upper()]: var
for name, var in which_variants.items()
}
if isinstance(test_object, VariantsTestCaseGenerator):
# Merge variants for nested wrappers.
test_object.add_variants(which_variants)
else:
test_object = VariantsTestCaseGenerator(test_object, which_variants)
return test_object
@toolz.curry
# pylint: disable=redefined-outer-name
def variants(test_method,
with_jit: bool = False,
without_jit: bool = False,
with_device: bool = False,
without_device: bool = False,
with_pmap: bool = False) -> VariantsTestCaseGenerator:
# pylint: enable=redefined-outer-name
"""Decorates a test to expose Chex variants.
The decorated test has access to a decorator called ``self.variant``, which
may be applied to functions to test different JAX behaviors. Consider:
.. code-block:: python
@chex.variants(with_jit=True, without_jit=True)
def test(self):
@self.variant
def f(x, y):
return x + y
self.assertEqual(f(1, 2), 3)
In this example, the function ``test`` will be called twice: once with `f`
jitted (i.e. using `jax.jit`) and another where `f` is not jitted.
Variants `with_jit=True` and `with_pmap=True` accept additional specific to
them arguments. Example:
.. code-block:: python
@chex.variants(with_jit=True)
def test(self):
@self.variant(static_argnums=(1,))
def f(x, y):
# `y` is not traced.
return x + y
self.assertEqual(f(1, 2), 3)
Variant `with_pmap=True` also accepts `broadcast_args_to_devices`
(whether to broadcast each input argument to all participating devices),
`reduce_fn` (a function to apply to results of pmapped `fn`), and
`n_devices` (number of devices to use in the `pmap` computation).
See the docstring of `_with_pmap` for more details (including default values).
If used with ``absl.testing.parameterized``, `@chex.variants` must wrap it:
.. code-block:: python
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters('test', *args)
def test(self, *args):
...
Tests that use this wrapper must be inherited from ``parameterized.TestCase``.
For more examples see ``variants_test.py``.
Args:
test_method: A test method to decorate.
with_jit: Whether to test with `jax.jit`.
without_jit: Whether to test without `jax.jit`. Any jit compilation done
within the test method will not be affected.
with_device: Whether to test with args placed on device, using
`jax.device_put`.
without_device: Whether to test with args (explicitly) not placed on device,
using `jax.device_get`.
with_pmap: Whether to test with `jax.pmap`, with computation duplicated
across devices.
Returns:
A decorated ``test_method``.
"""
return _variants_fn(
test_method,
with_jit=with_jit,
without_jit=without_jit,
with_device=with_device,
without_device=without_device,
with_pmap=with_pmap)
@toolz.curry
# pylint: disable=redefined-outer-name
def all_variants(test_method,
with_jit: bool = True,
without_jit: bool = True,
with_device: bool = True,
without_device: bool = True,
with_pmap: bool = True) -> VariantsTestCaseGenerator:
# pylint: enable=redefined-outer-name
"""Equivalent to ``chex.variants`` but with flipped defaults."""
return _variants_fn(
test_method,
with_jit=with_jit,
without_jit=without_jit,
with_device=with_device,
without_device=without_device,
with_pmap=with_pmap)
def check_variant_arguments(variant_fn):
"""Raises `ValueError` if `variant_fn` got an unknown argument."""
@functools.wraps(variant_fn)
def wrapper(*args, **kwargs):
unknown_args = set(kwargs.keys()) - _valid_kwargs_keys
if unknown_args:
raise ValueError(f"Unknown arguments in `self.variant`: {unknown_args}.")
return variant_fn(*args, **kwargs)
return wrapper
@toolz.curry
@check_variant_arguments
def _with_jit(fn,
static_argnums=None,
static_argnames=None,
device=None,
backend=None,
**unused_kwargs):
"""Variant that applies `jax.jit` to fn."""
return jax.jit(
fn,
static_argnums=static_argnums,
static_argnames=static_argnames,
device=device,
backend=backend)
@toolz.curry
@check_variant_arguments
def _without_jit(fn, **unused_kwargs):
"""Variant that does not apply `jax.jit` to a fn (identity)."""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)
return wrapper
@toolz.curry
@check_variant_arguments
def _with_device(fn, ignore_argnums=(), static_argnums=(), **unused_kwargs):
"""Variant that applies `jax.device_put` to the args of fn."""
if isinstance(ignore_argnums, int):
ignore_argnums = (ignore_argnums,)
if isinstance(static_argnums, int):
static_argnums = (static_argnums,)
@functools.wraps(fn)
def wrapper(*args, **kwargs):
def put(x):
try:
return jax.device_put(x)
except TypeError: # not a valid JAX type
return x
device_args = [
arg if (idx in ignore_argnums or idx in static_argnums) else tree_map(
put, arg) for idx, arg in enumerate(args)
]
device_kwargs = tree_map(put, kwargs)
return fn(*device_args, **device_kwargs)
return wrapper
@toolz.curry
@check_variant_arguments
def _without_device(fn, **unused_kwargs):
"""Variant that applies `jax.device_get` to the args of fn."""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
def get(x):
if isinstance(x, jax.Array):
return jax.device_get(x)
return x
no_device_args = tree_map(get, args)
no_device_kwargs = tree_map(get, kwargs)
return fn(*no_device_args, **no_device_kwargs)
return wrapper
@toolz.curry
@check_variant_arguments
def _with_pmap(fn,
broadcast_args_to_devices=True,
reduce_fn="first_device_output",
n_devices=None,
axis_name="i",
devices=None,
in_axes=0,
static_broadcasted_argnums=(),
static_argnums=(),
backend=None,
**unused_kwargs):
"""Variant that applies `jax.pmap` to fn.
Args:
fn: A function to wrap.
broadcast_args_to_devices: Whether to broadcast `fn` args to pmap format
(i.e. pmapped axes' sizes == a number of devices).
reduce_fn: A function to apply to outputs of `fn`.
n_devices: A number of devices to use (can specify a `backend` if required).
axis_name: An argument for `pmap`.
devices: An argument for `pmap`.
in_axes: An argument for `pmap`.
static_broadcasted_argnums: An argument for `pmap`.
static_argnums: An alias of ``static_broadcasted_argnums``.
backend: An argument for `pmap`.
**unused_kwargs: Unused kwargs (e.g. related to other variants).
Returns:
Wrapped `fn` that accepts `args` and `kwargs` and returns a superposition of
`reduce_fn` and `fn` applied to them.
Raises:
ValueError: If `broadcast_args_to_devices` used with `in_axes` or
`static_broadcasted_argnums`; if number of available devices is less than
required; if pmappable arg axes' sizes are not equal to the number of
devices.
SkipTest: If the flag ``chex_skip_pmap_variant_if_single_device`` is set and
there is only one device available.
"""
if (FLAGS["chex_skip_pmap_variant_if_single_device"].value and
jax.device_count() < 2):
raise unittest.SkipTest(f"Only 1 device is available ({jax.devices()}).")
if broadcast_args_to_devices and in_axes != 0:
raise ValueError(
"Do not use `broadcast_args_to_devices` when specifying `in_axes`.")
# Set up a reduce function.
if reduce_fn == "first_device_output":
reduce_fn = lambda t: tree_map(lambda x: x[0], t)
elif reduce_fn == "identity" or reduce_fn is None: # Identity.
reduce_fn = lambda t: t
if not static_argnums and static_argnums != 0:
static_argnums = static_broadcasted_argnums
if isinstance(static_argnums, int):
static_argnums = (static_argnums,)
pmap_kwargs = dict(
axis_name=axis_name,
devices=devices,
in_axes=in_axes,
static_broadcasted_argnums=static_argnums,
backend=backend)
pmapped_fn = jax.pmap(fn, **pmap_kwargs)
@functools.wraps(pmapped_fn)
def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):
if kwargs and (in_axes != 0 or static_argnums):
raise ValueError("Do not use kwargs with `in_axes` or `static_argnums` "
"in pmapped function.")
devices_ = list(devices or jax.devices(backend))
n_devices_ = n_devices or len(devices_)
devices_ = devices_[:n_devices_]
if len(devices_) != n_devices_:
raise ValueError("Number of available devices is less than required for "
f"test ({len(devices_)} < {n_devices_})")
bcast_fn = lambda x: jnp.broadcast_to(x, (n_devices_,) + jnp.array(x).shape)
if broadcast_args_to_devices:
args = [
tree_map(bcast_fn, arg) if idx not in static_argnums else arg
for idx, arg in enumerate(args)
]
kwargs = tree_map(bcast_fn, kwargs)
else:
# Pmappable axes size must be equal to number of devices.
in_axes_ = in_axes if isinstance(in_axes,
(tuple, list)) else [in_axes] * len(args)
is_pmappable_arg = [
idx not in static_argnums and in_axes_[idx] is not None
for idx in range(len(args))
]
for is_pmappable_arg, arg in zip(is_pmappable_arg, args):
if not is_pmappable_arg:
continue
if not all(
x.shape[0] == n_devices_ for x in jax.tree_util.tree_leaves(arg)):
shapes = tree_map(jnp.shape, arg)
raise ValueError(
f"Pmappable arg axes size must be equal to number of devices, "
f"got: {shapes} (expected the first dim to be {n_devices_}). "
"Consider setting `broadcast_args_to_devices=True`.")
new_kwargs = dict(
axis_name=axis_name,
devices=devices_,
in_axes=in_axes,
static_broadcasted_argnums=static_argnums,
backend=backend)
# Re-compile fn if kwargs changed.
nonlocal pmap_kwargs
nonlocal pmapped_fn
if new_kwargs != pmap_kwargs:
pmap_kwargs = new_kwargs
pmapped_fn = jax.pmap(fn, **pmap_kwargs)
res = pmapped_fn(*args, **kwargs)
return reduce_fn(res)
return wrapper
_variant_decorators = dict({
ChexVariantType.WITH_JIT: _with_jit,
ChexVariantType.WITHOUT_JIT: _without_jit,
ChexVariantType.WITH_DEVICE: _with_device,
ChexVariantType.WITHOUT_DEVICE: _without_device,
ChexVariantType.WITH_PMAP: _with_pmap,
})
class Variant:
"""Variant class for typing and string representation."""
def __init__(self, name, fn):
self._fn = fn
self._name = name
def __repr__(self):
return self._name
def __call__(self, *args, **kwargs):
# Could apply decorators (currying, arg-checking) here
return self._fn(*args, **kwargs)
# Expose variant objects.
without_device = Variant("chex_without_device", _without_device)
without_jit = Variant("chex_without_jit", _without_jit)
with_device = Variant("chex_with_device", _with_device)
with_jit = Variant("chex_with_jit", _with_jit)
with_pmap = Variant("chex_with_pmap", _with_pmap)
ALL_VARIANTS = (without_device, without_jit, with_device, with_jit, with_pmap)
# Collect valid argument names from all variant decorators.
_valid_kwargs_keys = set()
for fn_ in _variant_decorators.values():
original_fn = fn_.func.__wrapped__
_valid_kwargs_keys.update(inspect.getfullargspec(original_fn).args)