AlanB committed
Commit a551d9c · 1 Parent(s): e5ecc74

Forgot something..

Files changed (1)
  1. pipeline.py +226 -7
pipeline.py CHANGED
@@ -4,6 +4,9 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 import torch
+import random
+import sys
+from tqdm.auto import tqdm
 
 import PIL
 from diffusers.configuration_utils import FrozenDict
@@ -11,7 +14,7 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
 from diffusers.utils import deprecate, is_accelerate_available, logging
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -435,7 +438,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -468,8 +471,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             new_config["clip_sample"] = False
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if safety_checker is None:
-            logger.warning(
+        #if safety_checker is None:
+        if False:
+            logger.warn(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -520,9 +524,14 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             `attention_head_dim` must be a multiple of `slice_size`.
         """
        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
@@ -1146,3 +1155,213 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
             callback_steps=callback_steps,
             **kwargs,
         )
+    # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+    def get_text_latent_space(self, prompt, guidance_scale = 7.5):
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""], padding="max_length", max_length=max_length, return_tensors="pt"
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
+        """ helper function to spherically interpolate two arrays v1 v2
+        from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
+        this should be better than lerping for moving between noise spaces """
+
+        if not isinstance(v0, np.ndarray):
+            inputs_are_torch = True
+            input_device = v0.device
+            v0 = v0.cpu().numpy()
+            v1 = v1.cpu().numpy()
+
+        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+        if np.abs(dot) > DOT_THRESHOLD:
+            v2 = (1 - t) * v0 + t * v1
+        else:
+            theta_0 = np.arccos(dot)
+            sin_theta_0 = np.sin(theta_0)
+            theta_t = theta_0 * t
+            sin_theta_t = np.sin(theta_t)
+            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+            s1 = sin_theta_t / sin_theta_0
+            v2 = s0 * v0 + s1 * v1
+
+        if inputs_are_torch:
+            v2 = torch.from_numpy(v2).to(input_device)
+
+        return v2
+
+    def lerp_between_prompts(self, first_prompt, second_prompt, seed = None, length = 10, save=False, guidance_scale: Optional[float] = 7.5, **kwargs):
+        first_embedding = self.get_text_latent_space(first_prompt)
+        second_embedding = self.get_text_latent_space(second_prompt)
+        if not seed:
+            seed = random.randint(0, sys.maxsize)
+        generator = torch.Generator(self.device)
+        generator.manual_seed(seed)
+        generator_state = generator.get_state()
+        lerp_embed_points = []
+        for i in range(length):
+            weight = i / length
+            tensor_lerp = torch.lerp(first_embedding, second_embedding, weight)
+            lerp_embed_points.append(tensor_lerp)
+        images = []
+        for idx, latent_point in enumerate(lerp_embed_points):
+            generator.set_state(generator_state)
+            image = self.diffuse_from_inits(latent_point, **kwargs)["image"][0]
+            images.append(image)
+            if save:
+                image.save(f"{first_prompt}-{second_prompt}-{idx:02d}.png", "PNG")
+        return {"images": images, "latent_points": lerp_embed_points,"generator_state": generator_state}
+
+    def slerp_through_seeds(self,
+                            prompt,
+                            height: Optional[int] = 512,
+                            width: Optional[int] = 512,
+                            save = False,
+                            seed = None, steps = 10, **kwargs):
+
+        if not seed:
+            seed = random.randint(0, sys.maxsize)
+        generator = torch.Generator(self.device)
+        generator.manual_seed(seed)
+        init_start = torch.randn(
+            (1, self.unet.in_channels, height // 8, width // 8),
+            generator = generator, device = self.device)
+        init_end = torch.randn(
+            (1, self.unet.in_channels, height // 8, width // 8),
+            generator = generator, device = self.device)
+        generator_state = generator.get_state()
+        slerp_embed_points = []
+        # weight from 0 to 1/(steps - 1), add init_end specifically so that we
+        # have len(images) = steps
+        for i in range(steps - 1):
+            weight = i / steps
+            tensor_slerp = self.slerp(weight, init_start, init_end)
+            slerp_embed_points.append(tensor_slerp)
+        slerp_embed_points.append(init_end)
+        images = []
+        embed_point = self.get_text_latent_space(prompt)
+        for idx, noise_point in enumerate(slerp_embed_points):
+            generator.set_state(generator_state)
+            image = self.diffuse_from_inits(embed_point, init = noise_point, **kwargs)["image"][0]
+            images.append(image)
+            if save:
+                image.save(f"{seed}-{idx:02d}.png", "PNG")
+        return {"images": images, "noise_samples": slerp_embed_points,"generator_state": generator_state}
+
+    @torch.no_grad()
+    def diffuse_from_inits(self, text_embeddings,
+                           init = None,
+                           height: Optional[int] = 512,
+                           width: Optional[int] = 512,
+                           num_inference_steps: Optional[int] = 50,
+                           guidance_scale: Optional[float] = 7.5,
+                           eta: Optional[float] = 0.0,
+                           generator: Optional[torch.Generator] = None,
+                           output_type: Optional[str] = "pil",
+                           **kwargs,):
+
+        batch_size = 1
+
+        if generator == None:
+            generator = torch.Generator("cuda")
+        generator_state = generator.get_state()
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get the intial random noise
+        latents = init if init is not None else torch.randn(
+            (batch_size, self.unet.in_channels, height // 8, width // 8),
+            generator=generator,
+            device=self.device,)
+
+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            latents = latents * self.scheduler.sigmas[0]
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[i]
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"image": image, "generator_state": generator_state}
+
+    def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
+        # random vector to move in latent space
+        rand_t = (torch.rand(text_embeddings.shape, device = self.device) * 2) - 1
+        rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
+        scaled_rand_t = rand_t / rand_mag
+        variation_embedding = text_embeddings + scaled_rand_t
+
+        generator = torch.Generator("cuda")
+        generator.set_state(generator_state)
+        result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
+        result.update({"latent_point": variation_embedding})
+        return result
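
For reference, the methods added in this commit are called on the loaded pipeline object. The following is a minimal usage sketch, not part of the commit: it assumes a local copy of this repo is loaded as a community pipeline via `custom_pipeline`, that a CUDA device is available (several of the new methods hard-code `torch.Generator("cuda")`), and that a diffusers version contemporary with this commit is installed. The model id, prompts, and paths are placeholders.

import torch
from diffusers import DiffusionPipeline

# Load Stable Diffusion weights and attach this community pipeline.
# "CompVis/stable-diffusion-v1-4" and "path/to/this/repo" are placeholders.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="path/to/this/repo",  # directory containing this pipeline.py
).to("cuda")

# Walk between two prompts in text-embedding space; extra kwargs such as
# num_inference_steps are forwarded to diffuse_from_inits.
lerp_out = pipe.lerp_between_prompts(
    "a photo of a cat", "a photo of a dog",
    length=5, num_inference_steps=30,
)
frames = lerp_out["images"]

# Spherically interpolate the initial noise for a single prompt.
slerp_out = pipe.slerp_through_seeds("a watercolor landscape", steps=5, num_inference_steps=30)

# Re-diffuse a perturbed copy of an embedding, reusing a recorded generator state.
embedding = pipe.get_text_latent_space("a photo of a cat")
var_out = pipe.variation(embedding, slerp_out["generator_state"], num_inference_steps=30)

The slerp helper follows the linked Karpathy gist: spherical interpolation keeps intermediate noise tensors at a magnitude comparable to the endpoints, which tends to give better in-between images than plain linear interpolation when moving between Gaussian noise samples.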