jbilcke-hf's picture
jbilcke-hf HF Staff
fix for Finetrainers
66c6879
raw
history blame
57.1 kB
import json
import logging
import math
import os
import gc
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List
import resource
import diffusers
import torch
import torch.backends
import transformers
import wandb
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
ProjectConfiguration,
gather_object,
set_seed,
)
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video, load_image, load_video
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm
from .args import Args, validate_args
from .constants import (
FINETRAINERS_LOG_LEVEL,
PRECOMPUTED_CONDITIONS_DIR_NAME,
PRECOMPUTED_DIR_NAME,
PRECOMPUTED_LATENTS_DIR_NAME,
)
from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
from .hooks import apply_layerwise_upcasting
from .models import get_config_from_model_name
from .patches import perform_peft_patches
from .state import State
from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
from .utils.data_utils import should_perform_precomputation
from .utils.diffusion_utils import (
get_scheduler_alphas,
get_scheduler_sigmas,
prepare_loss_weights,
prepare_sigmas,
prepare_target,
)
from .utils.file_utils import string_to_filename
from .utils.hub_utils import save_model_card
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
from .utils.optimizer_utils import get_optimizer
from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
# Module-level logger created through accelerate's `get_logger`; its level comes from FINETRAINERS_LOG_LEVEL.
logger = get_logger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)
class Trainer:
def __init__(self, args: Args) -> None:
    """Validate ``args`` and initialise the trainer's state.

    All model component attributes (tokenizers, text encoders, denoisers, VAE,
    scheduler and the transformer/VAE configs) start out as ``None``. The
    ``_init_*`` helpers are then run in order (distributed, logging,
    directories/repositories, config options), PEFT patches are applied when
    layerwise upcasting modules are configured, and the model config is
    resolved from ``args.model_name`` and ``args.training_type``.

    Args:
        args: Training arguments. They are checked with ``validate_args`` and
            stored on ``self.args``. ``args.seed`` is replaced by the current
            year only when it is ``None``.
    """
    validate_args(args)

    self.args = args
    # Only fall back to the current year when no seed was given. The previous
    # `seed or year` form also discarded an explicit seed of 0, which is a
    # perfectly valid seed.
    if self.args.seed is None:
        self.args.seed = datetime.now().year
    self.state = State()

    # Tokenizers
    self.tokenizer = None
    self.tokenizer_2 = None
    self.tokenizer_3 = None

    # Text encoders
    self.text_encoder = None
    self.text_encoder_2 = None
    self.text_encoder_3 = None

    # Denoisers
    self.transformer = None
    self.unet = None

    # Autoencoders
    self.vae = None

    # Scheduler
    self.scheduler = None

    self.transformer_config = None
    self.vae_config = None

    self._init_distributed()
    self._init_logging()
    self._init_directories_and_repositories()
    self._init_config_options()

    # Perform any patches needed for training
    if len(self.args.layerwise_upcasting_modules) > 0:
        perform_peft_patches()
    # TODO(aryan): handle text encoders
    # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
    #     perform_text_encoder_patches()

    self.state.model_name = self.args.model_name
    self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
def prepare_dataset(self) -> None:
    """Create ``self.dataset`` and the ``self.dataloader`` that iterates over it.

    The dataloader itself uses ``batch_size=1`` and delegates batching to a
    shuffling ``BucketSampler`` configured with ``args.batch_size``
    (presumably the sampler yields whole batches -- confirm against
    ``BucketSampler``).
    """
    # TODO(aryan): Make a background process for fetching
    logger.info("Initializing dataset and dataloader")

    args = self.args
    dataset = ImageOrVideoDatasetWithResizing(
        data_root=args.data_root,
        caption_column=args.caption_column,
        video_column=args.video_column,
        resolution_buckets=args.video_resolution_buckets,
        dataset_file=args.dataset_file,
        id_token=args.id_token,
        remove_llm_prefixes=args.remove_common_llm_caption_prefixes,
    )
    self.dataset = dataset

    bucket_sampler = BucketSampler(dataset, batch_size=args.batch_size, shuffle=True)
    self.dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=bucket_sampler,
        collate_fn=self.model_config.get("collate_fn"),
        num_workers=args.dataloader_num_workers,
        pin_memory=args.pin_memory,
    )
def prepare_models(self) -> None:
    """Load the condition, latent and diffusion components and register them.

    When ``args.precompute_conditions`` is set nothing is loaded here and
    ``_set_components`` receives an empty dict. Afterwards VAE slicing and
    tiling are enabled if requested and a VAE is present.
    """
    logger.info("Initializing models")

    load_kwargs = self._get_load_components_kwargs()
    loaded_groups = [{}, {}, {}]

    if not self.args.precompute_conditions:
        # The main process goes first so that it downloads any missing model files;
        # the other processes then load from the cache.
        with self.state.accelerator.main_process_first():
            loaded_groups = [
                self.model_config[loader_key](**load_kwargs)
                for loader_key in ("load_condition_models", "load_latent_models", "load_diffusion_models")
            ]

    # Merge in condition -> latent -> diffusion order so later groups win on key clashes.
    merged = {}
    for group in loaded_groups:
        merged.update(group)
    self._set_components(merged)

    vae = self.vae
    if vae is None:
        return
    if self.args.enable_slicing:
        vae.enable_slicing()
    if self.args.enable_tiling:
        vae.enable_tiling()
