# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import torch

from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
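# NOTE: the identifier strings above are matched as regular expressions against submodule names
# (see `_apply_pyramid_attention_broadcast_on_attention_class` below), so a substring pattern such as
# "transformer_blocks" also matches a fully qualified name like "transformer_blocks.0.attn1".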


@dataclass
class PyramidAttentionBroadcastConfig:
r""" | |
Configuration for Pyramid Attention Broadcast. | |
Args: | |
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`): | |
The number of times a specific spatial attention broadcast is skipped before computing the attention states | |
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., | |
old attention states will be re-used) before computing the new attention states again. | |
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): | |
The number of times a specific temporal attention broadcast is skipped before computing the attention | |
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times | |
(i.e., old attention states will be re-used) before computing the new attention states again. | |
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`): | |
The number of times a specific cross-attention broadcast is skipped before computing the attention states | |
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., | |
old attention states will be re-used) before computing the new attention states again. | |
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): | |
The range of timesteps to skip in the spatial attention layer. The attention computations will be | |
conditionally skipped if the current timestep is within the specified range. | |
temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): | |
The range of timesteps to skip in the temporal attention layer. The attention computations will be | |
conditionally skipped if the current timestep is within the specified range. | |
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`): | |
The range of timesteps to skip in the cross-attention layer. The attention computations will be | |
conditionally skipped if the current timestep is within the specified range. | |
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks", "single_transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
            The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
        current_timestep_callback (`Callable[[], int]`, defaults to `None`):
            A callback function that returns the current inference timestep. This must be provided for PAB to work.
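
    Example:
        A minimal configuration sketch; it assumes a `pipe` object that exposes a `current_timestep` attribute, as
        in the `CogVideoXPipeline` example shown under `apply_pyramid_attention_broadcast`:

        ```python
        >>> config = PyramidAttentionBroadcastConfig(
        ...     spatial_attention_block_skip_range=2,
        ...     spatial_attention_timestep_skip_range=(100, 800),
        ...     current_timestep_callback=lambda: pipe.current_timestep,
        ... )
        ```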
""" | |
spatial_attention_block_skip_range: Optional[int] = None | |
temporal_attention_block_skip_range: Optional[int] = None | |
cross_attention_block_skip_range: Optional[int] = None | |
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800) | |
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800) | |
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800) | |
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS | |
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS | |
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS | |
current_timestep_callback: Callable[[], int] = None | |
# TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase | |
# so not added for now) | |

    def __repr__(self) -> str:
        return (
            f"PyramidAttentionBroadcastConfig(\n"
            f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
            f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
            f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
            f" current_timestep_callback={self.current_timestep_callback}\n"
            ")"
        )


class PyramidAttentionBroadcastState:
    r"""
    State for Pyramid Attention Broadcast.

    Attributes:
        iteration (`int`):
            The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
            called before starting a new inference forward pass for PAB to work correctly.
        cache (`Any`):
            The cached output from the previous forward pass. This is used to re-use the attention states when the
            attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
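
    Example:
        An illustrative sketch only; in practice the state is created and managed by
        `PyramidAttentionBroadcastHook` rather than instantiated directly:

        ```python
        >>> state = PyramidAttentionBroadcastState()
        >>> state.iteration += 1
        >>> state.reset()
        >>> state.iteration
        0
        ```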
""" | |
def __init__(self) -> None: | |
self.iteration = 0 | |
self.cache = None | |
def reset(self): | |
self.iteration = 0 | |
self.cache = None | |
def __repr__(self): | |
cache_repr = "" | |
if self.cache is None: | |
cache_repr = "None" | |
else: | |
cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})" | |
return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})" | |


class PyramidAttentionBroadcastHook(ModelHook):
    r"""A hook that applies Pyramid Attention Broadcast to a given module."""

    _is_stateful = True

    def __init__(
        self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
    ) -> None:
        super().__init__()
        self.timestep_skip_range = timestep_skip_range
        self.block_skip_range = block_skip_range
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        self.state = PyramidAttentionBroadcastState()
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )
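
        # Recompute attention when there is nothing cached yet, when this is the first iteration, when the current
        # timestep falls outside the skip window, or on every `block_skip_range`-th iteration. Otherwise, broadcast
        # (re-use) the cached attention output from the previous computation.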
        should_compute_attention = (
            self.state.cache is None
            or self.state.iteration == 0
            or not is_within_timestep_range
            or self.state.iteration % self.block_skip_range == 0
        )

        if should_compute_attention:
            output = self.fn_ref.original_forward(*args, **kwargs)
        else:
            output = self.state.cache
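
        # Cache the (possibly re-used) output and advance the iteration counter for the next forward call.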
        self.state.cache = output
        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> None:
        self.state.reset()
        return module


def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
    r"""
    Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.

    PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
    reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
    similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal
    and spatial layers. This allows for skipping the attention computation in the cross-attention layers more
    frequently than in the temporal and spatial layers. Applying PAB will, therefore, speed up the inference process.

    Args:
        module (`torch.nn.Module`):
            The module to apply Pyramid Attention Broadcast to.
        config (`PyramidAttentionBroadcastConfig`):
            The configuration to use for Pyramid Attention Broadcast.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
    >>> from diffusers.utils import export_to_video

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = PyramidAttentionBroadcastConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(100, 800),
    ...     current_timestep_callback=lambda: pipe.current_timestep,
    ... )
    >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
    ```
    """

    if config.current_timestep_callback is None:
        raise ValueError(
            "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
        )

    if (
        config.spatial_attention_block_skip_range is None
        and config.temporal_attention_block_skip_range is None
        and config.cross_attention_block_skip_range is None
    ):
        logger.warning(
            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
            "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
            "To avoid this warning, please set one of the above parameters."
        )
        config.spatial_attention_block_skip_range = 2

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that
            # PAB cannot be applied to this layer. For custom layers, users can extend this functionality and
            # implement their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
            continue
        _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)


def _apply_pyramid_attention_broadcast_on_attention_class(
    name: str, module: Attention, config: PyramidAttentionBroadcastConfig
) -> bool:
    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_cross_attention = (
        any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
        and config.cross_attention_block_skip_range is not None
        and getattr(module, "is_cross_attention", False)
    )
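
    # Select the skip configuration for whichever layer type matched above. The block identifiers are treated as
    # regular expressions and matched against the fully qualified submodule name.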
    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"
    elif is_cross_attention:
        block_skip_range = config.cross_attention_block_skip_range
        timestep_skip_range = config.cross_attention_timestep_skip_range
        block_type = "cross"

    if block_skip_range is None or timestep_skip_range is None:
        logger.info(
            f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
            f"block identifiers in the configuration."
        )
        return False

    logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
    _apply_pyramid_attention_broadcast_hook(
        module, timestep_skip_range, block_skip_range, config.current_timestep_callback
    )
    return True


def _apply_pyramid_attention_broadcast_hook(
    module: Union[Attention, MochiAttention],
    timestep_skip_range: Tuple[int, int],
    block_skip_range: int,
    current_timestep_callback: Callable[[], int],
):
    r"""
    Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.

    Args:
        module (`torch.nn.Module`):
            The module to apply Pyramid Attention Broadcast to.
        timestep_skip_range (`Tuple[int, int]`):
            The range of timesteps to skip in the attention layer. The attention computations will be conditionally
            skipped if the current timestep is within the specified range.
        block_skip_range (`int`):
            The number of times a specific attention broadcast is skipped before computing the attention states to
            re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
            old attention states will be re-used) before computing the new attention states again.
        current_timestep_callback (`Callable[[], int]`):
            A callback function that returns the current inference timestep.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
    registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)