import torch
import numpy as np

from einops import rearrange
from typing import Optional, Tuple, Union
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from .models.cogvideox.custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .models.cogvideox.enhance_a_video.globals import set_num_frames
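
# TeaCache patch for CogVideoX: replaces the transformer's forward so that
# denoising steps whose time-modulated input barely changes reuse the previous
# step's cached residual instead of re-running every transformer block. The
# patched forward also keeps the wrapper's optional FasterCache path, which
# reconstructs the unconditional branch from cached frequency-domain deltas.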


def poly1d(coefficients, x):
    # Evaluate a polynomial at x, with coefficients ordered from highest
    # degree to lowest (numpy.poly1d convention); return its absolute value.
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()


def fft(tensor):
    # Split a (B, C, H, W) tensor into low- and high-frequency spectra using a
    # circular mask of radius min(H, W) // 5 around the centered spectrum.
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    B, C, H, W = tensor.size()
    radius = min(H, W) // 5

    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    center_x, center_y = W // 2, H // 2
    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2
    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft


def teacache_cogvideox_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    ofs: Optional[Union[int, float, torch.LongTensor]] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    controlnet_states: Optional[torch.Tensor] = None,
    controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
    video_flow_features: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
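    # Mirrors the wrapper's CogVideoXTransformer3DModel.forward, with the
    # TeaCache skip decision (and the existing FasterCache path) layered on top.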
    batch_size, num_frames, channels, height, width = hidden_states.shape

    set_num_frames(num_frames)

    # 1. Time embedding
    timesteps = timestep
    t_emb = self.time_proj(timesteps)

    # `time_proj` always returns float32, but the time embedding may run in a
    # lower precision, so cast to match the hidden states.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    if self.ofs_embedding is not None:
        ofs_emb = self.ofs_proj(ofs)
        ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
        ofs_emb = self.ofs_embedding(ofs_emb)
        emb = emb + ofs_emb

    # 2. Patch embedding
    p = self.config.patch_size
    p_t = self.config.patch_size_t

    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
    hidden_states = self.embedding_dropout(hidden_states)

    text_seq_length = encoder_hidden_states.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]
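
    # TeaCache: estimate how much the transformer output would change this step
    # from the relative L1 change of the time-modulated input, rescaled by an
    # empirically fitted polynomial. Skip the transformer blocks while the
    # accumulated estimate stays below rel_l1_thresh.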
    if not hasattr(self, 'accumulated_rel_l1_distance'):
        should_calc = True
        self.accumulated_rel_l1_distance = 0
    else:
        try:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
            else:
                # CogVideoX-5B
                coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]

            self.accumulated_rel_l1_distance += poly1d(coefficients, ((emb - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()))
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        except Exception:
            # previous_modulated_input may be missing or mismatched on the
            # first call; fall back to a full computation.
            should_calc = True
            self.accumulated_rel_l1_distance = 0

    self.previous_modulated_input = emb
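
    # FasterCache: once warmed up, most steps run only the conditional branch
    # and recover the unconditional branch from cached frequency-domain deltas;
    # every fifth step recomputes both branches to refresh the cache.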
    if self.use_fastercache:
        self.fastercache_counter += 1
    if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
        if not should_calc:
            # Reuse the cached residual from the last fully computed step.
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()

            # Run the blocks on the conditional sample only.
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states[:1],
                    encoder_hidden_states=encoder_hidden_states[:1],
                    temb=emb[:1],
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache=i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_start_step=self.fastercache_start_step,
                    fastercache_device=self.fastercache_device,
                )

                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        hidden_states = self.norm_out(hidden_states, temb=emb[:1])
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify back to (B, T, C, H, W).
        if p_t is None:
            output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                1, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
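
        # Synthesize the unconditional output from the conditional one plus the
        # cached low/high-frequency deltas, instead of running a second pass.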
        (bb, tt, cc, hh, ww) = output.shape
        cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
        lf_c, hf_c = fft(cond.float())

        if self.fastercache_counter <= self.fastercache_lf_step:
            self.delta_lf = self.delta_lf * 1.1
        if self.fastercache_counter >= self.fastercache_hf_step:
            self.delta_hf = self.delta_hf * 1.1

        new_hf_uc = self.delta_hf + hf_c
        new_lf_uc = self.delta_lf + lf_c

        combine_uc = new_lf_uc + new_hf_uc
        combined_fft = torch.fft.ifftshift(combine_uc)
        recovered_uncond = torch.fft.ifft2(combined_fft).real
        recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
        output = torch.cat([output, recovered_uncond])
    else:
        if not should_calc:
            # Reuse the cached residual from the last fully computed step.
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()

            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache=i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_start_step=self.fastercache_start_step,
                    fastercache_device=self.fastercache_device,
                )

                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify back to (B, T, C, H, W).
        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
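
        # After a full two-sample pass, refresh the cached low/high-frequency
        # deltas between the unconditional and conditional outputs.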
        if self.fastercache_counter >= self.fastercache_start_step + 1:
            (bb, tt, cc, hh, ww) = output.shape
            cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb // 2, C=cc, T=tt, H=hh, W=ww)
            uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb // 2, C=cc, T=tt, H=hh, W=ww)

            lf_c, hf_c = fft(cond)
            lf_uc, hf_uc = fft(uncond)

            self.delta_lf = lf_uc - lf_c
            self.delta_hf = hf_uc - hf_c

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)


class TeaCacheForCogVideoX:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("COGVIDEOMODEL", {"tooltip": "The CogVideoX model that TeaCache will be applied to."}),
                "enable_teacache": ("BOOLEAN", {"default": True, "tooltip": "Enabling TeaCache speeds up inference but may reduce visual quality."}),
                "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How aggressively to cache diffusion model outputs; higher values skip more steps. Must be non-negative."}),
            }
        }

    RETURN_TYPES = ("COGVIDEOMODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "apply_teacache"
    CATEGORY = "TeaCache"
    TITLE = "TeaCache For CogVideoX"

    def apply_teacache(self, model, enable_teacache: bool, rel_l1_thresh: float):
        transformer = model["pipe"].transformer
        if enable_teacache:
            transformer.__class__.rel_l1_thresh = rel_l1_thresh
            # Bind the caching forward to this transformer instance.
            transformer.forward = teacache_cogvideox_forward.__get__(
                transformer,
                transformer.__class__,
            )
        else:
            # Restore the original (uncached) forward.
            transformer.forward = CogVideoXTransformer3DModel.forward.__get__(
                transformer,
                transformer.__class__,
            )

        return (model,)


NODE_CLASS_MAPPINGS = {
    "TeaCacheForCogVideoX": TeaCacheForCogVideoX
}

NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()}
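
# Typical usage inside a ComfyUI graph (illustrative; exact upstream node names
# depend on your setup): feed a COGVIDEOMODEL into "TeaCache For CogVideoX"
# with enable_teacache=True and rel_l1_thresh around the 0.3 default, then pass
# the returned model to the sampler as usual. Higher thresholds skip more
# transformer passes at some cost in visual quality.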