jbilcke-hf (HF staff) committed
Commit 6ecbc3e · verified · 1 parent: 43fb04c

Delete teacache.py

Files changed (1)
  1. teacache.py +0 -206
teacache.py DELETED
@@ -1,206 +0,0 @@
- from typing import Any, Optional, Type
- import numpy as np
- import torch
- from diffusers.models.transformers import LTXVideoTransformer3DModel
- from diffusers.utils import (
-     USE_PEFT_BACKEND,
-     is_torch_version,
-     scale_lora_layers,
-     unscale_lora_layers
- )
-
- class TeaCacheConfig:
-     """Configuration for TeaCache optimization"""
-     def __init__(
-         self,
-         enabled: bool = True,
-         rel_l1_thresh: float = 0.05,  # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
-         num_inference_steps: int = 50
-     ):
-         self.enabled = enabled
-         self.rel_l1_thresh = rel_l1_thresh
-         self.num_inference_steps = num_inference_steps
-
-         # Internal state
-         self.cnt = 0
-         self.accumulated_rel_l1_distance = 0
-         self.previous_modulated_input = None
-         self.previous_residual = None
-
- def create_teacache_forward(original_forward: Any):
-     """Factory function to create a TeaCache-enabled forward pass"""
-
-     def teacache_forward(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         timestep: torch.LongTensor,
-         encoder_attention_mask: torch.Tensor,
-         num_frames: int,
-         height: int,
-         width: int,
-         rope_interpolation_scale: Optional[tuple[float, float, float]] = None,
-         attention_kwargs: Optional[dict[str, Any]] = None,
-         return_dict: bool = True,
-     ) -> torch.Tensor:
-         # Handle LoRA scaling
-         if attention_kwargs is not None:
-             attention_kwargs = attention_kwargs.copy()
-             lora_scale = attention_kwargs.pop("scale", 1.0)
-         else:
-             attention_kwargs = {}
-             lora_scale = 1.0
-
-         if USE_PEFT_BACKEND:
-             scale_lora_layers(self, lora_scale)
-
-         # Initial processing
-         image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
-
-         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
-             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
-         batch_size = hidden_states.size(0)
-         hidden_states = self.proj_in(hidden_states)
-
-         # Time embedding
-         temb, embedded_timestep = self.time_embed(
-             timestep.flatten(),
-             batch_size=batch_size,
-             hidden_dtype=hidden_states.dtype,
-         )
-         temb = temb.view(batch_size, -1, temb.size(-1))
-         embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
-
-         # Caption projection
-         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
-
-         # TeaCache optimization logic
-         should_calc = True
-         if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
-             inp = hidden_states.clone()
-             temb_ = temb.clone()
-             inp = self.transformer_blocks[0].norm1(inp)
-
-             num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
-             ada_values = (
-                 self.transformer_blocks[0].scale_shift_table[None, None] +
-                 temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
-             )
-             shift_msa, scale_msa, *_ = ada_values.unbind(dim=2)
-             modulated_inp = inp * (1 + scale_msa) + shift_msa
-
-             # Determine if we should calculate or reuse
-             if self.teacache_config.cnt == 0 or self.teacache_config.cnt == self.teacache_config.num_inference_steps - 1:
-                 should_calc = True
-                 self.teacache_config.accumulated_rel_l1_distance = 0
-             else:
-                 # Polynomial coefficients for rescaling
-                 coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
-                 rescale_func = np.poly1d(coefficients)
-
-                 rel_diff = (
-                     (modulated_inp - self.teacache_config.previous_modulated_input).abs().mean() /
-                     self.teacache_config.previous_modulated_input.abs().mean()
-                 ).cpu().item()
-
-                 self.teacache_config.accumulated_rel_l1_distance += rescale_func(rel_diff)
-
-                 if self.teacache_config.accumulated_rel_l1_distance < self.teacache_config.rel_l1_thresh:
-                     should_calc = False
-                 else:
-                     should_calc = True
-                     self.teacache_config.accumulated_rel_l1_distance = 0
-
-             self.teacache_config.previous_modulated_input = modulated_inp
-             self.teacache_config.cnt += 1
-             if self.teacache_config.cnt == self.teacache_config.num_inference_steps:
-                 self.teacache_config.cnt = 0
-
-         # Process hidden states
-         if hasattr(self, 'teacache_config') and self.teacache_config.enabled and not should_calc:
-             hidden_states += self.teacache_config.previous_residual
-         else:
-             ori_hidden_states = hidden_states.clone() if hasattr(self, 'teacache_config') and self.teacache_config.enabled else None
-
-             for block in self.transformer_blocks:
-                 if torch.is_grad_enabled() and self.gradient_checkpointing:
-                     def create_custom_forward(module, return_dict=None):
-                         def custom_forward(*inputs):
-                             if return_dict is not None:
-                                 return module(*inputs, return_dict=return_dict)
-                             else:
-                                 return module(*inputs)
-                         return custom_forward
-
-                     ckpt_kwargs: dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                     hidden_states = torch.utils.checkpoint.checkpoint(
-                         create_custom_forward(block),
-                         hidden_states,
-                         encoder_hidden_states,
-                         temb,
-                         image_rotary_emb,
-                         encoder_attention_mask,
-                         **ckpt_kwargs,
-                     )
-                 else:
-                     hidden_states = block(
-                         hidden_states=hidden_states,
-                         encoder_hidden_states=encoder_hidden_states,
-                         temb=temb,
-                         image_rotary_emb=image_rotary_emb,
-                         encoder_attention_mask=encoder_attention_mask,
-                     )
-
-             scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
-             shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-
-             hidden_states = self.norm_out(hidden_states)
-             hidden_states = hidden_states * (1 + scale) + shift
-
-             if hasattr(self, 'teacache_config') and self.teacache_config.enabled:
-                 self.teacache_config.previous_residual = hidden_states - ori_hidden_states
-
-         output = self.proj_out(hidden_states)
-
-         if USE_PEFT_BACKEND:
-             unscale_lora_layers(self, lora_scale)
-
-         if not return_dict:
-             return (output,)
-
-         return {"sample": output}
-
-     return teacache_forward
-
- def enable_teacache(model_class: Type[LTXVideoTransformer3DModel], config: TeaCacheConfig) -> None:
-     """Enable TeaCache optimization for a model class
-
-     Args:
-         model_class: The model class to patch
-         config: TeaCache configuration
-     """
-     # Store original forward method if needed
-     if not hasattr(model_class, '_original_forward'):
-         model_class._original_forward = model_class.forward
-
-     # Create new forward method with TeaCache
-     model_class.forward = create_teacache_forward(model_class._original_forward)
-
-     # Add config attribute to class
-     model_class.teacache_config = config
-
- def disable_teacache(model_class: Type[LTXVideoTransformer3DModel]) -> None:
-     """Disable TeaCache optimization for a model class
-
-     Args:
-         model_class: The model class to unpatch
-     """
-     if hasattr(model_class, '_original_forward'):
-         model_class.forward = model_class._original_forward
-         delattr(model_class, '_original_forward')
-
-     if hasattr(model_class, 'teacache_config'):
-         delattr(model_class, 'teacache_config')
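
For context, a minimal sketch of how the removed module was presumably meant to be wired in, based only on the deleted code above. This is not part of the repository: the pipeline class (LTXPipeline), checkpoint name ("Lightricks/LTX-Video"), prompt, and dtype/device choices are assumptions for illustration; only TeaCacheConfig, enable_teacache, and disable_teacache come from the deleted file.

    # Hypothetical usage of the deleted teacache.py (a sketch, not part of this commit)
    import torch
    from diffusers import LTXPipeline
    from diffusers.models.transformers import LTXVideoTransformer3DModel

    from teacache import TeaCacheConfig, disable_teacache, enable_teacache

    # Patch the transformer class before running the pipeline.
    # num_inference_steps must match the value passed to the pipeline,
    # since the internal step counter resets on it.
    config = TeaCacheConfig(
        enabled=True,
        rel_l1_thresh=0.05,      # per the deleted code: 0.03 ~1.6x, 0.05 ~2.1x speedup
        num_inference_steps=50,
    )
    enable_teacache(LTXVideoTransformer3DModel, config)

    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
    ).to("cuda")
    video = pipe(
        prompt="a boat drifting on a calm lake",
        num_inference_steps=50,
    ).frames[0]

    # Restore the original forward pass when finished.
    disable_teacache(LTXVideoTransformer3DModel)

Note that enable_teacache monkey-patches the class (not an instance), so the cache state in TeaCacheConfig is shared by every LTXVideoTransformer3DModel until disable_teacache is called.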