AlanB committed
Commit bf344dc · 1 Parent(s): e40757f

Updated for Torch Compiling fix

Files changed (1):
  1. pipeline.py  +222 -4

pipeline.py CHANGED
@@ -7,6 +7,9 @@ import PIL
 import torch
 from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+import random
+import sys
+from tqdm.auto import tqdm
 
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import FrozenDict

@@ -792,7 +795,7 @@ class StableDiffusionLongPromptWeightingPipeline(
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae.decode(latents, return_dict=False)[0]  # ).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

@@ -1063,7 +1066,9 @@ class StableDiffusionLongPromptWeightingPipeline(
                 t,
                 encoder_hidden_states=prompt_embeds,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
+            # ).sample
 
             # perform guidance
             if do_classifier_free_guidance:

@@ -1071,7 +1076,7 @@ class StableDiffusionLongPromptWeightingPipeline(
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]  # ).prev_sample
 
             if mask is not None:
                 # masking
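The `return_dict=False` edits above make the UNet, VAE, and scheduler calls return plain tuples instead of output dataclasses, which generally avoids graph breaks when the UNet is run under `torch.compile`. A minimal usage sketch under that assumption (the checkpoint id, the `lpw_stable_diffusion` custom-pipeline name, and the compile options are illustrative choices, not dictated by this commit):

    import torch
    from diffusers import DiffusionPipeline

    # Illustrative checkpoint; any SD 1.x checkpoint compatible with this pipeline would do.
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="lpw_stable_diffusion",  # assumes this pipeline.py is loaded as the custom pipeline
        torch_dtype=torch.float16,
    ).to("cuda")

    # Compile the UNet; tuple outputs keep the compiled graph simple.
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
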
@@ -1467,4 +1472,217 @@ class StableDiffusionLongPromptWeightingPipeline(
             is_cancelled_callback=is_cancelled_callback,
             callback_steps=callback_steps,
             cross_attention_kwargs=cross_attention_kwargs,
-        )
+        )
+
+
+    # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+    def get_text_latent_space(self, prompt, guidance_scale = 7.5):
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""], padding="max_length", max_length=max_length, return_tensors="pt"
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
+        """ helper function to spherically interpolate two arrays v1 v2
+        from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
+        this should be better than lerping for moving between noise spaces """
+
+        inputs_are_torch = False  # ensure defined when inputs are already numpy arrays
+        if not isinstance(v0, np.ndarray):
+            inputs_are_torch = True
+            input_device = v0.device
+            v0 = v0.cpu().numpy()
+            v1 = v1.cpu().numpy()
+
+        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+        if np.abs(dot) > DOT_THRESHOLD:
+            v2 = (1 - t) * v0 + t * v1
+        else:
+            theta_0 = np.arccos(dot)
+            sin_theta_0 = np.sin(theta_0)
+            theta_t = theta_0 * t
+            sin_theta_t = np.sin(theta_t)
+            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+            s1 = sin_theta_t / sin_theta_0
+            v2 = s0 * v0 + s1 * v1
+
+        if inputs_are_torch:
+            v2 = torch.from_numpy(v2).to(input_device)
+
+        return v2
+
+    def lerp_between_prompts(self, first_prompt, second_prompt, seed = None, length = 10, save=False, guidance_scale: Optional[float] = 7.5, **kwargs):
+        first_embedding = self.get_text_latent_space(first_prompt)
+        second_embedding = self.get_text_latent_space(second_prompt)
+        if not seed:
+            seed = random.randint(0, sys.maxsize)
+        generator = torch.Generator(self.device)
+        generator.manual_seed(seed)
+        generator_state = generator.get_state()
+        lerp_embed_points = []
+        for i in range(length):
+            weight = i / length
+            tensor_lerp = torch.lerp(first_embedding, second_embedding, weight)
+            lerp_embed_points.append(tensor_lerp)
+        images = []
+        for idx, latent_point in enumerate(lerp_embed_points):
+            generator.set_state(generator_state)
+            image = self.diffuse_from_inits(latent_point, **kwargs)["image"][0]
+            images.append(image)
+            if save:
+                image.save(f"{first_prompt}-{second_prompt}-{idx:02d}.png", "PNG")
+        return {"images": images, "latent_points": lerp_embed_points, "generator_state": generator_state}
+
+    def slerp_through_seeds(self,
+                            prompt,
+                            height: Optional[int] = 512,
+                            width: Optional[int] = 512,
+                            save = False,
+                            seed = None, steps = 10, **kwargs):
+
+        if not seed:
+            seed = random.randint(0, sys.maxsize)
+        generator = torch.Generator(self.device)
+        generator.manual_seed(seed)
+        init_start = torch.randn(
+            (1, self.unet.in_channels, height // 8, width // 8),
+            generator = generator, device = self.device)
+        init_end = torch.randn(
+            (1, self.unet.in_channels, height // 8, width // 8),
+            generator = generator, device = self.device)
+        generator_state = generator.get_state()
+        slerp_embed_points = []
+        # weight from 0 to 1/(steps - 1), add init_end specifically so that we
+        # have len(images) = steps
+        for i in range(steps - 1):
+            weight = i / steps
+            tensor_slerp = self.slerp(weight, init_start, init_end)
+            slerp_embed_points.append(tensor_slerp)
+        slerp_embed_points.append(init_end)
+        images = []
+        embed_point = self.get_text_latent_space(prompt)
+        for idx, noise_point in enumerate(slerp_embed_points):
+            generator.set_state(generator_state)
+            image = self.diffuse_from_inits(embed_point, init = noise_point, **kwargs)["image"][0]
+            images.append(image)
+            if save:
+                image.save(f"{seed}-{idx:02d}.png", "PNG")
+        return {"images": images, "noise_samples": slerp_embed_points, "generator_state": generator_state}
+
+    @torch.no_grad()
+    def diffuse_from_inits(self, text_embeddings,
+                           init = None,
+                           height: Optional[int] = 512,
+                           width: Optional[int] = 512,
+                           num_inference_steps: Optional[int] = 50,
+                           guidance_scale: Optional[float] = 7.5,
+                           eta: Optional[float] = 0.0,
+                           generator: Optional[torch.Generator] = None,
+                           output_type: Optional[str] = "pil",
+                           **kwargs,):
+
+        from diffusers.schedulers import LMSDiscreteScheduler
+        batch_size = 1
+
+        if generator is None:
+            generator = torch.Generator("cuda")
+        generator_state = generator.get_state()
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+        # get the initial random noise
+        latents = init if init is not None else torch.randn(
+            (batch_size, self.unet.in_channels, height // 8, width // 8),
+            generator=generator,
+            device=self.device,)
+
+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            latents = latents * self.scheduler.sigmas[0]
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[i]
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+            # predict the noise residual
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, return_dict=False)[0] #).sample
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs, return_dict=False)[0] #).prev_sample
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] #).prev_sample
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents, return_dict=False)[0]
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"image": image, "generator_state": generator_state}
+
+    def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
+        # random vector to move in latent space
+        rand_t = (torch.rand(text_embeddings.shape, device = self.device) * 2) - 1
+        rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
+        scaled_rand_t = rand_t / rand_mag
+        variation_embedding = text_embeddings + scaled_rand_t
+
+        generator = torch.Generator("cuda")
+        generator.set_state(generator_state)
+        result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
+        result.update({"latent_point": variation_embedding})
+        return result
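
The helpers added in this hunk (get_text_latent_space, slerp, lerp_between_prompts, slerp_through_seeds, diffuse_from_inits, variation) are called directly on the pipeline object. A rough usage sketch, assuming `pipe` is an instance of this pipeline already moved to a CUDA device (diffuse_from_inits hard-codes a CUDA generator when none is passed); the prompts, step counts, and magnitude below are arbitrary examples:

    # Interpolate the text latent space between two prompts (10 frames).
    frames = pipe.lerp_between_prompts(
        "a watercolor fox", "a watercolor wolf",
        length=10, num_inference_steps=30,
    )["images"]

    # Spherically interpolate between two initial noise tensors for one prompt.
    walk = pipe.slerp_through_seeds("a watercolor fox", steps=8, num_inference_steps=30)["images"]

    # Re-generate a prompt with a small random offset in embedding space.
    embeds = pipe.get_text_latent_space("a watercolor fox")
    base = pipe.diffuse_from_inits(embeds, num_inference_steps=30)
    varied = pipe.variation(embeds, base["generator_state"], variation_magnitude=100, num_inference_steps=30)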