# Uploaded to the Hugging Face Hub by kothariyashhh.
# Commit message: "Upload 72 files"
# Commit: c09bcc2 (verified)
import itertools
import math
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
import bitsandbytes as bnb
from tqdm.auto import tqdm
from argparse import Namespace
import logging
from dataset import DreamBoothDataset
# Configure the root logger at INFO level for the whole script, then create
# the module-level logger used by this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def load_models(pretrained_model_name_or_path):
    """Load the four Stable Diffusion components needed for fine-tuning.

    Each component is read from its own subfolder of the given checkpoint
    via the class's ``from_pretrained`` constructor.

    Args:
        pretrained_model_name_or_path: Model identifier or path that is
            forwarded unchanged to every ``from_pretrained`` call.

    Returns:
        A tuple ``(text_encoder, vae, unet, tokenizer)`` -- note that the
        tokenizer comes last even though it is loaded first.
    """
    # Imported lazily so that importing this module does not require the
    # heavy model classes until models are actually loaded.
    from transformers import CLIPTextModel, CLIPTokenizer
    from diffusers import AutoencoderKL, UNet2DConditionModel

    def _from_subfolder(component_cls, subfolder):
        # Every component lives in a dedicated subfolder of the checkpoint.
        return component_cls.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )

    tokenizer = _from_subfolder(CLIPTokenizer, "tokenizer")
    text_encoder = _from_subfolder(CLIPTextModel, "text_encoder")
    vae = _from_subfolder(AutoencoderKL, "vae")
    unet = _from_subfolder(UNet2DConditionModel, "unet")
    return text_encoder, vae, unet, tokenizer
def _resolve_weight_dtype(mixed_precision):
    """Map an accelerate mixed-precision mode to the dtype for frozen models."""
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32


def _save_pipeline(args, accelerator, unet, text_encoder):
    """Save a full pipeline that contains the fine-tuned UNet and text encoder."""
    # The trained modules must be passed in explicitly; loading the pipeline
    # from the base checkpoint alone would write the original weights back out.
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),
        text_encoder=accelerator.unwrap_model(text_encoder),
    )
    pipeline.save_pretrained(args.output_dir)


def training_function(args, text_encoder, vae, unet, tokenizer):
    """Fine-tune the UNet (and optionally the text encoder) DreamBooth-style.

    The VAE is always frozen. Images are encoded to latents, noised at a
    random timestep, and the UNet is trained with an MSE loss against the
    scheduler's prediction target. A pipeline with the trained weights is
    written to ``args.output_dir`` every ``args.save_interval`` epochs and
    once more when training finishes.

    Args:
        args: Namespace-like object. Attributes read here: ``seed``,
            ``gradient_accumulation_steps``, ``mixed_precision``,
            ``train_text_encoder``, ``gradient_checkpointing``,
            ``use_8bit_adam``, ``learning_rate``,
            ``pretrained_model_name_or_path``, ``instance_data_dir``,
            ``instance_prompt``, ``class_data_dir``, ``class_prompt``,
            ``with_prior_preservation``, ``resolution``, ``center_crop``,
            ``train_batch_size``, ``num_train_epochs``, ``max_train_steps``,
            ``save_interval`` and ``output_dir``.
        text_encoder: Text encoder producing the UNet conditioning.
        vae: Autoencoder used to encode images into latents (kept frozen).
        unet: Denoising model being trained.
        tokenizer: Tokenizer used by the dataset and for batch padding.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    set_seed(args.seed)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else unet.parameters()
    )
    optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)

    # Schedulers are loaded from a checkpoint with ``from_pretrained``;
    # passing a path to ``from_config`` is the deprecated form.
    noise_scheduler = DDPMScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )

    # Make sure the class-image directory exists before the dataset is built.
    # ``os`` is used here because ``Path`` is not imported in this file.
    if args.with_prior_preservation:
        os.makedirs(args.class_data_dir, exist_ok=True)

    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        """Stack images and pad prompt ids into one batch dictionary."""
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        # With prior preservation the class samples are appended to the
        # instance samples, so a single forward pass covers both and the
        # mean loss below averages over both halves.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
        ).input_ids
        return {"input_ids": input_ids, "pixel_values": pixel_values}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    # Prepare everything in a single call so the optimizer and dataloader are
    # wrapped exactly once, whether or not the text encoder is trained.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader
        )
    else:
        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    # Frozen models can run in reduced precision. A text encoder that is being
    # trained must keep its original precision, otherwise mixed-precision
    # gradient scaling cannot unscale its gradients.
    weight_dtype = _resolve_weight_dtype(accelerator.mixed_precision)
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.train_text_encoder:
        text_encoder.train()
    unet.train()

    # Latent scaling factor; falls back to the constant previously hard-coded
    # here when the VAE config does not expose one.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    global_step = 0
    last_saved_step = 0
    for epoch in range(args.num_train_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for batch in train_dataloader:
            with accelerator.accumulate(unet):
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * latent_scale

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                ).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                prediction_type = noise_scheduler.config.prediction_type
                if prediction_type == "epsilon":
                    target = noise
                elif prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    # Clip all trained parameters jointly under one norm.
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            # Count optimizer steps, not micro-batches, so max_train_steps
            # keeps its meaning under gradient accumulation.
            if accelerator.sync_gradients:
                global_step += 1

            logs = {"loss": loss.detach().item(), "lr": args.learning_rate}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        progress_bar.close()

        accelerator.wait_for_everyone()
        if (epoch + 1) % args.save_interval == 0:
            if accelerator.is_main_process:
                _save_pipeline(args, accelerator, unet, text_encoder)
            last_saved_step = global_step

        # The inner ``break`` only leaves the batch loop; stop the epoch loop
        # too once the step budget is used up.
        if global_step >= args.max_train_steps:
            break

    # Final save, so the last weights are kept even when training stops
    # between save intervals.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process and last_saved_step != global_step:
        _save_pipeline(args, accelerator, unet, text_encoder)
    accelerator.end_training()
def parse_args():
    """Return the fixed training configuration as a ``Namespace``.

    Nothing is read from the command line; every value is a constant
    defined in this function.

    Returns:
        A ``Namespace`` holding model/data paths, prompts, and the
        optimisation hyper-parameters used by the training routine.
    """
    defaults = {
        # Model and data locations.
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "instance_data_dir": "datasets/imagedata/images",
        "class_data_dir": "./class_images",
        "output_dir": "./output",
        # Prompts.
        "instance_prompt": "a photo of yash Kothari",
        "class_prompt": "A photo of Yash Kothari with medium, dark hair and a full beard, smiling slightly",
        # Image preprocessing.
        "resolution": 512,
        "center_crop": False,
        # Optimisation settings.
        "train_text_encoder": True,
        "gradient_accumulation_steps": 1,
        "mixed_precision": "fp16",
        "learning_rate": 5e-6,
        "use_8bit_adam": True,
        "train_batch_size": 4,
        "num_train_epochs": 100,
        "save_interval": 10,
        "max_train_steps": 2000,
        "gradient_checkpointing": False,
        "with_prior_preservation": True,
        "seed": 42,
    }
    return Namespace(**defaults)
if __name__ == "__main__":
    # Script entry point: build the configuration, load the pretrained
    # components, and start training.
    config = parse_args()
    # load_models returns (text_encoder, vae, unet, tokenizer), which is the
    # same order training_function expects after ``args``.
    components = load_models(config.pretrained_model_name_or_path)
    training_function(config, *components)