anton-l (HF staff) committed
Commit 70ed895 · Parent(s): 861b3c0

Update pipeline_glide.py

Files changed (1):
  1. pipeline_glide.py +66 -131

pipeline_glide.py CHANGED
@@ -24,20 +24,16 @@ import torch.utils.checkpoint
 from torch import nn

 import tqdm
-from diffusers.models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
 from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ModelOutput,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
+
+from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..pipeline_utils import DiffusionPipeline
+from ..schedulers import DDPMScheduler, DDIMScheduler
+from ..utils import logging


 #####################
@@ -719,7 +715,7 @@ class GLIDE(DiffusionPipeline):
     def __init__(
         self,
         text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
@@ -735,100 +731,28 @@ class GLIDE(DiffusionPipeline):
             upscale_noise_scheduler=upscale_noise_scheduler,
         )

-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-
-            q(x_{t-1} | x_t, x_0)
-
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
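For context, the removed helper evaluates the standard DDPM posterior q(x_{t-1} | x_t, x_0). A self-contained sketch of the coefficients it reads off the scheduler (textbook DDPM algebra with an illustrative linear beta schedule, not the diffusers API):

    import numpy as np

    # Posterior of the forward process: q(x_{t-1} | x_t, x_0) is Gaussian with
    # mean = coef1 * x_0 + coef2 * x_t and the variance below.
    betas = np.linspace(1e-4, 0.02, 1000)  # illustrative schedule
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)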
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
-        the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps
-                      as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to
-                             pass to the model. This can be used for conditioning.
-        :return: a dict with the following keys:
-                 - 'mean': the model mean output.
-                 - 'variance': the model variance output.
-                 - 'log_variance': the log of 'variance'.
-                 - 'pred_xstart': the prediction for x_0.
-        """
-
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
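The removed p_mean_variance (and the inline loop that replaces it below) maps a per-pixel network output in [-1, 1] onto [log(min_var), log(max_var)], following Improved DDPM (Nichol & Dhariwal, 2021). A minimal sketch with toy values:

    import torch

    # Learned-variance interpolation for one timestep (toy numbers).
    min_log = torch.tensor(-9.0)  # e.g. log of the clipped posterior variance
    max_log = torch.tensor(-4.0)  # e.g. log(beta_t)
    model_var_values = torch.rand(1, 3, 8, 8) * 2 - 1  # stand-in for the UNet output

    frac = (model_var_values + 1) / 2
    model_log_variance = frac * max_log + (1 - frac) * min_log
    model_variance = torch.exp(model_log_variance)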
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
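These removed helpers invert the forward process: since x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps with a_t the cumulative alpha product, it follows that x_0 = sqrt(1/a_t) * x_t - sqrt(1/a_t - 1) * eps. A toy round-trip check of that identity:

    import torch

    a_t = torch.tensor(0.5)  # cumulative alpha at some timestep
    x0 = torch.randn(1, 3, 8, 8)
    eps = torch.randn_like(x0)

    # Forward: noise x_0 to x_t, then recover x_0 from (x_t, eps).
    x_t = a_t.sqrt() * x0 + (1 - a_t).sqrt() * eps
    x0_rec = (1 / a_t).sqrt() * x_t - (1 / a_t - 1).sqrt() * eps
    assert torch.allclose(x0, x0_rec, atol=1e-5)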
     @torch.no_grad()
-    def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):
+
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
         self.upscale_unet.to(torch_device)

-        # Create a classifier-free guidance sampling function
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
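text_model_fn implements classifier-free guidance: the batch carries a conditional and an unconditional copy of the same latents, and the two noise predictions are blended. The blending step in isolation, on stand-in tensors:

    import torch

    guidance_scale = 3.0
    cond_eps = torch.randn(1, 3, 64, 64)    # prediction given the text prompt
    uncond_eps = torch.randn(1, 3, 64, 64)  # prediction given the empty prompt

    # Push the prediction away from the unconditional one, toward the prompt.
    guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)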
@@ -837,71 +761,82 @@ class GLIDE(DiffusionPipeline):

         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.text_noise_scheduler.sample_noise(
-            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
-        )
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)

         # 2. Encode tokens
-        # an empty input is needed to guide the model away from (
+        # an empty input is needed to guide the model away from it
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
         input_ids = inputs["input_ids"].to(torch_device)
         attention_mask = inputs["attention_mask"].to(torch_device)
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

         # 3. Run the text2image generation step
-        num_timesteps = len(self.text_noise_scheduler)
-        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
-            t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
-                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                # The model_var_values is [-1, 1] for [min_var, max_var].
+                frac = (model_var_values + 1) / 2
+                model_log_variance = frac * max_log + (1 - frac) * min_log
+
+            pred_prev_image = self.upscale_noise_scheduler.step(
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
             )
-            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
-            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
-            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
+            noise = torch.randn(image.shape, generator=generator).to(torch_device)
+            variance = torch.exp(0.5 * model_log_variance) * noise
+
+            # set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance

         # 4. Run the upscaling step
         batch_size = 1
         image = image[:1]
         low_res = ((image + 1) * 127.5).round() / 127.5 - 1
-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997

         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
-        )
-        image = image.to(torch_device)
-
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
+        ).to(torch_device)
+        image = image * upsample_temp
+
+        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
+        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
+
         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
             with torch.no_grad():
-                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                 model_output = self.upscale_unet(image, time_input, low_res)
                 noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.upscale_noise_scheduler.step(
-                noise_residual, image, t, num_inference_steps_upscale, eta
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
             )

             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = (
                     self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                 )
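The upscaling loop is a DDIM sampler (Eqs. (12) and (16) of the DDIM paper, https://arxiv.org/pdf/2010.02502.pdf), with eta scaling the stochastic term. One update step in isolation, assuming scalar cumulative alphas; the names follow the notation comments this hunk removes:

    import torch

    # One DDIM update x_t -> x_{t-1} with toy scalars (eta=0 is deterministic).
    a_t, a_prev, eta = torch.tensor(0.5), torch.tensor(0.6), 0.0
    x_t = torch.randn(1, 3, 256, 256)
    eps = torch.randn_like(x_t)  # stand-in for the UNet's noise prediction

    pred_x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()              # f_theta(x_t, t)
    sigma_t = eta * ((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).sqrt()
    dir_xt = (1 - a_prev - sigma_t**2).sqrt() * eps                    # "direction pointing to x_t"
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)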
@@ -909,6 +844,6 @@ class GLIDE(DiffusionPipeline):
         # 4. set current image to prev_image: x_t -> x_t-1
         image = pred_prev_image + variance

-        image = image.permute(0, 2, 3, 1)
+        image = image.clamp(-1, 1).permute(0, 2, 3, 1)

         return image
 
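After this change the pipeline returns an NHWC tensor clamped to [-1, 1]. A small post-processing sketch for turning that output into a saveable image (the random tensor stands in for the pipeline output):

    import numpy as np
    import torch
    from PIL import Image

    # Stand-in for the pipeline output: shape (1, H, W, 3), values in [-1, 1].
    image = torch.rand(1, 256, 256, 3) * 2 - 1

    # Map [-1, 1] -> [0, 255] uint8 and save the first sample.
    arr = ((image.cpu().numpy() + 1) * 127.5).round().astype(np.uint8)
    Image.fromarray(arr[0]).save("glide_sample.png")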