# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple, Union

import torch

from ..utils import get_logger, is_accelerate_available
from .hooks import HookRegistry, ModelHook


if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload
    from accelerate.utils import send_to_device


logger = get_logger(__name__)  # pylint: disable=invalid-name


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"

_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
    # because of double invocation of the same norm layer in CogVideoXLayerNorm
)
# fmt: on


class ModuleGroup:
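    r"""
    A group of `torch.nn.Module` instances (plus optional stray parameters and buffers) that are offloaded and
    onloaded together. The `onload_leader` is the module whose pre-forward triggers onloading of the group, and the
    `offload_leader` is the module whose post-forward triggers offloading. When a `stream` is provided, CPU copies of
    the tensors are kept (pinned unless `low_cpu_mem_usage` is set) so that host-to-device transfers can be done
    asynchronously.
    """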

    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
        parameters: Optional[List[torch.nn.Parameter]] = None,
        buffers: Optional[List[torch.Tensor]] = None,
        non_blocking: bool = False,
        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
        record_stream: Optional[bool] = False,
        low_cpu_mem_usage: bool = False,
        onload_self: bool = True,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader
        self.parameters = parameters or []
        self.buffers = buffers or []
        self.non_blocking = non_blocking or stream is not None
        self.stream = stream
        self.record_stream = record_stream
        self.onload_self = onload_self
        self.low_cpu_mem_usage = low_cpu_mem_usage
        self.cpu_param_dict = self._init_cpu_param_dict()

        if self.stream is None and self.record_stream:
            raise ValueError("`record_stream` cannot be True when `stream` is None.")

    def _init_cpu_param_dict(self):
        cpu_param_dict = {}
        if self.stream is None:
            return cpu_param_dict

        for module in self.modules:
            for param in module.parameters():
                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
            for buffer in module.buffers():
                cpu_param_dict[buffer] = (
                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
                )

        for param in self.parameters:
            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()

        for buffer in self.buffers:
            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()

        return cpu_param_dict
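
    # `_pinned_memory_tensors` is a context manager that yields a mapping from each parameter/buffer to a
    # pinned-memory copy of its CPU data. With `low_cpu_mem_usage=True` the CPU copies are not pre-pinned in
    # `_init_cpu_param_dict`, so they are pinned on-the-fly here.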
    @contextmanager
    def _pinned_memory_tensors(self):
        pinned_dict = {}
        try:
            for param, tensor in self.cpu_param_dict.items():
                if not tensor.is_pinned():
                    pinned_dict[param] = tensor.pin_memory()
                else:
                    pinned_dict[param] = tensor

            yield pinned_dict

        finally:
            pinned_dict = None

    def onload_(self):
        r"""Onloads the group of modules to the onload_device."""
        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
            if hasattr(torch, "accelerator")
            else torch.cuda
        )
        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()

        with context:
            if self.stream is not None:
                with self._pinned_memory_tensors() as pinned_memory:
                    for group_module in self.modules:
                        for param in group_module.parameters():
                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
                            if self.record_stream:
                                param.data.record_stream(current_stream)
                        for buffer in group_module.buffers():
                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
                            if self.record_stream:
                                buffer.data.record_stream(current_stream)

                    for param in self.parameters:
                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            param.data.record_stream(current_stream)

                    for buffer in self.buffers:
                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
                        if self.record_stream:
                            buffer.data.record_stream(current_stream)

            else:
                for group_module in self.modules:
                    for param in group_module.parameters():
                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
                    for buffer in group_module.buffers():
                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)

                for param in self.parameters:
                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)

                for buffer in self.buffers:
                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
                    if self.record_stream:
                        buffer.data.record_stream(current_stream)

    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
        torch_accelerator_module = (
            getattr(torch, torch.accelerator.current_accelerator().type)
            if hasattr(torch, "accelerator")
            else torch.cuda
        )
        if self.stream is not None:
            if not self.record_stream:
                torch_accelerator_module.current_stream().synchronize()
            for group_module in self.modules:
                for param in group_module.parameters():
                    param.data = self.cpu_param_dict[param]
            for param in self.parameters:
                param.data = self.cpu_param_dict[param]
            for buffer in self.buffers:
                buffer.data = self.cpu_param_dict[buffer]

        else:
            for group_module in self.modules:
                group_module.to(self.offload_device, non_blocking=self.non_blocking)
            for param in self.parameters:
                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
            for buffer in self.buffers:
                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
    computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
    module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
    group is responsible for onloading the current module group.
    """

    _is_stateful = False

    def __init__(
        self,
        group: ModuleGroup,
        next_group: Optional[ModuleGroup] = None,
    ) -> None:
        self.group = group
        self.next_group = next_group

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
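        # When the hook is attached to the group's offload leader, immediately offload the group so that it starts
        # out on the offload device before the first forward pass.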
        if self.group.offload_leader == module:
            self.group.offload_()
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
        # method is the onload_leader of the group.
        if self.group.onload_leader is None:
            self.group.onload_leader = module

        # If the current module is the onload_leader of the group, we onload the group if it is supposed
        # to onload itself. In the case of using prefetching with streams, we onload the next group if
        # it is not supposed to onload itself.
        if self.group.onload_leader == module:
            if self.group.onload_self:
                self.group.onload_()
            if self.next_group is not None and not self.next_group.onload_self:
                self.next_group.onload_()

        args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        if self.group.offload_leader == module:
            self.group.offload_()
        return output


class LazyPrefetchGroupOffloadingHook(ModelHook):
    r"""
    A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
    This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
    invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
    prefetching groups in the correct order.
    """

    _is_stateful = False

    def __init__(self):
        self.execution_order: List[Tuple[str, torch.nn.Module]] = []
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        def make_execution_order_update_callback(current_name, current_submodule):
            def callback():
                logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return callback

        # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
        # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
        # layers are executed during the forward pass.
        for name, submodule in module.named_modules():
            if name == "" or not hasattr(submodule, "_diffusers_hook"):
                continue

            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)

            if group_offloading_hook is not None:
                # For the first forward pass, we have to load in a blocking manner
                group_offloading_hook.group.non_blocking = False
                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                self._layer_execution_tracker_module_names.add(name)

        return module

    def post_forward(self, module, output):
        # At this point, for the current module's submodules, we know the execution order of the layers. We can now
        # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each
        # group offloading hook.
        num_executed = len(self.execution_order)
        execution_order_module_names = {name for name, _ in self.execution_order}

        # It may be possible that some layers were not executed during the forward pass. This can happen if the layer
        # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we
        # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors
        # if the missing layers end up being executed in the future.
        if execution_order_module_names != self._layer_execution_tracker_module_names:
            unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
            logger.warning(
                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
                f"{unexecuted_layers=}"
            )

        # Remove the layer execution tracker hooks from the submodules
        base_module_registry = module._diffusers_hook
        registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]

        for i in range(num_executed):
            registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)

        # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)

        # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
        # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
        # see the benefits of prefetching.
        for hook in group_offloading_hooks:
            hook.group.non_blocking = True

        # Set required attributes for prefetching
        if num_executed > 0:
            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
            base_module_group_offloading_hook.next_group.onload_self = False

        for i in range(num_executed - 1):
            name1, _ = self.execution_order[i]
            name2, _ = self.execution_order[i + 1]
            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
            group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
            group_offloading_hooks[i].next_group.onload_self = False

        return output


class LayerExecutionTrackerHook(ModelHook):
    r"""
    A hook that tracks the order in which the layers are executed during the forward pass by calling back to the
    LazyPrefetchGroupOffloadingHook to update the execution order.
    """

    _is_stateful = False

    def __init__(self, execution_order_update_callback):
        self.execution_order_update_callback = execution_order_update_callback

    def pre_forward(self, module, *args, **kwargs):
        self.execution_order_update_callback()
        return args, kwargs


def apply_group_offloading(
    module: torch.nn.Module,
    onload_device: torch.device,
    offload_device: torch.device = torch.device("cpu"),
    offload_type: str = "block_level",
    num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
    use_stream: bool = False,
    record_stream: bool = False,
    low_cpu_mem_usage: bool = False,
) -> None:
r""" | |
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and | |
where it is beneficial, we need to first provide some context on how other supported offloading methods work. | |
Typically, offloading is done at two levels: | |
- Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It | |
works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device | |
when needed for computation. This method is more memory-efficient than keeping all components on the accelerator, | |
but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of | |
the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward | |
pass. | |
- Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It | |
works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and | |
onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator | |
memory, but can be slower due to the excessive number of device synchronizations. | |
Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers, | |
(either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level | |
offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is | |
reduced. | |
Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to | |
overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This | |
is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to | |
the accelerator device while the current layer is being executed - this increases the memory requirements slightly. | |
Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. | |
Args: | |
module (`torch.nn.Module`): | |
The module to which group offloading is applied. | |
onload_device (`torch.device`): | |
The device to which the group of modules are onloaded. | |
offload_device (`torch.device`, defaults to `torch.device("cpu")`): | |
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. | |
offload_type (`str`, defaults to "block_level"): | |
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is | |
"block_level". | |
num_blocks_per_group (`int`, *optional*): | |
The number of blocks per group when using offload_type="block_level". This is required when using | |
offload_type="block_level". | |
non_blocking (`bool`, defaults to `False`): | |
If True, offloading and onloading is done with non-blocking data transfer. | |
use_stream (`bool`, defaults to `False`): | |
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for | |
overlapping computation and data transfer. | |
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor | |
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the | |
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more | |
details. | |
low_cpu_mem_usage (`bool`, defaults to `False`): | |
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This | |
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when | |
the CPU memory is a bottleneck but may counteract the benefits of using streams. | |
    Example:
        ```python
        >>> import torch
        >>> from diffusers import CogVideoXTransformer3DModel
        >>> from diffusers.hooks import apply_group_offloading

        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
        ... )

        >>> apply_group_offloading(
        ...     transformer,
        ...     onload_device=torch.device("cuda"),
        ...     offload_device=torch.device("cpu"),
        ...     offload_type="block_level",
        ...     num_blocks_per_group=2,
        ...     use_stream=True,
        ... )
        ```
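
        For minimal accelerator memory usage, the same call can request leaf-level offloading instead (an illustrative
        sketch reusing `transformer` from above; `num_blocks_per_group` is only needed for "block_level"):

        ```python
        >>> apply_group_offloading(
        ...     transformer,
        ...     onload_device=torch.device("cuda"),
        ...     offload_device=torch.device("cpu"),
        ...     offload_type="leaf_level",
        ...     use_stream=True,
        ... )
        ```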
""" | |
    stream = None
    if use_stream:
        if torch.cuda.is_available():
            stream = torch.cuda.Stream()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            stream = torch.Stream()
        else:
            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

    if offload_type == "block_level":
        if num_blocks_per_group is None:
            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

        _apply_group_offloading_block_level(
            module=module,
            num_blocks_per_group=num_blocks_per_group,
            offload_device=offload_device,
            onload_device=onload_device,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    elif offload_type == "leaf_level":
        _apply_group_offloading_leaf_level(
            module=module,
            offload_device=offload_device,
            onload_device=onload_device,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    else:
        raise ValueError(f"Unsupported offload_type: {offload_type}")


def _apply_group_offloading_block_level(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
) -> None:
r""" | |
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to | |
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. | |
Args: | |
module (`torch.nn.Module`): | |
The module to which group offloading is applied. | |
offload_device (`torch.device`): | |
The device to which the group of modules are offloaded. This should typically be the CPU. | |
onload_device (`torch.device`): | |
The device to which the group of modules are onloaded. | |
non_blocking (`bool`): | |
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation | |
and data transfer. | |
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*): | |
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful | |
for overlapping computation and data transfer. | |
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor | |
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the | |
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more | |
details. | |
low_cpu_mem_usage (`bool`, defaults to `False`): | |
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This | |
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when | |
the CPU memory is a bottleneck but may counteract the benefits of using streams. | |
""" | |
    if stream is not None and num_blocks_per_group != 1:
        logger.warning(
            f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
        )
        num_blocks_per_group = 1

    # Create module groups for ModuleList and Sequential blocks
    modules_with_group_offloading = set()
    unmatched_modules = []
    matched_module_groups = []
    for name, submodule in module.named_children():
        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            unmatched_modules.append((name, submodule))
            modules_with_group_offloading.add(name)
            continue

        for i in range(0, len(submodule), num_blocks_per_group):
            current_modules = submodule[i : i + num_blocks_per_group]
            group = ModuleGroup(
                modules=current_modules,
                offload_device=offload_device,
                onload_device=onload_device,
                offload_leader=current_modules[-1],
                onload_leader=current_modules[0],
                non_blocking=non_blocking,
                stream=stream,
                record_stream=record_stream,
                low_cpu_mem_usage=low_cpu_mem_usage,
                onload_self=True,
            )
            matched_module_groups.append(group)
            for j in range(i, i + len(current_modules)):
                modules_with_group_offloading.add(f"{name}.{j}")

    # Apply group offloading hooks to the module groups
    for i, group in enumerate(matched_module_groups):
        for group_module in group.modules:
            _apply_group_offloading_hook(group_module, group, None)

    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
    # when the forward pass of this module is called. This is because the top-level module is not
    # part of any group (as doing so would lead to no VRAM savings).
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
    parameters = [param for _, param in parameters]
    buffers = [buffer for _, buffer in buffers]

    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
    # device when the forward pass is called.
    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
    unmatched_group = ModuleGroup(
        modules=unmatched_modules,
        offload_device=offload_device,
        onload_device=onload_device,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
        buffers=buffers,
        non_blocking=False,
        stream=None,
        record_stream=False,
        onload_self=True,
    )

    if stream is None:
        _apply_group_offloading_hook(module, unmatched_group, None)
    else:
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_leaf_level(
    module: torch.nn.Module,
    offload_device: torch.device,
    onload_device: torch.device,
    non_blocking: bool,
    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
    record_stream: Optional[bool] = False,
    low_cpu_mem_usage: bool = False,
) -> None:
r""" | |
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory | |
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device | |
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can | |
reduce memory usage without any performance degradation. | |
Args: | |
module (`torch.nn.Module`): | |
The module to which group offloading is applied. | |
offload_device (`torch.device`): | |
The device to which the group of modules are offloaded. This should typically be the CPU. | |
onload_device (`torch.device`): | |
The device to which the group of modules are onloaded. | |
non_blocking (`bool`): | |
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation | |
and data transfer. | |
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*): | |
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful | |
for overlapping computation and data transfer. | |
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor | |
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the | |
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more | |
details. | |
low_cpu_mem_usage (`bool`, defaults to `False`): | |
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This | |
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when | |
the CPU memory is a bottleneck but may counteract the benefits of using streams. | |
""" | |
    # Create module groups for leaf modules and apply group offloading hooks
    modules_with_group_offloading = set()
    for name, submodule in module.named_modules():
        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
            continue
        group = ModuleGroup(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_group_offloading_hook(submodule, group, None)
        modules_with_group_offloading.add(name)

    # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
    # of the module is called
    module_dict = dict(module.named_modules())
    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

    # Find closest module parent for each parameter and buffer, and attach group hooks
    parent_to_parameters = {}
    for name, param in parameters:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_parameters:
            parent_to_parameters[parent_name].append(param)
        else:
            parent_to_parameters[parent_name] = [param]

    parent_to_buffers = {}
    for name, buffer in buffers:
        parent_name = _find_parent_module_in_module_dict(name, module_dict)
        if parent_name in parent_to_buffers:
            parent_to_buffers[parent_name].append(buffer)
        else:
            parent_to_buffers[parent_name] = [buffer]

    parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
    for name in parent_names:
        parameters = parent_to_parameters.get(name, [])
        buffers = parent_to_buffers.get(name, [])
        parent_module = module_dict[name]
        assert getattr(parent_module, "_diffusers_hook", None) is None
        group = ModuleGroup(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_leader=parent_module,
            onload_leader=parent_module,
            parameters=parameters,
            buffers=buffers,
            non_blocking=non_blocking,
            stream=stream,
            record_stream=record_stream,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_group_offloading_hook(parent_module, group, None)

    if stream is not None:
        # When using streams, we need to know the layer execution order for applying prefetching (to overlap data
        # transfer and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that
        # will find the execution order and apply prefetching in the correct order.
        unmatched_group = ModuleGroup(
            modules=[],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_leader=module,
            onload_leader=module,
            parameters=None,
            buffers=None,
            non_blocking=False,
            stream=None,
            record_stream=False,
            low_cpu_mem_usage=low_cpu_mem_usage,
            onload_self=True,
        )
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        hook = GroupOffloadingHook(group, next_group)
        registry.register_hook(hook, _GROUP_OFFLOADING)


def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    next_group: Optional[ModuleGroup] = None,
) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(module)

    # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
    # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
    if registry.get_hook(_GROUP_OFFLOADING) is None:
        hook = GroupOffloadingHook(group, next_group)
        registry.register_hook(hook, _GROUP_OFFLOADING)

    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
    registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _gather_parameters_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[Tuple[str, torch.nn.Parameter]]:
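    r"""
    Returns `(name, parameter)` pairs for the parameters of `module` whose dotted name has no prefix present in
    `modules_with_group_offloading`, i.e. parameters that are not already covered by a group offloading hook on one
    of their parent modules.
    """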
    parameters = []
    for name, parameter in module.named_parameters():
        has_parent_with_group_offloading = False
        atoms = name.split(".")

        while len(atoms) > 0:
            parent_name = ".".join(atoms)
            if parent_name in modules_with_group_offloading:
                has_parent_with_group_offloading = True
                break
            atoms.pop()

        if not has_parent_with_group_offloading:
            parameters.append((name, parameter))

    return parameters


def _gather_buffers_with_no_group_offloading_parent(
    module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[Tuple[str, torch.Tensor]]:
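    r"""
    Returns `(name, buffer)` pairs for the buffers of `module` whose dotted name has no prefix present in
    `modules_with_group_offloading`, i.e. buffers that are not already covered by a group offloading hook on one of
    their parent modules.
    """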
    buffers = []
    for name, buffer in module.named_buffers():
        has_parent_with_group_offloading = False
        atoms = name.split(".")

        while len(atoms) > 0:
            parent_name = ".".join(atoms)
            if parent_name in modules_with_group_offloading:
                has_parent_with_group_offloading = True
                break
            atoms.pop()

        if not has_parent_with_group_offloading:
            buffers.append((name, buffer))

    return buffers


def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
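    r"""
    Walks up the dotted `name` and returns the longest prefix that is a key in `module_dict` (i.e. the closest parent
    module), or the empty string (the name of the root module) if no prefix matches.
    """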
    atoms = name.split(".")
    while len(atoms) > 0:
        parent_name = ".".join(atoms)
        if parent_name in module_dict:
            return parent_name
        atoms.pop()
    return ""


def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
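    r"""
    Raises an error if any submodule already has an Accelerate `AlignDevicesHook` or `CpuOffload` hook attached (i.e.
    model or sequential CPU offloading from Accelerate is already enabled), since group offloading cannot be combined
    with those offloading strategies.
    """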
    if not is_accelerate_available():
        return
    for name, submodule in module.named_modules():
        if not hasattr(submodule, "_hf_hook"):
            continue
        if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
            raise ValueError(
                f"Cannot apply group offloading to a module that is already applying an alternative "
                f"offloading strategy from Accelerate. If you want to apply group offloading, please "
                f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
            )


def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
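    r"""Returns True if any submodule of `module` has a group offloading hook attached."""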
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return True
    return False


def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
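    r"""
    Returns the onload device of the first group offloading hook found on `module` or its submodules. Raises a
    ValueError if group offloading is not enabled on the module.
    """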
    for submodule in module.modules():
        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
    raise ValueError("Group offloading is not enabled for the provided module.")