def prepare_precomputations(self) -> None:
# Precompute text conditions and latents for every dataset item, store them as `.pt` files under
# `<data_root>/<model_name>_<cleaned_model_id>_<PRECOMPUTED_DIR_NAME>`, and point `self.dataloader`
# at a `PrecomputedDataset`. Returns immediately unless `args.precompute_conditions` is set.
# Raises ValueError when `args.batch_size` is not 1.
# NOTE(review): leading whitespace is lost in this copy of the file, so the nesting of the `with`
# and `for` bodies below has to be inferred from the statements themselves -- confirm against the original.
if not self.args.precompute_conditions:
return
logger.info("Initializing precomputations")
if self.args.batch_size != 1:
raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.")
# Collate precomputed samples: tensor values are concatenated along dim 0, any other value is taken
# from the first sample only.
def collate_fn(batch):
latent_conditions = [x["latent_conditions"] for x in batch]
text_conditions = [x["text_conditions"] for x in batch]
batched_latent_conditions = {}
batched_text_conditions = {}
for key in list(latent_conditions[0].keys()):
if torch.is_tensor(latent_conditions[0][key]):
batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
else:
# TODO(aryan): implement batch sampler for precomputed latents
batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
for key in list(text_conditions[0].keys()):
if torch.is_tensor(text_conditions[0][key]):
batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
else:
# TODO(aryan): implement batch sampler for precomputed latents
batched_text_conditions[key] = [x[key] for x in text_conditions][0]
return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}
cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
precomputation_dir = (
Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
)
should_precompute = should_perform_precomputation(precomputation_dir)
# Fast path: reuse what is already on disk and skip loading any model.
if not should_precompute:
logger.info("Precomputed conditions and latents found. Loading precomputed data.")
self.dataloader = torch.utils.data.DataLoader(
PrecomputedDataset(
data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
),
batch_size=self.args.batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.pin_memory,
)
return
logger.info("Precomputed conditions and latents not found. Running precomputation.")
# At this point, no models are loaded, so we need to load and precompute conditions and latents
with self.state.accelerator.main_process_first():
condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
self._set_components(condition_components)
self._move_components_to_device()
self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
# Only a warning is emitted here; nothing in this method applies caption dropout.
if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
logger.warning(
"Caption dropout is not supported with precomputation yet. This will be supported in the future."
)
conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
conditions_dir.mkdir(parents=True, exist_ok=True)
latents_dir.mkdir(parents=True, exist_ok=True)
accelerator = self.state.accelerator
# Precompute conditions
# The progress bar total is ceil(len(dataset) / num_processes), i.e. the per-process share of items.
progress_bar = tqdm(
range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
desc="Precomputing conditions",
disable=not accelerator.is_local_main_process,
)
# `index` counts the items handled by this process and is used in the output file name.
index = 0
for i, data in enumerate(self.dataset):
# Round-robin sharding: item i is handled by the process whose index equals i % num_processes.
if i % accelerator.num_processes != accelerator.process_index:
continue
logger.debug(
f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
)
text_conditions = self.model_config["prepare_conditions"](
tokenizer=self.tokenizer,
tokenizer_2=self.tokenizer_2,
tokenizer_3=self.tokenizer_3,
text_encoder=self.text_encoder,
text_encoder_2=self.text_encoder_2,
text_encoder_3=self.text_encoder_3,
prompt=data["prompt"],
device=accelerator.device,
dtype=self.args.transformer_dtype,
)
filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
torch.save(text_conditions, filename.as_posix())
index += 1
progress_bar.update(1)
self._delete_components()
memory_statistics = get_memory_statistics()
logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
# NOTE(review): this call is unconditional; presumably it assumes `accelerator.device` is a CUDA device -- confirm.
torch.cuda.reset_peak_memory_stats(accelerator.device)
# Precompute latents
with self.state.accelerator.main_process_first():
latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
self._set_components(latent_components)
self._move_components_to_device()
self._disable_grad_for_components([self.vae])
if self.vae is not None:
if self.args.enable_slicing:
self.vae.enable_slicing()
if self.args.enable_tiling:
self.vae.enable_tiling()
progress_bar = tqdm(
range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
desc="Precomputing latents",
disable=not accelerator.is_local_main_process,
)
index = 0
for i, data in enumerate(self.dataset):
# Same round-robin sharding as for the conditions above.
if i % accelerator.num_processes != accelerator.process_index:
continue
logger.debug(
f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
)
# `unsqueeze(0)` adds a leading dimension to the sample before it is handed to `prepare_latents`.
latent_conditions = self.model_config["prepare_latents"](
vae=self.vae,
image_or_video=data["video"].unsqueeze(0),
device=accelerator.device,
dtype=self.args.transformer_dtype,
generator=self.state.generator,
precompute=True,
)
filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
torch.save(latent_conditions, filename.as_posix())
index += 1
progress_bar.update(1)
self._delete_components()
# Wait until every process has finished before the precomputed data is read back.
accelerator.wait_for_everyone()
logger.info("Precomputation complete")
memory_statistics = get_memory_statistics()
logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
torch.cuda.reset_peak_memory_stats(accelerator.device)
# Update dataloader to use precomputed conditions and latents
self.dataloader = torch.utils.data.DataLoader(
PrecomputedDataset(
data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
),
batch_size=self.args.batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.pin_memory,
)
def prepare_trainable_parameters(self) -> None:
    """Load the diffusion model and decide which parameters are trained.

    Text encoders and the VAE always have gradients disabled. For
    ``"full-finetune"`` the transformer's gradients are enabled; otherwise
    they are disabled and, for ``"lora"``, a LoRA adapter is added after the
    components have been moved to the device. Finally the checkpoint
    save/load hooks are registered with the (possibly ``None``) LoRA config.
    """
    logger.info("Initializing trainable parameters")

    with self.state.accelerator.main_process_first():
        load_diffusion_models = self.model_config["load_diffusion_models"]
        diffusion_components = load_diffusion_models(**self._get_load_components_kwargs())
        self._set_components(diffusion_components)

    frozen_components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
    self._disable_grad_for_components(frozen_components)

    training_type = self.args.training_type
    uses_lora = training_type == "lora"

    if training_type == "full-finetune":
        logger.info("Finetuning transformer with no additional parameters")
        self._enable_grad_for_components([self.transformer])
    else:
        logger.info("Finetuning transformer with PEFT parameters")
        self._disable_grad_for_components([self.transformer])

    # Layerwise upcasting has to happen before the LoRA adapter is added, and it is done
    # before moving to the device (i.e. on CPU) because doing it afterwards might OOM the GPU.
    # This can go once Diffusers supports loading with layerwise upcasting enabled directly.
    if uses_lora and "transformer" in self.args.layerwise_upcasting_modules:
        apply_layerwise_upcasting(
            self.transformer,
            storage_dtype=self.args.layerwise_upcasting_storage_dtype,
            compute_dtype=self.args.transformer_dtype,
            skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
            non_blocking=True,
        )

    self._move_components_to_device()

    if self.args.gradient_checkpointing:
        self.transformer.enable_gradient_checkpointing()

    transformer_lora_config = None
    if uses_lora:
        transformer_lora_config = LoraConfig(
            r=self.args.rank,
            lora_alpha=self.args.lora_alpha,
            init_lora_weights=True,
            target_modules=self.args.target_modules,
        )
        self.transformer.add_adapter(transformer_lora_config)

    # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
    # even if layerwise upcasting. Would be nice to have a test as well
    self.register_saving_loading_hooks(transformer_lora_config)
