Spaces:
Building
Building
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Chexification utilities.""" | |
import atexit | |
import collections | |
from concurrent import futures | |
import dataclasses | |
import functools | |
import re | |
from typing import Any, Callable, FrozenSet | |
from absl import logging | |
from chex._src import asserts_internal as _ai | |
import jax | |
from jax.experimental import checkify | |
class _ChexifyChecks: | |
"""A set of checks imported from checkify.""" | |
user: FrozenSet[checkify.ErrorCategory] = checkify.user_checks | |
nan: FrozenSet[checkify.ErrorCategory] = checkify.nan_checks | |
index: FrozenSet[checkify.ErrorCategory] = checkify.index_checks | |
div: FrozenSet[checkify.ErrorCategory] = checkify.div_checks | |
float: FrozenSet[checkify.ErrorCategory] = checkify.float_checks | |
automatic: FrozenSet[checkify.ErrorCategory] = checkify.automatic_checks | |
all: FrozenSet[checkify.ErrorCategory] = checkify.all_checks | |
_chexify_error_pattern = re.compile( | |
re.escape(_ai.get_chexify_err_message('ANY', 'ANY')).replace('ANY', '.*') | |
) | |
def _check_error(err: checkify.Error) -> None: | |
"""Checks the error and converts it to chex format.""" | |
try: | |
checkify.check_error(err) | |
except ValueError as exc: | |
msg = str(exc) | |
if _chexify_error_pattern.match(msg): | |
# Remove internal code pointers. | |
internal_info_pos = msg.rfind('(check failed at') | |
if internal_info_pos != -1: | |
msg = msg[:internal_info_pos] | |
raise AssertionError(msg) # pylint:disable=raise-missing-from | |
else: | |
raise | |
def block_until_chexify_assertions_complete() -> None: | |
"""Waits until all asynchronous checks complete. | |
See `chexify` for more detail. | |
""" | |
for wait_fn in _ai.CHEXIFY_STORAGE.wait_fns: | |
wait_fn() | |
# to catch uninspected error stats | |
def _check_if_hanging_assertions(): | |
if _ai.CHEXIFY_STORAGE.wait_fns: | |
logging.warning( | |
'[Chex] Some of chexify assetion statuses were not inspected due to ' | |
'async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html).' | |
' Consider calling `chex.block_until_chexify_assertions_complete()` at ' | |
'the end of computations that rely on jitted chex assetions.') | |
block_until_chexify_assertions_complete() | |
# Public API. | |
ChexifyChecks = _ChexifyChecks() | |
def chexify( | |
fn: Callable[..., Any], | |
async_check: bool = True, | |
errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user, | |
) -> Callable[..., Any]: | |
"""Wraps a transformed function `fn` to enable Chex value assertions. | |
Chex value/runtime assertions access concrete values of tensors (e.g. | |
`assert_tree_all_finite`) which are not available during JAX tracing, see | |
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html | |
and | |
https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError. | |
This wrapper enables them in jitted/pmapped functions by performing a | |
specifically designed JAX transformation | |
https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation | |
and calling functionalised checks | |
https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html | |
Example: | |
.. code:: | |
@chex.chexify | |
@jax.jit | |
def logp1_abs_safe(x: chex.Array) -> chex.Array: | |
chex.assert_tree_all_finite(x) | |
return jnp.log(jnp.abs(x) + 1) | |
logp1_abs_safe(jnp.ones(2)) # OK | |
logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS | |
logp1_abs_safe.wait_checks() | |
Note 1: This wrapper allows identifying the first failed assertion in a jitted | |
code by printing a pointer to the line where the failed assertion was invoked. | |
For getting verbose messages (including concrete tensor values), an unjitted | |
version of the code will need to be executed with the same input values. Chex | |
does not currently provide tools to help with this. | |
Note 2: This wrapper fully supports asynchronous executions | |
(see https://jax.readthedocs.io/en/latest/async_dispatch.html). | |
To block program execution until asynchronous checks for a _chexified_ | |
function `fn` complete, call `fn.wait_checks()`. Similarly, | |
`chex.block_until_chexify_assertions_complete()` will block program execution | |
until _all_ asyncronous checks complete. | |
Note 3: Chex automatically selects the backend for executing its assertions | |
(i.e. CPU or device accelerator) depending on the program context. | |
Note 4: Value assertions can have impact on the performance of a function, see | |
https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations | |
Note 5: static assertions, such as `assert_shape` or | |
`assert_trees_all_equal_dtypes`, can be called from a jitted function without | |
`chexify` wrapper (since they do not access concrete values, only | |
shapes and/or dtypes which are available during JAX tracing). | |
More examples can be found at | |
https://github.com/deepmind/chex/blob/master/chex/_src/asserts_chexify_test.py | |
Args: | |
fn: A transformed function to wrap. | |
async_check: Whether to check errors in the async dispatch mode. See | |
https://jax.readthedocs.io/en/latest/async_dispatch.html. | |
errors: A set of `checkify.ErrorCategory` values which defines the set of | |
enabled checks. By default only explicit ``checks`` are enabled (`user`). | |
You can also for example enable NaN and Div-by-0 errors by passing the | |
`float` set, or for example combine multiple sets through set | |
operations (`float | user`). | |
Returns: | |
A _chexified_ function, i.e. the one with enabled value assertions. | |
The returned function has `wait_checks()` method that blocks the caller | |
until all pending async checks complete. | |
""" | |
# Hardware/XLA failures can only happen on the C++ side. They are expected to | |
# issue critical errors that will immediately crash the whole program. | |
# Nevertheless, Chex sets its own timeout for every chexified XLA comp. to | |
# ensure that a program never blocks on Chex side when running in async mode. | |
async_timeout = 1800 # 30 minutes | |
# Get function name. | |
if isinstance(fn, functools.partial): | |
func_name = fn.func.__name__ | |
else: | |
func_name = fn.__name__ | |
if async_check: | |
# Spawn a thread for processing blocking calls. | |
thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}') | |
# A deque for futures. | |
async_check_futures = collections.deque() | |
# Checkification. | |
checkified_fn = checkify.checkify(fn, errors=errors) | |
def _chexified_fn(*args, **kwargs): | |
if _ai.CHEXIFY_STORAGE.level: | |
raise RuntimeError( | |
'Nested @chexify wrapping is disallowed. ' | |
'Make sure that you only wrap the function at the outermost level.') | |
if _ai.has_tracers((args, kwargs)): | |
raise RuntimeError( | |
'@chexify must be applied on top of all (p)jit/pmap transformations' | |
' (otherwise it will result in `UnexpectedTracerError`). If you have' | |
' functions that use value assertions, do not wrap them' | |
' individually -- just wrap the outermost function after' | |
' applying all your JAX transformations. See the example at ' | |
'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions' # pylint:disable=line-too-long | |
) | |
if async_check: | |
# Check completed calls. | |
while async_check_futures and async_check_futures[0].done(): | |
_check_error(async_check_futures.popleft().result(async_timeout)) | |
# Run the checkified function. | |
_ai.CHEXIFY_STORAGE.level += 1 | |
try: | |
err, out = checkified_fn(*args, **kwargs) | |
finally: | |
_ai.CHEXIFY_STORAGE.level -= 1 | |
# Check errors. | |
if async_check: | |
# Blocking call is deferred to the thread. | |
async_check_futures.append( | |
thread_pool.submit(lambda: jax.device_get(err))) | |
else: | |
# Blocks until `fn`'s outputs are ready. | |
_check_error(err) | |
return out | |
def _wait_checks(): | |
if async_check: | |
while async_check_futures: | |
_check_error(async_check_futures.popleft().result(async_timeout)) | |
# Add a barrier callback to the global storage. | |
_ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks) | |
# Add the callback to the chexified funtion's properties. | |
if not hasattr(_chexified_fn, 'wait_checks'): | |
_chexified_fn.wait_checks = _wait_checks | |
else: | |
logging.warning( | |
"Function %s already defines 'wait_checks' method; " | |
'Chex will not redefine it.', func_name) | |
return _chexified_fn | |
def with_jittable_assertions(fn: Callable[..., Any], | |
async_check: bool = True) -> Callable[..., Any]: | |
"""An alias for `chexify` (see the docs).""" | |
return chexify(fn, async_check) | |