# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import torch

from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")


@dataclass
class PyramidAttentionBroadcastConfig:
    r"""
    Configuration for Pyramid Attention Broadcast.

    Args:
        spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            The number of times a specific spatial attention broadcast is skipped before computing the attention states
            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
            old attention states will be re-used) before computing the new attention states again.
        temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            The number of times a specific temporal attention broadcast is skipped before computing the attention
            states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
            (i.e., old attention states will be re-used) before computing the new attention states again.
        cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
            The number of times a specific cross-attention broadcast is skipped before computing the attention states
            to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
            old attention states will be re-used) before computing the new attention states again.
        spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
            The range of timesteps to skip in the spatial attention layer. The attention computations will be
            conditionally skipped if the current timestep is within the specified range.
        temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
            The range of timesteps to skip in the temporal attention layer. The attention computations will be
            conditionally skipped if the current timestep is within the specified range.
        cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
            The range of timesteps to skip in the cross-attention layer. The attention computations will be
            conditionally skipped if the current timestep is within the specified range.
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks",
        "single_transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
            The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
        cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
            The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
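
    Example:

    A minimal construction example. The skip values here are illustrative, and `pipe` is assumed to be an
    already-loaded pipeline exposing `current_timestep`:

    ```python
    >>> from diffusers import PyramidAttentionBroadcastConfig

    >>> config = PyramidAttentionBroadcastConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     cross_attention_block_skip_range=6,
    ...     spatial_attention_timestep_skip_range=(100, 800),
    ...     current_timestep_callback=lambda: pipe.current_timestep,
    ... )
    ```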
    """

    spatial_attention_block_skip_range: Optional[int] = None
    temporal_attention_block_skip_range: Optional[int] = None
    cross_attention_block_skip_range: Optional[int] = None

    spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
    temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
    cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)

    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS

    current_timestep_callback: Optional[Callable[[], int]] = None

    # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
    # so not added for now)

    def __repr__(self) -> str:
        return (
            f"PyramidAttentionBroadcastConfig(\n"
            f"  spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
            f"  temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
            f"  cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
            f"  spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
            f"  temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
            f"  cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
            f"  spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
            f"  temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
            f"  cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
            f"  current_timestep_callback={self.current_timestep_callback}\n"
            ")"
        )


class PyramidAttentionBroadcastState:
    r"""
    State for Pyramid Attention Broadcast.

    Attributes:
        iteration (`int`):
            The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
            called before starting a new inference forward pass for PAB to work correctly.
        cache (`Any`):
            The cached output from the previous forward pass. This is used to re-use the attention states when the
            attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
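
    Example:

    A minimal sketch of the state lifecycle. In practice the hook creates and resets this state for you via
    `initialize_hook` and `reset_state`:

    ```python
    >>> state = PyramidAttentionBroadcastState()
    >>> state.iteration, state.cache
    (0, None)
    >>> state.reset()
    ```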
    """

    def __init__(self) -> None:
        self.iteration = 0
        self.cache = None

    def reset(self):
        self.iteration = 0
        self.cache = None

    def __repr__(self):
        if self.cache is None:
            cache_repr = "None"
        elif torch.is_tensor(self.cache):
            cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
        else:
            # The cache may also be a tuple of tensors, depending on the module's output.
            cache_repr = f"Tuple(length={len(self.cache)})"
        return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"


class PyramidAttentionBroadcastHook(ModelHook):
    r"""A hook that applies Pyramid Attention Broadcast to a given module."""

    _is_stateful = True

    def __init__(
        self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
    ) -> None:
        super().__init__()

        self.timestep_skip_range = timestep_skip_range
        self.block_skip_range = block_skip_range
        self.current_timestep_callback = current_timestep_callback

    def initialize_hook(self, module):
        self.state = PyramidAttentionBroadcastState()
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        is_within_timestep_range = (
            self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
        )
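        # Recompute attention when there is no cached output yet, on the first iteration, when the current timestep
        # is outside the configured skip window, or on every `block_skip_range`-th iteration; otherwise re-use
        # (broadcast) the cached attention output.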
        should_compute_attention = (
            self.state.cache is None
            or self.state.iteration == 0
            or not is_within_timestep_range
            or self.state.iteration % self.block_skip_range == 0
        )

        if should_compute_attention:
            output = self.fn_ref.original_forward(*args, **kwargs)
        else:
            output = self.state.cache

        self.state.cache = output
        self.state.iteration += 1
        return output

    def reset_state(self, module: torch.nn.Module) -> None:
        self.state.reset()
        return module


def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
    r"""
    Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given module, typically a
    pipeline's denoiser such as its transformer.

    PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
    reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
    similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
    spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
    than in the temporal and spatial layers. Applying PAB therefore speeds up the inference process.

    Args:
        module (`torch.nn.Module`):
            The module to apply Pyramid Attention Broadcast to.
        config (`PyramidAttentionBroadcastConfig`):
            The configuration to use for Pyramid Attention Broadcast.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
    >>> from diffusers.utils import export_to_video

    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = PyramidAttentionBroadcastConfig(
    ...     spatial_attention_block_skip_range=2,
    ...     spatial_attention_timestep_skip_range=(100, 800),
    ...     current_timestep_callback=lambda: pipe.current_timestep,
    ... )
    >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
    ```
    """
    if config.current_timestep_callback is None:
        raise ValueError(
            "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
        )

    if (
        config.spatial_attention_block_skip_range is None
        and config.temporal_attention_block_skip_range is None
        and config.cross_attention_block_skip_range is None
    ):
        logger.warning(
            "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
            "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
            "To avoid this warning, please set one of the above parameters."
        )
        config.spatial_attention_block_skip_range = 2

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
            continue
        _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)


def _apply_pyramid_attention_broadcast_on_attention_class(
    name: str, module: Attention, config: PyramidAttentionBroadcastConfig
) -> bool:
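    # Classify the attention layer by matching its fully qualified module name against the configured identifiers
    # and by checking the module's `is_cross_attention` flag.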
    is_spatial_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
        and config.spatial_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
    is_cross_attention = (
        any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
        and config.cross_attention_block_skip_range is not None
        and getattr(module, "is_cross_attention", False)
    )

    block_skip_range, timestep_skip_range, block_type = None, None, None
    if is_spatial_self_attention:
        block_skip_range = config.spatial_attention_block_skip_range
        timestep_skip_range = config.spatial_attention_timestep_skip_range
        block_type = "spatial"
    elif is_temporal_self_attention:
        block_skip_range = config.temporal_attention_block_skip_range
        timestep_skip_range = config.temporal_attention_timestep_skip_range
        block_type = "temporal"
    elif is_cross_attention:
        block_skip_range = config.cross_attention_block_skip_range
        timestep_skip_range = config.cross_attention_timestep_skip_range
        block_type = "cross"

    if block_skip_range is None or timestep_skip_range is None:
        logger.info(
            f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
            f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
            f"however, that this layer may still be valid for applying PAB. Please specify the correct "
            f"block identifiers in the configuration."
        )
        return False

    logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
    _apply_pyramid_attention_broadcast_hook(
        module, timestep_skip_range, block_skip_range, config.current_timestep_callback
    )
    return True


def _apply_pyramid_attention_broadcast_hook(
    module: Union[Attention, MochiAttention],
    timestep_skip_range: Tuple[int, int],
    block_skip_range: int,
    current_timestep_callback: Callable[[], int],
):
    r"""
    Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.

    Args:
        module (`torch.nn.Module`):
            The module to apply Pyramid Attention Broadcast to.
        timestep_skip_range (`Tuple[int, int]`):
            The range of timesteps to skip in the attention layer. The attention computations will be conditionally
            skipped if the current timestep is within the specified range.
        block_skip_range (`int`):
            The number of times a specific attention broadcast is skipped before computing the attention states to
            re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
            attention states will be re-used) before computing the new attention states again.
        current_timestep_callback (`Callable[[], int]`):
            A callback function that returns the current inference timestep.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
    registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)