def register_saving_loading_hooks(self, transformer_lora_config):
# Define the save/load hooks below and register them on the accelerator as save-state and
# load-state pre-hooks. `transformer_lora_config` is only used by the load hook, where it is
# passed to `add_adapter` on a freshly loaded transformer.
# NOTE(review): leading whitespace is lost in this copy of the file, so the nesting of the branches
# below is inferred from the statements -- confirm against the original.
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
# Everything in this hook is guarded by the main-process check.
if self.state.accelerator.is_main_process:
transformer_lora_layers_to_save = None
for model in models:
# Only models of the same (unwrapped) type as self.transformer are accepted.
if isinstance(
unwrap_model(self.state.accelerator, model),
type(unwrap_model(self.state.accelerator, self.transformer)),
):
model = unwrap_model(self.state.accelerator, model)
if self.args.training_type == "lora":
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"Unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
if self.args.training_type == "lora":
self.model_config["pipeline_cls"].save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
)
else:
# NOTE(review): `model` here is the variable left over from the `for` loop above; presumably
# `models` is never empty when this branch runs -- confirm.
model.save_pretrained(os.path.join(output_dir, "transformer"))
# In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
# to able to load all diffusion components from a specific checkpoint folder during validation, we need to
# ensure the scheduler config is serialized as well.
self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
def load_model_hook(models, input_dir):
# NOTE(review): `transformer_` is assigned in every branch below but is not returned or stored
# anywhere in the visible code -- confirm how the loaded weights are meant to be consumed.
if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
# Drain `models`, accepting only entries of the transformer's (unwrapped) type.
while len(models) > 0:
model = models.pop()
if isinstance(
unwrap_model(self.state.accelerator, model),
type(unwrap_model(self.state.accelerator, self.transformer)),
):
transformer_ = unwrap_model(self.state.accelerator, model)
else:
# NOTE(review): the message says "save" although this is the load hook.
raise ValueError(
f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
)
else:
transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
if self.args.training_type == "lora":
transformer_ = transformer_cls_.from_pretrained(
self.args.pretrained_model_name_or_path, subfolder="transformer"
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
# Keep only keys that start with "transformer." and drop that text from them.
# NOTE(review): `str.replace` removes every occurrence of "transformer.", not just the leading one.
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v
for k, v in lora_state_dict.items()
if k.startswith("transformer.")
}
incompatible_keys = set_peft_model_state_dict(
transformer_, transformer_state_dict, adapter_name="default"
)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
else:
transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
self.state.accelerator.register_save_state_pre_hook(save_model_hook)
self.state.accelerator.register_load_state_pre_hook(load_model_hook)
def prepare_optimizer(self) -> None:
    """Create the optimizer and learning-rate scheduler for the transformer.

    Copies the epoch/step targets from ``args`` onto ``state``, optionally
    scales the learning rate by gradient accumulation steps, batch size and
    process count, builds the optimizer over the transformer parameters that
    require grad, and derives ``state.train_steps`` from the epoch count when
    no step count was given. Results are stored on ``self.optimizer`` and
    ``self.lr_scheduler``.
    """
    logger.info("Initializing optimizer and lr scheduler")

    args = self.args
    state = self.state
    accelerator = state.accelerator

    state.train_epochs = args.train_epochs
    state.train_steps = args.train_steps

    # Make sure the trainable params are in float32
    if args.training_type == "lora":
        cast_training_params([self.transformer], dtype=torch.float32)

    learning_rate = args.lr
    if args.scale_lr:
        learning_rate = (
            learning_rate * args.gradient_accumulation_steps * args.batch_size * accelerator.num_processes
        )
    state.learning_rate = learning_rate

    trainable_params = [param for param in self.transformer.parameters() if param.requires_grad]
    params_to_optimize = [{"params": trainable_params, "lr": state.learning_rate}]
    state.num_trainable_parameters = sum(param.numel() for param in trainable_params)

    # When the DeepSpeed config already declares an optimizer / scheduler, that is signalled downstream.
    deepspeed_plugin = accelerator.state.deepspeed_plugin
    use_deepspeed_opt = deepspeed_plugin is not None and "optimizer" in deepspeed_plugin.deepspeed_config
    optimizer = get_optimizer(
        params_to_optimize=params_to_optimize,
        optimizer_name=args.optimizer,
        learning_rate=state.learning_rate,
        beta1=args.beta1,
        beta2=args.beta2,
        beta3=args.beta3,
        epsilon=args.epsilon,
        weight_decay=args.weight_decay,
        use_8bit=args.use_8bit_bnb,
        use_deepspeed=use_deepspeed_opt,
    )

    steps_per_epoch = math.ceil(len(self.dataloader) / args.gradient_accumulation_steps)
    if state.train_steps is None:
        state.train_steps = state.train_epochs * steps_per_epoch
        state.overwrote_max_train_steps = True

    use_deepspeed_lr_scheduler = (
        deepspeed_plugin is not None and "scheduler" in deepspeed_plugin.deepspeed_config
    )

    # Both step counts handed to the scheduler are multiplied by the number of processes.
    world_size = accelerator.num_processes
    total_training_steps = state.train_steps * world_size
    num_warmup_steps = args.lr_warmup_steps * world_size

    if use_deepspeed_lr_scheduler:
        from accelerate.utils import DummyScheduler

        lr_scheduler = DummyScheduler(
            name=args.lr_scheduler,
            optimizer=optimizer,
            total_num_steps=total_training_steps,
            num_warmup_steps=num_warmup_steps,
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=total_training_steps,
            num_cycles=args.lr_num_cycles,
            power=args.lr_power,
        )

    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
