|
import itertools
|
|
import math
|
|
import os
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset
|
|
from accelerate import Accelerator
|
|
from accelerate.utils import set_seed
|
|
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
|
from diffusers.optimization import get_scheduler
|
|
import bitsandbytes as bnb
|
|
from tqdm.auto import tqdm
|
|
from argparse import Namespace
|
|
import logging
|
|
from dataset import DreamBoothDataset
|
|
|
|
# Configure the root logger at INFO level, then grab this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
def load_models(pretrained_model_name_or_path):
    """Load the Stable Diffusion components used for DreamBooth training.

    Every component is loaded with ``from_pretrained`` from its own subfolder
    (``tokenizer``, ``text_encoder``, ``vae``, ``unet``) of
    ``pretrained_model_name_or_path``.

    Args:
        pretrained_model_name_or_path: Hub id or local directory of the base model.

    Returns:
        A ``(text_encoder, vae, unet, tokenizer)`` tuple, in that order.
    """
    # The model classes are imported inside the function body.
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import AutoencoderKL, UNet2DConditionModel

    def from_subfolder(model_cls, subfolder):
        # Shared loader: same base path, different subfolder per component.
        return model_cls.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)

    tokenizer = from_subfolder(CLIPTokenizer, 'tokenizer')
    text_encoder = from_subfolder(CLIPTextModel, 'text_encoder')
    vae = from_subfolder(AutoencoderKL, 'vae')
    unet = from_subfolder(UNet2DConditionModel, 'unet')

    return text_encoder, vae, unet, tokenizer
|
|
|
|
def _make_collate_fn(args, tokenizer):
    """Build the DataLoader ``collate_fn`` for DreamBooth batches.

    Instance examples come first; when ``args.with_prior_preservation`` is set,
    the class examples are appended after them, giving ``[instance..., class...]``.
    The training loop relies on that ordering to split predictions into halves.
    """

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
        ).input_ids

        return {"input_ids": input_ids, "pixel_values": pixel_values}

    return collate_fn


def _trainable_parameters(args, unet, text_encoder):
    """Return a fresh iterator over every parameter the optimizer updates."""
    if args.train_text_encoder:
        return itertools.chain(unet.parameters(), text_encoder.parameters())
    return unet.parameters()


def _save_pipeline(args, accelerator, unet, text_encoder):
    """Save a StableDiffusionPipeline built from the fine-tuned weights.

    The trained modules are unwrapped from their Accelerate wrappers and passed
    to ``from_pretrained`` so they replace the base-model weights. Untrained
    components (VAE, tokenizer, scheduler, and the text encoder when it is
    frozen) are reloaded from ``args.pretrained_model_name_or_path``. A frozen
    text encoder may have been cast to a half-precision ``weight_dtype``, which
    is why it is not passed in that case.
    """
    components = {"unet": accelerator.unwrap_model(unet)}
    if args.train_text_encoder:
        components["text_encoder"] = accelerator.unwrap_model(text_encoder)
    pipeline = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, **components)
    pipeline.save_pretrained(args.output_dir)


