import torch
import numpy as np
from einops import rearrange
from typing import Optional, Tuple, Union
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from .models.cogvideox.custom_cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .models.cogvideox.enhance_a_video.globals import set_num_frames


def poly1d(coefficients, x):
    # Evaluate a polynomial (highest-order coefficient first) at x and return its absolute value.
    result = torch.zeros_like(x)
    for i, coeff in enumerate(coefficients):
        result += coeff * (x ** (len(coefficients) - 1 - i))
    return result.abs()


def fft(tensor):
    # Split a (B, C, H, W) tensor into low- and high-frequency FFT components
    # using a centered circular mask with radius min(H, W) // 5.
    tensor_fft = torch.fft.fft2(tensor)
    tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
    B, C, H, W = tensor.size()
    radius = min(H, W) // 5

    Y, X = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    center_x, center_y = W // 2, H // 2
    mask = (X - center_x) ** 2 + (Y - center_y) ** 2 <= radius ** 2

    low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = tensor_fft_shifted * low_freq_mask
    high_freq_fft = tensor_fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft


# TeaCache-aware replacement for CogVideoXTransformer3DModel.forward;
# bound onto the transformer instance by the TeaCacheForCogVideoX node below.
def teacache_cogvideox_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    ofs: Optional[Union[int, float, torch.LongTensor]] = None,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    controlnet_states: Optional[torch.Tensor] = None,
    controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
    video_flow_features: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    batch_size, num_frames, channels, height, width = hidden_states.shape
    set_num_frames(num_frames)  # Enhance-A-Video global

    # 1. Time embedding
    timesteps = timestep
    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    if self.ofs_embedding is not None:  # 1.5 I2V
        ofs_emb = self.ofs_proj(ofs)
        ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
        ofs_emb = self.ofs_embedding(ofs_emb)
        emb = emb + ofs_emb
    # 2. Patch embedding
    p = self.config.patch_size
    p_t = self.config.patch_size_t

    hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
    hidden_states = self.embedding_dropout(hidden_states)

    text_seq_length = encoder_hidden_states.shape[1]
    encoder_hidden_states = hidden_states[:, :text_seq_length]
    hidden_states = hidden_states[:, text_seq_length:]

    # TeaCache: decide whether this step can reuse the cached residual instead of
    # running the transformer blocks, based on the accumulated rescaled relative
    # L1 distance between the current and previous modulated time embeddings.
    if not hasattr(self, 'accumulated_rel_l1_distance'):
        should_calc = True
        self.accumulated_rel_l1_distance = 0
    else:
        try:
            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
            else:
                # CogVideoX-5B
                coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
            self.accumulated_rel_l1_distance += poly1d(
                coefficients,
                ((emb - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()),
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        except Exception:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
    self.previous_modulated_input = emb

    if self.use_fastercache:
        self.fastercache_counter += 1

    if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 != 0:
        # FasterCache path: only the conditional batch is computed; the unconditional
        # branch is reconstructed from cached frequency deltas after unpatchify.
        if not should_calc:
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()

            # 3. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states[:1],
                    encoder_hidden_states=encoder_hidden_states[:1],
                    temb=emb[:1],
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i][:1] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache=i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_start_step=self.fastercache_start_step,
                    fastercache_device=self.fastercache_device,
                )

                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb[:1])
        hidden_states = self.proj_out(hidden_states)
        # 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        #   - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
        #   - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
        if p_t is None:
            output = hidden_states.reshape(1, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                1, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        # Reconstruct the unconditional output from the conditional output plus the
        # cached low/high-frequency deltas, then concatenate it back onto the batch.
        (bb, tt, cc, hh, ww) = output.shape
        cond = rearrange(output, "B T C H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
        lf_c, hf_c = fft(cond.float())

        if self.fastercache_counter <= self.fastercache_lf_step:
            self.delta_lf = self.delta_lf * 1.1
        if self.fastercache_counter >= self.fastercache_hf_step:
            self.delta_hf = self.delta_hf * 1.1

        new_hf_uc = self.delta_hf + hf_c
        new_lf_uc = self.delta_lf + lf_c

        combine_uc = new_lf_uc + new_hf_uc
        combined_fft = torch.fft.ifftshift(combine_uc)
        recovered_uncond = torch.fft.ifft2(combined_fft).real
        recovered_uncond = rearrange(recovered_uncond.to(output.dtype), "(B T) C H W -> B T C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
        output = torch.cat([output, recovered_uncond])
    else:
        if not should_calc:
            hidden_states += self.previous_residual
            encoder_hidden_states += self.previous_residual_encoder
        else:
            ori_hidden_states = hidden_states.clone()
            ori_encoder_hidden_states = encoder_hidden_states.clone()

            # 3. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    video_flow_feature=video_flow_features[i] if video_flow_features is not None else None,
                    fuser=self.fuser_list[i] if self.fuser_list is not None else None,
                    block_use_fastercache=i <= self.fastercache_num_blocks_to_cache,
                    fastercache_counter=self.fastercache_counter,
                    fastercache_start_step=self.fastercache_start_step,
                    fastercache_device=self.fastercache_device,
                )

                # controlnet
                if (controlnet_states is not None) and (i < len(controlnet_states)):
                    controlnet_states_block = controlnet_states[i]
                    controlnet_block_weight = 1.0
                    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                        controlnet_block_weight = controlnet_weights[i]
                    elif isinstance(controlnet_weights, (float, int)):
                        controlnet_block_weight = controlnet_weights
                    hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

            self.previous_residual = hidden_states - ori_hidden_states
            self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)
        # 5. Unpatchify
        # Note: we use `-1` instead of `channels`:
        #   - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
        #   - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if self.fastercache_counter >= self.fastercache_start_step + 1:
            # Cache the low/high-frequency deltas between the unconditional and
            # conditional outputs for reuse in later FasterCache steps.
            (bb, tt, cc, hh, ww) = output.shape
            cond = rearrange(output[0:1].float(), "B T C H W -> (B T) C H W", B=bb // 2, C=cc, T=tt, H=hh, W=ww)
            uncond = rearrange(output[1:2].float(), "B T C H W -> (B T) C H W", B=bb // 2, C=cc, T=tt, H=hh, W=ww)

            lf_c, hf_c = fft(cond)
            lf_uc, hf_uc = fft(uncond)

            self.delta_lf = lf_uc - lf_c
            self.delta_hf = hf_uc - hf_c

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)


class TeaCacheForCogVideoX:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("COGVIDEOMODEL", {"tooltip": "The CogVideoX model the TeaCache will be applied to."}),
                "enable_teacache": ("BOOLEAN", {"default": True, "tooltip": "Enabling TeaCache will speed up inference but may lose visual quality."}),
                "rel_l1_thresh": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of the diffusion model. This value must be non-negative."}),
            }
        }

    RETURN_TYPES = ("COGVIDEOMODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "apply_teacache"
    CATEGORY = "TeaCache"
    TITLE = "TeaCache For CogVideoX"

    def apply_teacache(self, model, enable_teacache: bool, rel_l1_thresh: float):
        if enable_teacache:
            transformer = model["pipe"].transformer
            transformer.__class__.rel_l1_thresh = rel_l1_thresh
            transformer.forward = teacache_cogvideox_forward.__get__(
                transformer, transformer.__class__
            )
        else:
            # Restore the stock forward when TeaCache is disabled.
            transformer = model["pipe"].transformer
            transformer.forward = CogVideoXTransformer3DModel.forward.__get__(
                transformer, transformer.__class__
            )
        return (model,)


NODE_CLASS_MAPPINGS = {
    "TeaCacheForCogVideoX": TeaCacheForCogVideoX
}
NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()}
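
# Minimal usage sketch (comments only, illustrative, not part of the node registration).
# It assumes a model dict with a "pipe" entry holding the loaded CogVideoX pipeline,
# which is what this node expects to receive; `model` and the 0.3 threshold are
# placeholder assumptions.
#
#   node = TeaCacheForCogVideoX()
#   (patched_model,) = node.apply_teacache(model, enable_teacache=True, rel_l1_thresh=0.3)
#   # patched_model["pipe"].transformer.forward is now the TeaCache-aware forward above.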