File size: 9,378 Bytes
f5f3483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Chexification utilities."""

import atexit
import collections
from concurrent import futures
import dataclasses
import functools
import re
from typing import Any, Callable, FrozenSet

from absl import logging
from chex._src import asserts_internal as _ai
import jax
from jax.experimental import checkify


@dataclasses.dataclass(frozen=True)
class _ChexifyChecks:
  """A set of checks imported from checkify."""

  user: FrozenSet[checkify.ErrorCategory] = checkify.user_checks
  nan: FrozenSet[checkify.ErrorCategory] = checkify.nan_checks
  index: FrozenSet[checkify.ErrorCategory] = checkify.index_checks
  div: FrozenSet[checkify.ErrorCategory] = checkify.div_checks
  float: FrozenSet[checkify.ErrorCategory] = checkify.float_checks
  automatic: FrozenSet[checkify.ErrorCategory] = checkify.automatic_checks
  all: FrozenSet[checkify.ErrorCategory] = checkify.all_checks


_chexify_error_pattern = re.compile(
    re.escape(_ai.get_chexify_err_message('ANY', 'ANY')).replace('ANY', '.*')
)


def _check_error(err: checkify.Error) -> None:
  """Checks the error and converts it to chex format."""
  try:
    checkify.check_error(err)
  except ValueError as exc:
    msg = str(exc)
    if _chexify_error_pattern.match(msg):
      # Remove internal code pointers.
      internal_info_pos = msg.rfind('(check failed at')
      if internal_info_pos != -1:
        msg = msg[:internal_info_pos]
      raise AssertionError(msg)  # pylint:disable=raise-missing-from
    else:
      raise


def block_until_chexify_assertions_complete() -> None:
  """Waits until all asynchronous checks complete.

  See `chexify` for more detail.
  """
  for wait_fn in _ai.CHEXIFY_STORAGE.wait_fns:
    wait_fn()


@atexit.register  # to catch uninspected error stats
def _check_if_hanging_assertions():
  if _ai.CHEXIFY_STORAGE.wait_fns:
    logging.warning(
        '[Chex] Some of chexify assetion statuses were not inspected due to '
        'async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html).'
        ' Consider calling `chex.block_until_chexify_assertions_complete()` at '
        'the end of computations that rely on jitted chex assetions.')
    block_until_chexify_assertions_complete()


# Public API.
ChexifyChecks = _ChexifyChecks()


def chexify(
    fn: Callable[..., Any],
    async_check: bool = True,
    errors: FrozenSet[checkify.ErrorCategory] = ChexifyChecks.user,
) -> Callable[..., Any]:
  """Wraps a transformed function `fn` to enable Chex value assertions.

  Chex value/runtime assertions access concrete values of tensors (e.g.
  `assert_tree_all_finite`) which are not available during JAX tracing, see
  https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
  and
  https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#ConcretizationTypeError.

  This wrapper enables them in jitted/pmapped functions by performing a
  specifically designed JAX transformation
  https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#the-checkify-transformation
  and calling functionalised checks
  https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.checkify.check.html

  Example:

  .. code::

    @chex.chexify
    @jax.jit
    def logp1_abs_safe(x: chex.Array) -> chex.Array:
      chex.assert_tree_all_finite(x)
      return jnp.log(jnp.abs(x) + 1)

    logp1_abs_safe(jnp.ones(2))  # OK
    logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS
    logp1_abs_safe.wait_checks()

  Note 1: This wrapper allows identifying the first failed assertion in a jitted
  code by printing a pointer to the line where the failed assertion was invoked.
  For getting verbose messages (including concrete tensor values), an unjitted
  version of the code will need to be executed with the same input values. Chex
  does not currently provide tools to help with this.

  Note 2: This wrapper fully supports asynchronous executions
  (see https://jax.readthedocs.io/en/latest/async_dispatch.html).
  To block program execution until asynchronous checks for a _chexified_
  function `fn` complete, call `fn.wait_checks()`. Similarly,
  `chex.block_until_chexify_assertions_complete()` will block program execution
  until _all_ asyncronous checks complete.

  Note 3: Chex automatically selects the backend for executing its assertions
  (i.e. CPU or device accelerator) depending on the program context.

  Note 4: Value assertions can have impact on the performance of a function, see
  https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html#limitations

  Note 5: static assertions, such as `assert_shape` or
  `assert_trees_all_equal_dtypes`, can be called from a jitted function without
  `chexify` wrapper (since they do not access concrete values, only
  shapes and/or dtypes which are available during JAX tracing).

  More examples can be found at
  https://github.com/deepmind/chex/blob/master/chex/_src/asserts_chexify_test.py

  Args:
    fn: A transformed function to wrap.
    async_check: Whether to check errors in the async dispatch mode. See
      https://jax.readthedocs.io/en/latest/async_dispatch.html.
    errors: A set of `checkify.ErrorCategory` values which defines the set of
      enabled checks. By default only explicit ``checks`` are enabled (`user`).
      You can also for example enable NaN and Div-by-0 errors by passing the
      `float` set, or for example combine multiple sets through set
      operations (`float | user`).

  Returns:
    A _chexified_ function, i.e. the one with enabled value assertions.
    The returned function has `wait_checks()` method that blocks the caller
    until all pending async checks complete.
  """
  # Hardware/XLA failures can only happen on the C++ side. They are expected to
  # issue critical errors that will immediately crash the whole program.
  # Nevertheless, Chex sets its own timeout for every chexified XLA comp. to
  # ensure that a program never blocks on Chex side when running in async mode.
  async_timeout = 1800  # 30 minutes

  # Get function name.
  if isinstance(fn, functools.partial):
    func_name = fn.func.__name__
  else:
    func_name = fn.__name__

  if async_check:
    # Spawn a thread for processing blocking calls.
    thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}')
    # A deque for futures.
    async_check_futures = collections.deque()

  # Checkification.
  checkified_fn = checkify.checkify(fn, errors=errors)

  @functools.wraps(fn)
  def _chexified_fn(*args, **kwargs):
    if _ai.CHEXIFY_STORAGE.level:
      raise RuntimeError(
          'Nested @chexify wrapping is disallowed. '
          'Make sure that you only wrap the function at the outermost level.')

    if _ai.has_tracers((args, kwargs)):
      raise RuntimeError(
          '@chexify must be applied on top of all (p)jit/pmap transformations'
          ' (otherwise it will result in `UnexpectedTracerError`). If you have'
          ' functions that use value assertions, do not wrap them'
          ' individually -- just wrap the outermost function after'
          ' applying all your JAX transformations. See the example at '
          'https://github.com/google-deepmind/chex#static-and-value-aka-runtime-assertions'  # pylint:disable=line-too-long
      )

    if async_check:
      # Check completed calls.
      while async_check_futures and async_check_futures[0].done():
        _check_error(async_check_futures.popleft().result(async_timeout))

    # Run the checkified function.
    _ai.CHEXIFY_STORAGE.level += 1
    try:
      err, out = checkified_fn(*args, **kwargs)
    finally:
      _ai.CHEXIFY_STORAGE.level -= 1

    # Check errors.
    if async_check:
      # Blocking call is deferred to the thread.
      async_check_futures.append(
          thread_pool.submit(lambda: jax.device_get(err)))
    else:
      # Blocks until `fn`'s outputs are ready.
      _check_error(err)

    return out

  def _wait_checks():
    if async_check:
      while async_check_futures:
        _check_error(async_check_futures.popleft().result(async_timeout))

  # Add a barrier callback to the global storage.
  _ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks)

  # Add the callback to the chexified funtion's properties.
  if not hasattr(_chexified_fn, 'wait_checks'):
    _chexified_fn.wait_checks = _wait_checks
  else:
    logging.warning(
        "Function %s already defines 'wait_checks' method; "
        'Chex will not redefine it.', func_name)

  return _chexified_fn


def with_jittable_assertions(fn: Callable[..., Any],
                             async_check: bool = True) -> Callable[..., Any]:
  """An alias for `chexify` (see the docs)."""
  return chexify(fn, async_check)