# VideoModelStudio / training / cogvideox / cogvideox_image_to_video_sft.py
# Uploaded by jbilcke-hf (HF Staff)
# Commit 91fb4ef: "initial commit log 🪵🦫"
# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import logging
import math
import os
import random
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict
import diffusers
import torch
import transformers
import wandb
from accelerate import Accelerator, DistributedType, init_empty_weights
from accelerate.logging import get_logger
from accelerate.utils import (
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed,
)
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import create_repo, upload_folder
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel
from args import get_args # isort:skip
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
from text_encoder import compute_prompt_embeddings # isort:skip
from utils import (
get_gradient_norm,
get_optimizer,
prepare_rotary_positional_embeddings,
print_memory,
reset_memory,
unwrap_model,
)
# Module-level logger created through accelerate's `get_logger` helper; used by the
# validation helpers and the training loop below.
logger = get_logger(__name__)
def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=8,
):
    """Write sample videos and a README.md model card into ``repo_folder``.

    Args:
        repo_id: Hub repository id forwarded to ``load_or_create_model_card``.
        videos: Optional iterable of videos; each one is exported as
            ``final_video_{i}.mp4`` inside ``repo_folder`` and referenced from the
            model card widget.
        base_model: Identifier of the model that was fine-tuned; interpolated into
            the card description and forwarded as card metadata.
        validation_prompt: Prompt used as the widget text for every video (a single
            space is used when it is empty or ``None``).
        repo_folder: Local directory that receives the videos and ``README.md``.
        fps: Frame rate passed to ``export_to_video``.
    """
    widget_dict = []
    if videos is not None:
        for i, video in enumerate(videos):
            video_filename = f"final_video_{i}.mp4"
            # `fps` belongs to `export_to_video`, not to `os.path.join` (which does not
            # accept keyword arguments and would raise a TypeError).
            export_to_video(video, os.path.join(repo_folder, video_filename), fps=fps)
            widget_dict.append(
                {
                    "text": validation_prompt if validation_prompt else " ",
                    # The widget URL must name the file that was actually written above.
                    "output": {"url": video_filename},
                }
            )

    model_description = f"""
# CogVideoX Full Finetune
<Gallery />
## Model description
This is a full finetune of the CogVideoX model `{base_model}`.
## License
Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="other",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "image-to-video",
        "diffusers-training",
        "diffusers",
        "cogvideox",
        "cogvideox-diffusers",
    ]

    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
    accelerator: Accelerator,
    pipe: CogVideoXImageToVideoPipeline,
    args: Dict[str, Any],
    pipeline_args: Dict[str, Any],
    is_final_validation: bool = False,
):
    """Generate validation videos with ``pipe`` and log them to the wandb tracker.

    Args:
        accelerator: Accelerator whose device and trackers are used.
        pipe: Image-to-video pipeline used for inference; moved to
            ``accelerator.device`` before sampling.
        args: Parsed training arguments. Accessed via attributes (``seed``,
            ``num_validation_videos``, ``output_dir``).
        pipeline_args: Keyword arguments forwarded to ``pipe``; must contain a
            ``"prompt"`` entry.
        is_final_validation: When true the videos are logged under ``"test"``
            instead of ``"validation"``.

    Returns:
        The list of generated videos (one per ``args.num_validation_videos``).
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )

    pipe = pipe.to(accelerator.device)

    # run inference
    # Compare against None so that an explicit seed of 0 is honoured (a plain truthiness
    # check would silently disable seeding for seed 0).
    generator = (
        torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
    )

    videos = []
    for _ in range(args.num_validation_videos):
        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
        videos.append(video)

    phase_name = "test" if is_final_validation else "validation"
    # File-system friendly fragment of the prompt; identical for every video, so it is
    # computed once rather than inside the per-video loop.
    prompt = (
        pipeline_args["prompt"][:25]
        .replace(" ", "_")
        .replace(" ", "_")
        .replace("'", "_")
        .replace('"', "_")
        .replace("/", "_")
    )

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            video_filenames = []
            for i, video in enumerate(videos):
                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
                export_to_video(video, filename, fps=8)
                video_filenames.append(filename)

            tracker.log(
                {
                    phase_name: [
                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
                        for i, filename in enumerate(video_filenames)
                    ]
                }
            )

    return videos
