Update pipeline_glide.py

pipeline_glide.py · CHANGED · +66 −131
@@ -24,20 +24,16 @@ import torch.utils.checkpoint
 from torch import nn
 
 import tqdm
-from diffusers.models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
 from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    …
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
+
+from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..pipeline_utils import DiffusionPipeline
+from ..schedulers import DDPMScheduler, DDIMScheduler
+from ..utils import logging
 
 
 #####################
@@ -719,7 +715,7 @@ class GLIDE(DiffusionPipeline):
     def __init__(
         self,
        text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
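For orientation, a construction sketch matching the new signature. Every component value here is hypothetical (assumed loaded elsewhere), and the scheduler constructors are shown with defaults that may not match the API of the day; only the keyword names come from the diff:

```python
# Hypothetical wiring of the pipeline after this change: the text stage now
# takes a plain DDPMScheduler, and classifier-free guidance moves inline into
# __call__ instead of living in a dedicated scheduler.
pipeline = GLIDE(
    text_unet=text_unet,                   # GLIDETextToImageUNetModel
    text_noise_scheduler=DDPMScheduler(),  # was ClassifierFreeGuidanceScheduler
    text_encoder=text_encoder,             # CLIPTextModel
    tokenizer=tokenizer,                   # GPT2Tokenizer
    upscale_unet=upscale_unet,             # GLIDESuperResUNetModel
    upscale_noise_scheduler=DDIMScheduler(),
)
```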
@@ -735,100 +731,28 @@ class GLIDE(DiffusionPipeline):
             upscale_noise_scheduler=upscale_noise_scheduler,
         )
 
-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-
-            q(x_{t-1} | x_t, x_0)
-
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
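For reference, the helper removed here computed the standard DDPM posterior moments (Ho et al., 2020); the scheduler buffers `posterior_mean_coef1`, `posterior_mean_coef2`, and `posterior_variance` are exactly the coefficients in:

```latex
q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(\tilde{\mu}_t(x_t, x_0),\; \tilde{\beta}_t \mathbf{I}\big),
\qquad
\tilde{\mu}_t(x_t, x_0)
  = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\, x_0
  + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t
```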
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to
-            pass to the model. This can be used for conditioning.
-        :return: a tuple (mean, variance, log_variance, pred_xstart), where
-            - 'mean' is the model mean output,
-            - 'variance' is the model variance output,
-            - 'log_variance' is the log of 'variance',
-            - 'pred_xstart' is the prediction for x_0.
-        """
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
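The learned-variance interpolation in this removed block follows Nichol & Dhariwal's "Improved Denoising Diffusion Probabilistic Models": the network's extra channels predict a per-pixel value in [-1, 1] that blends two fixed log-variance bounds. The same trick reappears in the new `__call__` further down. A standalone sketch with stand-in tensors:

```python
import torch

def interpolate_log_variance(model_var_values, min_log, max_log):
    frac = (model_var_values + 1) / 2          # map [-1, 1] -> [0, 1]
    return frac * max_log + (1 - frac) * min_log

v = torch.tanh(torch.randn(1, 3, 8, 8))        # stand-in for the variance head output
min_log = torch.full_like(v, -11.0)            # stand-in: clipped log posterior variance
max_log = torch.full_like(v, -6.0)             # stand-in: log(beta_t)
model_log_variance = interpolate_log_variance(v, min_log, max_log)
model_variance = model_log_variance.exp()
```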
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
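Both removed conversion helpers are rearrangements of the forward-process identity, with the scheduler buffers `sqrt_recip_alphas_cumprod` = 1/√ᾱ_t and `sqrt_recipm1_alphas_cumprod` = √(1/ᾱ_t − 1):

```latex
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
\quad\Longleftrightarrow\quad
x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\, x_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\; \epsilon
```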
     @torch.no_grad()
-    def __call__(…):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):
+
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
         self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
         self.upscale_unet.to(torch_device)
 
-
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
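`text_model_fn` implements classifier-free guidance: the latent batch is duplicated so the conditional and unconditional predictions come from a single forward pass, and the two eps halves are then recombined. The hunk ends before the function returns, so here is the guidance rule in isolation, with random stand-ins for the U-Net outputs:

```python
import torch

def classifier_free_guidance(cond_eps, uncond_eps, guidance_scale):
    # guidance_scale = 1.0 recovers the conditional prediction;
    # larger values push the sample harder toward the text condition.
    return uncond_eps + guidance_scale * (cond_eps - uncond_eps)

cond_eps, uncond_eps = torch.randn(2, 1, 3, 64, 64)  # stand-ins for the split eps halves
half_eps = classifier_free_guidance(cond_eps, uncond_eps, guidance_scale=3.0)
```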
@@ -837,71 +761,82 @@ class GLIDE(DiffusionPipeline):
 
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = …
-            (
-            …
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)
 
         # 2. Encode tokens
-        # an empty input is needed to guide the model away from
+        # an empty input is needed to guide the model away from it
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
         input_ids = inputs["input_ids"].to(torch_device)
         attention_mask = inputs["attention_mask"].to(torch_device)
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
 
         # 3. Run the text2image generation step
-        …
-        for …
-        …
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                # The model_var_values is [-1, 1] for [min_var, max_var].
+                frac = (model_var_values + 1) / 2
+                model_log_variance = frac * max_log + (1 - frac) * min_log
+
+                pred_prev_image = self.upscale_noise_scheduler.step(
+                    noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
         )
-        noise = …
-        …
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
+                variance = torch.exp(0.5 * model_log_variance) * noise
+
+                # set current image to prev_image: x_t -> x_t-1
+                image = pred_prev_image + variance
 
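One detail worth pausing on: the loop forms the noise term as `torch.exp(0.5 * model_log_variance) * noise`. Since the network parameterizes log σ², exp(½ · log σ²) = σ, so this is just σ · N(0, I). A quick numeric check:

```python
import torch

log_variance = torch.log(torch.tensor(0.04))  # sigma^2 = 0.04
std = torch.exp(0.5 * log_variance)           # exp(0.5 * log sigma^2) = sigma
print(round(std.item(), 4))                   # 0.2
```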
         # 4. Run the upscaling step
         batch_size = 1
         image = image[:1]
         low_res = ((image + 1) * 127.5).round() / 127.5 - 1
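The `low_res` line is doing real work: it snaps the generator's float output onto the 256-level grid of an 8-bit image, so the upscaler sees the same value grid it was trained on. A standalone check, independent of any pipeline state:

```python
import torch

x = torch.tensor([-1.0, -0.4973, 0.0, 0.5012, 1.0])
low_res = ((x + 1) * 127.5).round() / 127.5 - 1
print(low_res)  # each value is now k / 127.5 - 1 for some integer k in [0, 255]
```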
-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (
-            …
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
-        )
-        image = image …
-        …
-        # Notation (<variable name> -> <name in paper>)
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
+        ).to(torch_device)
+        image = image * upsample_temp
+
+        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
+        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
+
         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
             with torch.no_grad():
-                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                 model_output = self.upscale_unet(image, time_input, low_res)
                 noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
 
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.upscale_noise_scheduler.step(
-                noise_residual, image, t, num_inference_steps_upscale, eta
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
             )
 
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(…
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = (
                     self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                 )
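With the default `eta=0.0` the DDIM-style update is deterministic, so the variance branch never fires, and `upsample_temp` scales the initial noise (values slightly below 1.0 trade diversity for sharpness, per the removed comment). The `inference_step_times` line spreads the 50 inference steps evenly over the trained schedule; a quick check, assuming the usual 1000 trained timesteps (the pipeline reads the real value from the scheduler):

```python
num_trained_timesteps = 1000  # assumed for this sketch
num_inference_steps_upscale = 50
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
print(list(inference_step_times)[:4])  # [0, 20, 40, 60]
print(len(inference_step_times))       # 50
```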
@@ -909,6 +844,6 @@ class GLIDE(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
 
-        image = image.permute(0, 2, 3, 1)
+        image = image.clamp(-1, 1).permute(0, 2, 3, 1)
 
         return image
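End to end, `__call__` now returns a batch in NHWC layout clamped to [-1, 1], which makes saving straightforward. A usage sketch building on the wiring example above; the PIL conversion is an assumption, not part of the diff:

```python
import torch
import numpy as np
from PIL import Image

# `pipeline` as constructed in the sketch near the top of this diff.
image = pipeline("an oil painting of a corgi", generator=torch.manual_seed(0))
array = ((image[0].cpu().numpy() + 1) * 127.5).round().astype(np.uint8)  # NHWC, [-1, 1] -> [0, 255]
Image.fromarray(array).save("corgi.png")
```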