import os
import shutil
import json
import torch
import random
from pathlib import Path
from torch.utils.data import Dataset
from torchvision import transforms
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from tqdm.auto import tqdm
from PIL import Image
class CustomDataset(Dataset):
    """Loads the instance images and tokenizes the shared instance prompt."""

    def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
        self.data_dir = Path(data_dir)
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.images = [f for f in self.data_dir.iterdir() if f.is_file() and f.suffix != ".txt"]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = self.image_transforms(image)
        # Return a tensor (not a plain list) so the default collate_fn can stack the prompt ids.
        prompt_ids = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]
        return {"image": image, "prompt_ids": prompt_ids}
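# Example (illustrative sketch only; the model id, folder, and prompt are placeholders,
# not values taken from this repo):
#   tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
#   dataset = CustomDataset("./instance_images", "a photo of sks dog", tokenizer, size=512)
#   sample = dataset[0]  # {"image": FloatTensor [3, 512, 512], "prompt_ids": LongTensor [77]}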
def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir, seed=1337,
                    resolution=512, train_batch_size=1, max_train_steps=800):
    # Setup: run on CPU via Accelerate and make the run reproducible.
    accelerator = Accelerator(cpu=True)
    set_seed(seed)

    # Load the pipeline components from their subfolders in the base model repo.
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

    dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True)

    # Only the UNet is trained; the VAE and text encoder stay frozen.
    optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.train()

    # Loop over the (small) dataset repeatedly until max_train_steps is reached.
    global_step = 0
    progress_bar = tqdm(total=max_train_steps)
    while global_step < max_train_steps:
        for batch in dataloader:
            # Encode images to latents (0.18215 is the Stable Diffusion VAE scaling factor).
            latents = vae.encode(batch["image"].to(accelerator.device)).latent_dist.sample() * 0.18215
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device))[0]

            # Predict the added noise and minimise the MSE against it.
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            progress_bar.update(1)
            if global_step >= max_train_steps:
                break
    progress_bar.close()

    # Save each component into its own subfolder so the configs do not overwrite each other.
    unet = accelerator.unwrap_model(unet)
    unet.save_pretrained(os.path.join(output_dir, "unet"))
    vae.save_pretrained(os.path.join(output_dir, "vae"))
    text_encoder.save_pretrained(os.path.join(output_dir, "text_encoder"))
    tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
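
# Example usage sketch (illustrative only; the model id, data path, prompt, and output
# directory below are placeholder assumptions, not values defined by this repo):
if __name__ == "__main__":
    fine_tune_model(
        instance_data_dir="./instance_images",
        instance_prompt="a photo of sks dog",
        model_name="runwayml/stable-diffusion-v1-5",
        output_dir="./fine_tuned_model",
        max_train_steps=800,
    )

    # After training, the saved components can be reassembled into a pipeline for inference.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        unet=UNet2DConditionModel.from_pretrained("./fine_tuned_model/unet"),
        vae=AutoencoderKL.from_pretrained("./fine_tuned_model/vae"),
        text_encoder=CLIPTextModel.from_pretrained("./fine_tuned_model/text_encoder"),
        tokenizer=CLIPTokenizer.from_pretrained("./fine_tuned_model/tokenizer"),
        safety_checker=None,
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    image = pipe("a photo of sks dog on the beach", num_inference_steps=50).images[0]
    image.save("sample.png")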