# Copyright 2024 The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import math
import os
import random
import sys
from glob import glob
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from diffusers import FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from args import get_args  # isort:skip
from dataset_simple import LatentEmbedDataset
from utils import print_memory, reset_memory  # isort:skip


# Taken from
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/demos/fine_tuner/train.py#L139
def get_cosine_annealing_lr_scheduler(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
):
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        else:
            return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
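
# Illustrative sketch only (the numbers below are assumptions, not values from args.py):
# with warmup_steps=200 and total_steps=1000, the multiplier produced by lr_lambda ramps
# linearly from 0 to 1 over the first 200 steps, then follows a half-cosine from 1 down
# to 0 at step 1000. The scheduler is driven in the usual way:
#
#     lr_scheduler = get_cosine_annealing_lr_scheduler(optimizer, warmup_steps=200, total_steps=1000)
#     for _ in range(1000):
#         optimizer.step()
#         lr_scheduler.step()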


def save_model_card(
    repo_id: str,
    videos=None,
    base_model: str = None,
    validation_prompt=None,
    repo_folder=None,
    fps=30,
):
    widget_dict = []
    if videos is not None and len(videos) > 0:
        for i, video in enumerate(videos):
            export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
            widget_dict.append(
                {
                    "text": validation_prompt if validation_prompt else " ",
                    "output": {"url": f"final_video_{i}.mp4"},
                }
            )

    model_description = f"""
# Mochi-1 Preview LoRA Finetune

<Gallery />

## Model description

This is a LoRA fine-tune of the Mochi-1 preview model `{base_model}`.

The model was trained using [CogVideoX Factory](https://github.com/a-r-r-o-w/cogvideox-factory) - a repository containing memory-optimized training scripts for the CogVideoX and Mochi family of models using [TorchAO](https://github.com/pytorch/ao) and [DeepSpeed](https://github.com/microsoft/DeepSpeed). The scripts were adapted from the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).

## Download model

[Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.

## Usage

Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.

```py
from diffusers import MochiPipeline
from diffusers.utils import export_to_video
import torch

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
pipe.load_lora_weights("CHANGE_ME")
pipe.enable_model_cpu_offload()

with torch.autocast("cuda", torch.bfloat16):
    video = pipe(
        prompt="CHANGE_ME",
        guidance_scale=6.0,
        num_inference_steps=64,
        height=480,
        width=848,
        max_sequence_length=256,
        output_type="np"
    ).frames[0]
export_to_video(video)
```

For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
"""
    model_card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="apache-2.0",
        base_model=base_model,
        prompt=validation_prompt,
        model_description=model_description,
        widget=widget_dict,
    )
    tags = [
        "text-to-video",
        "diffusers-training",
        "diffusers",
        "lora",
        "mochi-1-preview",
        "mochi-1-preview-diffusers",
        "template:sd-lora",
    ]
    model_card = populate_model_card(model_card, tags=tags)
    model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(
    pipe: MochiPipeline,
    args: Dict[str, Any],
    pipeline_args: Dict[str, Any],
    epoch,
    wandb_run: str = None,
    is_final_validation: bool = False,
):
    print(
        f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}."
    )
    phase_name = "test" if is_final_validation else "validation"

    if not args.enable_model_cpu_offload:
        pipe = pipe.to("cuda")

    # Run inference.
    generator = torch.manual_seed(args.seed) if args.seed else None

    videos = []
    with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
        for _ in range(args.num_validation_videos):
            video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
            videos.append(video)

    video_filenames = []
    for i, video in enumerate(videos):
        # Build a filesystem-friendly slug from the first 25 characters of the prompt.
        prompt = (
            pipeline_args["prompt"][:25]
            .replace(" ", "_")
            .replace("'", "_")
            .replace('"', "_")
            .replace("/", "_")
        )
        filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
        export_to_video(video, filename, fps=30)
        video_filenames.append(filename)

    if wandb_run:
        wandb.log(
            {
                phase_name: [
                    wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}", fps=30)
                    for i, filename in enumerate(video_filenames)
                ]
            }
        )

    return videos


# Adapted from the original code:
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/pipelines.py#L578
def cast_dit(model, dtype):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            assert any(
                n in name for n in ["time_embed", "proj_out", "blocks", "norm_out"]
            ), f"Unexpected linear layer: {name}"
            module.to(dtype=dtype)
        elif isinstance(module, torch.nn.Conv2d):
            module.to(dtype=dtype)
    return model
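
# Note (our reading of the upstream reference, which does not spell this out): casting only
# Linear and Conv2d modules to the lower-precision dtype leaves normalization layers and the
# remaining parameters in their original precision, the usual mixed-precision trade-off
# between memory savings and numerical stability.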


def save_checkpoint(model, optimizer, lr_scheduler, global_step, checkpoint_path):
    lora_state_dict = get_peft_model_state_dict(model)
    torch.save(
        {
            "state_dict": lora_state_dict,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "global_step": global_step,
        },
        checkpoint_path,
    )
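
# Sketch of how such a checkpoint is consumed (this mirrors the --resume_from_checkpoint
# handling in main() below; "checkpoint.pt" is a placeholder path):
#
#     checkpoint = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
#     set_peft_model_state_dict(transformer, checkpoint["state_dict"])
#     optimizer.load_state_dict(checkpoint["optimizer"])
#     lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
#     global_step = checkpoint["global_step"]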


class CollateFunction:
    def __init__(self, caption_dropout: float = None) -> None:
        self.caption_dropout = caption_dropout

    def __call__(self, samples: List[Tuple[dict, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        ldists = torch.cat([data[0]["ldist"] for data in samples], dim=0)
        z = DiagonalGaussianDistribution(ldists).sample()
        assert torch.isfinite(z).all()

        # Sample noise which we will add to the samples.
        eps = torch.randn_like(z)
        sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)

        prompt_embeds = torch.cat([data[1]["prompt_embeds"] for data in samples], dim=0)
        prompt_attention_mask = torch.cat([data[1]["prompt_attention_mask"] for data in samples], dim=0)
        if self.caption_dropout and random.random() < self.caption_dropout:
            prompt_embeds.zero_()
            prompt_attention_mask = prompt_attention_mask.long()
            prompt_attention_mask.zero_()
            prompt_attention_mask = prompt_attention_mask.bool()

        return dict(
            z=z, eps=eps, sigma=sigma, prompt_embeds=prompt_embeds, prompt_attention_mask=prompt_attention_mask
        )
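
# Note: zeroing the prompt embeddings and attention mask with probability `caption_dropout`
# turns those batches into unconditional training examples; we read this as the usual
# classifier-free-guidance-style caption dropout, though the original code does not say so
# explicitly.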


def main(args):
    if not torch.cuda.is_available():
        raise ValueError("This training script requires a CUDA-capable GPU.")

    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    # Handle the repository creation.
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    if args.push_to_hub:
        repo_id = create_repo(
            repo_id=args.hub_model_id or Path(args.output_dir).name,
            exist_ok=True,
        ).repo_id

    # Prepare models and scheduler.
    transformer = MochiTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        revision=args.revision,
        variant=args.variant,
    )
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )

    transformer.requires_grad_(False)
    transformer.to("cuda")
    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()
    if args.cast_dit:
        transformer = cast_dit(transformer, torch.bfloat16)
    if args.compile_dit:
        transformer.compile()

    # Now we will add new LoRA weights to the attention layers.
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.lora_alpha,
        init_lora_weights="gaussian",
        target_modules=args.target_modules,
    )
    transformer.add_adapter(transformer_lora_config)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size

    # Only upcast trainable parameters (LoRA) into fp32.
    cast_training_params([transformer], dtype=torch.float32)

    # Prepare optimizer.
    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    num_trainable_parameters = sum(param.numel() for param in transformer_lora_parameters)
    optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)

    # Dataset and DataLoader.
    train_vids = list(sorted(glob(f"{args.data_root}/*.mp4")))
    train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
    print(f"Found {len(train_vids)} training videos in {args.data_root}")
    assert len(train_vids) > 0, f"No training data found in {args.data_root}"

    collate_fn = CollateFunction(caption_dropout=args.caption_dropout)
    train_dataset = LatentEmbedDataset(train_vids, repeat=1)
    train_dataloader = DataLoader(
        train_dataset,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        pin_memory=args.pin_memory,
    )
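
    # Note on the data layout (inferred from CollateFunction above; `dataset_simple.py` is
    # not shown in this file): each sample is expected to provide a precomputed VAE latent
    # distribution under "ldist" plus precomputed "prompt_embeds" and "prompt_attention_mask",
    # so the `.mp4` paths presumably serve to locate cached latents/embeddings rather than
    # being decoded during training.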

    # LR scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = len(train_dataloader)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_cosine_annealing_lr_scheduler(
        optimizer, warmup_steps=args.lr_warmup_steps, total_steps=args.max_train_steps
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = len(train_dataloader)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Initialize the experiment tracker and store the training configuration.
    wandb_run = None
    if args.report_to == "wandb":
        tracker_name = args.tracker_name or "mochi-1-lora"
        wandb_run = wandb.init(project=tracker_name, config=vars(args))

    # Resume from checkpoint if specified.
    if args.resume_from_checkpoint:
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu", weights_only=True)
        if "global_step" in checkpoint:
            global_step = checkpoint["global_step"]
        if "optimizer" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if "lr_scheduler" in checkpoint:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        set_peft_model_state_dict(transformer, checkpoint["state_dict"])
        print(f"Resuming from checkpoint: {args.resume_from_checkpoint}")
        print(f"Resuming from global step: {global_step}")
    else:
        global_step = 0

    print("===== Memory before training =====")
    reset_memory("cuda")
    print_memory("cuda")

    # Train!
    total_batch_size = args.train_batch_size
    print("***** Running training *****")
    print(f" Num trainable parameters = {num_trainable_parameters}")
    print(f" Num examples = {len(train_dataset)}")
    print(f" Num batches each epoch = {len(train_dataloader)}")
    print(f" Num epochs = {args.num_train_epochs}")
    print(f" Instantaneous batch size per device = {args.train_batch_size}")
    print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    print(f" Total optimization steps = {args.max_train_steps}")

    first_epoch = 0
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
    )
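
    # Training objective, as implemented below: for clean latents z and noise eps, the collate
    # function samples sigma ~ U(0, 1); the noisy input is z_sigma = (1 - sigma) * z + sigma * eps,
    # the regression target is the velocity ut = z - eps, and the transformer is conditioned on
    # timestep (1 - sigma) * num_train_timesteps to match the diffusers Mochi convention.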
    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            with torch.no_grad():
                z = batch["z"].to("cuda")
                eps = batch["eps"].to("cuda")
                sigma = batch["sigma"].to("cuda")
                prompt_embeds = batch["prompt_embeds"].to("cuda")
                prompt_attention_mask = batch["prompt_attention_mask"].to("cuda")

                sigma_bcthw = sigma[:, None, None, None, None]  # [B, 1, 1, 1, 1]
                # Add noise according to flow matching: z_sigma = (1 - sigma) * z + sigma * eps,
                # with the velocity ut = z - eps as the regression target.
                z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
                ut = z - eps

                # (1 - sigma) because of
                # https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
                # Also, we operate on the scaled version of the `timesteps` directly in the `diffusers` implementation.
                timesteps = (1 - sigma) * scheduler.config.num_train_timesteps

            with torch.autocast("cuda", torch.bfloat16):
                model_pred = transformer(
                    hidden_states=z_sigma,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=timesteps,
                    return_dict=False,
                )[0]
            assert model_pred.shape == z.shape
            loss = F.mse_loss(model_pred.float(), ut.float())
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()

            progress_bar.update(1)
            global_step += 1

            last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
            logs = {"loss": loss.detach().item(), "lr": last_lr}
            progress_bar.set_postfix(**logs)
            if wandb_run:
                wandb_run.log(logs, step=global_step)

            if args.checkpointing_steps is not None and global_step % args.checkpointing_steps == 0:
                print(f"Saving checkpoint at step {global_step}")
                checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
                save_checkpoint(
                    transformer,
                    optimizer,
                    lr_scheduler,
                    global_step,
                    checkpoint_path,
                )

            if global_step >= args.max_train_steps:
                break

        if global_step >= args.max_train_steps:
            break

        if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
            print("===== Memory before validation =====")
            print_memory("cuda")

            transformer.eval()
            pipe = MochiPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                transformer=transformer,
                scheduler=scheduler,
                revision=args.revision,
                variant=args.variant,
            )

            if args.enable_slicing:
                pipe.vae.enable_slicing()
            if args.enable_tiling:
                pipe.vae.enable_tiling()
            if args.enable_model_cpu_offload:
                pipe.enable_model_cpu_offload()

            validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
            for validation_prompt in validation_prompts:
                pipeline_args = {
                    "prompt": validation_prompt,
                    "guidance_scale": 6.0,
                    "num_inference_steps": 64,
                    "height": args.height,
                    "width": args.width,
                    "max_sequence_length": 256,
                }
                log_validation(
                    pipe=pipe,
                    args=args,
                    pipeline_args=pipeline_args,
                    epoch=epoch,
                    wandb_run=wandb_run,
                )

            print("===== Memory after validation =====")
            print_memory("cuda")
            reset_memory("cuda")

            del pipe.text_encoder
            del pipe.vae
            del pipe
            gc.collect()
            torch.cuda.empty_cache()

            transformer.train()

    transformer.eval()
    transformer_lora_layers = get_peft_model_state_dict(transformer)
    MochiPipeline.save_lora_weights(save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers)

    # Clean up the trained model to save memory.
    del transformer
    gc.collect()
    torch.cuda.empty_cache()

    # Final test inference.
    validation_outputs = []
    if args.validation_prompt and args.num_validation_videos > 0:
        print("===== Memory before testing =====")
        print_memory("cuda")
        reset_memory("cuda")

        pipe = MochiPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
        )

        if args.enable_slicing:
            pipe.vae.enable_slicing()
        if args.enable_tiling:
            pipe.vae.enable_tiling()
        if args.enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()

        # Load LoRA weights.
        lora_scaling = args.lora_alpha / args.rank
        pipe.load_lora_weights(args.output_dir, adapter_name="mochi-lora")
        pipe.set_adapters(["mochi-lora"], [lora_scaling])

        # Run inference.
        validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
        for validation_prompt in validation_prompts:
            pipeline_args = {
                "prompt": validation_prompt,
                "guidance_scale": 6.0,
                "num_inference_steps": 64,
                "height": args.height,
                "width": args.width,
                "max_sequence_length": 256,
            }
            video = log_validation(
                pipe=pipe,
                args=args,
                pipeline_args=pipeline_args,
                epoch=epoch,
                wandb_run=wandb_run,
                is_final_validation=True,
            )
            validation_outputs.extend(video)

        print("===== Memory after testing =====")
        print_memory("cuda")
        reset_memory("cuda")
        torch.cuda.synchronize("cuda")

    if args.push_to_hub:
        save_model_card(
            repo_id,
            videos=validation_outputs,
            base_model=args.pretrained_model_name_or_path,
            validation_prompt=args.validation_prompt,
            repo_folder=args.output_dir,
            fps=args.fps,
        )
        upload_folder(
            repo_id=repo_id,
            folder_path=args.output_dir,
            commit_message="End of training",
            ignore_patterns=["*.bin"],
        )
        print(f"Params pushed to {repo_id}.")


if __name__ == "__main__":
    args = get_args()
    main(args)