Singularity666 committed on
Commit 48dd315 · verified · 1 Parent(s): 2ca2b1f

Update main.py

Files changed (1)
  1. main.py +96 -60
main.py CHANGED
@@ -1,62 +1,98 @@
  import os
  import torch
- from torch import autocast
- from diffusers import StableDiffusionPipeline, DDIMScheduler
- from train_dreambooth import train_dreambooth
-
- class DreamboothApp:
-     def __init__(self, model_path, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"):
-         self.model_path = model_path
-         self.pretrained_model_name_or_path = pretrained_model_name_or_path
-         self.pipe = None
-         self.g_cuda = torch.Generator(device='cuda')
-
-     def load_model(self):
-         self.pipe = StableDiffusionPipeline.from_pretrained(self.model_path, safety_checker=None, torch_dtype=torch.float16).to("cuda")
-         self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
-         self.pipe.enable_xformers_memory_efficient_attention()
-
-     def train(self, instance_data_dir, class_data_dir, instance_prompt, class_prompt, num_class_images=50, max_train_steps=800, output_dir="stable_diffusion_weights"):
-         concepts_list = [
-             {
-                 "instance_prompt": instance_prompt,
-                 "class_prompt": class_prompt,
-                 "instance_data_dir": instance_data_dir,
-                 "class_data_dir": class_data_dir
-             }
-         ]
-         train_dreambooth(pretrained_model_name_or_path=self.pretrained_model_name_or_path,
-                          pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse",
-                          output_dir=output_dir,
-                          revision="fp16",
-                          with_prior_preservation=True,
-                          prior_loss_weight=1.0,
-                          seed=1337,
-                          resolution=512,
-                          train_batch_size=1,
-                          train_text_encoder=True,
-                          mixed_precision="fp16",
-                          use_8bit_adam=True,
-                          gradient_accumulation_steps=1,
-                          learning_rate=1e-6,
-                          lr_scheduler="constant",
-                          lr_warmup_steps=0,
-                          num_class_images=num_class_images,
-                          sample_batch_size=4,
-                          max_train_steps=max_train_steps,
-                          save_interval=10000,
-                          save_sample_prompt=instance_prompt,
-                          concepts_list=concepts_list)
-         self.model_path = output_dir
-
-     def inference(self, prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, seed=None):
-         if seed is not None:
-             self.g_cuda.manual_seed(seed)
-         with autocast("cuda"), torch.inference_mode():
-             return self.pipe(
-                 prompt, height=int(height), width=int(width),
-                 negative_prompt=negative_prompt,
-                 num_images_per_prompt=int(num_samples),
-                 num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
-                 generator=self.g_cuda
-             ).images

  import os
+ import shutil
+ import json
  import torch
+ import random
+ from pathlib import Path
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL, DDPMScheduler
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from accelerate import Accelerator
+ from tqdm.auto import tqdm
+ from PIL import Image
+
+ class CustomDataset(Dataset):
+     def __init__(self, data_dir, prompt, tokenizer, size=512, center_crop=False):
+         self.data_dir = Path(data_dir)
+         self.prompt = prompt
+         self.tokenizer = tokenizer
+         self.size = size
+         self.center_crop = center_crop
+
+         self.image_transforms = transforms.Compose([
+             transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5])
+         ])
+
+         # Treat every non-.txt file in the directory as a training image.
+         self.images = [f for f in self.data_dir.iterdir() if f.is_file() and not str(f).endswith(".txt")]
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, idx):
+         image_path = self.images[idx]
+         image = Image.open(image_path)
+         if not image.mode == "RGB":
+             image = image.convert("RGB")
+
+         image = self.image_transforms(image)
+         # Return a fixed-length tensor so the default collate_fn can stack prompts into a batch.
+         prompt_ids = self.tokenizer(
+             self.prompt, padding="max_length", truncation=True,
+             max_length=self.tokenizer.model_max_length, return_tensors="pt"
+         ).input_ids[0]
+
+         return {"image": image, "prompt_ids": prompt_ids}
+
+ def fine_tune_model(instance_data_dir, instance_prompt, model_name, output_dir, seed=1337, resolution=512, train_batch_size=1, max_train_steps=800):
+     # Setup
+     accelerator = Accelerator(cpu=True)
+     set_seed(seed)
+
+     # A Stable Diffusion checkpoint stores each component in its own subfolder.
+     tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
+     vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
+     unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
+     noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
+
+     # Only the UNet is trained; freeze the VAE and text encoder.
+     vae.requires_grad_(False)
+     text_encoder.requires_grad_(False)
+
+     dataset = CustomDataset(instance_data_dir, instance_prompt, tokenizer, resolution)
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_batch_size, shuffle=True)
+
+     optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)
+
+     unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
+     vae.to(accelerator.device)
+     text_encoder.to(accelerator.device)
+
+     unet.train()
+     global_step = 0
+     progress_bar = tqdm(total=max_train_steps)
+     # The instance set is typically tiny, so loop over it until max_train_steps is reached.
+     while global_step < max_train_steps:
+         for batch in dataloader:
+             with torch.no_grad():
+                 latents = vae.encode(batch["image"].to(accelerator.device)).latent_dist.sample() * 0.18215
+                 encoder_hidden_states = text_encoder(batch["prompt_ids"].to(accelerator.device))[0]
+             noise = torch.randn_like(latents)
+             timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
+             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+             # Standard epsilon-prediction objective: the UNet learns to predict the added noise.
+             loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
+             accelerator.backward(loss)
+
+             optimizer.step()
+             optimizer.zero_grad()
+             global_step += 1
+             progress_bar.update(1)
+             if global_step >= max_train_steps:
+                 break
+
+     # Save each component to its own subfolder so the configs do not overwrite one another.
+     unet = accelerator.unwrap_model(unet)
+     unet.save_pretrained(os.path.join(output_dir, "unet"))
+     vae.save_pretrained(os.path.join(output_dir, "vae"))
+     text_encoder.save_pretrained(os.path.join(output_dir, "text_encoder"))
+     tokenizer.save_pretrained(os.path.join(output_dir, "tokenizer"))
+
+ def set_seed(seed):
+     random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
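
This commit drops the old DreamboothApp.inference path, so callers now have to reassemble the saved components into a pipeline themselves. Below is a minimal usage sketch under that assumption: the data directory, the "sks" placeholder prompt, and the output directory are hypothetical, and the base-model name is carried over from the removed code.

# Hypothetical usage sketch; paths and prompts are placeholders.
import os
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

BASE_MODEL = "runwayml/stable-diffusion-v1-5"   # base checkpoint, as in the removed code
WEIGHTS_DIR = "stable_diffusion_weights"        # hypothetical output directory

fine_tune_model(
    instance_data_dir="data/my_subject",        # hypothetical folder of instance images
    instance_prompt="a photo of sks person",    # "sks" is the usual DreamBooth placeholder token
    model_name=BASE_MODEL,
    output_dir=WEIGHTS_DIR,
)

# Rebuild a pipeline from the per-component subfolders written by fine_tune_model,
# taking everything else (scheduler config, feature extractor) from the base model.
pipe = StableDiffusionPipeline.from_pretrained(
    BASE_MODEL,
    unet=UNet2DConditionModel.from_pretrained(os.path.join(WEIGHTS_DIR, "unet")),
    vae=AutoencoderKL.from_pretrained(os.path.join(WEIGHTS_DIR, "vae")),
    text_encoder=CLIPTextModel.from_pretrained(os.path.join(WEIGHTS_DIR, "text_encoder")),
    tokenizer=CLIPTokenizer.from_pretrained(os.path.join(WEIGHTS_DIR, "tokenizer")),
    safety_checker=None,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image = pipe("a photo of sks person on a beach", num_inference_steps=50, guidance_scale=7.5).images[0]

Since the training function runs under Accelerator(cpu=True), the rebuilt pipeline is left on CPU here as well; moving it to CUDA would mirror the removed load_model method.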