Spaces:
Building
Building
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""A context manager that objects to JAX compilation for specified backends.

This is useful, for example, when certain JAX code needs to run in an
environment where an accelerator is present but reserved for other purposes.
Typically one would use `jax.jit(..., backend='cpu')` to keep the code away
from the accelerator, but it is hard to check by hand that this has been done
without exception throughout an entire subsystem. Then, `restrict_backends()`
can be used to detect any overlooked case and report it by raising an exception.

Similarly, it can be useful for a system such as a learner to make sure that
all required JAX programs have been assigned to their respective backends by
the end of its first iteration; this helps to show that it will not later run
into memory fragmentation problems. By entering a `restrict_backends()` context
at the end of the first iteration, the system can detect any overlooked cases.
"""
import contextlib | |
import functools | |
from typing import Optional, Sequence | |
# pylint: disable=g-import-not-at-top | |
try: | |
from jax._src import compiler | |
except ImportError: | |
# TODO(phawkins): remove this path after jax>=0.4.15 is the minimum version | |
# required by chex. | |
from jax._src import dispatch as compiler # type: ignore | |
# pylint: enable=g-import-not-at-top | |
class RestrictedBackendError(RuntimeError):
  """Error raised when compilation targets a platform that is not permitted.

  It is raised by the `backend_compile` wrapper that `restrict_backends()`
  installs for the duration of its context.
  """
@contextlib.contextmanager
def restrict_backends(
    *,
    allowed: Optional[Sequence[str]] = None,
    forbidden: Optional[Sequence[str]] = None):
  """Disallows JAX compilation for certain backends.

  Use as a context manager::

    with restrict_backends(allowed=['cpu']):
      ...  # Compiling for any platform other than 'cpu' raises here.

  While the context is active, `compiler.backend_compile` is replaced by a
  wrapper that checks `backend.platform` before delegating to the original
  function. The original function is put back when the context exits.

  Args:
    allowed: Names of backend platforms (e.g. 'cpu' or 'tpu') for which
      compilation is still to be permitted. If given, every other platform is
      forbidden. A single string is treated as a one-element sequence.
    forbidden: Names of backend platforms for which compilation is to be
      forbidden. Only consulted when `allowed` is `None`. A single string is
      treated as a one-element sequence.

  Yields:
    None, in a context where compilation for forbidden platforms will raise
    a `RestrictedBackendError`.

  Raises:
    ValueError: on entering the context, if neither `allowed` nor `forbidden`
      is specified (i.e. they are both `None`), or if anything is both allowed
      and forbidden.
  """
  def as_tuple(platforms):
    # A bare string is itself a Sequence[str]; without this special case
    # `tuple('cpu')` would silently become ('c', 'p', 'u').
    if platforms is None:
      return None
    if isinstance(platforms, str):
      return (platforms,)
    return tuple(platforms)

  allowed = as_tuple(allowed)
  forbidden = as_tuple(forbidden)
  if allowed is None and forbidden is None:
    raise ValueError('No restrictions specified.')
  contradictions = set(allowed or ()) & set(forbidden or ())
  if contradictions:
    raise ValueError(
        f"Backends {contradictions} can't be both allowed and forbidden.")

  def is_allowed(backend_platform):
    # `allowed`, when given, is authoritative (a whitelist). Otherwise
    # `forbidden` -- guaranteed non-None by the check above -- is a blacklist.
    if allowed is not None:
      return backend_platform in allowed
    return backend_platform not in forbidden

  inner_backend_compile = compiler.backend_compile

  @functools.wraps(inner_backend_compile)
  def wrapper(backend, *args, **kwargs):
    if not is_allowed(backend.platform):
      raise RestrictedBackendError(
          f'Compiling a JAX program for {backend.platform} is forbidden by '
          f'restrict_backends().')
    return inner_backend_compile(backend, *args, **kwargs)

  try:
    compiler.backend_compile = wrapper
    yield
  finally:
    # Sanity check: nothing else should have replaced our wrapper while the
    # context was active (e.g. nested contexts exited out of order);
    # restoring blindly would clobber such a patch.
    backend_compile = compiler.backend_compile
    assert backend_compile is wrapper, backend_compile
    compiler.backend_compile = inner_backend_compile