from typing import Any, Dict, List, Optional, Union

import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer

from .utils import prepare_rotary_positional_embeddings


def load_condition_models(
    model_id: str = "THUDM/CogVideoX-5b",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
    )
    return {"tokenizer": tokenizer, "text_encoder": text_encoder}


def load_latent_models(
    model_id: str = "THUDM/CogVideoX-5b",
    vae_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
    vae = AutoencoderKLCogVideoX.from_pretrained(
        model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
    )
    return {"vae": vae}


def load_diffusion_models(
    model_id: str = "THUDM/CogVideoX-5b",
    transformer_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
    )
    scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    return {"transformer": transformer, "scheduler": scheduler}
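

# Illustrative sketch (an assumption about typical usage, not part of the upstream spec): the three
# loaders above return dicts with disjoint keys, so they can be merged to collect every component
# needed by `initialize_pipeline` or a training run in a single call. The helper name is hypothetical.
def _load_all_components_example(model_id: str = "THUDM/CogVideoX-5b", **kwargs):
    components = {}
    components.update(load_condition_models(model_id=model_id, **kwargs))
    components.update(load_latent_models(model_id=model_id, **kwargs))
    components.update(load_diffusion_models(model_id=model_id, **kwargs))
    # Resulting keys: "tokenizer", "text_encoder", "vae", "transformer", "scheduler".
    return components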


def initialize_pipeline(
    model_id: str = "THUDM/CogVideoX-5b",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    transformer_dtype: torch.dtype = torch.bfloat16,
    vae_dtype: torch.dtype = torch.bfloat16,
    tokenizer: Optional[T5Tokenizer] = None,
    text_encoder: Optional[T5EncoderModel] = None,
    transformer: Optional[CogVideoXTransformer3DModel] = None,
    vae: Optional[AutoencoderKLCogVideoX] = None,
    scheduler: Optional[CogVideoXDDIMScheduler] = None,
    device: Optional[torch.device] = None,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    enable_slicing: bool = False,
    enable_tiling: bool = False,
    enable_model_cpu_offload: bool = False,
    is_training: bool = False,
    **kwargs,
) -> CogVideoXPipeline:
    component_name_pairs = [
        ("tokenizer", tokenizer),
        ("text_encoder", text_encoder),
        ("transformer", transformer),
        ("vae", vae),
        ("scheduler", scheduler),
    ]
    components = {}
    for name, component in component_name_pairs:
        if component is not None:
            components[name] = component

    pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
    pipe.vae = pipe.vae.to(dtype=vae_dtype)

    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
    # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
    # DDP optimizer step.
    if not is_training:
        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)

    if enable_slicing:
        pipe.vae.enable_slicing()
    if enable_tiling:
        pipe.vae.enable_tiling()

    if enable_model_cpu_offload:
        pipe.enable_model_cpu_offload(device=device)
    else:
        pipe.to(device=device)

    return pipe


def prepare_conditions(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 226,  # TODO: this should be configurable
    **kwargs,
):
    device = device or text_encoder.device
    dtype = dtype or text_encoder.dtype
    return _get_t5_prompt_embeds(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        prompt=prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )


def prepare_latents(
    vae: AutoencoderKLCogVideoX,
    image_or_video: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    generator: Optional[torch.Generator] = None,
    precompute: bool = False,
    **kwargs,
) -> torch.Tensor:
    device = device or vae.device
    dtype = dtype or vae.dtype

    if image_or_video.ndim == 4:
        image_or_video = image_or_video.unsqueeze(2)
    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"

    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    if not precompute:
        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
        if not vae.config.invert_scale_latents:
            latents = latents * vae.config.scaling_factor
        # For training Cog 1.5, we don't need to handle the scaling factor here.
        # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
        # is probably only needed for image-to-video training.
        # TODO(aryan): investigate this
        # else:
        #     latents = 1 / vae.config.scaling_factor * latents
        latents = latents.to(dtype=dtype)
        return {"latents": latents}
    else:
        # handle vae scaling in the `train()` method directly.
        if vae.use_slicing and image_or_video.shape[0] > 1:
            encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = vae._encode(image_or_video)
        return {"latents": h}
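

# Shape note (illustrative, assuming the default CogVideoX VAE configuration): videos are expected
# as [B, F, C, H, W] and are permuted to [B, C, F, H, W] before VAE encoding. With 8x spatial and
# 4x temporal compression and 16 latent channels, a [1, 49, 3, 480, 720] clip encodes to latents of
# shape [1, 16, 13, 60, 90], since (49 - 1) // 4 + 1 = 13 latent frames.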


def post_latent_preparation(
    vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs
) -> torch.Tensor:
    if not vae_config.invert_scale_latents:
        latents = latents * vae_config.scaling_factor
    # For training Cog 1.5, we don't need to handle the scaling factor here.
    # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
    # is probably only needed for image-to-video training.
    # TODO(aryan): investigate this
    # else:
    #     latents = 1 / vae_config.scaling_factor * latents
    latents = _pad_frames(latents, patch_size_t)
    latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
    return {"latents": latents}


def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
    return {
        "prompts": [x["prompt"] for x in batch[0]],
        "videos": torch.stack([x["video"] for x in batch[0]]),
    }


def calculate_noisy_latents(
    scheduler: CogVideoXDDIMScheduler,
    noise: torch.Tensor,
    latents: torch.Tensor,
    timesteps: torch.LongTensor,
) -> torch.Tensor:
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    return noisy_latents


def forward_pass(
    transformer: CogVideoXTransformer3DModel,
    scheduler: CogVideoXDDIMScheduler,
    prompt_embeds: torch.Tensor,
    latents: torch.Tensor,
    noisy_latents: torch.Tensor,
    timesteps: torch.LongTensor,
    ofs_emb: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
    VAE_SPATIAL_SCALE_FACTOR = 8
    transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
    batch_size, num_frames, num_channels, height, width = noisy_latents.shape
    rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
    rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR

    image_rotary_emb = (
        prepare_rotary_positional_embeddings(
            height=height * VAE_SPATIAL_SCALE_FACTOR,
            width=width * VAE_SPATIAL_SCALE_FACTOR,
            num_frames=num_frames,
            vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
            patch_size=transformer_config.patch_size,
            patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None,
            attention_head_dim=transformer_config.attention_head_dim,
            device=transformer.device,
            base_height=rope_base_height,
            base_width=rope_base_width,
        )
        if transformer_config.use_rotary_positional_embeddings
        else None
    )
    ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0)

    velocity = transformer(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=prompt_embeds,
        ofs=ofs_emb,
        image_rotary_emb=image_rotary_emb,
        return_dict=False,
    )[0]
    # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
    # code paths as scheduler.get_velocity(), which can be confusing to understand.
    denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)

    return {"latents": denoised_latents}
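

# Illustrative loss computation (a hedged sketch, not the trainer's actual loop; the helper name is
# hypothetical): given prompt embeddings from `prepare_conditions` and clean latents in [B, F, C, H, W]
# layout from `post_latent_preparation`, noise is added with `calculate_noisy_latents` and the denoised
# prediction from `forward_pass` is regressed against the clean latents. A real trainer may additionally
# apply a timestep-dependent loss weighting.
def _example_training_loss(transformer, scheduler, prompt_embeds, latents):
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = calculate_noisy_latents(scheduler, noise, latents, timesteps)
    pred = forward_pass(
        transformer=transformer,
        scheduler=scheduler,
        prompt_embeds=prompt_embeds,
        latents=latents,
        noisy_latents=noisy_latents,
        timesteps=timesteps,
    )["latents"]
    # Because `forward_pass` already converts the predicted velocity into denoised latents,
    # the regression target here is simply the clean latents.
    return torch.nn.functional.mse_loss(pred.float(), latents.float())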


def validation(
    pipeline: CogVideoXPipeline,
    prompt: str,
    image: Optional[Image.Image] = None,
    video: Optional[List[Image.Image]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: Optional[int] = None,
    num_videos_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    **kwargs,
):
    generation_kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "num_videos_per_prompt": num_videos_per_prompt,
        "generator": generator,
        "return_dict": True,
        "output_type": "pil",
    }
    generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
    output = pipeline(**generation_kwargs).frames[0]
    return [("video", output)]


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Optional[Union[str, List[str]]] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return {"prompt_embeds": prompt_embeds}


def _pad_frames(latents: torch.Tensor, patch_size_t: int):
    if patch_size_t is None or patch_size_t == 1:
        return latents

    # `latents` should be of the following format: [B, C, F, H, W].
    # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
    latent_num_frames = latents.shape[2]
    additional_frames = (patch_size_t - latent_num_frames % patch_size_t) % patch_size_t

    if additional_frames > 0:
        last_frame = latents[:, :, -1:, :, :]
        padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
        latents = torch.cat([latents, padding_frames], dim=2)

    return latents


# TODO(aryan): refactor into model specs for better re-use
COGVIDEOX_T2V_LORA_CONFIG = {
    "pipeline_cls": CogVideoXPipeline,
    "load_condition_models": load_condition_models,
    "load_latent_models": load_latent_models,
    "load_diffusion_models": load_diffusion_models,
    "initialize_pipeline": initialize_pipeline,
    "prepare_conditions": prepare_conditions,
    "prepare_latents": prepare_latents,
    "post_latent_preparation": post_latent_preparation,
    "collate_fn": collate_fn_t2v,
    "calculate_noisy_latents": calculate_noisy_latents,
    "forward_pass": forward_pass,
    "validation": validation,
}
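

# Illustrative end-to-end sketch (an assumption about typical usage, not part of the spec; the helper
# name, prompt, resolution and frame count are placeholders): build the full pipeline and render a
# short clip with the `validation` helper above.
def _example_text_to_video(device: str = "cuda"):
    pipeline = initialize_pipeline(
        model_id="THUDM/CogVideoX-5b",
        device=torch.device(device),
        enable_slicing=True,
        enable_tiling=True,
    )
    artifacts = validation(
        pipeline,
        prompt="A panda strumming a guitar in a bamboo forest",
        height=480,
        width=720,
        num_frames=49,
        generator=torch.Generator().manual_seed(42),
    )
    # `artifacts` is a list of ("video", frames) tuples, where `frames` is a list of PIL images.
    return artifacts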