# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
A collection of utilities for ensuring that training can always occur. Heavily influenced by the | |
[toma](https://github.com/BlackHC/toma) library. | |
""" | |
import functools | |
import gc | |
import inspect | |
import torch | |
from .imports import is_npu_available, is_xpu_available | |


def release_memory(*objects):
    """
    Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.
    Returned objects should be reassigned to the same variables.

    Args:
        objects (`Iterable`):
            An iterable of objects

    Returns:
        A list of `None` objects to replace `objects`

    Example:

        ```python
        >>> import torch
        >>> from accelerate.utils import release_memory

        >>> a = torch.ones(1000, 1000).cuda()
        >>> b = torch.ones(1000, 1000).cuda()
        >>> a, b = release_memory(a, b)
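        >>> a, b
        (None, None)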
        ```
    """
    if not isinstance(objects, list):
        objects = list(objects)
    for i in range(len(objects)):
        objects[i] = None
    gc.collect()
    if is_xpu_available():
        torch.xpu.empty_cache()
    elif is_npu_available():
        torch.npu.empty_cache()
    else:
        torch.cuda.empty_cache()
    return objects


def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, cuDNN not being supported, or CPU out-of-memory.

    Args:
        exception (`Exception`):
            An exception
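
    Example (an illustrative sketch; `model` and `batch` are placeholder names, not part of this module):

        ```python
        >>> from accelerate.utils.memory import should_reduce_batch_size

        >>> try:
        ...     loss = model(batch).sum()
        ...     loss.backward()
        ... except RuntimeError as e:
        ...     if should_reduce_batch_size(e):
        ...         ...  # fall back to a smaller batch size and retry
        ...     else:
        ...         raise
        ```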
""" | |
_statements = [ | |
"CUDA out of memory.", # CUDA OOM | |
"cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU | |
"DefaultCPUAllocator: can't allocate memory", # CPU OOM | |
] | |
if isinstance(exception, RuntimeError) and len(exception.args) == 1: | |
return any(err in exception.args[0] for err in _statements) | |
return False | |


def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size is cut in half and passed to `function`.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory

    Example:

        ```python
        >>> from accelerate.utils import find_executable_batch_size

        >>> @find_executable_batch_size(starting_batch_size=128)
        ... def train(batch_size, model, optimizer):
        ...     ...

        >>> train(model, optimizer)
        ```
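
        The decorator can also be applied to an existing function directly (an equivalent, illustrative usage):

        ```python
        >>> train = find_executable_batch_size(train, starting_batch_size=128)
        >>> train(model, optimizer)
        ```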
""" | |
    if function is None:
        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)

    batch_size = starting_batch_size

    def decorator(*args, **kwargs):
        nonlocal batch_size
        gc.collect()
        if is_xpu_available():
            torch.xpu.empty_cache()
        elif is_npu_available():
            torch.npu.empty_cache()
        else:
            torch.cuda.empty_cache()
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    gc.collect()
                    if is_xpu_available():
                        torch.xpu.empty_cache()
                    elif is_npu_available():
                        torch.npu.empty_cache()
                    else:
                        torch.cuda.empty_cache()
                    batch_size //= 2
                else:
                    raise

    return decorator
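

# An illustrative sketch (not part of the original module) of how the decorator
# recovers from a simulated CUDA OOM: each failed attempt halves the batch size
# until the wrapped function succeeds. `_train` is a hypothetical example function.
#
# @find_executable_batch_size(starting_batch_size=64)
# def _train(batch_size):
#     if batch_size > 16:  # pretend anything above 16 does not fit in memory
#         raise RuntimeError("CUDA out of memory.")
#     return batch_size
#
# _train()  # retries with 64 -> 32 -> 16 and returns 16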