PTWZ's picture
Upload folder using huggingface_hub
f5f3483 verified
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chexification utilities."""
import atexit
import collections
from concurrent import futures
import dataclasses
import functools
import re
from typing import Any, Callable, FrozenSet
from absl import logging
from chex._src import asserts_internal as _ai
import jax
from jax.experimental import checkify
@dataclasses.dataclass(frozen=True)
class _ChexifyChecks:
"""A set of checks imported from checkify."""
user: FrozenSet[checkify.ErrorCategory] = checkify.user_checks
nan: FrozenSet[checkify.ErrorCategory] = checkify.nan_checks
index: FrozenSet[checkify.ErrorCategory] = checkify.index_checks
div: FrozenSet[checkify.ErrorCategory] = checkify.div_checks
float: FrozenSet[checkify.ErrorCategory] = checkify.float_checks
automatic: FrozenSet[checkify.ErrorCategory] = checkify.automatic_checks
all: FrozenSet[checkify.ErrorCategory] = checkify.all_checks
# Regex matching any message produced by `_ai.get_chexify_err_message`.
# The message template is rendered with the placeholder 'ANY' for both of its
# arguments, regex-escaped, and each placeholder is then widened to `.*`.
# NOTE(review): assumes the template text itself never contains the literal
# 'ANY'; also, without re.DOTALL the `.*` wildcards do not cross newlines.
_chexify_error_pattern = re.compile(
    re.escape(_ai.get_chexify_err_message('ANY', 'ANY')).replace('ANY', '.*')
)
def _check_error(err: checkify.Error) -> None:
  """Raises if `err` carries a failure, translating Chex failures.

  Args:
    err: A `checkify.Error` returned by a checkified function.

  Raises:
    AssertionError: if the failure message matches the chexify message
      format; any trailing '(check failed at ...' pointer is removed.
    ValueError: any other failure raised by `checkify.check_error` is
      re-raised unchanged.
  """
  try:
    checkify.check_error(err)
  except ValueError as exc:
    message = str(exc)
    if not _chexify_error_pattern.match(message):
      # Not a Chex assertion: propagate the original exception as-is.
      raise
    # Drop the trailing internal code pointer, if there is one.
    pointer_start = message.rfind('(check failed at')
    if pointer_start >= 0:
      message = message[:pointer_start]
    raise AssertionError(message)  # pylint:disable=raise-missing-from
def block_until_chexify_assertions_complete() -> None:
  """Blocks until all pending asynchronous chexify checks have finished.

  Calls every barrier callback registered in the global chexify storage, in
  registration order. Exceptions raised by a callback (e.g. a failed check)
  are not caught and propagate to the caller. See `chexify` for more detail.
  """
  registered_barriers = _ai.CHEXIFY_STORAGE.wait_fns
  for barrier in registered_barriers:
    barrier()
@atexit.register  # to catch uninspected error stats
def _check_if_hanging_assertions():
  """Exit hook that flushes chexify checks which may still be pending.

  If any barrier callbacks are registered in the global chexify storage, logs
  a warning and then waits for all of them. A failed deferred check may
  therefore raise during interpreter shutdown.
  """
  # NOTE(review): `wait_fns` records that barrier callbacks were registered,
  # not that any results are still uninspected, so this warning can fire even
  # after every check has already been consumed.
  if _ai.CHEXIFY_STORAGE.wait_fns:
    logging.warning(
        '[Chex] Some of chexify assertion statuses were not inspected due to '
        'async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html).'
        ' Consider calling `chex.block_until_chexify_assertions_complete()` at '
        'the end of computations that rely on jitted chex assertions.')
    block_until_chexify_assertions_complete()
# Public API.
# Singleton exposing the predefined `checkify` error-category sets. Pass e.g.
# `ChexifyChecks.user` (the default) or `ChexifyChecks.float | ChexifyChecks.user`
# as the `errors` argument of `chexify`.
ChexifyChecks = _ChexifyChecks()
def _get_fn_name(fn: Callable[..., Any]) -> str:
  """Returns a readable name for `fn`, looking through `functools.partial`."""
  while isinstance(fn, functools.partial):
    fn = fn.func
  # Callable objects (instances defining `__call__`) may have no `__name__`;
  # fall back to the name of their type instead of raising AttributeError.
  return getattr(fn, '__name__', type(fn).__name__)


def chexify(
    fn: Callable[..., Any],
    async_check: bool = True,
    errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user,
) -> Callable[..., Any]:
  """Wraps a transformed function `fn` to enable Chex value assertions.

  Chex value/runtime assertions access concrete values of tensors (e.g.
  `assert_tree_all_finite`) which are not available during JAX tracing, see
  https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
  and
  https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.

  This wrapper enables them in jitted/pmapped functions by performing a
  specifically designed JAX transformation
  https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation
  and calling functionalised checks
  https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html

  Example:

  .. code::

    @chex.chexify
    @jax.jit
    def logp1_abs_safe(x: chex.Array) -> chex.Array:
      chex.assert_tree_all_finite(x)
      return jnp.log(jnp.abs(x) + 1)

    logp1_abs_safe(jnp.ones(2))  # OK
    logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS
    logp1_abs_safe.wait_checks()

  Note 1: This wrapper allows identifying the first failed assertion in a jitted
  code by printing a pointer to the line where the failed assertion was invoked.
  For getting verbose messages (including concrete tensor values), an unjitted
  version of the code will need to be executed with the same input values. Chex
  does not currently provide tools to help with this.

  Note 2: This wrapper fully supports asynchronous executions
  (see https://jax.readthedocs.io/en/latest/async_dispatch.html).
  To block program execution until asynchronous checks for a _chexified_
  function `fn` complete, call `fn.wait_checks()`. Similarly,
  `chex.block_until_chexify_assertions_complete()` will block program execution
  until _all_ asynchronous checks complete.

  Note 3: Chex automatically selects the backend for executing its assertions
  (i.e. CPU or device accelerator) depending on the program context.

  Note 4: Value assertions can have impact on the performance of a function, see
  https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations

  Note 5: static assertions, such as `assert_shape` or
  `assert_trees_all_equal_dtypes`, can be called from a jitted function without
  `chexify` wrapper (since they do not access concrete values, only
  shapes and/or dtypes which are available during JAX tracing).

  More examples can be found at
  https://github.com/deepmind/chex/blob/master/chex/_src/asserts_chexify_test.py

  Args:
    fn: A transformed function to wrap.
    async_check: Whether to check errors in the async dispatch mode. See
      https://jax.readthedocs.io/en/latest/async_dispatch.html.
    errors: A set of `checkify.ErrorCategory` values which defines the set of
      enabled checks. By default only explicit ``checks`` are enabled (`user`).
      You can also for example enable NaN and Div-by-0 errors by passing the
      `float` set, or for example combine multiple sets through set
      operations (`float | user`).

  Returns:
    A _chexified_ function, i.e. the one with enabled value assertions.
    The returned function has `wait_checks()` method that blocks the caller
    until all pending async checks complete.

    Calling the returned function raises `RuntimeError` when it is invoked
    from inside another chexified call or with JAX tracers among its
    arguments, and `AssertionError` when a Chex check fails (in async mode,
    the failure of an earlier call is reported by a later call or by
    `wait_checks()`).
  """
  # Hardware/XLA failures can only happen on the C++ side. They are expected to
  # issue critical errors that will immediately crash the whole program.
  # Nevertheless, Chex sets its own timeout for every chexified XLA comp. to
  # ensure that a program never blocks on Chex side when running in async mode.
  async_timeout = 1800  # 30 minutes

  func_name = _get_fn_name(fn)

  if async_check:
    # A single worker thread performs the blocking fetches of error values so
    # the caller is not blocked by them.
    thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}')
    # Pending futures, in call order.
    async_check_futures = collections.deque()

  checkified_fn = checkify.checkify(fn, errors=errors)

  @functools.wraps(fn)
  def _chexified_fn(*args, **kwargs):
    if _ai.CHEXIFY_STORAGE.level:
      raise RuntimeError(
          'Nested @chexify wrapping is disallowed. '
          'Make sure that you only wrap the function at the outermost level.')
    if _ai.has_tracers((args, kwargs)):
      raise RuntimeError(
          '@chexify must be applied on top of all (p)jit/pmap transformations'
          ' (otherwise it will result in `UnexpectedTracerError`). If you have'
          ' functions that use value assertions, do not wrap them'
          ' individually -- just wrap the outermost function after'
          ' applying all your JAX transformations. See the example at '
          'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions'  # pylint:disable=line-too-long
      )

    if async_check:
      # Surface failures of earlier calls whose checks have already completed.
      while async_check_futures and async_check_futures[0].done():
        _check_error(async_check_futures.popleft().result(async_timeout))

    # Track nesting depth so that nested chexified calls are rejected above.
    _ai.CHEXIFY_STORAGE.level += 1
    try:
      err, out = checkified_fn(*args, **kwargs)
    finally:
      _ai.CHEXIFY_STORAGE.level -= 1

    if async_check:
      # The blocking fetch of `err` is deferred to the worker thread.
      async_check_futures.append(thread_pool.submit(jax.device_get, err))
    else:
      # Blocks until `fn`'s outputs are ready.
      _check_error(err)
    return out

  def _wait_checks():
    if async_check:
      while async_check_futures:
        _check_error(async_check_futures.popleft().result(async_timeout))

  if async_check:
    # Only async mode can leave checks pending (in sync mode `_wait_checks` is
    # a no-op), so only then add a barrier callback to the global storage.
    _ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks)

  # Expose the barrier as a method of the chexified function.
  if not hasattr(_chexified_fn, 'wait_checks'):
    _chexified_fn.wait_checks = _wait_checks
  else:
    logging.warning(
        "Function %s already defines 'wait_checks' method; "
        'Chex will not redefine it.', func_name)

  return _chexified_fn
def with_jittable_assertions(
    fn: Callable[..., Any],
    async_check: bool = True,
    errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user,
) -> Callable[..., Any]:
  """An alias for `chexify` (see the docs).

  Args:
    fn: A transformed function to wrap.
    async_check: Whether to check errors in the async dispatch mode.
    errors: The set of enabled `checkify.ErrorCategory` values; forwarded to
      `chexify` so the alias exposes the same options (defaults to `user`).

  Returns:
    The chexified function returned by `chexify`.
  """
  return chexify(fn, async_check=async_check, errors=errors)