def run_validation(
    args: Dict[str, Any],
    accelerator: Accelerator,
    transformer,
    scheduler,
    model_config: Dict[str, Any],
    weight_dtype: torch.dtype,
) -> None:
    """Build an inference pipeline around the in-training transformer and validate it.

    One `log_validation` call is made per (image, prompt) pair obtained by splitting
    `args.validation_images` and `args.validation_prompt` on
    `args.validation_prompt_separator`. The pipeline is released afterwards and the
    CUDA cache is emptied.
    """
    device = accelerator.device

    accelerator.print("===== Memory before validation =====")
    print_memory(device)
    torch.cuda.synchronize(device)

    # Reuse the unwrapped training transformer and the training scheduler; the remaining
    # components are loaded from the pretrained checkpoint.
    pipe = CogVideoXImageToVideoPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        transformer=unwrap_model(accelerator, transformer),
        scheduler=scheduler,
        revision=args.revision,
        variant=args.variant,
        torch_dtype=weight_dtype,
    )

    if args.enable_slicing:
        pipe.vae.enable_slicing()
    if args.enable_tiling:
        pipe.vae.enable_tiling()
    if args.enable_model_cpu_offload:
        pipe.enable_model_cpu_offload()

    separator = args.validation_prompt_separator
    prompt_list = args.validation_prompt.split(separator)
    image_list = args.validation_images.split(separator)

    for image_path, prompt_text in zip(image_list, prompt_list):
        log_validation(
            pipe=pipe,
            args=args,
            accelerator=accelerator,
            pipeline_args={
                "image": load_image(image_path),
                "prompt": prompt_text,
                "guidance_scale": args.guidance_scale,
                "use_dynamic_cfg": args.use_dynamic_cfg,
                "height": args.height,
                "width": args.width,
                "max_sequence_length": model_config.max_text_seq_length,
            },
        )

    accelerator.print("===== Memory after validation =====")
    print_memory(device)
    reset_memory(device)

    # Drop the pipeline and release cached memory before training resumes.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize(device)
