AlanB committed
Commit 69e533f · 1 Parent(s): afbfa5a

Updates from diffusers

Files changed (1):
  1. pipeline.py +21 -53
pipeline.py CHANGED
@@ -3,19 +3,19 @@ import re
 from typing import Callable, List, Optional, Union
 
 import numpy as np
+import PIL
 import torch
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 import random
 import sys
 from tqdm.auto import tqdm
 
 import diffusers
-import PIL
 from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
-from diffusers.utils import deprecate, logging
-from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from diffusers.utils import logging
 
 
 try:
@@ -255,7 +255,6 @@ def get_weighted_text_embeddings(
     no_boseos_middle: Optional[bool] = False,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
-    **kwargs,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -603,7 +602,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         latents = 1 / 0.18215 * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
@@ -684,8 +683,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -761,10 +759,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
-        message = "Please use `image` instead of `init_image`."
-        init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
-        image = init_image or image
-
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -886,8 +880,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function for text-to-image generation.
@@ -963,7 +956,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
-            **kwargs,
         )
 
     def img2img(
@@ -982,8 +974,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function for image-to-image generation.
@@ -1059,7 +1050,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
-            **kwargs,
        )
 
     def inpaint(
@@ -1079,8 +1069,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
+        callback_steps: int = 1,
     ):
         r"""
         Function for inpaint.
@@ -1161,13 +1150,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             callback=callback,
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
-            **kwargs,
         )
 
 
     # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
     def get_text_latent_space(self, prompt, guidance_scale = 7.5):
-
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -1177,7 +1164,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -1196,7 +1183,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
         return text_embeddings
-
+
     def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
         """ helper function to spherically interpolate two arrays v1 v2
         from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
@@ -1293,11 +1280,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
-        save_n_steps: Optional[int] = None,
         **kwargs,):
+
         from diffusers.schedulers import LMSDiscreteScheduler
         batch_size = 1
-
+
         if generator == None:
             generator = torch.Generator("cuda")
         generator_state = generator.get_state()
@@ -1331,27 +1318,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
-        if save_n_steps:
-            mid_latents = []
-            mid_images = []
-        else:
-            mid_latents = None
-            mid_images = None
+
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
-            if save_n_steps:
-                if i % save_n_steps == 0:
-                    # scale and decode the image latents with vae
-                    dec_mid_latents = 1 / 0.18215 * latents
-                    mid_latents.append(dec_mid_latents)
-                    image = self.vae.decode(dec_mid_latents).sample
-
-                    image = (image / 2 + 0.5).clamp(0, 1)
-                    image = image.cpu().permute(0, 2, 3, 1).numpy()
-
-                    if output_type == "pil":
-                        image = self.numpy_to_pil(image)
-                    mid_latents.append(image)
-                    mid_images.append(image)
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -1359,7 +1327,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -1368,21 +1336,21 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents)
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
-
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"image": image, "generator_state": generator_state, "mid_latents": mid_latents, "mid_images": mid_images}
+        return {"image": image, "generator_state": generator_state}
 
     def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
         # random vector to move in latent space
@@ -1390,7 +1358,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
         scaled_rand_t = rand_t / rand_mag
         variation_embedding = text_embeddings + scaled_rand_t
-
+
         generator = torch.Generator("cuda")
         generator.set_state(generator_state)
         result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
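In short, the commit syncs this community pipeline with upstream diffusers: it removes the `init_image` deprecation shim (and the `deprecate` import), drops the catch-all `**kwargs` from `get_weighted_text_embeddings`, `__call__`, `text2img`, `img2img`, and `inpaint`, tightens `callback_steps` from `Optional[int]` to `int`, reorders the imports, deletes the unused `save_n_steps` mid-step preview path from `diffuse_from_inits`, and switches dict-style output indexing to attribute access (`["sample"]` to `.sample`, `["prev_sample"]` to `.prev_sample`). One line worth double-checking: the updated `image = self.vae.decode(latents)` drops `.sample`, and since `AutoencoderKL.decode` in contemporary diffusers returns a `DecoderOutput`, the tensor arithmetic on the following line would still need the `.sample` attribute.

A minimal usage sketch of the pipeline after this change (not part of the commit; the model id, the `custom_pipeline` value, and the prompt are placeholders, so point them at wherever this pipeline.py is actually hosted):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",  # placeholder: use the repo containing this pipeline.py
    torch_dtype=torch.float16,
).to("cuda")

# `callback_steps` is now a plain `int` (default 1), and `**kwargs` is gone, so a
# leftover `init_image=...` argument raises a TypeError instead of being rerouted
# through `deprecate`; img2img/inpaint calls take the image via `image=`.
result = pipe.text2img(
    "a fantasy landscape, ((highly detailed)), [blurry]",  # bracket weighting from this pipeline
    width=512,
    height=512,
    num_inference_steps=30,
    callback_steps=1,
)
result.images[0].save("out.png")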
 
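The dict-to-attribute migration above follows the diffusers output-object convention: `BaseOutput` subclasses still accept string indexing, but attribute access is the current style. A standalone sketch of the pattern (assumes only that `torch` and `diffusers` are installed; `DDPMScheduler` is used here because it runs without model weights):

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

sample = torch.randn(1, 4, 8, 8)        # stand-in for latents
noise_pred = torch.randn_like(sample)   # stand-in for the UNet's noise prediction

t = scheduler.timesteps[0]
out = scheduler.step(noise_pred, t, sample)
latents = out.prev_sample               # attribute access, as in the updated pipeline
# latents = out["prev_sample"]          # older dict-style access the commit replaces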