Delete teacache.py
teacache.py +0 -206
teacache.py
DELETED
@@ -1,206 +0,0 @@
from typing import Any, Optional, Type

import numpy as np
import torch
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    scale_lora_layers,
    unscale_lora_layers
)

class TeaCacheConfig:
    """Configuration for TeaCache optimization"""
    def __init__(
        self,
        enabled: bool = True,
        rel_l1_thresh: float = 0.05,  # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
        num_inference_steps: int = 50
    ):
        self.enabled = enabled
        self.rel_l1_thresh = rel_l1_thresh
        self.num_inference_steps = num_inference_steps

        # Internal state
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None

def create_teacache_forward(original_forward: Any):
    """Factory function to create a TeaCache-enabled forward pass"""

    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
        attention_kwargs: Optional[dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        # Handle LoRA scaling
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            attention_kwargs = {}
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        # Initial processing
        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.size(0)
        hidden_states = self.proj_in(hidden_states)

        # Time embedding
        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        # Caption projection
        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

        # TeaCache optimization logic
        should_calc = True
        if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
            inp = hidden_states.clone()
            temb_ = temb.clone()
            inp = self.transformer_blocks[0].norm1(inp)

            num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
            ada_values = (
                self.transformer_blocks[0].scale_shift_table[None, None] +
                temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
            )
            shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
            modulated_inp = inp * (1 + scale_msa) + shift_msa

            # Determine if we should calculate or reuse
            if self.teacache_config.cnt == 0 or self.teacache_config.cnt == self.teacache_config.num_inference_steps - 1:
                should_calc = True
                self.teacache_config.accumulated_rel_l1_distance = 0
            else:
                # Polynomial coefficients for rescaling
                coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
                rescale_func = np.poly1d(coefficients)

                rel_diff = (
                    (modulated_inp - self.teacache_config.previous_modulated_input).abs().mean() /
                    self.teacache_config.previous_modulated_input.abs().mean()
                ).cpu().item()

                self.teacache_config.accumulated_rel_l1_distance += rescale_func(rel_diff)

                if self.teacache_config.accumulated_rel_l1_distance < self.teacache_config.rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.teacache_config.accumulated_rel_l1_distance = 0

            self.teacache_config.previous_modulated_input = modulated_inp
            self.teacache_config.cnt += 1
            if self.teacache_config.cnt == self.teacache_config.num_inference_steps:
                self.teacache_config.cnt = 0

        # Process hidden states
        if hasattr(self, 'teacache_config') and self.teacache_config.enabled and not should_calc:
            hidden_states += self.teacache_config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone() if hasattr(self, 'teacache_config') and self.teacache_config.enabled else None

            for block in self.transformer_blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)
                        return custom_forward

                    ckpt_kwargs: dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        encoder_attention_mask=encoder_attention_mask,
                    )

            scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
            shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

            hidden_states = self.norm_out(hidden_states)
            hidden_states = hidden_states * (1 + scale) + shift

            if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
                self.teacache_config.previous_residual = hidden_states - ori_hidden_states

        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return {"sample": output}

    return teacache_forward

def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
    """Enable TeaCache optimization for a model class

    Args:
        model_class: The model class to patch
        config: TeaCache configuration
    """
    # Store original forward method if needed
    if not hasattr(model_class, '_original_forward'):
        model_class._original_forward = model_class.forward

    # Create new forward method with TeaCache
    model_class.forward = create_teacache_forward(model_class._original_forward)

    # Add config attribute to class
    model_class.teacache_config = config

def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
    """Disable TeaCache optimization for a model class

    Args:
        model_class: The model class to unpatch
    """
    if hasattr(model_class, '_original_forward'):
        model_class.forward = model_class._original_forward
        delattr(model_class, '_original_forward')

    if hasattr(model_class, 'teacache_config'):
        delattr(model_class, 'teacache_config')
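For reference, a minimal sketch of how the removed module was presumably wired into an LTX-Video pipeline before this deletion. The checkpoint name, pipeline class, prompt, and step count below are assumptions for illustration, not taken from this commit, and they presume the diffusers version this file was written against.

# Hypothetical usage sketch for the removed teacache.py; checkpoint, prompt,
# and pipeline class are assumptions, not part of this commit.
import torch
from diffusers import LTXPipeline
from diffusers.models.transformers import LTXVideoTransformer3DModel

from teacache import TeaCacheConfig, disable_teacache, enable_teacache

# Patch the transformer class before running inference.
config = TeaCacheConfig(
    enabled=True,
    rel_l1_thresh=0.05,      # per the module's comment: 0.03 ~1.6x, 0.05 ~2.1x speedup
    num_inference_steps=50,  # must match the step count passed to the pipeline
)
enable_teacache(LTXVideoTransformer3DModel, config)

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A sailboat drifting across a calm lake at sunset",
    num_inference_steps=config.num_inference_steps,
).frames[0]

# Restore the original forward pass when finished.
disable_teacache(LTXVideoTransformer3DModel)

Because enable_teacache patches the class rather than an instance, the cache state in teacache_config is shared across all models of that class until disable_teacache is called.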