class CollateFunction:
    """Collate callable that turns one pre-bucketed batch into stacked tensors.

    Only `data[0]` is read: the DataLoader is expected to deliver a single-element
    list whose item is the list of samples forming the batch. Each sample is a mapping
    with "prompt", "image" and "video" entries.
    """

    def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
        # dtype that every stacked tensor is cast to
        self.weight_dtype = weight_dtype
        # when True, prompts are tensors and get stacked too; otherwise they are
        # returned as a plain list
        self.load_tensors = load_tensors

    def _stack(self, tensors):
        """Stack a list of tensors along a new leading dim and cast to `weight_dtype`."""
        return torch.stack(tensors).to(dtype=self.weight_dtype, non_blocking=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        samples = data[0]

        prompts = [sample["prompt"] for sample in samples]
        if self.load_tensors:
            prompts = self._stack(prompts)

        images = self._stack([sample["image"] for sample in samples])
        videos = self._stack([sample["video"] for sample in samples])

        return {
            "images": images,
            "videos": videos,
            "prompts": prompts,
        }
def main(args):
    """Run full fine-tuning of a CogVideoX image-to-video transformer.

    The function, in order: validates the argument combination, creates the
    `Accelerator`, loads tokenizer / text encoder / transformer / VAE / scheduler,
    registers checkpoint save & load hooks, builds the optimizer, dataset, dataloader
    and LR scheduler, optionally resumes from a checkpoint, runs the training loop
    (with periodic checkpointing and validation), saves the trained transformer to
    `<output_dir>/transformer`, runs a final test inference and optionally pushes the
    output folder to the Hub.

    Args:
        args: Parsed training arguments (as returned by `get_args()`), accessed via
            attributes. `args.learning_rate`, `args.max_train_steps`,
            `args.num_train_epochs` and `args.resume_from_checkpoint` may be
            overwritten in place.

    Raises:
        ValueError: If `--report_to=wandb` is combined with `--hub_token`, or if
            bfloat16 is requested while MPS is available.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    init_process_group_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            # `repo_id` is only bound on the main process when pushing to the Hub; it is
            # read again under the same two conditions at the end of this function.
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Prepare models and scheduler
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
    )

    text_encoder = T5EncoderModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    # CogVideoX-2b weights are stored in float16
    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
    load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
        revision=args.revision,
        variant=args.variant,
    )

    if args.ignore_learned_positional_embeddings:
        # Drop the learned positional embedding and flag it as unused in both the patch
        # embedding module and the model config.
        del transformer.patch_embed.pos_embedding
        transformer.patch_embed.use_learned_positional_embeddings = False
        transformer.config.use_learned_positional_embeddings = False

    vae = AutoencoderKLCogVideoX.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )

    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    if args.enable_slicing:
        vae.enable_slicing()
    if args.enable_tiling:
        vae.enable_tiling()

    # Only the transformer is trained; text encoder and VAE stay frozen.
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    transformer.requires_grad_(True)

    VAE_SCALING_FACTOR = vae.config.scaling_factor
    VAE_SCALE_FACTOR_SPATIAL = 2 ** (len(vae.config.block_out_channels) - 1)
    RoPE_BASE_HEIGHT = transformer.config.sample_height * VAE_SCALE_FACTOR_SPATIAL
    RoPE_BASE_WIDTH = transformer.config.sample_width * VAE_SCALE_FACTOR_SPATIAL

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.state.deepspeed_plugin:
        # DeepSpeed is handling precision, use what's in the DeepSpeed config
        if (
            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
        ):
            weight_dtype = torch.float16
        if (
            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
        ):
            weight_dtype = torch.bfloat16
    else:
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    text_encoder.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        """Save the transformer in `save_pretrained` format under `<output_dir>/transformer`."""
        if accelerator.is_main_process:
            for model in models:
                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
                    model = unwrap_model(accelerator, model)
                    model.save_pretrained(
                        os.path.join(output_dir, "transformer"), safe_serialization=True, max_shard_size="5GB"
                    )
                else:
                    raise ValueError(f"Unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                if weights:
                    weights.pop()

    def load_model_hook(models, input_dir):
        """Restore transformer weights and config from `<input_dir>/transformer`."""
        transformer_ = None
        init_under_meta = False

        # This is a bit of a hack but I don't know any other solution.
        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
            while len(models) > 0:
                model = models.pop()

                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
                    transformer_ = unwrap_model(accelerator, model)
                else:
                    raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
        else:
            # Under DeepSpeed a fresh model is instantiated with empty weights; the loaded
            # state dict is then assigned to it below (`assign=init_under_meta`).
            with init_empty_weights():
                transformer_ = CogVideoXTransformer3DModel.from_config(
                    args.pretrained_model_name_or_path, subfolder="transformer"
                )
                init_under_meta = True

        load_model = CogVideoXTransformer3DModel.from_pretrained(os.path.join(input_dir, "transformer"))
        transformer_.register_to_config(**load_model.config)
        transformer_.load_state_dict(load_model.state_dict(), assign=init_under_meta)
        del load_model

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            cast_training_params([transformer_])

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        cast_training_params([transformer], dtype=torch.float32)

    transformer_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))

    # Optimization parameters
    transformer_parameters_with_lr = {
        "params": transformer_parameters,
        "lr": args.learning_rate,
    }
    params_to_optimize = [transformer_parameters_with_lr]
    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])

    use_deepspeed_optimizer = (
        accelerator.state.deepspeed_plugin is not None
        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
    )
    use_deepspeed_scheduler = (
        accelerator.state.deepspeed_plugin is not None
        and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
    )

    optimizer = get_optimizer(
        params_to_optimize=params_to_optimize,
        optimizer_name=args.optimizer,
        learning_rate=args.learning_rate,
        beta1=args.beta1,
        beta2=args.beta2,
        beta3=args.beta3,
        epsilon=args.epsilon,
        weight_decay=args.weight_decay,
        prodigy_decouple=args.prodigy_decouple,
        prodigy_use_bias_correction=args.prodigy_use_bias_correction,
        prodigy_safeguard_warmup=args.prodigy_safeguard_warmup,
        use_8bit=args.use_8bit,
        use_4bit=args.use_4bit,
        use_torchao=args.use_torchao,
        use_deepspeed=use_deepspeed_optimizer,
        use_cpu_offload_optimizer=args.use_cpu_offload_optimizer,
        offload_gradients=args.offload_gradients,
    )

    # Dataset and DataLoader
    dataset_init_kwargs = {
        "data_root": args.data_root,
        "dataset_file": args.dataset_file,
        "caption_column": args.caption_column,
        "video_column": args.video_column,
        "max_num_frames": args.max_num_frames,
        "id_token": args.id_token,
        "height_buckets": args.height_buckets,
        "width_buckets": args.width_buckets,
        "frame_buckets": args.frame_buckets,
        "load_tensors": args.load_tensors,
        "random_flip": args.random_flip,
        "image_to_video": True,
    }
    if args.video_reshape_mode is None:
        train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
    else:
        train_dataset = VideoDatasetWithResizeAndRectangleCrop(
            video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
        )

    collate_fn = CollateFunction(weight_dtype, args.load_tensors)

    # The DataLoader batch size is 1 because the BucketSampler is given the real batch
    # size (`args.train_batch_size`); CollateFunction reads that batch from `data[0]`.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.pin_memory,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    if args.use_cpu_offload_optimizer:
        lr_scheduler = None
        accelerator.print(
            "CPU Offload Optimizer cannot be used with DeepSpeed or builtin PyTorch LR Schedulers. If "
            "you are training with those settings, they will be ignored."
        )
    else:
        if use_deepspeed_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=args.max_train_steps * accelerator.num_processes,
                num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
            )
        else:
            lr_scheduler = get_scheduler(
                args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                num_training_steps=args.max_train_steps * accelerator.num_processes,
                num_cycles=args.lr_num_cycles,
                power=args.lr_power,
            )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = args.tracker_name or "cogvideox-sft"
        accelerator.init_trackers(tracker_name, config=vars(args))

        accelerator.print("===== Memory before training =====")
        reset_memory(accelerator.device)
        print_memory(accelerator.device)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    accelerator.print("***** Running training *****")
    accelerator.print(f" Num trainable parameters = {num_trainable_parameters}")
    accelerator.print(f" Num examples = {len(train_dataset)}")
    accelerator.print(f" Num batches each epoch = {len(train_dataloader)}")
    accelerator.print(f" Num epochs = {args.num_train_epochs}")
    accelerator.print(f" Instantaneous batch size per device = {args.train_batch_size}")
    accelerator.print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    accelerator.print(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
    accelerator.print(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if not args.resume_from_checkpoint:
        initial_global_step = 0
    else:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named `checkpoint-<global_step>`.
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # For DeepSpeed training
    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

    if args.load_tensors:
        # Pre-encoded inputs are loaded from the dataset, so the VAE and text encoder
        # are not needed during training.
        del vae, text_encoder
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize(accelerator.device)

    alphas_cumprod = scheduler.alphas_cumprod.to(accelerator.device, dtype=torch.float32)

    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            logs = {}

            with accelerator.accumulate(models_to_accumulate):
                images = batch["images"].to(accelerator.device, non_blocking=True)
                videos = batch["videos"].to(accelerator.device, non_blocking=True)
                prompts = batch["prompts"]

                # Encode videos
                if not args.load_tensors:
                    images = images.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                    # Per-sample noise level for the conditioning image: exp(N(-3.0, 0.5)).
                    image_noise_sigma = torch.normal(
                        mean=-3.0, std=0.5, size=(images.size(0),), device=accelerator.device, dtype=weight_dtype
                    )
                    image_noise_sigma = torch.exp(image_noise_sigma)
                    noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
                    image_latent_dist = vae.encode(noisy_images).latent_dist

                    videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                    latent_dist = vae.encode(videos).latent_dist
                else:
                    image_latent_dist = DiagonalGaussianDistribution(images)
                    latent_dist = DiagonalGaussianDistribution(videos)

                image_latents = image_latent_dist.sample() * VAE_SCALING_FACTOR
                image_latents = image_latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
                image_latents = image_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)

                video_latents = latent_dist.sample() * VAE_SCALING_FACTOR
                video_latents = video_latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
                video_latents = video_latents.to(memory_format=torch.contiguous_format, dtype=weight_dtype)

                # Zero-pad the image latents along the frame dimension so that they have
                # as many frames as the video latents.
                padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])
                latent_padding = image_latents.new_zeros(padding_shape)
                image_latents = torch.cat([image_latents, latent_padding], dim=1)

                # With probability `noised_image_dropout`, drop the image conditioning
                # for this batch by zeroing it out.
                if random.random() < args.noised_image_dropout:
                    image_latents = torch.zeros_like(image_latents)

                # Encode prompts
                if not args.load_tensors:
                    prompt_embeds = compute_prompt_embeddings(
                        tokenizer,
                        text_encoder,
                        prompts,
                        model_config.max_text_seq_length,
                        accelerator.device,
                        weight_dtype,
                        requires_grad=False,
                    )
                else:
                    prompt_embeds = prompts.to(dtype=weight_dtype)

                # Sample noise that will be added to the latents
                noise = torch.randn_like(video_latents)
                batch_size, num_frames, num_channels, height, width = video_latents.shape

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    scheduler.config.num_train_timesteps,
                    (batch_size,),
                    dtype=torch.int64,
                    device=accelerator.device,
                )

                # Prepare rotary embeds
                image_rotary_emb = (
                    prepare_rotary_positional_embeddings(
                        height=height * VAE_SCALE_FACTOR_SPATIAL,
                        width=width * VAE_SCALE_FACTOR_SPATIAL,
                        num_frames=num_frames,
                        vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
                        patch_size=model_config.patch_size,
                        patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
                        attention_head_dim=model_config.attention_head_dim,
                        device=accelerator.device,
                        base_height=RoPE_BASE_HEIGHT,
                        base_width=RoPE_BASE_WIDTH,
                    )
                    if model_config.use_rotary_positional_embeddings
                    else None
                )

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)
                noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)

                # NOTE: there must be no trailing comma on the assignment below. A trailing
                # comma turns the value into a 1-tuple, which is never `None`, so `ofs_emb`
                # would be created even for models without an `ofs_embed_dim`.
                ofs_embed_dim = model_config.ofs_embed_dim if hasattr(model_config, "ofs_embed_dim") else None
                ofs_emb = None if ofs_embed_dim is None else noisy_model_input.new_full((1,), fill_value=2.0)

                # Predict the noise residual
                model_output = transformer(
                    hidden_states=noisy_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timesteps,
                    ofs=ofs_emb,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                )[0]

                model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)

                # Per-timestep loss weight 1 / (1 - alpha_bar_t), broadcast to the
                # prediction's rank.
                weights = 1 / (1 - alphas_cumprod[timesteps])
                while len(weights.shape) < len(model_pred.shape):
                    weights = weights.unsqueeze(-1)

                target = video_latents

                loss = torch.mean(
                    (weights * (model_pred - target) ** 2).reshape(batch_size, -1),
                    dim=1,
                )
                loss = loss.mean()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
                    accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
                    gradient_norm_after_clip = get_gradient_norm(transformer.parameters())
                    logs.update(
                        {
                            "gradient_norm_before_clip": gradient_norm_before_clip,
                            "gradient_norm_after_clip": gradient_norm_after_clip,
                        }
                    )

                if accelerator.state.deepspeed_plugin is None:
                    optimizer.step()
                    optimizer.zero_grad()

                # `lr_scheduler` is None when the CPU offload optimizer is used.
                if not args.use_cpu_offload_optimizer:
                    lr_scheduler.step()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # Checkpointing
                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                # Validation
                should_run_validation = args.validation_prompt is not None and (
                    args.validation_steps is not None and global_step % args.validation_steps == 0
                )
                if should_run_validation:
                    run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)

            last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
            logs.update(
                {
                    "loss": loss.detach().item(),
                    "lr": last_lr,
                }
            )
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            should_run_validation = args.validation_prompt is not None and (
                args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
            )
            if should_run_validation:
                run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)

    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        transformer = unwrap_model(accelerator, transformer)
        dtype = (
            torch.float16
            if args.mixed_precision == "fp16"
            else torch.bfloat16
            if args.mixed_precision == "bf16"
            else torch.float32
        )
        transformer = transformer.to(dtype)

        transformer.save_pretrained(
            os.path.join(args.output_dir, "transformer"),
            safe_serialization=True,
            max_shard_size="5GB",
        )

        # Cleanup trained models to save memory
        # (vae and text_encoder were already deleted before training when load_tensors is set)
        if args.load_tensors:
            del transformer
        else:
            del transformer, text_encoder, vae

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize(accelerator.device)

        accelerator.print("===== Memory before testing =====")
        print_memory(accelerator.device)
        reset_memory(accelerator.device)

        # Final test inference
        # Reload the transformer that was just saved so the test videos (and the model
        # card built from them) come from the fine-tuned weights rather than from the
        # base checkpoint's transformer.
        finetuned_transformer = CogVideoXTransformer3DModel.from_pretrained(
            os.path.join(args.output_dir, "transformer"),
            torch_dtype=weight_dtype,
        )
        pipe = CogVideoXImageToVideoPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            transformer=finetuned_transformer,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)

        if args.enable_slicing:
            pipe.vae.enable_slicing()
        if args.enable_tiling:
            pipe.vae.enable_tiling()
        if args.enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()

        # Run inference
        validation_outputs = []
        if args.validation_prompt and args.num_validation_videos > 0:
            validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
            validation_images = args.validation_images.split(args.validation_prompt_separator)
            for validation_image, validation_prompt in zip(validation_images, validation_prompts):
                pipeline_args = {
                    "image": load_image(validation_image),
                    "prompt": validation_prompt,
                    "guidance_scale": args.guidance_scale,
                    "use_dynamic_cfg": args.use_dynamic_cfg,
                    "height": args.height,
                    "width": args.width,
                }

                video = log_validation(
                    accelerator=accelerator,
                    pipe=pipe,
                    args=args,
                    pipeline_args=pipeline_args,
                    is_final_validation=True,
                )
                validation_outputs.extend(video)

        accelerator.print("===== Memory after testing =====")
        print_memory(accelerator.device)
        reset_memory(accelerator.device)
        torch.cuda.synchronize(accelerator.device)

        if args.push_to_hub:
            save_model_card(
                repo_id,
                videos=validation_outputs,
                base_model=args.pretrained_model_name_or_path,
                validation_prompt=args.validation_prompt,
                repo_folder=args.output_dir,
                fps=args.fps,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
# Script entry point: parse the training arguments with `get_args()` and hand them to `main`.
if __name__ == "__main__":
    args = get_args()
    main(args)