jbilcke-hf (HF staff) committed
Commit c168216 · verified · 1 Parent(s): f4b832e

Rename monkey.py to teacache.py

Files changed (2)
  1. monkey.py +0 -17
  2. teacache.py +251 -0
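
For context, a minimal usage sketch of the renamed module (illustrative only: the pipeline class and checkpoint are assumptions, `apply_teacache` and `TeaCacheConfig` come from teacache.py as added below, and the module itself expects a diffusers build that provides `diffusers.models.hooks`):

import torch
from diffusers import FluxPipeline

from teacache import TeaCacheConfig, apply_teacache

# Illustrative checkpoint; any pipeline whose denoiser is a supported transformer works.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# A higher l1_threshold lets more denoiser steps be skipped: faster, but lossier.
apply_teacache(pipe, TeaCacheConfig(l1_threshold=0.25))

image = pipe("a photo of a cat", num_inference_steps=28).images[0]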
monkey.py DELETED
@@ -1,17 +0,0 @@
- from dataclasses import dataclass, field
- from typing import List, Optional
-
- # Create a fixed version of the TeaCacheConfig class
- @dataclass
- class TeaCacheConfig:
-     l1_threshold: Optional[float] = None
-     skip_layer_identifiers: List[str] = field(default_factory=lambda: ["blocks", "transformer_blocks",
-                                                                        "single_transformer_blocks",
-                                                                        "temporal_transformer_blocks"])
-     _polynomial_coefficients: Optional[List[float]] = None
-
- # Apply the monkey patch
- def apply_patches():
-     import diffusers.pipelines.teacache_utils
-     diffusers.pipelines.teacache_utils.TeaCacheConfig = TeaCacheConfig
-     print("Applied monkey patch for TeaCacheConfig")
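The deleted patch existed because plain mutable defaults are rejected on dataclass fields, which is what the `field(default_factory=...)` above works around. A minimal repro of the failure mode (class names hypothetical, not part of this commit):

from dataclasses import dataclass, field
from typing import List

# A bare mutable default fails at class-creation time:
#
#     @dataclass
#     class Broken:
#         ids: List[str] = ["blocks"]
#
# raises: ValueError: mutable default <class 'list'> for field ids is not allowed
# A default_factory defers construction so each instance gets a fresh list:

@dataclass
class Fixed:
    ids: List[str] = field(default_factory=lambda: ["blocks"])

print(Fixed().ids)  # ['blocks']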
teacache.py ADDED
@@ -0,0 +1,251 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from dataclasses import dataclass, field  # `field` is needed for the default_factory below
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from diffusers.models import (
+     FluxTransformer2DModel,
+     HunyuanVideoTransformer3DModel,
+     LTXVideoTransformer3DModel,
+     LuminaNextDiT2DModel,
+     MochiTransformer3DModel,
+ )
+ from diffusers.models.hooks import ModelHook, add_hook_to_module
+ from diffusers.utils import logging
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ # Source: https://github.com/ali-vilab/TeaCache
+ # TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly,
+ # and export to file for re-use.
+ # fmt: off
+ _MODEL_TO_POLY_COEFFICIENTS = {
+     FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
+     HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
+     LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
+     LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
+     MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
+ }
+ # fmt: on
+
+ _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
+     FluxTransformer2DModel: 0.25,
+     HunyuanVideoTransformer3DModel: 0.1,
+     LTXVideoTransformer3DModel: 0.05,
+     LuminaNextDiT2DModel: 0.2,
+     MochiTransformer3DModel: 0.06,
+ }
+
+ _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
+     FluxTransformer2DModel: "transformer_blocks.0.norm1",
+ }
+
+ _MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
+     FluxTransformer2DModel: "norm_out",
+ }
+
+ _DEFAULT_SKIP_LAYER_IDENTIFIERS = [
+     "blocks",
+     "transformer_blocks",
+     "single_transformer_blocks",
+     "temporal_transformer_blocks",
+ ]
+
+
+ @dataclass
+ class TeaCacheConfig:
+     l1_threshold: Optional[float] = None
+
+     # Identifier of the layer whose timestep-modulated output is cached;
+     # apply_teacache() reads this attribute, so it must exist on the config.
+     timestep_modulated_layer_identifier: Optional[str] = None
+
+     # Julian: dataclass fields reject mutable defaults, so a default_factory is
+     # required; copy the module-level list so instances don't share it.
+     skip_layer_identifiers: List[str] = field(default_factory=lambda: list(_DEFAULT_SKIP_LAYER_IDENTIFIERS))
+
+     _polynomial_coefficients: Optional[List[float]] = None
+
+
+ class TeaCacheDenoiserState:
+     def __init__(self):
+         self.iteration: int = 0
+         self.accumulated_l1_difference: float = 0.0
+         self.timestep_modulated_cache: Optional[torch.Tensor] = None
+         self.residual_cache: Optional[torch.Tensor] = None
+         self.should_skip_blocks: bool = False
+
+     def reset(self):
+         self.iteration = 0
+         self.accumulated_l1_difference = 0.0
+         self.timestep_modulated_cache = None
+         self.residual_cache = None
+         # Also restore the skip flag so a freshly reset state never skips the first step.
+         self.should_skip_blocks = False
+
+
+ def apply_teacache(
+     pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
+ ) -> None:
+     r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
+
+     Args:
+         pipeline (`DiffusionPipeline`):
+             The pipeline whose denoiser should be patched.
+         config (`TeaCacheConfig`, *optional*):
+             The TeaCache configuration. If omitted, defaults are used; they are only known for the
+             model classes listed in `_MODEL_TO_POLY_COEFFICIENTS`.
+         denoiser (`nn.Module`, *optional*):
+             The denoiser to patch directly. Defaults to `pipeline.transformer` (or `pipeline.unet`).
+     """
+
+     if config is None:
+         logger.warning("No TeaCacheConfig provided. Using the default configuration.")
+         config = TeaCacheConfig()
+
+     if denoiser is None:
+         denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+
+     # isinstance() needs a tuple of types, not a dict_keys view.
+     if isinstance(denoiser, tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())):
+         if config.l1_threshold is None:
+             logger.info(
+                 f"No L1 threshold was provided for {type(denoiser)}. Using the default threshold from the "
+                 f"TeaCache paper for a 1.5x speedup. For a higher speedup, increase the threshold."
+             )
+             config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
+         if config.timestep_modulated_layer_identifier is None:
+             logger.info(
+                 f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using the default "
+                 f"identifier from the TeaCache paper."
+             )
+             config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
+         if config._polynomial_coefficients is None:
+             logger.info(
+                 f"No polynomial coefficients were provided for {type(denoiser)}. Using the default coefficients "
+                 f"from the TeaCache paper."
+             )
+             config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
+     else:
+         if config.l1_threshold is None:
+             raise ValueError(
+                 f"No L1 threshold was provided for {type(denoiser)}, and there is no default for this model "
+                 f"in Diffusers. Please set the `l1_threshold` attribute in the config."
+             )
+         if config.timestep_modulated_layer_identifier is None:
+             raise ValueError(
+                 f"No timestep modulated layer identifier was provided for {type(denoiser)}, and there is no "
+                 f"default for this model in Diffusers. Please set the `timestep_modulated_layer_identifier` "
+                 f"attribute in the config."
+             )
+         if config._polynomial_coefficients is None:
+             raise ValueError(
+                 f"No polynomial coefficients were provided for {type(denoiser)}, and there is no default for "
+                 f"this model in Diffusers. Please set the `_polynomial_coefficients` attribute in the config. "
+                 f"Automatic calibration will be implemented in the future."
+             )
+
+     # Preserve named_modules() registration order so "first match" is deterministic
+     # (re.match() anchors at the start only, so e.g. "norm1" also matches "norm1_context").
+     timestep_modulated_layer_matches = [
+         module
+         for name, module in denoiser.named_modules()
+         if re.match(config.timestep_modulated_layer_identifier, name)
+     ]
+
+     if len(timestep_modulated_layer_matches) == 0:
+         raise ValueError(
+             f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
+             f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
+         )
+     if len(timestep_modulated_layer_matches) > 1:
+         logger.warning(
+             f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
+             f"{config.timestep_modulated_layer_identifier}. Using the first match."
+         )
+
+     denoiser_state = TeaCacheDenoiserState()
+
+     timestep_modulated_layer = timestep_modulated_layer_matches[0]
+     hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
+     add_hook_to_module(timestep_modulated_layer, hook, append=True)
+
+     skip_layer_identifiers = config.skip_layer_identifiers
+     skip_layer_matches = list(
+         {
+             module
+             for name, module in denoiser.named_modules()
+             if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
+         }
+     )
+
+     for skip_layer in skip_layer_matches:
+         hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
+         add_hook_to_module(skip_layer, hook, append=True)
+
+
+ class TimestepModulatedOutputCacheHook(ModelHook):
+     # The denoiser hook will reset its state, so we don't have to handle it here
+     _is_stateful = False
+
+     def __init__(
+         self,
+         denoiser_state: TeaCacheDenoiserState,
+         l1_threshold: float,
+         polynomial_coefficients: List[float],
+     ) -> None:
+         self.denoiser_state = denoiser_state
+         self.l1_threshold = l1_threshold
+         # TODO(aryan): implement torch equivalent
+         self.rescale_fn = np.poly1d(polynomial_coefficients)
+
+     def post_forward(self, module, output):
+         if isinstance(output, tuple):
+             # This assumes that the first element of the output tuple is the timestep modulated noise output.
+             # For Diffusers models, this is true. For models outside Diffusers, users will have to ensure
+             # that the first element of the output tuple is the timestep modulated noise output (this seems
+             # to be the case for most research model implementations).
+             timestep_modulated_noise = output[0]
+         elif torch.is_tensor(output):
+             timestep_modulated_noise = output
+         else:
+             raise ValueError(
+                 f"Expected output to be a tensor or a tuple with the first element as the timestep modulated "
+                 f"noise. Got {type(output)} instead. Please ensure that the denoiser module returns the "
+                 f"timestep modulated noise output as the first element."
+             )
+
+         if self.denoiser_state.timestep_modulated_cache is not None:
+             l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
+             # .item() converts the 0-dim tensor to a Python float so np.poly1d works
+             # for CPU and CUDA tensors alike.
+             normalized_l1_diff = (l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()).item()
+             rescaled_l1_diff = self.rescale_fn(normalized_l1_diff)
+             self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
+
+             if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
+                 self.denoiser_state.should_skip_blocks = True
+                 self.denoiser_state.accumulated_l1_difference = 0.0
+             else:
+                 self.denoiser_state.should_skip_blocks = False
+
+         self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
+         return output
+
+
+ class DenoiserStateBasedSkipLayerHook(ModelHook):
+     _is_stateful = False
+
+     def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
+         self.denoiser_state = denoiser_state
+
+     def new_forward(self, module, *args, **kwargs):
+         args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+
+         if not self.denoiser_state.should_skip_blocks:
+             output = module._old_forward(*args, **kwargs)
+         else:
+             # Diffusers models either expect one output (hidden_states) or a tuple of two outputs
+             # (hidden_states, encoder_hidden_states). Returning a tuple of None values handles both
+             # cases. This is okay because these outputs are never used when
+             # self.denoiser_state.should_skip_blocks is True.
+             output = (None, None)
+
+         return module._diffusers_hook.post_forward(module, output)
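
One way to address the `TODO(aryan): implement torch equivalent` next to `rescale_fn` above: evaluate the polynomial with a Horner loop so the rescaling can stay on-device. A hedged sketch (the `poly_eval` name is hypothetical, not part of this commit):

import torch

def poly_eval(coefficients, x: torch.Tensor) -> torch.Tensor:
    # Matches np.poly1d ordering: coefficients[0] is the highest-degree term.
    result = torch.zeros_like(x)
    for c in coefficients:
        result = result * x + c
    return result

# Example: poly_eval(_MODEL_TO_POLY_COEFFICIENTS[FluxTransformer2DModel], torch.tensor(0.1))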