def prepare_for_training(self) -> None:
    """Wrap the training objects with the accelerator and refresh step/epoch counts.

    The transformer, optimizer, dataloader and LR scheduler are passed through
    ``accelerator.prepare``. Because the dataloader length may change as a
    result, the number of update steps per epoch is recomputed; the total
    step count is recomputed too when it was originally derived from the
    epoch count, and the epoch count is then derived from the step count.
    """
    prepared = self.state.accelerator.prepare(
        self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
    )
    self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = prepared

    state = self.state
    steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)

    # Only re-derive the step budget when it came from the epoch count in the first place.
    if state.overwrote_max_train_steps:
        state.train_steps = state.train_epochs * steps_per_epoch

    state.train_epochs = math.ceil(state.train_steps / steps_per_epoch)
    state.num_update_steps_per_epoch = steps_per_epoch
def prepare_trackers(self) -> None:
    """Initialise the accelerator's trackers with the training info as their config.

    Uses ``args.tracker_name`` when it is truthy and the default name
    ``"finetrainers-experiment"`` otherwise.
    """
    logger.info("Initializing trackers")

    tracker_name = self.args.tracker_name
    if not tracker_name:
        tracker_name = "finetrainers-experiment"

    training_info = self._get_training_info()
    self.state.accelerator.init_trackers(tracker_name, config=training_info)
