# NOTE: web-page header residue, not part of the module:
# "PTWZ's picture" / "Upload folder using huggingface_hub" / "f5f3483 verified"
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A context manager that objects to JAX compilation for specified backends.

This is useful, for example, when certain JAX code needs to run in an
environment where an accelerator is present but reserved for other purposes.
Typically one would use `jax.jit(..., backend='cpu')` to keep the code away
from the accelerator, but it is hard to check by hand that this has been done
without exception throughout an entire subsystem. Then, `restrict_backends()`
can be used to detect any overlooked case and report it by raising an exception.

Similarly, it can be useful for a system such as a learner to make sure that
all required JAX programs have been assigned to their respective backends by
the end of its first iteration; this helps to show that it will not later run
into memory fragmentation problems. By entering a `restrict_backends()` context
at the end of the first iteration, the system can detect any overlooked cases.
"""
import contextlib
import functools
from typing import Optional, Sequence
# pylint: disable=g-import-not-at-top
try:
from jax._src import compiler
except ImportError:
# TODO(phawkins): remove this path after jax>=0.4.15 is the minimum version
# required by chex.
from jax._src import dispatch as compiler # type: ignore
# pylint: enable=g-import-not-at-top
class RestrictedBackendError(RuntimeError):
  """Signals an attempt to compile a JAX program for a forbidden backend.

  Raised by the compile hook that `restrict_backends()` installs when the
  platform of the target backend is not permitted in the active context.
  """
@contextlib.contextmanager
def restrict_backends(
    *,
    allowed: Optional[Sequence[str]] = None,
    forbidden: Optional[Sequence[str]] = None):
  """Disallows JAX compilation for certain backends.

  While the context is active, `compiler.backend_compile` is replaced by a
  wrapper that inspects `backend.platform` and raises instead of compiling
  when that platform is not permitted. The previous `backend_compile` is
  reinstated when the context exits.

  Args:
    allowed: Names of backend platforms (e.g. 'cpu' or 'tpu') for which
      compilation is still to be permitted. When given, every platform that
      is not listed is treated as forbidden. A single platform may be passed
      as a bare string.
    forbidden: Names of backend platforms for which compilation is to be
      forbidden. Only consulted when `allowed` is `None`. A single platform
      may be passed as a bare string.

  Yields:
    None, in a context where compilation for forbidden platforms will raise
    a `RestrictedBackendError`.

  Raises:
    ValueError: if neither `allowed` nor `forbidden` is specified (i.e. they
      are both `None`), or if anything is both allowed and forbidden.
  """

  def as_platform_tuple(names):
    """Returns `names` as a tuple of platform names, or `None` if unset."""
    if names is None:
      return None
    if isinstance(names, str):
      # A bare string is itself a sequence: `tuple('tpu')` would give
      # ('t', 'p', 'u'), which matches no platform and would silently turn
      # `forbidden='tpu'` into "forbid nothing". Treat it as a single name.
      return (names,)
    return tuple(names)

  allowed = as_platform_tuple(allowed)
  forbidden = as_platform_tuple(forbidden)
  if allowed is None and forbidden is None:
    raise ValueError('No restrictions specified.')
  contradictions = set(allowed or ()) & set(forbidden or ())
  if contradictions:
    raise ValueError(
        f"Backends {contradictions} can't be both allowed and forbidden.")

  def is_allowed(backend_platform):
    # An explicit allow-list decides on its own (an empty one permits
    # nothing); the deny-list is only used when no allow-list was given.
    if allowed is not None:
      return backend_platform in allowed
    return backend_platform not in forbidden

  inner_backend_compile = compiler.backend_compile

  @functools.wraps(inner_backend_compile)
  def wrapper(backend, *args, **kwargs):
    # Check the platform first so that nothing is compiled for a forbidden
    # backend; otherwise defer to the function that was installed on entry.
    if not is_allowed(backend.platform):
      raise RestrictedBackendError(
          f'Compiling a JAX program for {backend.platform} is forbidden by '
          f'restrict_backends().')
    return inner_backend_compile(backend, *args, **kwargs)

  try:
    compiler.backend_compile = wrapper
    yield
  finally:
    # The hook must still be our wrapper at exit. Anything else means that
    # `compiler.backend_compile` was reassigned inside the context and not
    # put back, so restoring the saved function would be unsafe.
    backend_compile = compiler.backend_compile
    assert backend_compile is wrapper, backend_compile
    compiler.backend_compile = inner_backend_compile