# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn

from diffusers.models import (
    FluxTransformer2DModel,
    HunyuanVideoTransformer3DModel,
    LTXVideoTransformer3DModel,
    LuminaNextDiT2DModel,
    MochiTransformer3DModel,
)
from diffusers.models.hooks import ModelHook, add_hook_to_module
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Source: https://github.com/ali-vilab/TeaCache
# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
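# Coefficients are ordered from the highest-degree term to the constant term, matching the
# `np.poly1d` convention used by `TimestepModulatedOutputCacheHook.rescale_fn` below.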
# fmt: off
_MODEL_TO_POLY_COEFFICIENTS = {
    FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
    HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
    LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
    LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
    MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
}
# fmt: on

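# Default thresholds on the accumulated rescaled L1 difference, chosen to target roughly a
# 1.5x speedup per the TeaCache paper; larger thresholds skip more steps at some cost to quality.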
_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
    FluxTransformer2DModel: 0.25,
    HunyuanVideoTransformer3DModel: 0.1,
    LTXVideoTransformer3DModel: 0.05,
    LuminaNextDiT2DModel: 0.2,
    MochiTransformer3DModel: 0.06,
}

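# All identifiers below are regular-expression patterns matched with `re.search` against the
# module names produced by `nn.Module.named_modules()`.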
_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
    FluxTransformer2DModel: "transformer_blocks.0.norm1",
}

_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
    FluxTransformer2DModel: "norm_out",
}

_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
    "blocks",
    "transformer_blocks",
    "single_transformer_blocks",
    "temporal_transformer_blocks",
]


@dataclass
class TeaCacheConfig:
    l1_threshold: Optional[float] = None
    # Mutable defaults need `default_factory`; copying keeps per-instance edits from
    # mutating the module-level list.
    skip_layer_identifiers: List[str] = field(default_factory=lambda: _DEFAULT_SKIP_LAYER_IDENTIFIERS.copy())
    _polynomial_coefficients: Optional[List[float]] = None
    timestep_modulated_layer_identifier: Optional[str] = None


class TeaCacheDenoiserState:
    """Mutable per-denoiser state shared by the TeaCache hooks; call `reset()` between generations."""

    def __init__(self):
        self.iteration: int = 0
        self.accumulated_l1_difference: float = 0.0
        self.timestep_modulated_cache: Optional[torch.Tensor] = None
        self.residual_cache: Optional[torch.Tensor] = None
        self.should_skip_blocks: bool = False

    def reset(self):
        self.iteration = 0
        self.accumulated_l1_difference = 0.0
        self.timestep_modulated_cache = None
        self.residual_cache = None
        self.should_skip_blocks = False


def apply_teacache(
    pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
) -> None:
    r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.

    Args:
        pipeline (`DiffusionPipeline`):
            The pipeline whose denoiser TeaCache should be applied to.
        config (`TeaCacheConfig`, *optional*):
            Configuration for TeaCache. If not provided, a default configuration with
            model-specific defaults is used.
        denoiser (`nn.Module`, *optional*):
            The denoiser module. If not provided, `pipeline.transformer` is used, falling
            back to `pipeline.unet`.
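
    Example (a minimal usage sketch; the checkpoint, prompt, and step count are illustrative
    and not prescribed by this module):

    ```py
    >>> import torch
    >>> from diffusers import FluxPipeline

    >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> # Uses the calibrated FLUX defaults for the threshold, modulated-layer identifier,
    >>> # and polynomial coefficients.
    >>> apply_teacache(pipe)
    >>> image = pipe("A photo of a cat", num_inference_steps=50).images[0]
    ```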
    """

    if config is None:
        logger.warning("No TeaCacheConfig provided. Using default configuration.")
        config = TeaCacheConfig()

    if denoiser is None:
        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet

    model_type = denoiser.__class__
    model_type_name = model_type.__name__
    logger.info(f"Detected model type: {model_type_name}")

    # Exact class lookup rather than isinstance, so that subclasses do not silently pick up
    # calibrated coefficients that were not measured for them.
    is_supported_model = model_type in _MODEL_TO_POLY_COEFFICIENTS
    
    if is_supported_model:
        if config.l1_threshold is None:
            logger.info(
                f"No L1 threshold was provided for {model_type}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
                f"For higher speedup, increase the threshold."
            )
            config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[model_type]
        if config.timestep_modulated_layer_identifier is None:
            logger.info(
                f"No timestep modulated layer identifier was provided for {model_type}. Using default identifier as provided in the TeaCache paper."
            )
            if model_type in _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER:
                config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[model_type]
            else:
                # Use a default for model types without a specific identifier
                config.timestep_modulated_layer_identifier = "transformer_blocks.0.norm1"
        if config._polynomial_coefficients is None:
            logger.info(
                f"No polynomial coefficients were provided for {model_type}. Using default coefficients as provided in the TeaCache paper."
            )
            config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[model_type]
    else:
        # LTX variants and subclasses fail the exact-match lookup above, so fall back to the
        # calibrated LTX defaults for them.
        if "LTX" in model_type_name:
            logger.info(f"Model {model_type_name} appears to be an LTX model variant. Using LTX defaults.")
            if config.l1_threshold is None:
                config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[LTXVideoTransformer3DModel]
            if config.timestep_modulated_layer_identifier is None:
                config.timestep_modulated_layer_identifier = "transformer_blocks.0.norm1"
            if config._polynomial_coefficients is None:
                config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[LTXVideoTransformer3DModel]
        else:
            if config.l1_threshold is None:
                raise ValueError(
                    f"No L1 threshold was provided for {model_type}. Using TeaCache with this model is not supported "
                    f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
                )
            if config.timestep_modulated_layer_identifier is None:
                raise ValueError(
                    f"No timestep modulated layer identifier was provided for {model_type}. Using TeaCache with this model is not supported "
                    f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
                )
            if config._polynomial_coefficients is None:
                raise ValueError(
                    f"No polynomial coefficients were provided for {model_type}. Using TeaCache with this model is not "
                    f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
                    f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
                )

    # A list comprehension (not a set) preserves `named_modules()` order, so "first match"
    # below is deterministic.
    timestep_modulated_layer_matches = [
        module
        for name, module in denoiser.named_modules()
        if re.search(config.timestep_modulated_layer_identifier, name)
    ]

    if len(timestep_modulated_layer_matches) == 0:
        raise ValueError(
            f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
            f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
        )
    if len(timestep_modulated_layer_matches) > 1:
        logger.warning(
            f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
            f"{config.timestep_modulated_layer_identifier}. Using the first match."
        )

    denoiser_state = TeaCacheDenoiserState()

    timestep_modulated_layer = timestep_modulated_layer_matches[0]
    hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
    add_hook_to_module(timestep_modulated_layer, hook, append=True)

    skip_layer_identifiers = config.skip_layer_identifiers
    # `dict.fromkeys` deduplicates shared modules while preserving `named_modules()` order.
    skip_layer_matches = list(
        dict.fromkeys(
            module
            for name, module in denoiser.named_modules()
            if any(re.search(identifier, name) for identifier in skip_layer_identifiers)
        )
    )

    for skip_layer in skip_layer_matches:
        hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
        add_hook_to_module(skip_layer, hook, append=True)


class TimestepModulatedOutputCacheHook(ModelHook):
    # The denoiser hook will reset its state, so we don't have to handle it here
    _is_stateful = False

    def __init__(
        self,
        denoiser_state: TeaCacheDenoiserState,
        l1_threshold: float,
        polynomial_coefficients: List[float],
    ) -> None:
        self.denoiser_state = denoiser_state
        self.l1_threshold = l1_threshold
        # TODO(aryan): implement torch equivalent
        self.rescale_fn = np.poly1d(polynomial_coefficients)
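        # For intuition: with the FLUX coefficients above, a raw relative change of 0.05 rescales
        # to roughly 0.18 (an illustrative value, not a figure from the paper).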

    def post_forward(self, module, output):
        if isinstance(output, tuple):
            # This assumes that the first element of the output tuple is the timestep modulated noise output.
            # For Diffusers models, this is true. For models outside diffusers, users will have to ensure
            # that the first element of the output tuple is the timestep modulated noise output (seems to be
            # the case for most research model implementations).
            timestep_modulated_noise = output[0]
        elif torch.is_tensor(output):
            timestep_modulated_noise = output
        else:
            raise ValueError(
                f"Expected output to be a tensor or a tuple with first element as timestep modulated noise. "
                f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep "
                f"modulated noise output as the first element."
            )

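        # Relative L1 change between the current and previous timestep-modulated activations:
        #     rel_change = mean(|x_t - x_prev|) / mean(|x_prev|)
        # The calibrated polynomial rescales this raw change before it is accumulated.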
        if self.denoiser_state.timestep_modulated_cache is not None:
            l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
            # `.item()` yields a Python float so that `np.poly1d` can evaluate it even when the
            # tensors live on an accelerator device.
            normalized_l1_diff = (l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()).item()
            rescaled_l1_diff = float(self.rescale_fn(normalized_l1_diff))
            self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff

            # As in the reference TeaCache implementation, blocks are skipped while the accumulated
            # difference stays below the threshold; once it crosses the threshold, the blocks are
            # recomputed and the accumulator is reset.
            if self.denoiser_state.accumulated_l1_difference < self.l1_threshold:
                self.denoiser_state.should_skip_blocks = True
            else:
                self.denoiser_state.should_skip_blocks = False
                self.denoiser_state.accumulated_l1_difference = 0.0

        self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
        return output


class DenoiserStateBasedSkipLayerHook(ModelHook):
    _is_stateful = False

    def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
        self.denoiser_state = denoiser_state

    def new_forward(self, module, *args, **kwargs):
        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)

        if not self.denoiser_state.should_skip_blocks:
            output = module._old_forward(*args, **kwargs)
        else:
            # Diffusers blocks return either `hidden_states` alone or a
            # `(hidden_states, encoder_hidden_states)` tuple. A tuple of None values covers both
            # cases; the values are never consumed when `should_skip_blocks` is True.
            output = (None, None)

        return module._diffusers_hook.post_forward(module, output)