def train(self) -> None:
# Run the training loop: optionally resume from a checkpoint, iterate over epochs and batches,
# compute the weighted squared-error loss between the model prediction and the target, step the
# optimizer and LR scheduler, periodically checkpoint and validate, then save the final weights,
# run a final validation and optionally upload the output directory to the Hub.
# NOTE(review): leading whitespace is lost in this copy of the file, so the nesting of the blocks
# below is inferred from the statements -- confirm against the original before relying on it.
logger.info("Starting training")
# Raise the soft open-file limit to the hard limit where `resource` exposes RLIMIT_NOFILE.
# Any failure is logged as a warning and otherwise ignored.
if hasattr(resource, 'RLIMIT_NOFILE'):
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
logger.info(f"Current file descriptor limits in trainer: soft={soft}, hard={hard}")
# Try to increase to hard limit if possible
if soft < hard:
resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
except Exception as e:
logger.warning(f"Could not check or update file descriptor limits: {e}")
memory_statistics = get_memory_statistics()
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
if self.vae_config is None:
# If we've precomputed conditions and latents already, and are now re-using it, we will never load
# the VAE so self.vae_config will not be set. So, we need to load it here.
vae_cls = resolve_vae_cls_from_ckpt_path(
self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
)
vae_config = vae_cls.load_config(
self.args.pretrained_model_name_or_path,
subfolder="vae",
revision=self.args.revision,
cache_dir=self.args.cache_dir,
)
self.vae_config = FrozenDict(**vae_config)
# In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
# to able to load all diffusion components from a specific checkpoint folder during validation, we need to
# ensure the scheduler config is serialized as well.
if self.args.training_type == "full-finetune":
self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))
# Effective batch size = per-device batch size * number of processes * gradient accumulation steps.
self.state.train_batch_size = (
self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
)
info = {
"trainable parameters": self.state.num_trainable_parameters,
"total samples": len(self.dataset),
"train epochs": self.state.train_epochs,
"train steps": self.state.train_steps,
"batches per device": self.args.batch_size,
"total batches observed per epoch": len(self.dataloader),
"train batch size": self.state.train_batch_size,
"gradient accumulation steps": self.args.gradient_accumulation_steps,
}
logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
# These initial values are overwritten by the tuple unpacking that follows.
global_step = 0
first_epoch = 0
initial_global_step = 0
# Potentially load in the weights and states from a previous save
(
resume_from_checkpoint_path,
initial_global_step,
global_step,
first_epoch,
) = get_latest_ckpt_path_to_resume_from(
resume_from_checkpoint=self.args.resume_from_checkpoint,
num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
output_dir=self.args.output_dir,
)
if resume_from_checkpoint_path:
self.state.accelerator.load_state(resume_from_checkpoint_path)
progress_bar = tqdm(
range(0, self.state.train_steps),
initial=initial_global_step,
desc="Training steps",
disable=not self.state.accelerator.is_local_main_process,
)
accelerator = self.state.accelerator
# One generator on the accelerator device, seeded when a seed is set and stored on `self.state`.
generator = torch.Generator(device=accelerator.device)
if self.args.seed is not None:
generator = generator.manual_seed(self.args.seed)
self.state.generator = generator
# Scheduler sigmas/alphas may be None; when present they are moved to the device as float32.
scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
scheduler_sigmas = (
scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
if scheduler_sigmas is not None
else None
)
scheduler_alphas = get_scheduler_alphas(self.scheduler)
scheduler_alphas = (
scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
if scheduler_alphas is not None
else None
)
for epoch in range(first_epoch, self.state.train_epochs):
logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")
self.transformer.train()
models_to_accumulate = [self.transformer]
# Running sum of step losses and the number of steps, used for the per-epoch mean loss.
epoch_loss = 0.0
num_loss_updates = 0
for step, batch in enumerate(self.dataloader):
logger.debug(f"Starting step {step + 1}")
logs = {}
with accelerator.accumulate(models_to_accumulate):
# Without precomputation, latents and text conditions are computed from the raw batch.
if not self.args.precompute_conditions:
videos = batch["videos"]
prompts = batch["prompts"]
batch_size = len(prompts)
# "empty" caption dropout: with probability caption_dropout_p every prompt becomes "".
if self.args.caption_dropout_technique == "empty":
if random.random() < self.args.caption_dropout_p:
prompts = [""] * batch_size
latent_conditions = self.model_config["prepare_latents"](
vae=self.vae,
image_or_video=videos,
patch_size=self.transformer_config.patch_size,
patch_size_t=self.transformer_config.patch_size_t,
device=accelerator.device,
dtype=self.args.transformer_dtype,
generator=self.state.generator,
)
# NOTE(review): unlike the call in prepare_precomputations, tokenizer_3 / text_encoder_3
# are not passed here -- confirm whether that is intended.
text_conditions = self.model_config["prepare_conditions"](
tokenizer=self.tokenizer,
text_encoder=self.text_encoder,
tokenizer_2=self.tokenizer_2,
text_encoder_2=self.text_encoder_2,
prompt=prompts,
device=accelerator.device,
dtype=self.args.transformer_dtype,
)
else:
latent_conditions = batch["latent_conditions"]
text_conditions = batch["text_conditions"]
# The stored "latents" entry is wrapped in a DiagonalGaussianDistribution and a sample is drawn from it.
latent_conditions["latents"] = DiagonalGaussianDistribution(
latent_conditions["latents"]
).sample(self.state.generator)
# This method should only be called for precomputed latents.
# TODO(aryan): rename this in separate PR
latent_conditions = self.model_config["post_latent_preparation"](
vae_config=self.vae_config,
patch_size=self.transformer_config.patch_size,
patch_size_t=self.transformer_config.patch_size_t,
**latent_conditions,
)
# The return values are not used, so these calls presumably update the dicts in place -- confirm.
align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
batch_size = latent_conditions["latents"].shape[0]
latent_conditions = make_contiguous(latent_conditions)
text_conditions = make_contiguous(text_conditions)
# "zero" caption dropout: with probability caption_dropout_p the text embeddings and mask are zeroed in place.
if self.args.caption_dropout_technique == "zero":
if random.random() < self.args.caption_dropout_p:
text_conditions["prompt_embeds"].fill_(0)
text_conditions["prompt_attention_mask"].fill_(False)
# TODO(aryan): refactor later
if "pooled_prompt_embeds" in text_conditions:
text_conditions["pooled_prompt_embeds"].fill_(0)
sigmas = prepare_sigmas(
scheduler=self.scheduler,
sigmas=scheduler_sigmas,
batch_size=batch_size,
num_train_timesteps=self.scheduler.config.num_train_timesteps,
flow_weighting_scheme=self.args.flow_weighting_scheme,
flow_logit_mean=self.args.flow_logit_mean,
flow_logit_std=self.args.flow_logit_std,
flow_mode_scale=self.args.flow_mode_scale,
device=accelerator.device,
generator=self.state.generator,
)
# NOTE(review): the factor 1000.0 is hard-coded here while prepare_sigmas receives
# scheduler.config.num_train_timesteps -- confirm the two always agree.
timesteps = (sigmas * 1000.0).long()
noise = torch.randn(
latent_conditions["latents"].shape,
generator=self.state.generator,
device=accelerator.device,
dtype=self.args.transformer_dtype,
)
sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
# TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of
# scheduler and calculate the noisy latents accordingly. Look into this later.
if "calculate_noisy_latents" in self.model_config.keys():
noisy_latents = self.model_config["calculate_noisy_latents"](
scheduler=self.scheduler,
noise=noise,
latents=latent_conditions["latents"],
timesteps=timesteps,
)
else:
# Default to flow-matching noise addition
noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)
latent_conditions.update({"noisy_latents": noisy_latents})
weights = prepare_loss_weights(
scheduler=self.scheduler,
alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
sigmas=sigmas,
flow_weighting_scheme=self.args.flow_weighting_scheme,
)
weights = expand_tensor_dims(weights, noise.ndim)
pred = self.model_config["forward_pass"](
transformer=self.transformer,
scheduler=self.scheduler,
timesteps=timesteps,
**latent_conditions,
**text_conditions,
)
target = prepare_target(
scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
)
# Weighted squared error, computed in float32.
loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
# Average loss across all but batch dimension
loss = loss.mean(list(range(1, loss.ndim)))
# Average loss across batch dimension
loss = loss.mean()
accelerator.backward(loss)
# The gradient norm is only obtained (and, outside DeepSpeed, clipped) on steps where gradients are synced.
if accelerator.sync_gradients:
if accelerator.distributed_type == DistributedType.DEEPSPEED:
grad_norm = self.transformer.get_global_grad_norm()
# In some cases the grad norm may not return a float
if torch.is_tensor(grad_norm):
grad_norm = grad_norm.item()
else:
grad_norm = accelerator.clip_grad_norm_(
self.transformer.parameters(), self.args.max_grad_norm
)
if torch.is_tensor(grad_norm):
grad_norm = grad_norm.item()
logs["grad_norm"] = grad_norm
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# Checkpointing
if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
if global_step % self.args.checkpointing_steps == 0:
save_path = get_intermediate_ckpt_path(
checkpointing_limit=self.args.checkpointing_limit,
step=global_step,
output_dir=self.args.output_dir,
)
accelerator.save_state(save_path)
# Maybe run validation
should_run_validation = (
self.args.validation_every_n_steps is not None
and global_step % self.args.validation_every_n_steps == 0
)
if should_run_validation:
self.validate(global_step)
loss_item = loss.detach().item()
epoch_loss += loss_item
num_loss_updates += 1
logs["step_loss"] = loss_item
logs["lr"] = self.lr_scheduler.get_last_lr()[0]
progress_bar.set_postfix(logs)
accelerator.log(logs, step=global_step)
# NOTE(review): global_step only advances on gradient-sync steps; if this check runs for every
# batch it fires repeatedly while global_step sits at a multiple of 100 (including 0) -- confirm.
if global_step % 100 == 0: # Every 100 steps
# Force garbage collection to clean up any lingering resources
gc.collect()
if global_step >= self.state.train_steps:
break
# Log the mean step loss for the epoch (skipped when no step contributed a loss).
if num_loss_updates > 0:
epoch_loss /= num_loss_updates
accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
memory_statistics = get_memory_statistics()
logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
# Maybe run validation
should_run_validation = (
self.args.validation_every_n_epochs is not None
and (epoch + 1) % self.args.validation_every_n_epochs == 0
)
if should_run_validation:
self.validate(global_step)
# `epoch` is zero-based, so this condition is true for epochs 0, 3, 6, ...
if epoch % 3 == 0: # Every 3 epochs
logger.info("Performing periodic resource cleanup")
free_memory()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize(accelerator.device)
accelerator.wait_for_everyone()
# Final save on the main process: LoRA layers for "lora", the whole transformer otherwise.
if accelerator.is_main_process:
transformer = unwrap_model(accelerator, self.transformer)
if self.args.training_type == "lora":
transformer_lora_layers = get_peft_model_state_dict(transformer)
self.model_config["pipeline_cls"].save_lora_weights(
save_directory=self.args.output_dir,
transformer_lora_layers=transformer_lora_layers,
)
else:
transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))
accelerator.wait_for_everyone()
self.validate(step=global_step, final_validation=True)
# Upload the output directory (excluding "checkpoint-*" entries) when push_to_hub is set.
if accelerator.is_main_process:
if self.args.push_to_hub:
upload_folder(
repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
)
self._delete_components()
memory_statistics = get_memory_statistics()
logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
accelerator.end_training()
def validate(self, step: int, final_validation: bool = False) -> None:
    """Run the validation prompts through the pipeline, save the outputs and log them.

    Validation samples are split round-robin across processes: sample ``i`` is
    handled by the process whose index equals ``i % num_processes``. Each process
    writes its images/videos into ``self.args.output_dir``; the wrapped artifacts
    are then gathered and logged by the main process to the ``wandb`` tracker.

    Args:
        step: Step number used in the artifact filenames and as the tracker step.
        final_validation: Forwarded to ``_get_and_prepare_pipeline_for_validation``.
            When ``True`` files are prefixed with ``final-`` instead of
            ``validation-``, a model card is saved if ``push_to_hub`` is set, and
            the transformer is not put back into train mode afterwards.
    """
    logger.info("Starting validation")

    accelerator = self.state.accelerator
    num_validation_samples = len(self.args.validation_prompts)

    if num_validation_samples == 0:
        logger.warning("No validation samples found. Skipping validation.")
        # A model card without samples is still written when pushing to the Hub.
        if accelerator.is_main_process and self.args.push_to_hub:
            save_model_card(
                args=self.args,
                repo_id=self.state.repo_id,
                videos=None,
                validation_prompts=None,
            )
        return

    self.transformer.eval()

    memory_statistics = get_memory_statistics()
    logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

    pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)

    all_processes_artifacts = []
    prompts_to_filenames = {}
    for i in range(num_validation_samples):
        # Skip current validation on all processes but one
        if i % accelerator.num_processes != accelerator.process_index:
            continue

        prompt = self.args.validation_prompts[i]
        image = self.args.validation_images[i]
        video = self.args.validation_videos[i]
        height = self.args.validation_heights[i]
        width = self.args.validation_widths[i]
        num_frames = self.args.validation_num_frames[i]
        frame_rate = self.args.validation_frame_rate
        if image is not None:
            image = load_image(image)
        if video is not None:
            video = load_video(video)

        logger.debug(
            f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
            main_process_only=False,
        )
        validation_artifacts = self.model_config["validation"](
            pipeline=pipeline,
            prompt=prompt,
            image=image,
            video=video,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=frame_rate,
            num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
            generator=torch.Generator(device=accelerator.device).manual_seed(
                self.args.seed if self.args.seed is not None else 0
            ),
            # todo support passing `fps` for supported pipelines.
        )

        prompt_filename = string_to_filename(prompt)[:25]
        # The conditioning inputs are saved alongside the generated outputs.
        artifacts = {
            "image": {"type": "image", "value": image},
            "video": {"type": "video", "value": video},
        }
        # A dedicated index name keeps the outer sample index `i` from being rebound.
        for artifact_index, (artifact_type, artifact_value) in enumerate(validation_artifacts):
            if artifact_value:
                artifacts[f"artifact_{artifact_index}"] = {"type": artifact_type, "value": artifact_value}
        logger.debug(
            f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
            main_process_only=False,
        )

        for index, artifact in enumerate(artifacts.values()):
            artifact_type = artifact["type"]
            artifact_value = artifact["value"]
            if artifact_type not in ["image", "video"] or artifact_value is None:
                continue

            extension = "png" if artifact_type == "image" else "mp4"
            filename = "validation-" if not final_validation else "final-"
            filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
            # Only video filenames produced on the main process are recorded for the model card.
            if accelerator.is_main_process and extension == "mp4":
                prompts_to_filenames[prompt] = filename
            filename = os.path.join(self.args.output_dir, filename)

            if artifact_type == "image" and artifact_value:
                logger.debug(f"Saving image to {filename}")
                artifact_value.save(filename)
                artifact_value = wandb.Image(filename)
            elif artifact_type == "video" and artifact_value:
                logger.debug(f"Saving video to {filename}")
                # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
                export_to_video(artifact_value, filename, fps=frame_rate)
                artifact_value = wandb.Video(filename, caption=prompt)

            all_processes_artifacts.append(artifact_value)

    all_artifacts = gather_object(all_processes_artifacts)

    if accelerator.is_main_process:
        tracker_key = "final" if final_validation else "validation"
        for tracker in accelerator.trackers:
            if tracker.name == "wandb":
                artifact_log_dict = {}

                image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                if image_artifacts:
                    artifact_log_dict["images"] = image_artifacts
                video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                if video_artifacts:
                    artifact_log_dict["videos"] = video_artifacts
                tracker.log({tracker_key: artifact_log_dict}, step=step)

        if self.args.push_to_hub and final_validation:
            video_filenames = list(prompts_to_filenames.values())
            prompts = list(prompts_to_filenames.keys())
            save_model_card(
                args=self.args,
                repo_id=self.state.repo_id,
                videos=video_filenames,
                validation_prompts=prompts,
            )

    # Remove all hooks that might have been added during pipeline initialization to the models
    pipeline.remove_all_hooks()
    del pipeline

    accelerator.wait_for_everyone()

    free_memory()
    memory_statistics = get_memory_statistics()
    logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
    # Peak-memory counters are a CUDA-only API, so skip the reset on other devices.
    if accelerator.device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(accelerator.device)

    if not final_validation:
        self.transformer.train()
