import json
import logging
import math
import os
import random
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List

import diffusers
import torch
import torch.backends
import transformers
import wandb
from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    gather_object,
    set_seed,
)
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video, load_image, load_video
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from tqdm import tqdm

from .args import Args, validate_args
from .constants import (
    FINETRAINERS_LOG_LEVEL,
    PRECOMPUTED_CONDITIONS_DIR_NAME,
    PRECOMPUTED_DIR_NAME,
    PRECOMPUTED_LATENTS_DIR_NAME,
)
from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
from .hooks import apply_layerwise_upcasting
from .models import get_config_from_model_name
from .patches import perform_peft_patches
from .state import State
from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
from .utils.data_utils import should_perform_precomputation
from .utils.diffusion_utils import (
    get_scheduler_alphas,
    get_scheduler_sigmas,
    prepare_loss_weights,
    prepare_sigmas,
    prepare_target,
)
from .utils.file_utils import string_to_filename
from .utils.hub_utils import save_model_card
from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
from .utils.model_utils import resolve_vae_cls_from_ckpt_path
from .utils.optimizer_utils import get_optimizer
from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model


logger = get_logger("finetrainers")
logger.setLevel(FINETRAINERS_LOG_LEVEL)


class Trainer:
    def __init__(self, args: Args) -> None:
        validate_args(args)

        self.args = args
        self.args.seed = self.args.seed or datetime.now().year
        self.state = State()

        # Tokenizers
        self.tokenizer = None
        self.tokenizer_2 = None
        self.tokenizer_3 = None

        # Text encoders
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None

        # Denoisers
        self.transformer = None
        self.unet = None

        # Autoencoders
        self.vae = None

        # Scheduler
        self.scheduler = None

        self.transformer_config = None
        self.vae_config = None

        self._init_distributed()
        self._init_logging()
        self._init_directories_and_repositories()
        self._init_config_options()

        # Perform any patches needed for training
        if len(self.args.layerwise_upcasting_modules) > 0:
            perform_peft_patches()
        # TODO(aryan): handle text encoders
        # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
        #     perform_text_encoder_patches()

        self.state.model_name = self.args.model_name
        self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)

    def prepare_dataset(self) -> None:
        # TODO(aryan): Make a background process for fetching
        logger.info("Initializing dataset and dataloader")

        self.dataset = ImageOrVideoDatasetWithResizing(
            data_root=self.args.data_root,
            caption_column=self.args.caption_column,
            video_column=self.args.video_column,
            resolution_buckets=self.args.video_resolution_buckets,
            dataset_file=self.args.dataset_file,
            id_token=self.args.id_token,
            remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes,
        )
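        # The dataloader batch size stays at 1 because BucketSampler already groups samples into
        # resolution buckets and yields lists of `batch_size` items at a time; the model-specific
        # `collate_fn` is then responsible for stacking each list into batched tensors.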
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=1,
            sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
            collate_fn=self.model_config.get("collate_fn"),
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.pin_memory,
        )

    def prepare_models(self) -> None:
        logger.info("Initializing models")

        load_components_kwargs = self._get_load_components_kwargs()
        condition_components, latent_components, diffusion_components = {}, {}, {}
        if not self.args.precompute_conditions:
            # Download the model files on the main process first (if not already present), so that
            # the other processes can then load from the cached files.
            with self.state.accelerator.main_process_first():
                condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
                latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
                diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)

        components = {}
        components.update(condition_components)
        components.update(latent_components)
        components.update(diffusion_components)
        self._set_components(components)

        if self.vae is not None:
            if self.args.enable_slicing:
                self.vae.enable_slicing()
            if self.args.enable_tiling:
                self.vae.enable_tiling()

    def prepare_precomputations(self) -> None:
        if not self.args.precompute_conditions:
            return

        logger.info("Initializing precomputations")

        if self.args.batch_size != 1:
            raise ValueError(
                "Precomputation is only supported with batch size 1. This will be supported in the future."
            )

        def collate_fn(batch):
            latent_conditions = [x["latent_conditions"] for x in batch]
            text_conditions = [x["text_conditions"] for x in batch]
            batched_latent_conditions = {}
            batched_text_conditions = {}
            for key in list(latent_conditions[0].keys()):
                if torch.is_tensor(latent_conditions[0][key]):
                    batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
                else:
                    # TODO(aryan): implement batch sampler for precomputed latents
                    batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
            for key in list(text_conditions[0].keys()):
                if torch.is_tensor(text_conditions[0][key]):
                    batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
                else:
                    # TODO(aryan): implement batch sampler for precomputed latents
                    batched_text_conditions[key] = [x[key] for x in text_conditions][0]
            return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}

        cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
        precomputation_dir = (
            Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
        )
        should_precompute = should_perform_precomputation(precomputation_dir)
        if not should_precompute:
            logger.info("Precomputed conditions and latents found. Loading precomputed data.")
            self.dataloader = torch.utils.data.DataLoader(
                PrecomputedDataset(
                    data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
                ),
                batch_size=self.args.batch_size,
                shuffle=True,
                collate_fn=collate_fn,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.pin_memory,
            )
            return

        logger.info("Precomputed conditions and latents not found. Running precomputation.")
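        # Precomputation runs in two phases so that the text encoders and the VAE never have to be
        # resident in memory at the same time: the condition models are loaded, used, and deleted
        # first, and then the latent models go through the same cycle.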
        # At this point, no models are loaded, so we need to load and precompute conditions and latents
        with self.state.accelerator.main_process_first():
            condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())

        self._set_components(condition_components)
        self._move_components_to_device()
        self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])

        if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
            logger.warning(
                "Caption dropout is not supported with precomputation yet. This will be supported in the future."
            )

        conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
        latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
        conditions_dir.mkdir(parents=True, exist_ok=True)
        latents_dir.mkdir(parents=True, exist_ok=True)

        accelerator = self.state.accelerator

        # Precompute conditions
        progress_bar = tqdm(
            range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
            desc="Precomputing conditions",
            disable=not accelerator.is_local_main_process,
        )
        index = 0
        for i, data in enumerate(self.dataset):
            if i % accelerator.num_processes != accelerator.process_index:
                continue

            logger.debug(
                f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
            )

            text_conditions = self.model_config["prepare_conditions"](
                tokenizer=self.tokenizer,
                tokenizer_2=self.tokenizer_2,
                tokenizer_3=self.tokenizer_3,
                text_encoder=self.text_encoder,
                text_encoder_2=self.text_encoder_2,
                text_encoder_3=self.text_encoder_3,
                prompt=data["prompt"],
                device=accelerator.device,
                dtype=self.args.transformer_dtype,
            )
            filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
            torch.save(text_conditions, filename.as_posix())
            index += 1
            progress_bar.update(1)
        self._delete_components()

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        # Precompute latents
        with self.state.accelerator.main_process_first():
            latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())

        self._set_components(latent_components)
        self._move_components_to_device()
        self._disable_grad_for_components([self.vae])

        if self.vae is not None:
            if self.args.enable_slicing:
                self.vae.enable_slicing()
            if self.args.enable_tiling:
                self.vae.enable_tiling()

        progress_bar = tqdm(
            range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
            desc="Precomputing latents",
            disable=not accelerator.is_local_main_process,
        )
        index = 0
        for i, data in enumerate(self.dataset):
            if i % accelerator.num_processes != accelerator.process_index:
                continue

            logger.debug(
                f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
            )

            latent_conditions = self.model_config["prepare_latents"](
                vae=self.vae,
                image_or_video=data["video"].unsqueeze(0),
                device=accelerator.device,
                dtype=self.args.transformer_dtype,
                generator=self.state.generator,
                precompute=True,
            )
            filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
            torch.save(latent_conditions, filename.as_posix())
            index += 1
            progress_bar.update(1)
        self._delete_components()

        accelerator.wait_for_everyone()
        logger.info("Precomputation complete")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)
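        # Samples are sharded round-robin across processes (`i % num_processes`), and each saved file
        # is tagged with the process index and a per-process counter. PrecomputedDataset is expected
        # to pair up the `conditions-*.pt` and `latents-*.pt` files by this naming scheme.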
        # Update dataloader to use precomputed conditions and latents
        self.dataloader = torch.utils.data.DataLoader(
            PrecomputedDataset(
                data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
            ),
            batch_size=self.args.batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.pin_memory,
        )

    def prepare_trainable_parameters(self) -> None:
        logger.info("Initializing trainable parameters")

        with self.state.accelerator.main_process_first():
            diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())

        self._set_components(diffusion_components)

        components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
        self._disable_grad_for_components(components)

        if self.args.training_type == "full-finetune":
            logger.info("Finetuning transformer with no additional parameters")
            self._enable_grad_for_components([self.transformer])
        else:
            logger.info("Finetuning transformer with PEFT parameters")
            self._disable_grad_for_components([self.transformer])

        # Layerwise upcasting must be applied before adding the LoRA adapter. If we don't perform this
        # before moving to device, we might OOM on the GPU. So, it is best done on the CPU for now, until
        # support is added in Diffusers for loading and enabling layerwise upcasting directly.
        if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
            apply_layerwise_upcasting(
                self.transformer,
                storage_dtype=self.args.layerwise_upcasting_storage_dtype,
                compute_dtype=self.args.transformer_dtype,
                skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
                non_blocking=True,
            )

        self._move_components_to_device()

        if self.args.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()

        if self.args.training_type == "lora":
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.transformer.add_adapter(transformer_lora_config)
        else:
            transformer_lora_config = None

        # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are
        # still in fp32 even if layerwise upcasting is enabled. Would be nice to have a test as well
        self.register_saving_loading_hooks(transformer_lora_config)

    def register_saving_loading_hooks(self, transformer_lora_config):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if self.state.accelerator.is_main_process:
                transformer_lora_layers_to_save = None

                for model in models:
                    if isinstance(
                        unwrap_model(self.state.accelerator, model),
                        type(unwrap_model(self.state.accelerator, self.transformer)),
                    ):
                        model = unwrap_model(self.state.accelerator, model)
                        if self.args.training_type == "lora":
                            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                    else:
                        raise ValueError(f"Unexpected save model: {model.__class__}")

                    # make sure to pop weight so that corresponding model is not saved again
                    if weights:
                        weights.pop()

                if self.args.training_type == "lora":
                    self.model_config["pipeline_cls"].save_lora_weights(
                        output_dir,
                        transformer_lora_layers=transformer_lora_layers_to_save,
                    )
                else:
                    model.save_pretrained(os.path.join(output_dir, "transformer"))

                # In some cases, the scheduler needs to be loaded with a specific config (e.g. in CogVideoX).
                # Since we need to be able to load all diffusion components from a specific checkpoint folder
                # during validation, we need to ensure the scheduler config is serialized as well.
                self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
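        # `accelerator.save_state(...)` hands the wrapped models and their state dicts to the hook
        # above; popping an entry off `weights` signals that the hook has already serialized that
        # model, so accelerate skips its default serialization for it.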
        def load_model_hook(models, input_dir):
            if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
                while len(models) > 0:
                    model = models.pop()
                    if isinstance(
                        unwrap_model(self.state.accelerator, model),
                        type(unwrap_model(self.state.accelerator, self.transformer)),
                    ):
                        transformer_ = unwrap_model(self.state.accelerator, model)
                    else:
                        raise ValueError(
                            f"Unexpected model when loading: {unwrap_model(self.state.accelerator, model).__class__}"
                        )
            else:
                transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__

                if self.args.training_type == "lora":
                    transformer_ = transformer_cls_.from_pretrained(
                        self.args.pretrained_model_name_or_path, subfolder="transformer"
                    )
                    transformer_.add_adapter(transformer_lora_config)
                    lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
                    transformer_state_dict = {
                        f'{k.replace("transformer.", "")}': v
                        for k, v in lora_state_dict.items()
                        if k.startswith("transformer.")
                    }
                    incompatible_keys = set_peft_model_state_dict(
                        transformer_, transformer_state_dict, adapter_name="default"
                    )
                    if incompatible_keys is not None:
                        # check only for unexpected keys
                        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                        if unexpected_keys:
                            logger.warning(
                                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                                f" {unexpected_keys}. "
                            )
                else:
                    transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))

        self.state.accelerator.register_save_state_pre_hook(save_model_hook)
        self.state.accelerator.register_load_state_pre_hook(load_model_hook)

    def prepare_optimizer(self) -> None:
        logger.info("Initializing optimizer and lr scheduler")

        self.state.train_epochs = self.args.train_epochs
        self.state.train_steps = self.args.train_steps

        # Make sure the trainable params are in float32
        if self.args.training_type == "lora":
            cast_training_params([self.transformer], dtype=torch.float32)

        self.state.learning_rate = self.args.lr
        if self.args.scale_lr:
            self.state.learning_rate = (
                self.state.learning_rate
                * self.args.gradient_accumulation_steps
                * self.args.batch_size
                * self.state.accelerator.num_processes
            )

        transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
        transformer_parameters_with_lr = {
            "params": transformer_trainable_parameters,
            "lr": self.state.learning_rate,
        }
        params_to_optimize = [transformer_parameters_with_lr]
        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters)

        use_deepspeed_opt = (
            self.state.accelerator.state.deepspeed_plugin is not None
            and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.state.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_8bit=self.args.use_8bit_bnb,
            use_deepspeed=use_deepspeed_opt,
        )

        num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
        if self.state.train_steps is None:
            self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True
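        # `train_steps` takes precedence over `train_epochs` when both are provided; the flag above
        # records that steps were derived from epochs so they can be recomputed after
        # `accelerator.prepare` potentially changes the dataloader length.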
        use_deepspeed_lr_scheduler = (
            self.state.accelerator.state.deepspeed_plugin is not None
            and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        total_training_steps = self.state.train_steps * self.state.accelerator.num_processes
        num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def prepare_for_training(self) -> None:
        self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
            self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
        )

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
        if self.state.overwrote_max_train_steps:
            self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
        # Afterwards we recalculate our number of training epochs
        self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch

    def prepare_trackers(self) -> None:
        logger.info("Initializing trackers")

        tracker_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info())

    def train(self) -> None:
        logger.info("Starting training")

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        if self.vae_config is None:
            # If we've precomputed conditions and latents already, and are now re-using them, we will never
            # load the VAE, so self.vae_config will not be set. Load it here instead.
            vae_cls = resolve_vae_cls_from_ckpt_path(
                self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
            )
            vae_config = vae_cls.load_config(
                self.args.pretrained_model_name_or_path,
                subfolder="vae",
                revision=self.args.revision,
                cache_dir=self.args.cache_dir,
            )
            self.vae_config = FrozenDict(**vae_config)
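            # `FrozenDict` mirrors the read-only `.config` object a fully loaded VAE would expose, so
            # downstream code (e.g. `post_latent_preparation`) can consume `self.vae_config` the same
            # way regardless of which path produced it.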
        # In some cases, the scheduler needs to be loaded with a specific config (e.g. in CogVideoX).
        # Since we need to be able to load all diffusion components from a specific checkpoint folder
        # during validation, we need to ensure the scheduler config is serialized as well.
        if self.args.training_type == "full-finetune":
            self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))

        self.state.train_batch_size = (
            self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
        )
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "total samples": len(self.dataset),
            "train epochs": self.state.train_epochs,
            "train steps": self.state.train_steps,
            "batches per device": self.args.batch_size,
            "total batches observed per epoch": len(self.dataloader),
            "train batch size": self.state.train_batch_size,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        global_step = 0
        first_epoch = 0
        initial_global_step = 0

        # Potentially load in the weights and states from a previous save
        (
            resume_from_checkpoint_path,
            initial_global_step,
            global_step,
            first_epoch,
        ) = get_latest_ckpt_path_to_resume_from(
            resume_from_checkpoint=self.args.resume_from_checkpoint,
            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
            output_dir=self.args.output_dir,
        )
        if resume_from_checkpoint_path:
            self.state.accelerator.load_state(resume_from_checkpoint_path)

        progress_bar = tqdm(
            range(0, self.state.train_steps),
            initial=initial_global_step,
            desc="Training steps",
            disable=not self.state.accelerator.is_local_main_process,
        )

        accelerator = self.state.accelerator
        generator = torch.Generator(device=accelerator.device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
        scheduler_sigmas = (
            scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
            if scheduler_sigmas is not None
            else None
        )
        scheduler_alphas = get_scheduler_alphas(self.scheduler)
        scheduler_alphas = (
            scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
            if scheduler_alphas is not None
            else None
        )

        for epoch in range(first_epoch, self.state.train_epochs):
            logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")

            self.transformer.train()
            models_to_accumulate = [self.transformer]
            epoch_loss = 0.0
            num_loss_updates = 0

            for step, batch in enumerate(self.dataloader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate(models_to_accumulate):
                    if not self.args.precompute_conditions:
                        videos = batch["videos"]
                        prompts = batch["prompts"]
                        batch_size = len(prompts)

                        if self.args.caption_dropout_technique == "empty":
                            if random.random() < self.args.caption_dropout_p:
                                prompts = [""] * batch_size

                        latent_conditions = self.model_config["prepare_latents"](
                            vae=self.vae,
                            image_or_video=videos,
                            patch_size=self.transformer_config.patch_size,
                            patch_size_t=self.transformer_config.patch_size_t,
                            device=accelerator.device,
                            dtype=self.args.transformer_dtype,
                            generator=self.state.generator,
                        )
                        text_conditions = self.model_config["prepare_conditions"](
                            tokenizer=self.tokenizer,
                            text_encoder=self.text_encoder,
                            tokenizer_2=self.tokenizer_2,
                            text_encoder_2=self.text_encoder_2,
                            prompt=prompts,
                            device=accelerator.device,
                            dtype=self.args.transformer_dtype,
                        )
                    else:
                        latent_conditions = batch["latent_conditions"]
                        text_conditions = batch["text_conditions"]
                        latent_conditions["latents"] = DiagonalGaussianDistribution(
                            latent_conditions["latents"]
                        ).sample(self.state.generator)
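                        # When precomputing, `prepare_latents(..., precompute=True)` stores the VAE
                        # posterior parameters rather than a fixed sample, so a fresh latent is drawn
                        # from the DiagonalGaussianDistribution on every training step.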
                        # This method should only be called for precomputed latents.
                        # TODO(aryan): rename this in a separate PR
                        latent_conditions = self.model_config["post_latent_preparation"](
                            vae_config=self.vae_config,
                            patch_size=self.transformer_config.patch_size,
                            patch_size_t=self.transformer_config.patch_size_t,
                            **latent_conditions,
                        )

                    align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
                    align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
                    batch_size = latent_conditions["latents"].shape[0]

                    latent_conditions = make_contiguous(latent_conditions)
                    text_conditions = make_contiguous(text_conditions)

                    if self.args.caption_dropout_technique == "zero":
                        if random.random() < self.args.caption_dropout_p:
                            text_conditions["prompt_embeds"].fill_(0)
                            text_conditions["prompt_attention_mask"].fill_(False)

                            # TODO(aryan): refactor later
                            if "pooled_prompt_embeds" in text_conditions:
                                text_conditions["pooled_prompt_embeds"].fill_(0)
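                    # The block below implements the noising step. By default a sigma in [0, 1] is
                    # sampled per example, mapped to an integer timestep as t = sigma * 1000, and the
                    # latents are noised with the flow-matching interpolation
                    # x_t = (1 - sigma) * x_0 + sigma * noise. Models whose schedulers need a
                    # different formulation can supply their own `calculate_noisy_latents`, and
                    # `prepare_target` picks the matching regression target (e.g. the flow velocity
                    # or the plain noise, depending on the scheduler).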
                    sigmas = prepare_sigmas(
                        scheduler=self.scheduler,
                        sigmas=scheduler_sigmas,
                        batch_size=batch_size,
                        num_train_timesteps=self.scheduler.config.num_train_timesteps,
                        flow_weighting_scheme=self.args.flow_weighting_scheme,
                        flow_logit_mean=self.args.flow_logit_mean,
                        flow_logit_std=self.args.flow_logit_std,
                        flow_mode_scale=self.args.flow_mode_scale,
                        device=accelerator.device,
                        generator=self.state.generator,
                    )
                    timesteps = (sigmas * 1000.0).long()

                    noise = torch.randn(
                        latent_conditions["latents"].shape,
                        generator=self.state.generator,
                        device=accelerator.device,
                        dtype=self.args.transformer_dtype,
                    )
                    sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)

                    # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine
                    # the type of scheduler and calculate the noisy latents accordingly. Look into this later.
                    if "calculate_noisy_latents" in self.model_config.keys():
                        noisy_latents = self.model_config["calculate_noisy_latents"](
                            scheduler=self.scheduler,
                            noise=noise,
                            latents=latent_conditions["latents"],
                            timesteps=timesteps,
                        )
                    else:
                        # Default to flow-matching noise addition
                        noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
                        noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)

                    latent_conditions.update({"noisy_latents": noisy_latents})

                    weights = prepare_loss_weights(
                        scheduler=self.scheduler,
                        alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
                        sigmas=sigmas,
                        flow_weighting_scheme=self.args.flow_weighting_scheme,
                    )
                    weights = expand_tensor_dims(weights, noise.ndim)

                    pred = self.model_config["forward_pass"](
                        transformer=self.transformer,
                        scheduler=self.scheduler,
                        timesteps=timesteps,
                        **latent_conditions,
                        **text_conditions,
                    )
                    target = prepare_target(
                        scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
                    )

                    loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
                    # Average loss across all but batch dimension
                    loss = loss.mean(list(range(1, loss.ndim)))
                    # Average loss across batch dimension
                    loss = loss.mean()
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            grad_norm = self.transformer.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        else:
                            grad_norm = accelerator.clip_grad_norm_(
                                self.transformer.parameters(), self.args.max_grad_norm
                            )
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1

                    # Checkpointing
                    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                        if global_step % self.args.checkpointing_steps == 0:
                            save_path = get_intermediate_ckpt_path(
                                checkpointing_limit=self.args.checkpointing_limit,
                                step=global_step,
                                output_dir=self.args.output_dir,
                            )
                            accelerator.save_state(save_path)

                    # Maybe run validation
                    should_run_validation = (
                        self.args.validation_every_n_steps is not None
                        and global_step % self.args.validation_every_n_steps == 0
                    )
                    if should_run_validation:
                        self.validate(global_step)

                loss_item = loss.detach().item()
                epoch_loss += loss_item
                num_loss_updates += 1
                logs["step_loss"] = loss_item
                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                progress_bar.set_postfix(logs)
                accelerator.log(logs, step=global_step)

                if global_step >= self.state.train_steps:
                    break

            if num_loss_updates > 0:
                epoch_loss /= num_loss_updates
            accelerator.log({"epoch_loss": epoch_loss}, step=global_step)

            memory_statistics = get_memory_statistics()
            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")

            # Maybe run validation
            should_run_validation = (
                self.args.validation_every_n_epochs is not None
                and (epoch + 1) % self.args.validation_every_n_epochs == 0
            )
            if should_run_validation:
                self.validate(global_step)

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            transformer = unwrap_model(accelerator, self.transformer)

            if self.args.training_type == "lora":
                transformer_lora_layers = get_peft_model_state_dict(transformer)
                self.model_config["pipeline_cls"].save_lora_weights(
                    save_directory=self.args.output_dir,
                    transformer_lora_layers=transformer_lora_layers,
                )
            else:
                transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))

        accelerator.wait_for_everyone()
        self.validate(step=global_step, final_validation=True)

        if accelerator.is_main_process:
            if self.args.push_to_hub:
                upload_folder(
                    repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
                )

        self._delete_components()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        accelerator.end_training()
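    # Validation shards prompts across processes the same way precomputation shards samples
    # (`i % num_processes`), renders each prompt through a freshly assembled pipeline, and then
    # gathers the resulting wandb image/video objects onto the main process with `gather_object`
    # for tracker logging.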
    def validate(self, step: int, final_validation: bool = False) -> None:
        logger.info("Starting validation")

        accelerator = self.state.accelerator
        num_validation_samples = len(self.args.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            if accelerator.is_main_process:
                if self.args.push_to_hub:
                    save_model_card(
                        args=self.args,
                        repo_id=self.state.repo_id,
                        videos=None,
                        validation_prompts=None,
                    )
            return

        self.transformer.eval()

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)

        all_processes_artifacts = []
        prompts_to_filenames = {}
        for i in range(num_validation_samples):
            # Skip current validation on all processes but one
            if i % accelerator.num_processes != accelerator.process_index:
                continue

            prompt = self.args.validation_prompts[i]
            image = self.args.validation_images[i]
            video = self.args.validation_videos[i]
            height = self.args.validation_heights[i]
            width = self.args.validation_widths[i]
            num_frames = self.args.validation_num_frames[i]
            frame_rate = self.args.validation_frame_rate

            if image is not None:
                image = load_image(image)
            if video is not None:
                video = load_video(video)

            logger.debug(
                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            validation_artifacts = self.model_config["validation"](
                pipeline=pipeline,
                prompt=prompt,
                image=image,
                video=video,
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
                generator=torch.Generator(device=accelerator.device).manual_seed(
                    self.args.seed if self.args.seed is not None else 0
                ),
                # TODO: support passing `fps` for supported pipelines.
            )

            prompt_filename = string_to_filename(prompt)[:25]
            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                if artifact_value:
                    artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            for index, (key, value) in enumerate(list(artifacts.items())):
                artifact_type = value["type"]
                artifact_value = value["value"]
                if artifact_type not in ["image", "video"] or artifact_value is None:
                    continue

                extension = "png" if artifact_type == "image" else "mp4"
                filename = "validation-" if not final_validation else "final-"
                filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
                if accelerator.is_main_process and extension == "mp4":
                    prompts_to_filenames[prompt] = filename
                filename = os.path.join(self.args.output_dir, filename)
                if artifact_type == "image" and artifact_value:
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video" and artifact_value:
                    logger.debug(f"Saving video to {filename}")
                    # TODO: this should be configurable here as well as in validation runs where we call
                    # the pipeline that has `fps`.
                    export_to_video(artifact_value, filename, fps=frame_rate)
                    artifact_value = wandb.Video(filename, caption=prompt)

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "final" if final_validation else "validation"
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    artifact_log_dict = {}

                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                    if len(image_artifacts) > 0:
                        artifact_log_dict["images"] = image_artifacts

                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                    if len(video_artifacts) > 0:
                        artifact_log_dict["videos"] = video_artifacts

                    tracker.log({tracker_key: artifact_log_dict}, step=step)

            if self.args.push_to_hub and final_validation:
                video_filenames = list(prompts_to_filenames.values())
                prompts = list(prompts_to_filenames.keys())
                save_model_card(
                    args=self.args,
                    repo_id=self.state.repo_id,
                    videos=video_filenames,
                    validation_prompts=prompts,
                )

        # Remove all hooks that might have been added during pipeline initialization to the models
        pipeline.remove_all_hooks()
        del pipeline

        accelerator.wait_for_everyone()

        free_memory()
        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        if not final_validation:
            self.transformer.train()

    def evaluate(self) -> None:
        raise NotImplementedError("Evaluation has not been implemented yet.")

    def _init_distributed(self) -> None:
        logging_dir = Path(self.args.output_dir, self.args.logging_dir)
        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        init_process_group_kwargs = InitProcessGroupKwargs(
            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
        )
        report_to = None if self.args.report_to.lower() == "none" else self.args.report_to

        accelerator = Accelerator(
            project_config=project_config,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            log_with=report_to,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
        )

        # Disable AMP for MPS.
        if torch.backends.mps.is_available():
            accelerator.native_amp = False

        self.state.accelerator = accelerator

        if self.args.seed is not None:
            self.state.seed = self.args.seed
            set_seed(self.args.seed)
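    # Logging is intentionally asymmetric across processes: only the local main process receives
    # warning/info output from transformers and diffusers, while every other rank is clamped to
    # error-level so multi-GPU runs stay readable.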
    def _init_logging(self) -> None:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=FINETRAINERS_LOG_LEVEL,
        )
        if self.state.accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        logger.info("Initialized FineTrainers")
        logger.info(self.state.accelerator.state, main_process_only=False)

    def _init_directories_and_repositories(self) -> None:
        if self.state.accelerator.is_main_process:
            self.args.output_dir = Path(self.args.output_dir)
            self.args.output_dir.mkdir(parents=True, exist_ok=True)
            self.state.output_dir = Path(self.args.output_dir)

            if self.args.push_to_hub:
                repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
                self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs:
        # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True

    def _move_components_to_device(self):
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
        if self.text_encoder_2 is not None:
            self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
        if self.text_encoder_3 is not None:
            self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
        if self.transformer is not None:
            self.transformer = self.transformer.to(self.state.accelerator.device)
        if self.unet is not None:
            self.unet = self.unet.to(self.state.accelerator.device)
        if self.vae is not None:
            self.vae = self.vae.to(self.state.accelerator.device)

    def _get_load_components_kwargs(self) -> Dict[str, Any]:
        load_component_kwargs = {
            "text_encoder_dtype": self.args.text_encoder_dtype,
            "text_encoder_2_dtype": self.args.text_encoder_2_dtype,
            "text_encoder_3_dtype": self.args.text_encoder_3_dtype,
            "transformer_dtype": self.args.transformer_dtype,
            "vae_dtype": self.args.vae_dtype,
            "shift": self.args.flow_shift,
            "revision": self.args.revision,
            "cache_dir": self.args.cache_dir,
        }
        if self.args.pretrained_model_name_or_path is not None:
            load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
        return load_component_kwargs

    def _set_components(self, components: Dict[str, Any]) -> None:
        # Set models
        self.tokenizer = components.get("tokenizer", self.tokenizer)
        self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
        self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
        self.text_encoder = components.get("text_encoder", self.text_encoder)
        self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
        self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
        self.transformer = components.get("transformer", self.transformer)
        self.unet = components.get("unet", self.unet)
        self.vae = components.get("vae", self.vae)
        self.scheduler = components.get("scheduler", self.scheduler)

        # Set configs
        self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
        self.vae_config = self.vae.config if self.vae is not None else self.vae_config
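    # Deleting components only drops the Python references; `free_memory` plus the explicit CUDA
    # synchronize below is there so that device memory can actually be reclaimed before the next
    # set of models is loaded.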
    def _delete_components(self) -> None:
        self.tokenizer = None
        self.tokenizer_2 = None
        self.tokenizer_3 = None
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None
        self.transformer = None
        self.unet = None
        self.vae = None
        self.scheduler = None
        free_memory()
        torch.cuda.synchronize(self.state.accelerator.device)

    def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
        accelerator = self.state.accelerator
        if not final_validation:
            pipeline = self.model_config["initialize_pipeline"](
                model_id=self.args.pretrained_model_name_or_path,
                tokenizer=self.tokenizer,
                text_encoder=self.text_encoder,
                tokenizer_2=self.tokenizer_2,
                text_encoder_2=self.text_encoder_2,
                transformer=unwrap_model(accelerator, self.transformer),
                vae=self.vae,
                device=accelerator.device,
                revision=self.args.revision,
                cache_dir=self.args.cache_dir,
                enable_slicing=self.args.enable_slicing,
                enable_tiling=self.args.enable_tiling,
                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                is_training=True,
            )
        else:
            self._delete_components()

            # Load the transformer weights from the final checkpoint if performing a full finetune
            transformer = None
            if self.args.training_type == "full-finetune":
                transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]

            pipeline = self.model_config["initialize_pipeline"](
                model_id=self.args.pretrained_model_name_or_path,
                transformer=transformer,
                device=accelerator.device,
                revision=self.args.revision,
                cache_dir=self.args.cache_dir,
                enable_slicing=self.args.enable_slicing,
                enable_tiling=self.args.enable_tiling,
                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                is_training=False,
            )

            # Load the LoRA weights if performing LoRA finetuning
            if self.args.training_type == "lora":
                pipeline.load_lora_weights(self.args.output_dir)

        return pipeline

    def _disable_grad_for_components(self, components: List[torch.nn.Module]):
        for component in components:
            if component is not None:
                component.requires_grad_(False)

    def _enable_grad_for_components(self, components: List[torch.nn.Module]):
        for component in components:
            if component is not None:
                component.requires_grad_(True)

    def _get_training_info(self) -> dict:
        args = self.args.to_dict()

        training_args = args.get("training_arguments", {})
        training_type = training_args.get("training_type", "")

        # LoRA/non-LoRA stuff.
        if training_type == "full-finetune":
            filtered_training_args = {
                k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"}
            }
        else:
            filtered_training_args = training_args

        # Diffusion/flow stuff.
        diffusion_args = args.get("diffusion_arguments", {})
        scheduler_name = self.scheduler.__class__.__name__
        if scheduler_name != "FlowMatchEulerDiscreteScheduler":
            filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
        else:
            filtered_diffusion_args = diffusion_args

        # Rest of the stuff.
        updated_training_info = args.copy()
        updated_training_info["training_arguments"] = filtered_training_args
        updated_training_info["diffusion_arguments"] = filtered_diffusion_args
        return updated_training_info
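

# Typical driver (sketch): the real CLI wiring lives elsewhere in the package, but based on the
# method names above, an invocation is assumed to look roughly like this:
#
#     args = Args(...)  # parsed from the command line in practice
#     trainer = Trainer(args)
#     trainer.prepare_dataset()
#     trainer.prepare_models()
#     trainer.prepare_precomputations()
#     trainer.prepare_trainable_parameters()
#     trainer.prepare_optimizer()
#     trainer.prepare_for_training()
#     trainer.prepare_trackers()
#     trainer.train()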