Commit · 87f8c3d
1 Parent(s): 42eef31
Delete diffusion.py
diffusion.py  +0  -586  DELETED
@@ -1,586 +0,0 @@
# -*- coding: utf-8 -*-
"""stabledefusion_using_dataset.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1mORMC1aTJ8LzN06Z5zUGbdzlDiPb63EC

<h1>A Stable Diffusion example using <b>Hugging Face diffusers</b> (concepts-library)</h1>

## Initial setup

<h2>Install the required libs</h2>
"""

# Install the required libs
!pip install -U -qq git+https://github.com/huggingface/diffusers.git
!pip install -qq accelerate transformers ftfy
!pip install -qq "ipywidgets>=7,<8"

"""<h2>Install xformers for faster and memory-efficient training (for low-end GPUs)</h2>"""

# Commented out IPython magic to ensure Python compatibility.

!pip install -U --pre triton

from subprocess import getoutput
from IPython.display import HTML
from IPython.display import clear_output
import time

s = getoutput('nvidia-smi')
if 'T4' in s:
    gpu = 'T4'
elif 'P100' in s:
    gpu = 'P100'
elif 'V100' in s:
    gpu = 'V100'
elif 'A100' in s:
    gpu = 'A100'

while True:
    try:
        # If no supported GPU was detected above, `gpu` is undefined and this raises NameError.
        gpu == 'T4' or gpu == 'P100' or gpu == 'V100' or gpu == 'A100'
        break
    except NameError:
        pass
    print('\033[1;31mIt seems that your GPU is not supported at the moment')
    time.sleep(5)

# The %pip magics below were commented out by the notebook-to-script conversion,
# which leaves each branch empty; `pass` keeps the branches syntactically valid.
if (gpu=='T4'):
    # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl
    pass
elif (gpu=='P100'):
    # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/P100/xformers-0.0.13.dev0-py3-none-any.whl
    pass
elif (gpu=='V100'):
    # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/V100/xformers-0.0.13.dev0-py3-none-any.whl
    pass
elif (gpu=='A100'):
    # %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/A100/xformers-0.0.13.dev0-py3-none-any.whl
    pass

"""<h2>Environment Setup</h2>"""

import argparse
import itertools
import math
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
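
# Quick illustration (not in the original notebook): image_grid tiles PIL images row-major,
# so six 64x64 images arranged as 2 rows x 3 cols give a 192x128 canvas.
demo_grid = image_grid([Image.new("RGB", (64, 64), "gray")] * 6, rows=2, cols=3)
print(demo_grid.size)  # (192, 128)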

"""<h2>Getting the model</h2>"""

pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
# #@param ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5"] {allow-input: true}

"""<h2>Adding the dataset using URLs</h2>"""

urls = [
    "https://p4.wallpaperbetter.com/wallpaper/990/374/475/jinx-league-of-legends-arcane-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/597/595/773/arcane-jinx-league-of-legends-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/455/986/460/jinx-league-of-legends-arcane-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/769/667/606/jinx-league-of-legends-arcane-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/836/342/438/jinx-league-of-legends-vi-league-of-legends-arcane-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/211/1017/269/cyberpunk-edgerunners-cyberpunk-2077-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/8/868/36/cyberpunk-edgerunners-cyberpunk-2077-lucy-edgerunners-rebecca-edgerunners-hd-wallpaper-preview.jpg",
    "https://p4.wallpaperbetter.com/wallpaper/288/722/467/cyberpunk-edgerunners-lucy-edgerunners-anime-girls-cyberpunk-2077-cyberpunk-hd-wallpaper-preview.jpg",
]

"""<h2>Checking that the images are loaded</h2>"""

import requests
import glob
from io import BytesIO

def download_image(url):
    try:
        response = requests.get(url)
    except requests.exceptions.RequestException:
        return None
    return Image.open(BytesIO(response.content)).convert("RGB")

images = list(filter(None, [download_image(url) for url in urls]))
save_path = "./my_concept"
if not os.path.exists(save_path):
    os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
image_grid(images, 1, len(images))

"""<h2>Initializing the placeholder and initializer token for the newly created concept</h2>"""

what_to_teach = "object"

placeholder_token = "<anime-style>"

# Always double-check the tokenizer spelling (confirm the token before handing it to the tokenizer).
initializer_token = "character"  # "character" means we want the model to draw a character in the artwork

"""<h2>Setting prompts for training</h2>"""

imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
    # Note: the two templates below contain no "{}" placeholder, so the learned
    # token never appears in the prompts they generate.
    "4k",
    "hyper realistic",
]

imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
    "lying on rose bed of {}",
]
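
# Example (illustrative, not in the original notebook): training prompts are produced by
# formatting a randomly chosen template with the placeholder token, e.g.
print(random.choice(imagenet_templates_small).format(placeholder_token))
# -> something like "a photo of a <anime-style>"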

"""<h2>Setting up the dataset</h2>"""

class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        self.interpolation = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            crop = min(img.shape[0], img.shape[1])
            h, w = (
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float32)

        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

"""<h2>Load the tokenizer and add the placeholder token as an additional special token.</h2>"""

tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
)

# Add the placeholder token to the tokenizer
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
    raise ValueError(
        f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
        " `placeholder_token` that is not already in the tokenizer."
    )

"""<h2>Get token ids for our placeholder and initializer tokens.</h2>"""

# Convert the initializer_token and placeholder_token to ids
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
    raise ValueError("The initializer token must be a single token.")

initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

"""<h2>Load the Stable Diffusion model</h2>"""

# Load models and create wrapper for stable diffusion
# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
# del pipeline
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
)

"""<h2>We added the "placeholder_token" to the "tokenizer", so we resize the token embeddings</h2>
<h2>and create a new embedding vector in the token embeddings for our placeholder.</h2>
"""

text_encoder.resize_token_embeddings(len(tokenizer))

"""<h2>Initialise the newly added placeholder token</h2>"""

token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

"""<h2>Train only the newly added embedding vector</h2>"""

def freeze_params(params):
    for param in params:
        param.requires_grad = False

# Freeze vae and unet
freeze_params(vae.parameters())
freeze_params(unet.parameters())
# Freeze all parameters except for the token embeddings in the text encoder
params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
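
# Optional sanity check (illustrative, not in the original notebook): after freezing, the only
# trainable parameters left in the text encoder should be the token-embedding matrix.
trainable = [name for name, p in text_encoder.named_parameters() if p.requires_grad]
print("Trainable text-encoder parameters:", trainable)
assert all("token_embedding" in name for name in trainable)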

"""<h2>Creating the training data</h2>"""

train_dataset = TextualInversionDataset(
    data_root=save_path,
    tokenizer=tokenizer,
    size=vae.sample_size,
    placeholder_token=placeholder_token,
    repeats=100,
    learnable_property=what_to_teach,  # Option selected above between object and style
    center_crop=False,
    set="train",
)

def create_dataloader(train_batch_size=1):
    return torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

# Creating the noise scheduler
noise_scheduler = DDPMScheduler.from_config(pretrained_model_name_or_path, subfolder="scheduler")

"""<h2>Setting training arguments</h2>"""

hyperparameters = {
    "learning_rate": 5e-04,
    "scale_lr": True,
    "max_train_steps": 2000,
    "save_steps": 250,
    "train_batch_size": 4,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "mixed_precision": "fp16",
    "seed": 42,
    "output_dir": "sd-concept-output"
}
!mkdir -p sd-concept-output
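
# Illustrative sanity check (not in the original notebook): with "scale_lr" enabled, the
# effective learning rate used by training_function on a single GPU is
#   5e-4 * 1 (grad accumulation) * 4 (batch size) * 1 (process) = 2e-3.
effective_lr = (
    hyperparameters["learning_rate"]
    * hyperparameters["gradient_accumulation_steps"]
    * hyperparameters["train_batch_size"]
)
print(f"Effective single-GPU learning rate after scaling: {effective_lr}")  # 0.002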

"""<h2>Training functions</h2>"""

logger = get_logger(__name__)

def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)

def training_function(text_encoder, vae, unet):
    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
    learning_rate = hyperparameters["learning_rate"]
    max_train_steps = hyperparameters["max_train_steps"]
    output_dir = hyperparameters["output_dir"]
    gradient_checkpointing = hyperparameters["gradient_checkpointing"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=hyperparameters["mixed_precision"]
    )

    if gradient_checkpointing:
        text_encoder.gradient_checkpointing_enable()
        unet.enable_gradient_checkpointing()

    train_dataloader = create_dataloader(train_batch_size)

    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=learning_rate,
    )

    text_encoder, optimizer, train_dataloader = accelerator.prepare(
        text_encoder, optimizer, train_dataloader
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and unet to device
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)

    # Keep vae in eval mode as we don't train it
    vae.eval()
    # Keep unet in train mode to enable gradient checkpointing
    unet.train()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states.to(weight_dtype)).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                # Zero out the gradients for all token embeddings except the newly added
                # embeddings for the concept, as we only want to optimize the concept embeddings
                if accelerator.num_processes > 1:
                    grads = text_encoder.module.get_input_embeddings().weight.grad
                else:
                    grads = text_encoder.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % hyperparameters["save_steps"] == 0:
                    save_path = os.path.join(output_dir, f"learned_embeds-step-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, save_path)

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

        accelerator.wait_for_everyone()

    # Create the pipeline using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            text_encoder=accelerator.unwrap_model(text_encoder),
            tokenizer=tokenizer,
            vae=vae,
            unet=unet,
        )
        pipeline.save_pretrained(output_dir)
        # Also save the newly trained embeddings
        save_path = os.path.join(output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, save_path)

"""<h2>Launching training on GPU (will not work without a GPU)</h2>"""

import accelerate
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))

for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
    if param.grad is not None:
        del param.grad  # free some memory
torch.cuda.empty_cache()

"""<h2>Set up the pipeline</h2>"""

from diffusers import DPMSolverMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(
    hyperparameters["output_dir"],
    scheduler=DPMSolverMultistepScheduler.from_pretrained(hyperparameters["output_dir"], subfolder="scheduler"),
    torch_dtype=torch.float16,
).to("cuda")
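
# Alternative (illustrative sketch, not from the original notebook): instead of reloading the
# full pipeline saved in output_dir, the learned_embeds.bin produced by save_progress() can be
# applied to a fresh copy of the base model. This assumes a diffusers version recent enough to
# provide `load_textual_inversion` (the git install at the top should qualify).
# base_pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
# base_pipe.load_textual_inversion(
#     os.path.join(hyperparameters["output_dir"], "learned_embeds.bin"), token=placeholder_token
# )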

#@title Run the Stable Diffusion pipeline

prompt = "Planet scale halo of water in space digital art, Trending on ArtStation" #@param {type:"string"}

num_samples = 4
num_rows = 1

all_images = []
for _ in range(num_rows):
    images = pipe([prompt] * num_samples, num_inference_steps=30, guidance_scale=7.5).images
    all_images.extend(images)

grid = image_grid(all_images, num_rows, num_samples)

grid
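
# Note (illustrative, not in the original notebook): the demo prompt above does not contain the
# learned placeholder token, so it does not exercise the trained concept. To actually use the
# new embedding, include the token in the prompt, e.g.:
concept_prompt = f"a portrait of a girl, {placeholder_token}, digital art"
concept_images = pipe([concept_prompt] * num_samples, num_inference_steps=30, guidance_scale=7.5).images
image_grid(concept_images, 1, num_samples)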