davidfant committed
Commit 70d19e4
1 Parent(s): 89ed62a

stable-diffusion-inpainting, 4/5 breast size, updated code (2000 steps)

Files changed (36)
  1. args.json +1 -2
  2. samples/0/0.png +0 -0
  3. samples/0/1.png +0 -0
  4. samples/0/2.png +0 -0
  5. samples/0/3.png +0 -0
  6. samples/1/0.png +0 -0
  7. samples/1/1.png +0 -0
  8. samples/1/2.png +0 -0
  9. samples/1/3.png +0 -0
  10. samples/2/0.png +0 -0
  11. samples/2/1.png +0 -0
  12. samples/2/2.png +0 -0
  13. samples/2/3.png +0 -0
  14. samples/3/0.png +0 -0
  15. samples/3/1.png +0 -0
  16. samples/3/2.png +0 -0
  17. samples/3/3.png +0 -0
  18. samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/0.png +0 -0
  19. samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/1.png +0 -0
  20. samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/2.png +0 -0
  21. samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/3.png +0 -0
  22. samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/0.png +0 -0
  23. samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/1.png +0 -0
  24. samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/2.png +0 -0
  25. samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/3.png +0 -0
  26. samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/0.png +0 -0
  27. samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/1.png +0 -0
  28. samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/2.png +0 -0
  29. samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/3.png +0 -0
  30. samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/0.png +0 -0
  31. samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/1.png +0 -0
  32. samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/2.png +0 -0
  33. samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/3.png +0 -0
  34. text_encoder/pytorch_model.bin +1 -1
  35. train_inpainting_dreambooth.py +866 -0
  36. unet/diffusion_pytorch_model.bin +1 -1
args.json CHANGED
@@ -7,12 +7,10 @@
  "class_data_dir": null,
  "instance_prompt": null,
  "class_prompt": null,
- "save_sample_prompt": "close up of woman wearing bikini top, amfdk breast size|close up of woman wearing bikini top, boqnf breast size|close up of woman wearing bikini top, czufm breast size|close up of woman wearing bikini top, dpqjd breast size",
  "save_sample_negative_prompt": null,
  "n_save_sample": 4,
  "save_guidance_scale": 7.5,
  "save_infer_steps": 50,
- "pad_tokens": true,
  "with_prior_preservation": true,
  "prior_loss_weight": 1.0,
  "num_class_images": 300,
@@ -46,6 +44,7 @@
  "save_min_steps": 0,
  "mixed_precision": "fp16",
  "not_cache_latents": true,
+ "hflip": false,
  "local_rank": -1,
  "concepts_list": [
    {
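The `concepts_list` entry retained above is what the new training script consumes: `--concepts_list` points at a JSON file whose entries hold the instance/class prompts and data directories. A minimal sketch of producing such a file follows; the prompts and paths here are illustrative placeholders (only the key names come from the script), not the values used for this run.

```python
import json

# Hypothetical prompts and directories, using the keys the script reads:
# instance_prompt, class_prompt, instance_data_dir, class_data_dir.
concepts = [
    {
        "instance_prompt": "close up of woman wearing bikini top, amfdk breast size",
        "class_prompt": "close up of woman wearing bikini top",
        "instance_data_dir": "data/instance",  # placeholder path
        "class_data_dir": "data/class",        # placeholder path
    }
]

# The script loads this file with json.load() and feeds each entry to DreamBoothDataset.
with open("concepts_list.json", "w") as f:
    json.dump(concepts, f, indent=2)
```

At save time the script also uses each entry's `instance_prompt` to generate the images stored under `samples/<idx>/`.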
samples/0/0.png CHANGED
samples/0/1.png CHANGED
samples/0/2.png CHANGED
samples/0/3.png CHANGED
samples/1/0.png CHANGED
samples/1/1.png CHANGED
samples/1/2.png CHANGED
samples/1/3.png CHANGED
samples/2/0.png CHANGED
samples/2/1.png CHANGED
samples/2/2.png CHANGED
samples/2/3.png CHANGED
samples/3/0.png CHANGED
samples/3/1.png CHANGED
samples/3/2.png CHANGED
samples/3/3.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/0.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/1.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/2.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-amfdk-breast-size/3.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/0.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/1.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/2.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-boqnf-breast-size/3.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/0.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/1.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/2.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-czufm-breast-size/3.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/0.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/1.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/2.png CHANGED
samples/close-up-of-woman-wearing-bikini-top,-dpqjd-breast-size/3.png CHANGED
text_encoder/pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:18e4ec9289fee0ea712fa0be8a4cc42b7a0e697681033246b11d8c181a544146
+ oid sha256:1791b109357aaec3009ab57c590775556825c134f27a143c235945c3fdc635d2
  size 492308087
train_inpainting_dreambooth.py ADDED
@@ -0,0 +1,866 @@
+ import argparse
+ import hashlib
+ import itertools
+ import json
+ import math
+ import os
+ import random
+ import shutil
+ from contextlib import nullcontext
+ from pathlib import Path
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from huggingface_hub import HfFolder, Repository, whoami
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
+                        StableDiffusionInpaintPipeline, UNet2DConditionModel)
+ from diffusers.optimization import get_scheduler
+
+ torch.backends.cudnn.benchmark = True
+
+
+ logger = get_logger(__name__)
+
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--pretrained_vae_name_or_path",
+         type=str,
+         default=None,
+         help="Path to pretrained vae or vae identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default="fp16",
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--instance_data_dir",
+         type=str,
+         default=None,
+         help="A folder containing the training data of instance images.",
+     )
+     parser.add_argument(
+         "--class_data_dir",
+         type=str,
+         default=None,
+         help="A folder containing the training data of class images.",
+     )
+     parser.add_argument(
+         "--instance_prompt",
+         type=str,
+         default=None,
+         help="The prompt with identifier specifying the instance",
+     )
+     parser.add_argument(
+         "--class_prompt",
+         type=str,
+         default=None,
+         help="The prompt to specify images in the same class as provided instance images.",
+     )
+     # parser.add_argument(
+     #     "--save_sample_prompt",
+     #     type=str,
+     #     default=None,
+     #     help="The prompt used to generate sample outputs to save.",
+     # )
+     parser.add_argument(
+         "--save_sample_negative_prompt",
+         type=str,
+         default=None,
+         help="The negative prompt used to generate sample outputs to save.",
+     )
+     parser.add_argument(
+         "--n_save_sample",
+         type=int,
+         default=4,
+         help="The number of samples to save.",
+     )
+     parser.add_argument(
+         "--save_guidance_scale",
+         type=float,
+         default=7.5,
+         help="CFG for save sample.",
+     )
+     parser.add_argument(
+         "--save_infer_steps",
+         type=int,
+         default=50,
+         help="The number of inference steps for save sample.",
+     )
+     parser.add_argument(
+         "--with_prior_preservation",
+         default=False,
+         action="store_true",
+         help="Flag to add prior preservation loss.",
+     )
+     parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+     parser.add_argument(
+         "--num_class_images",
+         type=int,
+         default=100,
+         help=(
+             "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
+             " sampled with class_prompt."
+         ),
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="text-inversion-model",
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+     )
+     parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+     parser.add_argument(
+         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+     )
+     parser.add_argument(
+         "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=1)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=None,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of updates steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--gradient_checkpointing",
+         action="store_true",
+         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=5e-6,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=False,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+     )
+     parser.add_argument(
+         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+     )
+     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
+     parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
+     parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="no",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose"
+             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+             "and an Nvidia Ampere GPU."
+         ),
+     )
+     parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
+     parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
+     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+     parser.add_argument(
+         "--concepts_list",
+         type=str,
+         default=None,
+         help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
+     )
+
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     return args
+
+
+ def get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=32, max_height=128, min_width=32, max_width=128):
+     holes = []
+     for _n in range(random.randint(min_holes, max_holes)):
+         hole_height = random.randint(min_height, max_height)
+         hole_width = random.randint(min_width, max_width)
+         y1 = random.randint(0, height - hole_height)
+         x1 = random.randint(0, width - hole_width)
+         y2 = y1 + hole_height
+         x2 = x1 + hole_width
+         holes.append((x1, y1, x2, y2))
+     return holes
+
+
+ def generate_random_mask(image):
+     mask = torch.zeros_like(image[:1])
+     holes = get_cutout_holes(mask.shape[1], mask.shape[2])
+     for (x1, y1, x2, y2) in holes:
+         mask[:, y1:y2, x1:x2] = 1.
+     if random.uniform(0, 1) < 0.25:
+         mask.fill_(1.)
+     masked_image = image * (mask < 0.5)
+     return mask, masked_image
+
+
+ class DreamBoothDataset(Dataset):
+     """
+     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+     It pre-processes the images and the tokenizes prompts.
+     """
+
+     def __init__(
+         self,
+         concepts_list,
+         tokenizer,
+         with_prior_preservation=True,
+         size=512,
+         center_crop=False,
+         num_class_images=None,
+         hflip=False
+     ):
+         self.size = size
+         self.center_crop = center_crop
+         self.tokenizer = tokenizer
+         self.with_prior_preservation = with_prior_preservation
+         self.instance_images_path = []
+         self.class_images_path = []
+
+         for concept in concepts_list:
+             inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
+             self.instance_images_path.extend(inst_img_path)
+
+             if with_prior_preservation:
+                 class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
+                 self.class_images_path.extend(class_img_path[:num_class_images])
+
+         random.shuffle(self.instance_images_path)
+         self.num_instance_images = len(self.instance_images_path)
+         self.num_class_images = len(self.class_images_path)
+         self._length = max(self.num_class_images, self.num_instance_images)
+
+         self.image_transforms = transforms.Compose(
+             [
+                 transforms.RandomHorizontalFlip(0.5 * hflip),
+                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5], [0.5]),
+             ]
+         )
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, index):
+         example = {}
+         instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+         instance_image = Image.open(instance_path)
+         if not instance_image.mode == "RGB":
+             instance_image = instance_image.convert("RGB")
+         example["instance_images"] = self.image_transforms(instance_image)
+         example["instance_masks"], example["instance_masked_images"] = generate_random_mask(example["instance_images"])
+         example["instance_prompt_ids"] = self.tokenizer(
+             instance_prompt,
+             padding="max_length",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+         ).input_ids
+
+         if self.with_prior_preservation:
+             class_path, class_prompt = self.class_images_path[index % self.num_class_images]
+             class_image = Image.open(class_path)
+             if not class_image.mode == "RGB":
+                 class_image = class_image.convert("RGB")
+             example["class_images"] = self.image_transforms(class_image)
+             example["class_masks"], example["class_masked_images"] = generate_random_mask(example["class_images"])
+             example["class_prompt_ids"] = self.tokenizer(
+                 class_prompt,
+                 padding="max_length",
+                 truncation=True,
+                 max_length=self.tokenizer.model_max_length,
+             ).input_ids
+
+         return example
+
+
+ class PromptDataset(Dataset):
+     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+     def __init__(self, prompt, num_samples):
+         self.prompt = prompt
+         self.num_samples = num_samples
+
+     def __len__(self):
+         return self.num_samples
+
+     def __getitem__(self, index):
+         example = {}
+         example["prompt"] = self.prompt
+         example["index"] = index
+         return example
+
+
+ class LatentsDataset(Dataset):
+     def __init__(self, latents_cache, text_encoder_cache):
+         self.latents_cache = latents_cache
+         self.text_encoder_cache = text_encoder_cache
+
+     def __len__(self):
+         return len(self.latents_cache)
+
+     def __getitem__(self, index):
+         return self.latents_cache[index], self.text_encoder_cache[index]
+
+
+ class AverageMeter:
+     def __init__(self, name=None):
+         self.name = name
+         self.reset()
+
+     def reset(self):
+         self.sum = self.count = self.avg = 0
+
+     def update(self, val, n=1):
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+     if token is None:
+         token = HfFolder.get_token()
+     if organization is None:
+         username = whoami(token)["name"]
+         return f"{username}/{model_id}"
+     else:
+         return f"{organization}/{model_id}"
+
+
+ def main(args):
+     logging_dir = Path(args.output_dir, "0", args.logging_dir)
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with="tensorboard",
+         logging_dir=logging_dir,
+     )
+
+     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+     if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+         raise ValueError(
+             "Gradient accumulation is not supported when training the text encoder in distributed training. "
+             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+         )
+
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     if args.concepts_list is None:
+         args.concepts_list = [
+             {
+                 "instance_prompt": args.instance_prompt,
+                 "class_prompt": args.class_prompt,
+                 "instance_data_dir": args.instance_data_dir,
+                 "class_data_dir": args.class_data_dir
+             }
+         ]
+     else:
+         with open(args.concepts_list, "r") as f:
+             args.concepts_list = json.load(f)
+
+     if args.with_prior_preservation:
+         pipeline = None
+         for concept in args.concepts_list:
+             class_images_dir = Path(concept["class_data_dir"])
+             class_images_dir.mkdir(parents=True, exist_ok=True)
+             cur_class_images = len(list(class_images_dir.iterdir()))
+
+             if cur_class_images < args.num_class_images:
+                 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+                 if pipeline is None:
+                     pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                         args.pretrained_model_name_or_path,
+                         vae=AutoencoderKL.from_pretrained(
+                             args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
+                             revision=None if args.pretrained_vae_name_or_path else args.revision
+                         ),
+                         torch_dtype=torch_dtype,
+                         safety_checker=None,
+                         revision=args.revision
+                     )
+                     pipeline.set_progress_bar_config(disable=True)
+                     pipeline.to(accelerator.device)
+
+                 num_new_images = args.num_class_images - cur_class_images
+                 logger.info(f"Number of class images to sample: {num_new_images}.")
+
+                 sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
+                 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+                 sample_dataloader = accelerator.prepare(sample_dataloader)
+
+                 inp_img = Image.new("RGB", (512, 512), color=(0, 0, 0))
+                 inp_mask = Image.new("L", (512, 512), color=255)
+
+                 with torch.autocast("cuda"), torch.inference_mode():
+                     for example in tqdm(
+                         sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+                     ):
+                         images = pipeline(
+                             prompt=example["prompt"],
+                             image=inp_img,
+                             mask_image=inp_mask,
+                             num_inference_steps=args.save_infer_steps
+                         ).images
+
+                         for i, image in enumerate(images):
+                             hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+                             image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+                             image.save(image_filename)
+
+         del pipeline
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Load the tokenizer
+     if args.tokenizer_name:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.tokenizer_name,
+             revision=args.revision,
+         )
+     elif args.pretrained_model_name_or_path:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.pretrained_model_name_or_path,
+             subfolder="tokenizer",
+             revision=args.revision,
+         )
+
+     # Load models and create wrapper for stable diffusion
+     text_encoder = CLIPTextModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="text_encoder",
+         revision=args.revision,
+     )
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="vae",
+         revision=args.revision,
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="unet",
+         revision=args.revision,
+         torch_dtype=torch.float32
+     )
+
+     vae.requires_grad_(False)
+     if not args.train_text_encoder:
+         text_encoder.requires_grad_(False)
+
+     if args.gradient_checkpointing:
+         unet.enable_gradient_checkpointing()
+         if args.train_text_encoder:
+             text_encoder.gradient_checkpointing_enable()
+
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+         )
+
+     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+     if args.use_8bit_adam:
+         try:
+             import bitsandbytes as bnb
+         except ImportError:
+             raise ImportError(
+                 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+             )
+
+         optimizer_class = bnb.optim.AdamW8bit
+     else:
+         optimizer_class = torch.optim.AdamW
+
+     params_to_optimize = (
+         itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+     )
+     optimizer = optimizer_class(
+         params_to_optimize,
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+     train_dataset = DreamBoothDataset(
+         concepts_list=args.concepts_list,
+         tokenizer=tokenizer,
+         with_prior_preservation=args.with_prior_preservation,
+         size=args.resolution,
+         center_crop=args.center_crop,
+         num_class_images=args.num_class_images,
+         hflip=args.hflip
+     )
+
+     def collate_fn(examples):
+         input_ids = [example["instance_prompt_ids"] for example in examples]
+         pixel_values = [example["instance_images"] for example in examples]
+         mask_values = [example["instance_masks"] for example in examples]
+         masked_image_values = [example["instance_masked_images"] for example in examples]
+
+         # Concat class and instance examples for prior preservation.
+         # We do this to avoid doing two forward passes.
+         if args.with_prior_preservation:
+             input_ids += [example["class_prompt_ids"] for example in examples]
+             pixel_values += [example["class_images"] for example in examples]
+             mask_values += [example["class_masks"] for example in examples]
+             masked_image_values += [example["class_masked_images"] for example in examples]
+
+         pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+         mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+         masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+
+         input_ids = tokenizer.pad(
+             {"input_ids": input_ids},
+             padding="max_length",
+             max_length=tokenizer.model_max_length,
+             return_tensors="pt",
+         ).input_ids
+
+         batch = {
+             "input_ids": input_ids,
+             "pixel_values": pixel_values,
+             "mask_values": mask_values,
+             "masked_image_values": masked_image_values
+         }
+         return batch
+
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True, num_workers=8
+     )
+
+     weight_dtype = torch.float32
+     if args.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif args.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move text_encode and vae to gpu.
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     vae.to(accelerator.device, dtype=weight_dtype)
+     if not args.train_text_encoder:
+         text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+     if not args.not_cache_latents:
+         latents_cache = []
+         text_encoder_cache = []
+         for batch in tqdm(train_dataloader, desc="Caching latents"):
+             with torch.no_grad():
+                 batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
+                 batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
+                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                 if args.train_text_encoder:
+                     text_encoder_cache.append(batch["input_ids"])
+                 else:
+                     text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
+         train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
+         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
+
+         del vae
+         if not args.train_text_encoder:
+             del text_encoder
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+     )
+
+     if args.train_text_encoder:
+         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+         )
+     else:
+         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, optimizer, train_dataloader, lr_scheduler
+         )
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initializes automatically on the main process.
+     if accelerator.is_main_process:
+         accelerator.init_trackers("dreambooth")
+
+     # Train!
+     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {len(train_dataset)}")
+     logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+     logger.info(f" Num Epochs = {args.num_train_epochs}")
+     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f" Total optimization steps = {args.max_train_steps}")
+
+     def save_weights(step):
+         # Create the pipeline using using the trained modules and save it.
+         if accelerator.is_main_process:
+             if args.train_text_encoder:
+                 text_enc_model = accelerator.unwrap_model(text_encoder)
+             else:
+                 text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+             scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+             pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                 args.pretrained_model_name_or_path,
+                 unet=accelerator.unwrap_model(unet),
+                 text_encoder=text_enc_model,
+                 vae=AutoencoderKL.from_pretrained(
+                     args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
+                     subfolder=None if args.pretrained_vae_name_or_path else "vae",
+                     revision=None if args.pretrained_vae_name_or_path else args.revision
+                 ),
+                 safety_checker=None,
+                 scheduler=scheduler,
+                 torch_dtype=torch.float16,
+                 revision=args.revision,
+             )
+             save_dir = os.path.join(args.output_dir, f"{step}")
+             pipeline.save_pretrained(save_dir)
+             with open(os.path.join(save_dir, "args.json"), "w") as f:
+                 json.dump(args.__dict__, f, indent=2)
+
+             shutil.copy("train_inpainting_dreambooth.py", save_dir)
+
+             pipeline = pipeline.to(accelerator.device)
+             pipeline.set_progress_bar_config(disable=True)
+             for idx, concept in enumerate(args.concepts_list):
+                 g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+                 sample_dir = os.path.join(save_dir, "samples", str(idx))
+                 os.makedirs(sample_dir, exist_ok=True)
+                 inp_img = Image.new("RGB", (512, 512), color=(0, 0, 0))
+                 inp_mask = Image.new("L", (512, 512), color=255)
+                 with torch.autocast("cuda"), torch.inference_mode():
+                     for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
+                         images = pipeline(
+                             prompt=concept["instance_prompt"],
+                             image=inp_img,
+                             mask_image=inp_mask,
+                             negative_prompt=args.save_sample_negative_prompt,
+                             guidance_scale=args.save_guidance_scale,
+                             num_inference_steps=args.save_infer_steps,
+                             generator=g_cuda
+                         ).images
+                         images[0].save(os.path.join(sample_dir, f"{i}.png"))
+             del pipeline
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+             print(f"[*] Weights saved at {save_dir}")
+
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+     progress_bar.set_description("Steps")
+     global_step = 0
+     loss_avg = AverageMeter()
+     text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
+     for epoch in range(args.num_train_epochs):
+         unet.train()
+         if args.train_text_encoder:
+             text_encoder.train()
+         random.shuffle(train_dataset.class_images_path)
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(unet):
+                 # Convert images to latent space
+                 with torch.no_grad():
+                     if not args.not_cache_latents:
+                         latent_dist = batch[0][0]
+                     else:
+                         latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
+                     masked_latent_dist = vae.encode(batch["masked_image_values"].to(dtype=weight_dtype)).latent_dist
+                     latents = latent_dist.sample() * 0.18215
+                     masked_image_latents = masked_latent_dist.sample() * 0.18215
+                     mask = F.interpolate(batch["mask_values"], scale_factor=1 / 8)
+
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn_like(latents)
+                 bsz = latents.shape[0]
+                 # Sample a random timestep for each image
+                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                 timesteps = timesteps.long()
+
+                 # Add noise to the latents according to the noise magnitude at each timestep
+                 # (this is the forward diffusion process)
+                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                 # Get the text embedding for conditioning
+                 with text_enc_context:
+                     if not args.not_cache_latents:
+                         if args.train_text_encoder:
+                             encoder_hidden_states = text_encoder(batch[0][1])[0]
+                         else:
+                             encoder_hidden_states = batch[0][1]
+                     else:
+                         encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                 encoder_hidden_states = F.dropout(encoder_hidden_states, p=0.1)
+
+                 latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+                 # Predict the noise residual
+                 noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+
+                 if args.with_prior_preservation:
+                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                     # Compute instance loss
+                     loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                     # Compute prior loss
+                     prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                     # Add the prior loss to the instance loss.
+                     loss = loss + args.prior_loss_weight * prior_loss
+                 else:
+                     loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                 accelerator.backward(loss)
+                 # if accelerator.sync_gradients:
+                 #     params_to_clip = (
+                 #         itertools.chain(unet.parameters(), text_encoder.parameters())
+                 #         if args.train_text_encoder
+                 #         else unet.parameters()
+                 #     )
+                 #     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
+                 loss_avg.update(loss.detach_(), bsz)
+
+             if not global_step % args.log_interval:
+                 logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
+                 progress_bar.set_postfix(**logs)
+                 accelerator.log(logs, step=global_step)
+
+             if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
+                 save_weights(global_step)
+
+             progress_bar.update(1)
+             global_step += 1
+
+             if global_step >= args.max_train_steps:
+                 break
+
+         accelerator.wait_for_everyone()
+
+     save_weights(global_step)
+
+     accelerator.end_training()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
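Since the commit also swaps in new UNet and text-encoder weights (below), here is a minimal, hypothetical sketch of loading a checkpoint saved by `save_weights()` and reproducing the full-canvas sampling the script uses for the `samples/` folders. The checkpoint path is an assumption; the prompt is one of the sample prompts from this repo.

```python
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Hypothetical path; substitute the directory written by save_weights(), e.g. output_dir/2000.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "path/to/saved-checkpoint", torch_dtype=torch.float16, safety_checker=None
).to("cuda")

# The script samples with a black 512x512 init image and an all-white mask,
# i.e. the entire canvas is inpainted from the prompt alone.
init_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
mask = Image.new("L", (512, 512), color=255)

image = pipe(
    prompt="close up of woman wearing bikini top, amfdk breast size",
    image=init_image,
    mask_image=mask,
    guidance_scale=7.5,       # matches save_guidance_scale in args.json
    num_inference_steps=50,   # matches save_infer_steps in args.json
).images[0]
image.save("sample.png")
```

The all-white mask is what lets an inpainting checkpoint double as a plain text-to-image sampler here: with nothing preserved from the init image, the output depends only on the prompt.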
unet/diffusion_pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:42c5e779efd52a45b6e501b26277c17d8ea3b7948025a0202c350b690146e7a2
+ oid sha256:e3619be9a04e9be8314bfa3ce5781a9a604ec19fe44ec8395863611a5885c62a
  size 3438421925