def evaluate(self) -> None:
    """Placeholder for a standalone evaluation entry point.

    Raises:
        NotImplementedError: Always.
    """
    message = "Evaluation has not been implemented yet."
    raise NotImplementedError(message)
def _init_distributed(self) -> None:
    """Build the ``Accelerator``, store it on ``self.state`` and apply the seed.

    The accelerator is configured with the project/logging directories, the
    gradient accumulation steps and the tracker named by ``self.args.report_to``
    (the string ``"none"``, in any casing, disables tracking).
    """
    cfg = self.args

    project_config = ProjectConfiguration(
        project_dir=cfg.output_dir,
        logging_dir=Path(cfg.output_dir, cfg.logging_dir),
    )
    handlers = [
        DistributedDataParallelKwargs(find_unused_parameters=True),
        InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=cfg.nccl_timeout)),
    ]

    if cfg.report_to.lower() == "none":
        trackers = None
    else:
        trackers = cfg.report_to

    accel = Accelerator(
        project_config=project_config,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        log_with=trackers,
        kwargs_handlers=handlers,
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accel.native_amp = False

    self.state.accelerator = accel

    seed = cfg.seed
    if seed is None:
        return
    self.state.seed = seed
    set_seed(seed)
def _init_logging(self) -> None:
    """Configure root logging and the verbosity of transformers/diffusers.

    The local main process keeps transformers at warning level and diffusers at
    info level; every other process only reports errors from both libraries.
    """
    logging.basicConfig(
        level=FINETRAINERS_LOG_LEVEL,
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
    )

    accel = self.state.accelerator
    if not accel.is_local_main_process:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
    else:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()

    logger.info("Initialized FineTrainers")
    logger.info(accel.state, main_process_only=False)
def _init_directories_and_repositories(self) -> None:
    """Create the output directory and, if requested, the Hub repository.

    Runs on the main process only; other processes return without doing anything.
    """
    if not self.state.accelerator.is_main_process:
        return

    out_dir = Path(self.args.output_dir)
    self.args.output_dir = out_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    self.state.output_dir = Path(out_dir)

    if not self.args.push_to_hub:
        return

    # Fall back to the output directory's name when no explicit repo id is given.
    hub_repo = self.args.hub_model_id or Path(self.args.output_dir).name
    created = create_repo(token=self.args.hub_token, repo_id=hub_repo, exist_ok=True)
    self.state.repo_id = created.repo_id
def _init_config_options(self) -> None:
    """Turn on TF32 matmuls when requested via ``allow_tf32`` and CUDA is available."""
    # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    tf32_wanted = self.args.allow_tf32
    if not (tf32_wanted and torch.cuda.is_available()):
        return
    torch.backends.cuda.matmul.allow_tf32 = True
def _move_components_to_device(self):
    """Move every model component that is set onto the accelerator's device.

    Components that are ``None`` are left untouched.
    """
    component_names = (
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "transformer",
        "unet",
        "vae",
    )
    for attr in component_names:
        module = getattr(self, attr)
        if module is None:
            continue
        # `.to()` may return a new object, so the attribute is rebound to its result.
        setattr(self, attr, module.to(self.state.accelerator.device))
def _get_load_components_kwargs(self) -> Dict[str, Any]:
    """Collect the keyword arguments used to load the model components.

    Returns:
        A dict with the per-component dtypes, ``shift``, ``revision`` and
        ``cache_dir``; ``model_id`` is added only when
        ``pretrained_model_name_or_path`` is not ``None``.
    """
    cfg = self.args
    dtype_owners = ("text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae")

    kwargs = {f"{owner}_dtype": getattr(cfg, f"{owner}_dtype") for owner in dtype_owners}
    kwargs["shift"] = cfg.flow_shift
    kwargs["revision"] = cfg.revision
    kwargs["cache_dir"] = cfg.cache_dir

    model_id = cfg.pretrained_model_name_or_path
    if model_id is not None:
        kwargs["model_id"] = model_id
    return kwargs
def _set_components(self, components: Dict[str, Any]) -> None:
    """Update the trainer's components from ``components``.

    Any component missing from the mapping keeps its current value. The
    transformer/VAE configs are refreshed from the resulting components when
    those are not ``None``.

    Args:
        components: Mapping from component name (e.g. ``"vae"``) to the object to store.
    """
    # Set models
    component_names = (
        "tokenizer",
        "tokenizer_2",
        "tokenizer_3",
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "transformer",
        "unet",
        "vae",
        "scheduler",
    )
    for attr in component_names:
        setattr(self, attr, components.get(attr, getattr(self, attr)))

    # Set configs
    self.transformer_config = self.transformer_config if self.transformer is None else self.transformer.config
    self.vae_config = self.vae_config if self.vae is None else self.vae.config
def _delete_components(self) -> None:
    """Drop the trainer's references to all model components and free memory.

    Every tokenizer, text encoder, the transformer, unet, vae and scheduler
    attribute is reset to ``None`` before ``free_memory()`` is called.
    """
    component_names = (
        "tokenizer",
        "tokenizer_2",
        "tokenizer_3",
        "text_encoder",
        "text_encoder_2",
        "text_encoder_3",
        "transformer",
        "unet",
        "vae",
        "scheduler",
    )
    for attr in component_names:
        setattr(self, attr, None)

    free_memory()

    # `torch.cuda.synchronize` only accepts CUDA devices; skip it on CPU/MPS
    # so that cleanup does not raise there.
    device = self.state.accelerator.device
    if device.type == "cuda":
        torch.cuda.synchronize(device)
def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
    """Build the pipeline used to generate validation samples.

    Args:
        final_validation: When ``False`` the pipeline is assembled from the
            components currently held by the trainer (``is_training=True``).
            When ``True`` those components are deleted first and the pipeline is
            initialized with ``is_training=False``; for ``"full-finetune"`` the
            transformer is loaded from ``output_dir``, for ``"lora"`` the LoRA
            weights are loaded from ``output_dir`` into the pipeline.

    Returns:
        The pipeline returned by the model config's ``initialize_pipeline``.
    """
    accelerator = self.state.accelerator

    if final_validation:
        self._delete_components()

        # Load the transformer weights from the final checkpoint if performing full-finetune
        finetuned_transformer = (
            self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]
            if self.args.training_type == "full-finetune"
            else None
        )

        final_pipeline = self.model_config["initialize_pipeline"](
            model_id=self.args.pretrained_model_name_or_path,
            transformer=finetuned_transformer,
            device=accelerator.device,
            revision=self.args.revision,
            cache_dir=self.args.cache_dir,
            enable_slicing=self.args.enable_slicing,
            enable_tiling=self.args.enable_tiling,
            enable_model_cpu_offload=self.args.enable_model_cpu_offload,
            is_training=False,
        )

        # Load the LoRA weights if performing LoRA finetuning
        if self.args.training_type == "lora":
            final_pipeline.load_lora_weights(self.args.output_dir)
        return final_pipeline

    # Intermediate validation reuses the in-memory components.
    return self.model_config["initialize_pipeline"](
        model_id=self.args.pretrained_model_name_or_path,
        tokenizer=self.tokenizer,
        text_encoder=self.text_encoder,
        tokenizer_2=self.tokenizer_2,
        text_encoder_2=self.text_encoder_2,
        transformer=unwrap_model(accelerator, self.transformer),
        vae=self.vae,
        device=accelerator.device,
        revision=self.args.revision,
        cache_dir=self.args.cache_dir,
        enable_slicing=self.args.enable_slicing,
        enable_tiling=self.args.enable_tiling,
        enable_model_cpu_offload=self.args.enable_model_cpu_offload,
        is_training=True,
    )
