Updates from diffusers
pipeline.py CHANGED (+21 -53)
@@ -3,19 +3,19 @@ import re
 from typing import Callable, List, Optional, Union

 import numpy as np
+import PIL
 import torch
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 import random
 import sys
 from tqdm.auto import tqdm

 import diffusers
-import PIL
 from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
-from diffusers.utils import deprecate, logging
-from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from diffusers.utils import logging


 try:
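For reference, the retained `logging` import is the standard diffusers module-level logger; a minimal sketch of how it is typically used (the message is a placeholder):

    from diffusers.utils import logging

    logger = logging.get_logger(__name__)
    logger.info("pipeline loaded")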
@@ -255,7 +255,6 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
-    **kwargs,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -603,7 +602,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

@@ -684,8 +683,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
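`callback_steps` now takes a plain `int` (default 1) and the `**kwargs` catch-all is gone. A minimal sketch of driving the callback, assuming this file is loaded as the `lpw_stable_diffusion` community pipeline (checkpoint name and prompt are placeholders):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
        custom_pipeline="lpw_stable_diffusion",
        torch_dtype=torch.float16,
    ).to("cuda")

    def on_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # invoked every `callback_steps` denoising steps
        print(f"step {step}: latents {tuple(latents.shape)}")

    image = pipe("a (red:1.2) rose", callback=on_step, callback_steps=5).images[0]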
@@ -761,10 +759,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
-        image = init_image or image
-
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
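With the `init_image` deprecation shim removed, callers must pass `image` directly; a sketch, assuming `pipe` is an instance of this pipeline and `init` is a `PIL.Image`:

    # previously accepted via the deprecation shim:
    #     pipe.img2img(init_image=init, prompt="...")
    # now only the `image` argument is recognized:
    result = pipe.img2img(image=init, prompt="a watercolor landscape", strength=0.75)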
@@ -886,8 +880,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function for text-to-image generation.
@@ -963,7 +956,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
-            **kwargs,
         )

     def img2img(
@@ -982,8 +974,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function for image-to-image generation.
@@ -1059,7 +1050,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
-            **kwargs,
         )

     def inpaint(
@@ -1079,8 +1069,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function for inpaint.
@@ -1161,13 +1150,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
-            **kwargs,
         )


     # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
     def get_text_latent_space(self, prompt, guidance_scale = 7.5):
-
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -1177,7 +1164,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -1196,7 +1183,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

         return text_embeddings
-
+
     def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
         """ helper function to spherically interpolate two arrays v1 v2
         from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
@@ -1293,11 +1280,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
-        save_n_steps: Optional[int] = None,
         **kwargs,):
+
         from diffusers.schedulers import LMSDiscreteScheduler
         batch_size = 1
-
+
         if generator == None:
             generator = torch.Generator("cuda")
         generator_state = generator.get_state()
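Snapshotting the generator state before sampling is what lets `variation` (further down) replay a run from identical noise. The pattern in isolation:

    import torch

    generator = torch.Generator("cuda")
    state = generator.get_state()  # snapshot before any draws

    x1 = torch.randn(4, generator=generator, device="cuda")
    generator.set_state(state)     # rewind to the snapshot
    x2 = torch.randn(4, generator=generator, device="cuda")
    assert torch.equal(x1, x2)     # identical draws after restoring the state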
@@ -1331,27 +1318,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
-
-            mid_latents = []
-            mid_images = []
-        else:
-            mid_latents = None
-            mid_images = None
+
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            if save_n_steps:
-                if i % save_n_steps == 0:
-                    # scale and decode the image latents with vae
-                    dec_mid_latents = 1 / 0.18215 * latents
-                    mid_latents.append(dec_mid_latents)
-                    image = self.vae.decode(dec_mid_latents).sample
-
-                    image = (image / 2 + 0.5).clamp(0, 1)
-                    image = image.cpu().permute(0, 2, 3, 1).numpy()
-
-                    if output_type == "pil":
-                        image = self.numpy_to_pil(image)
-                        mid_latents.append(image)
-                    mid_images.append(image)
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -1359,7 +1327,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

             # perform guidance
             if do_classifier_free_guidance:
@@ -1368,21 +1336,21 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):

             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents)

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
-
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"image": image, "generator_state": generator_state
+        return {"image": image, "generator_state": generator_state}

     def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
         # random vector to move in latent space
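The new `.sample` and `.prev_sample` accessors match the diffusers API that returns output dataclasses instead of bare tensors. A sketch of one denoising step under that API, assuming `unet`, `scheduler`, `latents`, `text_embeddings`, and a timestep `t` are already in scope:

    # UNet2DConditionModel returns a UNet2DConditionOutput; the tensor lives in `.sample`
    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings).sample

    # SchedulerMixin.step returns a SchedulerOutput; updated latents live in `.prev_sample`
    latents = scheduler.step(noise_pred, t, latents).prev_sample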
@@ -1390,7 +1358,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
         scaled_rand_t = rand_t / rand_mag
         variation_embedding = text_embeddings + scaled_rand_t
-
+
         generator = torch.Generator("cuda")
         generator.set_state(generator_state)
         result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)