from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from accelerate.logging import get_logger
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer

logger = get_logger("finetrainers")  # pylint: disable=invalid-name


def load_condition_models(
model_id: str = "Lightricks/LTX-Video",
text_encoder_dtype: torch.dtype = torch.bfloat16,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> Dict[str, Union[T5Tokenizer, T5EncoderModel]]:
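    """Load the tokenizer and T5 text encoder used for prompt conditioning."""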
tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
text_encoder = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
)
return {"tokenizer": tokenizer, "text_encoder": text_encoder}


def load_latent_models(
model_id: str = "Lightricks/LTX-Video",
vae_dtype: torch.dtype = torch.bfloat16,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> Dict[str, nn.Module]:
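    """Load the LTX-Video VAE used to map pixels to and from the latent space."""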
vae = AutoencoderKLLTXVideo.from_pretrained(
model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
)
return {"vae": vae}


def load_diffusion_models(
model_id: str = "Lightricks/LTX-Video",
transformer_dtype: torch.dtype = torch.bfloat16,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
**kwargs,
) -> Dict[str, Union[LTXVideoTransformer3DModel, FlowMatchEulerDiscreteScheduler]]:
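    """Load the LTX-Video transformer and a fresh flow-matching Euler scheduler."""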
transformer = LTXVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}


def initialize_pipeline(
model_id: str = "Lightricks/LTX-Video",
text_encoder_dtype: torch.dtype = torch.bfloat16,
transformer_dtype: torch.dtype = torch.bfloat16,
vae_dtype: torch.dtype = torch.bfloat16,
tokenizer: Optional[T5Tokenizer] = None,
text_encoder: Optional[T5EncoderModel] = None,
transformer: Optional[LTXVideoTransformer3DModel] = None,
vae: Optional[AutoencoderKLLTXVideo] = None,
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
device: Optional[torch.device] = None,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
enable_slicing: bool = False,
enable_tiling: bool = False,
enable_model_cpu_offload: bool = False,
is_training: bool = False,
**kwargs,
) -> LTXPipeline:
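    """Create an LTXPipeline from `model_id`, reusing any components passed in, and apply dtype,
    memory-saving and device placement settings.
    """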
component_name_pairs = [
("tokenizer", tokenizer),
("text_encoder", text_encoder),
("transformer", transformer),
("vae", vae),
("scheduler", scheduler),
]
components = {}
for name, component in component_name_pairs:
if component is not None:
components[name] = component
pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
pipe.vae = pipe.vae.to(dtype=vae_dtype)
    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
    # Casting it while fp8 layerwise-upcasting hooks are attached would lead to an error during the DDP
    # optimizer step.
if not is_training:
pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
if enable_slicing:
pipe.vae.enable_slicing()
if enable_tiling:
pipe.vae.enable_tiling()
if enable_model_cpu_offload:
pipe.enable_model_cpu_offload(device=device)
else:
pipe.to(device=device)
return pipe


def prepare_conditions(
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 128,
**kwargs,
) -> Dict[str, torch.Tensor]:
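    """Tokenize and encode the prompt(s) with T5, returning embeddings and an attention mask."""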
device = device or text_encoder.device
dtype = dtype or text_encoder.dtype
if isinstance(prompt, str):
prompt = [prompt]
return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length)


def prepare_latents(
vae: AutoencoderKLLTXVideo,
image_or_video: torch.Tensor,
patch_size: int = 1,
patch_size_t: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
generator: Optional[torch.Generator] = None,
precompute: bool = False,
) -> Dict[str, Union[torch.Tensor, int]]:
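    """Encode images or videos into VAE latents.

    When `precompute` is False, the latents are sampled, normalized and packed immediately. Otherwise the
    raw encoder output is returned along with the statistics needed to normalize it later (see
    `post_latent_preparation`).
    """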
device = device or vae.device
if image_or_video.ndim == 4:
image_or_video = image_or_video.unsqueeze(2)
assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
    image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
if not precompute:
latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
latents = latents.to(dtype=dtype)
_, _, num_frames, height, width = latents.shape
latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
latents = _pack_latents(latents, patch_size, patch_size_t)
return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
else:
if vae.use_slicing and image_or_video.shape[0] > 1:
encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
h = torch.cat(encoded_slices)
else:
h = vae._encode(image_or_video)
_, _, num_frames, height, width = h.shape
        # TODO(aryan): It is wasteful that we might end up storing latents_mean and latents_std in every
        # precomputed file. We should have a single file where re-usable properties like this are stored,
        # so as to reduce the disk space requirements of the precomputed files.
return {
"latents": h,
"num_frames": num_frames,
"height": height,
"width": width,
"latents_mean": vae.latents_mean,
"latents_std": vae.latents_std,
}


def post_latent_preparation(
latents: torch.Tensor,
latents_mean: torch.Tensor,
latents_std: torch.Tensor,
num_frames: int,
height: int,
width: int,
patch_size: int = 1,
patch_size_t: int = 1,
**kwargs,
) -> Dict[str, Union[torch.Tensor, int]]:
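    """Normalize and pack latents precomputed by `prepare_latents(..., precompute=True)`."""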
latents = _normalize_latents(latents, latents_mean, latents_std)
latents = _pack_latents(latents, patch_size, patch_size_t)
return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}


def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, Union[List[str], torch.Tensor]]:
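    """Collate a text-to-video batch into a list of prompts and a stacked video tensor."""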
return {
"prompts": [x["prompt"] for x in batch[0]],
"videos": torch.stack([x["video"] for x in batch[0]]),
}


def forward_pass(
transformer: LTXVideoTransformer3DModel,
prompt_embeds: torch.Tensor,
prompt_attention_mask: torch.Tensor,
latents: torch.Tensor,
noisy_latents: torch.Tensor,
timesteps: torch.LongTensor,
num_frames: int,
height: int,
width: int,
**kwargs,
) -> Dict[str, torch.Tensor]:
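    """Run the LTX transformer on the noisy latents and return its prediction."""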
# TODO(aryan): make configurable
    frame_rate = 25
    latent_frame_rate = frame_rate / 8  # the LTX-Video VAE compresses the time axis by a factor of 8
    spatial_compression_ratio = 32
    rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio]
denoised_latents = transformer(
hidden_states=noisy_latents,
encoder_hidden_states=prompt_embeds,
timestep=timesteps,
encoder_attention_mask=prompt_attention_mask,
num_frames=num_frames,
height=height,
width=width,
rope_interpolation_scale=rope_interpolation_scale,
return_dict=False,
)[0]
return {"latents": denoised_latents}


def validation(
pipeline: LTXPipeline,
prompt: str,
image: Optional[Image.Image] = None,
video: Optional[List[Image.Image]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: Optional[int] = None,
frame_rate: int = 24,
num_videos_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
**kwargs,
):
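    """Generate a validation sample with the pipeline.

    `image` and `video` are accepted for interface parity with other model families but are unused here.
    """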
generation_kwargs = {
"prompt": prompt,
"height": height,
"width": width,
"num_frames": num_frames,
"frame_rate": frame_rate,
"num_videos_per_prompt": num_videos_per_prompt,
"generator": generator,
"return_dict": True,
"output_type": "pil",
}
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
    output_video = pipeline(**generation_kwargs).frames[0]
    return [("video", output_video)]


def _encode_prompt_t5(
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
prompt: List[str],
device: torch.device,
dtype: torch.dtype,
    max_sequence_length: int,
) -> Dict[str, torch.Tensor]:
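    """Encode prompts with T5 using fixed-length padding and truncation."""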
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}


def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents


def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape
    # [B, C, F // p_t, p_t, H // p, p, W // p, p]. The patch dimensions are then permuted and collapsed
    # into the channel dimension, yielding a tensor of shape
    # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
    # dim=0 is the batch size, dim=1 is the effective video sequence length, and dim=2 is the effective
    # number of input features per token.
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents


# Entry points the trainer uses for LTX-Video text-to-video LoRA training.
LTX_VIDEO_T2V_LORA_CONFIG = {
"pipeline_cls": LTXPipeline,
"load_condition_models": load_condition_models,
"load_latent_models": load_latent_models,
"load_diffusion_models": load_diffusion_models,
"initialize_pipeline": initialize_pipeline,
"prepare_conditions": prepare_conditions,
"prepare_latents": prepare_latents,
"post_latent_preparation": post_latent_preparation,
"collate_fn": collate_fn_t2v,
"forward_pass": forward_pass,
"validation": validation,
}
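

# Example usage (a minimal sketch, not part of the training entrypoint): the calls below download the
# Lightricks/LTX-Video checkpoints and assume a CUDA device with enough memory; the prompt and frame
# count are arbitrary.
#
#     condition_models = load_condition_models()
#     conditions = prepare_conditions(prompt="A cat walking through tall grass", **condition_models)
#
#     pipe = initialize_pipeline(device=torch.device("cuda"), **condition_models, **load_latent_models())
#     results = validation(pipe, prompt="A cat walking through tall grass", num_frames=49)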