Rename monkey.py to teacache.py
Browse files- monkey.py +0 -17
- teacache.py +251 -0
monkey.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass, field
|
2 |
-
from typing import List, Optional
|
3 |
-
|
4 |
-
# Create a fixed version of the TeaCacheConfig class
|
5 |
-
@dataclass
|
6 |
-
class TeaCacheConfig:
|
7 |
-
l1_threshold: Optional[float] = None
|
8 |
-
skip_layer_identifiers: List[str] = field(default_factory=lambda: ["blocks", "transformer_blocks",
|
9 |
-
"single_transformer_blocks",
|
10 |
-
"temporal_transformer_blocks"])
|
11 |
-
_polynomial_coefficients: Optional[List[float]] = None
|
12 |
-
|
13 |
-
# Apply the monkey patch
|
14 |
-
def apply_patches():
|
15 |
-
import diffusers.pipelines.teacache_utils
|
16 |
-
diffusers.pipelines.teacache_utils.TeaCacheConfig = TeaCacheConfig
|
17 |
-
print("Applied monkey patch for TeaCacheConfig")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
teacache.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import re
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import List, Optional
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
|
23 |
+
from diffusers.models import (
|
24 |
+
FluxTransformer2DModel,
|
25 |
+
HunyuanVideoTransformer3DModel,
|
26 |
+
LTXVideoTransformer3DModel,
|
27 |
+
LuminaNextDiT2DModel,
|
28 |
+
MochiTransformer3DModel,
|
29 |
+
)
|
30 |
+
from diffusers.models.hooks import ModelHook, add_hook_to_module
|
31 |
+
from diffusers.utils import logging
|
32 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
# Source: https://github.com/ali-vilab/TeaCache
|
37 |
+
# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
|
38 |
+
# fmt: off
|
39 |
+
_MODEL_TO_POLY_COEFFICIENTS = {
|
40 |
+
FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
|
41 |
+
HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
|
42 |
+
LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
|
43 |
+
LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
|
44 |
+
MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
|
45 |
+
}
|
46 |
+
# fmt: on
|
47 |
+
|
48 |
+
_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
|
49 |
+
FluxTransformer2DModel: 0.25,
|
50 |
+
HunyuanVideoTransformer3DModel: 0.1,
|
51 |
+
LTXVideoTransformer3DModel: 0.05,
|
52 |
+
LuminaNextDiT2DModel: 0.2,
|
53 |
+
MochiTransformer3DModel: 0.06,
|
54 |
+
}
|
55 |
+
|
56 |
+
_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
|
57 |
+
FluxTransformer2DModel: "transformer_blocks.0.norm1",
|
58 |
+
}
|
59 |
+
|
60 |
+
_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
|
61 |
+
FluxTransformer2DModel: "norm_out",
|
62 |
+
}
|
63 |
+
|
64 |
+
_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
|
65 |
+
"blocks",
|
66 |
+
"transformer_blocks",
|
67 |
+
"single_transformer_blocks",
|
68 |
+
"temporal_transformer_blocks",
|
69 |
+
]
|
70 |
+
|
71 |
+
|
72 |
+
@dataclass
|
73 |
+
class TeaCacheConfig:
|
74 |
+
l1_threshold: Optional[float] = None
|
75 |
+
|
76 |
+
# Julian: this is a fix for @dataclass
|
77 |
+
skip_layer_identifiers: List[str] = field(default_factory=lambda: _DEFAULT_SKIP_LAYER_IDENTIFIERS)
|
78 |
+
|
79 |
+
_polynomial_coefficients: Optional[List[float]] = None
|
80 |
+
|
81 |
+
|
82 |
+
class TeaCacheDenoiserState:
|
83 |
+
def __init__(self):
|
84 |
+
self.iteration: int = 0
|
85 |
+
self.accumulated_l1_difference: float = 0.0
|
86 |
+
self.timestep_modulated_cache: torch.Tensor = None
|
87 |
+
self.residual_cache: torch.Tensor = None
|
88 |
+
self.should_skip_blocks: bool = False
|
89 |
+
|
90 |
+
def reset(self):
|
91 |
+
self.iteration = 0
|
92 |
+
self.accumulated_l1_difference = 0.0
|
93 |
+
self.timestep_modulated_cache = None
|
94 |
+
self.residual_cache = None
|
95 |
+
|
96 |
+
|
97 |
+
def apply_teacache(
|
98 |
+
pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
|
99 |
+
) -> None:
|
100 |
+
r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
TODO
|
104 |
+
"""
|
105 |
+
|
106 |
+
if config is None:
|
107 |
+
logger.warning("No TeaCacheConfig provided. Using default configuration.")
|
108 |
+
config = TeaCacheConfig()
|
109 |
+
|
110 |
+
if denoiser is None:
|
111 |
+
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
|
112 |
+
|
113 |
+
if isinstance(denoiser, (_MODEL_TO_POLY_COEFFICIENTS.keys())):
|
114 |
+
if config.l1_threshold is None:
|
115 |
+
logger.info(
|
116 |
+
f"No L1 threshold was provided for {type(denoiser)}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
|
117 |
+
f"For higher speedup, increase the threshold."
|
118 |
+
)
|
119 |
+
config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
|
120 |
+
if config.timestep_modulated_layer_identifier is None:
|
121 |
+
logger.info(
|
122 |
+
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using default identifier as provided in the TeaCache paper."
|
123 |
+
)
|
124 |
+
config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
|
125 |
+
if config._polynomial_coefficients is None:
|
126 |
+
logger.info(
|
127 |
+
f"No polynomial coefficients were provided for {type(denoiser)}. Using default coefficients as provided in the TeaCache paper."
|
128 |
+
)
|
129 |
+
config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
|
130 |
+
else:
|
131 |
+
if config.l1_threshold is None:
|
132 |
+
raise ValueError(
|
133 |
+
f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
|
134 |
+
f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
|
135 |
+
)
|
136 |
+
if config.timestep_modulated_layer_identifier is None:
|
137 |
+
raise ValueError(
|
138 |
+
f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
|
139 |
+
f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
|
140 |
+
)
|
141 |
+
if config._polynomial_coefficients is None:
|
142 |
+
raise ValueError(
|
143 |
+
f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not "
|
144 |
+
f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
|
145 |
+
f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
|
146 |
+
)
|
147 |
+
|
148 |
+
timestep_modulated_layer_matches = list(
|
149 |
+
{
|
150 |
+
module
|
151 |
+
for name, module in denoiser.named_modules()
|
152 |
+
if re.match(config.timestep_modulated_layer_identifier, name)
|
153 |
+
}
|
154 |
+
)
|
155 |
+
|
156 |
+
if len(timestep_modulated_layer_matches) == 0:
|
157 |
+
raise ValueError(
|
158 |
+
f"No layer in the denoiser module matched the provided timestep modulated layer identifier: "
|
159 |
+
f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
|
160 |
+
)
|
161 |
+
if len(timestep_modulated_layer_matches) > 1:
|
162 |
+
logger.warning(
|
163 |
+
f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: "
|
164 |
+
f"{config.timestep_modulated_layer_identifier}. Using the first match."
|
165 |
+
)
|
166 |
+
|
167 |
+
denoiser_state = TeaCacheDenoiserState()
|
168 |
+
|
169 |
+
timestep_modulated_layer = timestep_modulated_layer_matches[0]
|
170 |
+
hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
|
171 |
+
add_hook_to_module(timestep_modulated_layer, hook, append=True)
|
172 |
+
|
173 |
+
skip_layer_identifiers = config.skip_layer_identifiers
|
174 |
+
skip_layer_matches = list(
|
175 |
+
{
|
176 |
+
module
|
177 |
+
for name, module in denoiser.named_modules()
|
178 |
+
if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
|
179 |
+
}
|
180 |
+
)
|
181 |
+
|
182 |
+
for skip_layer in skip_layer_matches:
|
183 |
+
hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
|
184 |
+
add_hook_to_module(skip_layer, hook, append=True)
|
185 |
+
|
186 |
+
|
187 |
+
class TimestepModulatedOutputCacheHook(ModelHook):
|
188 |
+
# The denoiser hook will reset its state, so we don't have to handle it here
|
189 |
+
_is_stateful = False
|
190 |
+
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
denoiser_state: TeaCacheDenoiserState,
|
194 |
+
l1_threshold: float,
|
195 |
+
polynomial_coefficients: List[float],
|
196 |
+
) -> None:
|
197 |
+
self.denoiser_state = denoiser_state
|
198 |
+
self.l1_threshold = l1_threshold
|
199 |
+
# TODO(aryan): implement torch equivalent
|
200 |
+
self.rescale_fn = np.poly1d(polynomial_coefficients)
|
201 |
+
|
202 |
+
def post_forward(self, module, output):
|
203 |
+
if isinstance(output, tuple):
|
204 |
+
# This assumes that the first element of the output tuple is the timestep modulated noise output.
|
205 |
+
# For Diffusers models, this is true. For models outside diffusers, users will have to ensure
|
206 |
+
# that the first element of the output tuple is the timestep modulated noise output (seems to be
|
207 |
+
# the case for most research model implementations).
|
208 |
+
timestep_modulated_noise = output[0]
|
209 |
+
elif torch.is_tensor(output):
|
210 |
+
timestep_modulated_noise = output
|
211 |
+
else:
|
212 |
+
raise ValueError(
|
213 |
+
f"Expected output to be a tensor or a tuple with first element as timestep modulated noise. "
|
214 |
+
f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep "
|
215 |
+
f"modulated noise output as the first element."
|
216 |
+
)
|
217 |
+
|
218 |
+
if self.denoiser_state.timestep_modulated_cache is not None:
|
219 |
+
l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
|
220 |
+
normalized_l1_diff = l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()
|
221 |
+
rescaled_l1_diff = self.rescale_fn(normalized_l1_diff)
|
222 |
+
self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
|
223 |
+
|
224 |
+
if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
|
225 |
+
self.denoiser_state.should_skip_blocks = True
|
226 |
+
self.denoiser_state.accumulated_l1_difference = 0.0
|
227 |
+
else:
|
228 |
+
self.denoiser_state.should_skip_blocks = False
|
229 |
+
|
230 |
+
self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
|
231 |
+
return output
|
232 |
+
|
233 |
+
|
234 |
+
class DenoiserStateBasedSkipLayerHook(ModelHook):
|
235 |
+
_is_stateful = False
|
236 |
+
|
237 |
+
def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
|
238 |
+
self.denoiser_state = denoiser_state
|
239 |
+
|
240 |
+
def new_forward(self, module, *args, **kwargs):
|
241 |
+
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
242 |
+
|
243 |
+
if not self.denoiser_state.should_skip_blocks:
|
244 |
+
output = module._old_forward(*args, **kwargs)
|
245 |
+
else:
|
246 |
+
# Diffusers models either expect one output (hidden_states) or a tuple of two outputs (hidden_states, encoder_hidden_states).
|
247 |
+
# Returning a tuple of None values handles both cases. It is okay to do because we are not going to be using these
|
248 |
+
# anywhere if self.denoiser_state.should_skip_blocks is True.
|
249 |
+
output = (None, None)
|
250 |
+
|
251 |
+
return module._diffusers_hook.post_forward(module, output)
|