"""DreamBooth fine-tuning script for Stable Diffusion.

Fine-tunes the UNet (and optionally the CLIP text encoder) of a pretrained
Stable Diffusion checkpoint on a small set of instance images, using
Hugging Face Accelerate for mixed precision / gradient accumulation.
"""

import itertools
import logging
import math
import os
from argparse import Namespace
from pathlib import Path  # fix: Path was used below but never imported

import bitsandbytes as bnb
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from torch.utils.data import Dataset
from tqdm.auto import tqdm

from dataset import DreamBoothDataset

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_models(pretrained_model_name_or_path):
    """Load the four frozen sub-models of a Stable Diffusion checkpoint.

    Args:
        pretrained_model_name_or_path: Hub id or local path of the checkpoint.

    Returns:
        Tuple ``(text_encoder, vae, unet, tokenizer)``.
    """
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import AutoencoderKL, UNet2DConditionModel

    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    return text_encoder, vae, unet, tokenizer


def training_function(args, text_encoder, vae, unet, tokenizer):
    """Run the DreamBooth training loop.

    Args:
        args: ``Namespace`` of hyperparameters (see ``parse_args``).
        text_encoder: CLIP text encoder (trained only if ``args.train_text_encoder``).
        vae: Frozen VAE used to encode images into latents.
        unet: The UNet being fine-tuned.
        tokenizer: CLIP tokenizer used for prompt padding in the collator.

    Side effects:
        Writes pipeline checkpoints to ``args.output_dir`` every
        ``args.save_interval`` epochs.
    """
    set_seed(args.seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    # The VAE is always frozen; the text encoder only when not being trained.
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else unet.parameters()
    )
    optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)

    # fix: from_config(model_path, ...) is the deprecated loading path;
    # from_pretrained is the supported way to load a scheduler subfolder.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    if args.with_prior_preservation:
        # fix: create the class-image directory BEFORE the dataset that reads it.
        # NOTE(review): nothing in this script generates class images — confirm
        # the directory is pre-populated before training with prior preservation.
        class_images_dir = Path(args.class_data_dir)
        class_images_dir.mkdir(parents=True, exist_ok=True)

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        # Concatenate class examples after instance examples so prior
        # preservation shares one forward pass through the UNet.
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
        ).input_ids
        return {"input_ids": input_ids, "pixel_values": pixel_values}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    # fix: prepare everything in ONE call. The original called prepare twice,
    # re-wrapping the optimizer and dataloader a second time.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader
        )
    else:
        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Frozen modules can run in the reduced precision. fix: do NOT downcast the
    # text encoder when it is being trained — trainable weights stay fp32 under
    # mixed precision (accelerator.prepare already moved it to the device).
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.train_text_encoder:
        text_encoder.train()
    unet.train()

    global_step = 0
    for epoch in range(args.num_train_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Encode images to latents; 0.18215 is the SD v1 VAE scaling factor.
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # fix: read num_train_timesteps from .config (direct attribute
                # access on the scheduler is deprecated).
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Target depends on the scheduler's parameterization.
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # NOTE(review): with prior preservation the loss is a plain MSE
                # over the concatenated instance+class batch; the canonical
                # DreamBooth recipe chunks the prediction and adds a separately
                # weighted prior loss — confirm this simplification is intended.
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    # fix: clip the combined parameter set with one global norm
                    # instead of clipping each model independently.
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            global_step += 1
            logs = {"loss": loss.detach().item(), "lr": args.learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if global_step >= args.max_train_steps:
                break
        progress_bar.close()

        accelerator.wait_for_everyone()
        if accelerator.is_main_process and (epoch + 1) % args.save_interval == 0:
            # fix: inject the TRAINED unet/text_encoder into the pipeline.
            # The original reloaded the pretrained checkpoint untouched, so
            # every saved checkpoint silently contained untrained weights.
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
            )
            pipeline.save_pretrained(args.output_dir)

        # fix: the inner `break` only left the step loop; stop the epoch loop too.
        if global_step >= args.max_train_steps:
            break

    accelerator.end_training()


def parse_args():
    """Return the hard-coded training configuration as a ``Namespace``."""
    args = Namespace(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        instance_data_dir="datasets/imagedata/images",
        class_data_dir="./class_images",
        output_dir="./output",
        instance_prompt="a photo of yash Kothari",
        # NOTE(review): the class prompt names the specific subject; prior
        # preservation normally uses a generic class prompt (e.g. "a photo of
        # a man") — confirm this is intentional.
        class_prompt="A photo of Yash Kothari with medium, dark hair and a full beard, smiling slightly",
        resolution=512,
        center_crop=False,
        train_text_encoder=True,
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        learning_rate=5e-6,
        use_8bit_adam=True,
        train_batch_size=4,
        num_train_epochs=100,
        save_interval=10,
        max_train_steps=2000,
        gradient_checkpointing=False,
        with_prior_preservation=True,
        seed=42,
    )
    return args


if __name__ == "__main__":
    args = parse_args()
    text_encoder, vae, unet, tokenizer = load_models(args.pretrained_model_name_or_path)
    training_function(args, text_encoder, vae, unet, tokenizer)