def training_function(args, text_encoder, vae, unet, tokenizer):
    """Fine-tune Stable Diffusion with DreamBooth.

    Trains ``unet`` (and ``text_encoder`` when ``args.train_text_encoder``) on
    the images in ``args.instance_data_dir``. When ``args.with_prior_preservation``
    is set, it adds a class-prior loss on ``args.class_data_dir`` weighted by
    ``args.prior_loss_weight``, which defaults to 1.0 if absent. The fine-tuned
    pipeline is saved to ``args.output_dir`` every ``args.save_interval`` epochs
    and once more at the end of training.

    Args:
        args: Namespace of hyper-parameters (see ``parse_args``).
        text_encoder, vae, unet, tokenizer: Components returned by ``load_models``.

    Raises:
        ValueError: if the noise scheduler's ``prediction_type`` is not
            ``"epsilon"`` or ``"v_prediction"``.
    """
    from pathlib import Path  # not imported at file level

    set_seed(args.seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    # The VAE only encodes images into latents and is never trained.
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
    optimizer = optimizer_class(_trainable_parameters(args, unet, text_encoder), lr=args.learning_rate)

    # from_pretrained (not from_config) is the API for loading a scheduler from a model path.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # Create the class-image directory before the dataset is pointed at it.
    if args.with_prior_preservation:
        Path(args.class_data_dir).mkdir(parents=True, exist_ok=True)

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=_make_collate_fn(args, tokenizer),
    )

    # Prepare everything in one call so the optimizer and dataloader are wrapped exactly once.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader
        )
    else:
        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Only inference-only modules are cast to half precision. A trained text
    # encoder must keep fp32 weights, because the fp16 GradScaler refuses to
    # unscale fp16 gradients; accelerator.prepare already moved it to the device.
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    prior_loss_weight = getattr(args, "prior_loss_weight", 1.0)
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    if args.train_text_encoder:
        text_encoder.train()
    unet.train()

    global_step = 0  # counts optimizer updates, not micro-batches
    last_saved_step = None
    for epoch in range(args.num_train_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            with accelerator.accumulate(unet):
                # The VAE is frozen, so no autograd graph is needed for encoding.
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                    # Presumably the SD v1 VAE latent scaling factor (vae.config.scaling_factor) — confirm.
                    latents = latents * 0.18215

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                prediction_type = noise_scheduler.config.prediction_type
                if prediction_type == "epsilon":
                    target = noise
                elif prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {prediction_type}")

                if args.with_prior_preservation:
                    # collate_fn puts instance examples first and class examples second.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)
                    instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
                    loss = instance_loss + prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    # A single clip call: each call unscales the optimizer, and
                    # GradScaler forbids unscaling twice per step under fp16.
                    accelerator.clip_grad_norm_(_trainable_parameters(args, unet, text_encoder), 1.0)

                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            if accelerator.sync_gradients:
                global_step += 1

            logs = {"loss": loss.detach().item(), "lr": args.learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        progress_bar.close()
        accelerator.wait_for_everyone()

        if accelerator.is_main_process and (epoch + 1) % args.save_interval == 0:
            _save_pipeline(args, accelerator, unet, text_encoder)
            last_saved_step = global_step

        # Stop the epoch loop too; otherwise each remaining epoch would run one more step.
        if global_step >= args.max_train_steps:
            break

    # Always persist the final weights unless they were just saved.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process and last_saved_step != global_step:
        _save_pipeline(args, accelerator, unet, text_encoder)

    accelerator.end_training()
|
|
|
|
def parse_args():
    """Return the training configuration as an ``argparse.Namespace``.

    Nothing is read from the command line: every value is hard-coded below, so
    edit this mapping to change a run.
    """
    settings = {
        # Base model and data locations.
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "instance_data_dir": "datasets/imagedata/images",
        "class_data_dir": "./class_images",
        "output_dir": "./output",
        # Prompts.
        # NOTE(review): class_prompt names the same subject as instance_prompt.
        # Prior preservation normally uses a generic class prompt — confirm intent.
        "instance_prompt": "a photo of yash Kothari",
        "class_prompt": "A photo of Yash Kothari with medium, dark hair and a full beard, smiling slightly",
        # Image preprocessing.
        "resolution": 512,
        "center_crop": False,
        # Optimisation.
        "train_text_encoder": True,
        "gradient_accumulation_steps": 1,
        "mixed_precision": "fp16",
        "learning_rate": 5e-6,
        "use_8bit_adam": True,
        "train_batch_size": 4,
        # Schedule and checkpointing.
        "num_train_epochs": 100,
        "save_interval": 10,
        "max_train_steps": 2000,
        "gradient_checkpointing": False,
        "with_prior_preservation": True,
        "seed": 42,
    }
    return Namespace(**settings)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the config, load the base model, then train.
    cli_args = parse_args()
    text_encoder, vae, unet, tokenizer = load_models(cli_args.pretrained_model_name_or_path)
    training_function(cli_args, text_encoder, vae, unet, tokenizer)
|
|
|