anton-l committed
Commit bbd7574 · 1 Parent(s): 71d8718

add pipeline src

Files changed (1)
  1. modeling_glide.py +228 -0
modeling_glide.py ADDED
@@ -0,0 +1,228 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import torch

import tqdm
from diffusers import (
    ClassifierFreeGuidanceScheduler,
    GlideDDIMScheduler,
    CLIPTextModel,
    DiffusionPipeline,
    GLIDETextToImageUNetModel,
    GLIDESuperResUNetModel,
)
from transformers import GPT2Tokenizer


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape broadcast_shape where entry [i, ...] is filled
             with arr[timesteps[i]].
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


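# For intuition, a hypothetical call with illustrative values (not from the source):
# given arr = np.array([0.1, 0.2, 0.3]) and timesteps = torch.tensor([2, 0]),
# _extract_into_tensor(arr, timesteps, (2, 3, 8, 8)) returns a (2, 3, 8, 8) tensor
# whose batch element 0 is filled with 0.3 and batch element 1 with 0.1.

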
class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet,
            text_noise_scheduler=text_noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            upscale_unet=upscale_unet,
            upscale_noise_scheduler=upscale_noise_scheduler,
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        For a DDPM this posterior is Gaussian with mean

            posterior_mean = coef1 * x_0 + coef2 * x_t,

        where coef1 = beta_t * sqrt(alphas_cumprod[t-1]) / (1 - alphas_cumprod[t]) and
        coef2 = sqrt(alphas[t]) * (1 - alphas_cumprod[t-1]) / (1 - alphas_cumprod[t]);
        the scheduler precomputes both coefficients and the posterior variance.
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param scheduler: the noise scheduler holding the diffusion coefficients.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param transformer_out: text encoder hidden states; when given, the
                                text2image model is used, otherwise the super-res model.
        :param low_res: the low-resolution image to condition the super-res model on.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :return: a tuple (mean, variance, log_variance, pred_xstart): the model
                 mean output, the model variance output, the log of the variance,
                 and the prediction for x_0.
        """
        B, C = x.shape[:2]
        assert t.shape == (B,)
        if transformer_out is None:
            # super-res model
            model_output = model(x, t, low_res)
        else:
            # text2image model
            model_output = model(x, t, transformer_out)

        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
        # The model_var_values is [-1, 1] for [min_var, max_var]: interpolate the learned
        # log-variance between the clipped posterior variance and log(beta_t).
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)

        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

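    # The two helpers above invert the DDPM forward relation
    #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps,
    # i.e. x_0 = sqrt(1 / alphas_cumprod[t]) * x_t - sqrt(1 / alphas_cumprod[t] - 1) * eps,
    # which is exactly what the precomputed sqrt_recip_alphas_cumprod and
    # sqrt_recipm1_alphas_cumprod scheduler attributes express.
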
    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None):
        # fall back to the default device only when the caller did not pass one
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        # Create a classifier-free guidance sampling function
        guidance_scale = 3.0

        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            # duplicate the first half of the batch so both the prompt-conditioned and the
            # unconditional (empty prompt) encodings in transformer_out see the same latent
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
            # the first 3 channels are the noise prediction, the rest the variance values
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            # classifier-free guidance: eps = eps_uncond + scale * (eps_cond - eps_uncond)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # 1. Sample gaussian noise
        batch_size = 2  # second image is empty for classifier-free guidance
        image = self.text_noise_scheduler.sample_noise(
            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
        # an empty prompt is needed so the model also produces the unconditional
        # prediction used for classifier-free guidance
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        # 3. Run the text2image generation step
        num_timesteps = len(self.text_noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
            )
            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        # 4. Run the upscaling step
        batch_size = 1
        # keep only the first sample; the second batch slot only served classifier-free guidance
        image = image[:1]
        # round-trip the 64x64 sample through 8-bit quantization before conditioning the super-res model
        low_res = ((image + 1) * 127.5).round() / 127.5 - 1
        eta = 0.0  # DDIM eta (unused by the DDPM-style update below)

        # Tune this parameter to control the sharpness of 256x256 images.
        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
        upsample_temp = 0.997

        image = (
            self.upscale_noise_scheduler.sample_noise(
                (batch_size, 3, 256, 256), device=torch_device, generator=generator
            )
            * upsample_temp
        )

        num_timesteps = len(self.upscale_noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * self.upscale_noise_scheduler.get_beta(t)
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
            model_output = self.upscale_unet(image, time_input, low_res)
            noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = self.upscale_noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        # return a single (H, W, C) image tensor
        image = image[0].permute(1, 2, 0)

        return image
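

# A minimal usage sketch under stated assumptions: `fusing/glide-base` is a
# hypothetical checkpoint id, assumed to host the six modules registered above.
if __name__ == "__main__":
    from PIL import Image

    pipeline = GLIDE.from_pretrained("fusing/glide-base")  # hypothetical model id
    generator = torch.manual_seed(0)

    # __call__ returns a single (256, 256, 3) image tensor, roughly in [-1, 1]
    image = pipeline("an oil painting of a corgi", generator=generator)

    image = ((image + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
    Image.fromarray(image).save("glide_sample.png")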