def _disable_grad_for_components(self, components: List[torch.nn.Module]):
    """Call ``requires_grad_(False)`` on every component that is not ``None``.

    Args:
        components: Modules to freeze; ``None`` entries are skipped.
    """
    present = (module for module in components if module is not None)
    for module in present:
        module.requires_grad_(False)
def _enable_grad_for_components(self, components: List[torch.nn.Module]):
    """Call ``requires_grad_(True)`` on every component that is not ``None``.

    Args:
        components: Modules to unfreeze; ``None`` entries are skipped.
    """
    present = (module for module in components if module is not None)
    for module in present:
        module.requires_grad_(True)
def _get_training_info(self) -> dict:
    """Return the argument dict with options irrelevant to this run filtered out.

    For ``"full-finetune"`` the ``rank``, ``lora_alpha`` and ``target_modules``
    entries are removed from ``training_arguments``. Unless the scheduler class
    is named ``FlowMatchEulerDiscreteScheduler``, every key containing ``"flow"``
    is removed from ``diffusion_arguments``.

    Returns:
        A shallow copy of ``self.args.to_dict()`` with the two sections replaced.
    """
    all_args = self.args.to_dict()

    # LoRA/non-LoRA stuff.
    training_section = all_args.get("training_arguments", {})
    if training_section.get("training_type", "") == "full-finetune":
        lora_only_keys = {"rank", "lora_alpha", "target_modules"}
        training_section = {
            name: value for name, value in training_section.items() if name not in lora_only_keys
        }

    # Diffusion/flow stuff.
    diffusion_section = all_args.get("diffusion_arguments", {})
    uses_flow_matching = self.scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
    if not uses_flow_matching:
        diffusion_section = {name: value for name, value in diffusion_section.items() if "flow" not in name}

    # Rest of the stuff.
    info = all_args.copy()
    info["training_arguments"] = training_section
    info["diffusion_arguments"] = diffusion_section
    return info