from typing import Any, Dict, List, Optional, Union
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
from .utils import prepare_rotary_positional_embeddings


def load_condition_models(
    model_id: str = "THUDM/CogVideoX-5b",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
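    """Load the T5 tokenizer and text encoder that condition CogVideoX generation."""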
    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
    )
    return {"tokenizer": tokenizer, "text_encoder": text_encoder}


def load_latent_models(
    model_id: str = "THUDM/CogVideoX-5b",
    vae_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
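    """Load the CogVideoX VAE used to map between pixel space and latent space."""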
    vae = AutoencoderKLCogVideoX.from_pretrained(
        model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
    )
    return {"vae": vae}


def load_diffusion_models(
    model_id: str = "THUDM/CogVideoX-5b",
    transformer_dtype: torch.dtype = torch.bfloat16,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    **kwargs,
):
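    """Load the CogVideoX transformer and the DDIM scheduler used during training."""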
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
    )
    scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    return {"transformer": transformer, "scheduler": scheduler}


def initialize_pipeline(
    model_id: str = "THUDM/CogVideoX-5b",
    text_encoder_dtype: torch.dtype = torch.bfloat16,
    transformer_dtype: torch.dtype = torch.bfloat16,
    vae_dtype: torch.dtype = torch.bfloat16,
    tokenizer: Optional[T5Tokenizer] = None,
    text_encoder: Optional[T5EncoderModel] = None,
    transformer: Optional[CogVideoXTransformer3DModel] = None,
    vae: Optional[AutoencoderKLCogVideoX] = None,
    scheduler: Optional[CogVideoXDDIMScheduler] = None,
    device: Optional[torch.device] = None,
    revision: Optional[str] = None,
    cache_dir: Optional[str] = None,
    enable_slicing: bool = False,
    enable_tiling: bool = False,
    enable_model_cpu_offload: bool = False,
    is_training: bool = False,
    **kwargs,
) -> CogVideoXPipeline:
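    """Assemble a CogVideoXPipeline from any preloaded components (missing ones are downloaded), cast dtypes,
    and apply the requested memory optimizations (VAE slicing/tiling, model CPU offload)."""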
    component_name_pairs = [
        ("tokenizer", tokenizer),
        ("text_encoder", text_encoder),
        ("transformer", transformer),
        ("vae", vae),
        ("scheduler", scheduler),
    ]
    components = {}
    for name, component in component_name_pairs:
        if component is not None:
            components[name] = component

    pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
    pipe.vae = pipe.vae.to(dtype=vae_dtype)

    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
    # If we cast it here while fp8 layerwise upcasting hooks are in use, the DDP optimizer step errors out
    # during training.
    if not is_training:
        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)

    if enable_slicing:
        pipe.vae.enable_slicing()
    if enable_tiling:
        pipe.vae.enable_tiling()

    if enable_model_cpu_offload:
        pipe.enable_model_cpu_offload(device=device)
    else:
        pipe.to(device=device)

    return pipe


def prepare_conditions(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 226,  # TODO: this should be configurable
    **kwargs,
):
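    """Encode the prompt(s) with T5 and return a dict containing `prompt_embeds`."""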
    device = device or text_encoder.device
    dtype = dtype or text_encoder.dtype
    return _get_t5_prompt_embeds(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        prompt=prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )


def prepare_latents(
    vae: AutoencoderKLCogVideoX,
    image_or_video: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    generator: Optional[torch.Generator] = None,
    precompute: bool = False,
    **kwargs,
) -> Dict[str, torch.Tensor]:
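    """Encode an image or video batch into VAE latents.

    When `precompute` is False, latents are sampled and scaled right away; when True, the raw encoder output is
    returned and scaling is deferred to the training loop.
    """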
    device = device or vae.device
    dtype = dtype or vae.dtype

    if image_or_video.ndim == 4:
        image_or_video = image_or_video.unsqueeze(2)
    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"

    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    if not precompute:
        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
        if not vae.config.invert_scale_latents:
            latents = latents * vae.config.scaling_factor
        # For training Cog 1.5, we don't need to handle the scaling factor here.
        # The CogVideoX team forgot to multiply here, so we don't do it either. The `invert_scale_latents` path
        # is probably only needed for image-to-video training.
        # TODO(aryan): investigate this
        # else:
        #     latents = 1 / vae.config.scaling_factor * latents
        latents = latents.to(dtype=dtype)
        return {"latents": latents}
    else:
        # Handle VAE scaling in the `train()` method directly.
        if vae.use_slicing and image_or_video.shape[0] > 1:
            encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = vae._encode(image_or_video)
        return {"latents": h}


def post_latent_preparation(
    vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs
) -> Dict[str, torch.Tensor]:
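    """Scale precomputed latents, pad frames for `patch_size_t`, and permute them into [B, F, C, H, W] layout."""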
    if not vae_config.invert_scale_latents:
        latents = latents * vae_config.scaling_factor
    # For training Cog 1.5, we don't need to handle the scaling factor here.
    # The CogVideoX team forgot to multiply here, so we don't do it either. The `invert_scale_latents` path
    # is probably only needed for image-to-video training.
    # TODO(aryan): investigate this
    # else:
    #     latents = 1 / vae_config.scaling_factor * latents
    latents = _pad_frames(latents, patch_size_t)
    latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
    return {"latents": latents}


def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
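    """Collate a batch of prompt/video samples into a list of prompts and a stacked video tensor."""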
    return {
        "prompts": [x["prompt"] for x in batch[0]],
        "videos": torch.stack([x["video"] for x in batch[0]]),
    }


def calculate_noisy_latents(
    scheduler: CogVideoXDDIMScheduler,
    noise: torch.Tensor,
    latents: torch.Tensor,
    timesteps: torch.LongTensor,
) -> torch.Tensor:
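    """Add scheduler noise to the clean latents at the given timesteps."""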
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    return noisy_latents


def forward_pass(
    transformer: CogVideoXTransformer3DModel,
    scheduler: CogVideoXDDIMScheduler,
    prompt_embeds: torch.Tensor,
    latents: torch.Tensor,
    noisy_latents: torch.Tensor,
    timesteps: torch.LongTensor,
    ofs_emb: Optional[torch.Tensor] = None,
    **kwargs,
) -> Dict[str, torch.Tensor]:
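    """Run the transformer on the noisy latents and return the denoised latents recovered from its velocity
    prediction."""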
    # Just hardcode for now. In Diffusers, we will refactor such that RoPE is handled within the model itself.
    VAE_SPATIAL_SCALE_FACTOR = 8

    transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
    batch_size, num_frames, num_channels, height, width = noisy_latents.shape
    rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
    rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR

    image_rotary_emb = (
        prepare_rotary_positional_embeddings(
            height=height * VAE_SPATIAL_SCALE_FACTOR,
            width=width * VAE_SPATIAL_SCALE_FACTOR,
            num_frames=num_frames,
            vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
            patch_size=transformer_config.patch_size,
            patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None,
            attention_head_dim=transformer_config.attention_head_dim,
            device=transformer.device,
            base_height=rope_base_height,
            base_width=rope_base_width,
        )
        if transformer_config.use_rotary_positional_embeddings
        else None
    )
    ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0)

    velocity = transformer(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=prompt_embeds,
        ofs=ofs_emb,
        image_rotary_emb=image_rotary_emb,
        return_dict=False,
    )[0]
    # For CogVideoX, the transformer predicts the velocity. The denoised latents are recovered by reusing the
    # code path of scheduler.get_velocity(), which can be confusing to follow at first glance.
    denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)

    return {"latents": denoised_latents}


def validation(
    pipeline: CogVideoXPipeline,
    prompt: str,
    image: Optional[Image.Image] = None,
    video: Optional[List[Image.Image]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: Optional[int] = None,
    num_videos_per_prompt: int = 1,
    generator: Optional[torch.Generator] = None,
    **kwargs,
):
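    """Generate a validation video with the pipeline and return it as a list of ("video", frames) artifacts."""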
    generation_kwargs = {
        "prompt": prompt,
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "num_videos_per_prompt": num_videos_per_prompt,
        "generator": generator,
        "return_dict": True,
        "output_type": "pil",
    }
    generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
    output = pipeline(**generation_kwargs).frames[0]
    return [("video", output)]


def _get_t5_prompt_embeds(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
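    """Tokenize the prompt(s) to a fixed length and return the T5 encoder hidden states as `prompt_embeds`."""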
    prompt = [prompt] if isinstance(prompt, str) else prompt

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return {"prompt_embeds": prompt_embeds}


def _pad_frames(latents: torch.Tensor, patch_size_t: int):
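    """Pad the latent frame dimension (by repeating the last frame) so it is divisible by `patch_size_t`."""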
    if patch_size_t is None or patch_size_t == 1:
        return latents

    # `latents` should be of the following format: [B, C, F, H, W].
    # For CogVideoX 1.5, the latent frames should be padded to make the frame count divisible by patch_size_t.
    latent_num_frames = latents.shape[2]
    # Take the remainder so that no padding is added when the frame count is already divisible by patch_size_t.
    additional_frames = (patch_size_t - latent_num_frames % patch_size_t) % patch_size_t

    if additional_frames > 0:
        last_frame = latents[:, :, -1:, :, :]
        padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
        latents = torch.cat([latents, padding_frames], dim=2)

    return latents


# TODO(aryan): refactor into model specs for better re-use
COGVIDEOX_T2V_LORA_CONFIG = {
    "pipeline_cls": CogVideoXPipeline,
    "load_condition_models": load_condition_models,
    "load_latent_models": load_latent_models,
    "load_diffusion_models": load_diffusion_models,
    "initialize_pipeline": initialize_pipeline,
    "prepare_conditions": prepare_conditions,
    "prepare_latents": prepare_latents,
    "post_latent_preparation": post_latent_preparation,
    "collate_fn": collate_fn_t2v,
    "calculate_noisy_latents": calculate_noisy_latents,
    "forward_pass": forward_pass,
    "validation": validation,
}
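
# A minimal sketch of how a training loop might consume COGVIDEOX_T2V_LORA_CONFIG. The variable names and the
# wiring below are illustrative assumptions, not part of this module:
#
#     config = COGVIDEOX_T2V_LORA_CONFIG
#     components = {
#         **config["load_condition_models"](),
#         **config["load_latent_models"](),
#         **config["load_diffusion_models"](),
#     }
#     conditions = config["prepare_conditions"](prompt="a cat playing the piano", **components)
#     latents = config["prepare_latents"](image_or_video=video_batch, **components)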