make style

modeling_glide.py (+38 -17)
@@ -18,7 +18,14 @@ import numpy as np
 import torch
 
 import tqdm
-from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, GlideDDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from diffusers import (
+    ClassifierFreeGuidanceScheduler,
+    CLIPTextModel,
+    DiffusionPipeline,
+    GlideDDIMScheduler,
+    GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
+)
 from transformers import GPT2Tokenizer
 
 
@@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
-        upscale_noise_scheduler: GlideDDIMScheduler
+        upscale_noise_scheduler: GlideDDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
-            text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder,
-            tokenizer=tokenizer, upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
+            text_unet=text_unet,
+            text_noise_scheduler=text_noise_scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            upscale_unet=upscale_unet,
+            upscale_noise_scheduler=upscale_noise_scheduler,
         )
 
     def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
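Note on `register_modules`: the diff only shows the call site. It presumably stores each keyword argument as an attribute (`self.text_unet`, `self.tokenizer`, ...) and records the module classes for save/reload bookkeeping. A minimal sketch with illustrative names, not the actual diffusers implementation:

```python
class DiffusionPipelineSketch:
    # Hypothetical reimplementation for illustration only.
    def register_modules(self, **kwargs):
        for name, module in kwargs.items():
            # remember each module's class name (assumed serialization bookkeeping)
            self.module_config = getattr(self, "module_config", {})
            self.module_config[name] = module.__class__.__name__
            # expose the module directly, e.g. self.text_unet, self.upscale_unet
            setattr(self, name, module)
```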
@@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
             + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(
-            scheduler.posterior_log_variance_clipped, t, x_t.shape
-        )
+        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
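`_extract_into_tensor` itself is not part of this diff. In the GLIDE/guided-diffusion lineage it gathers one scalar per batch element from a numpy schedule array indexed by timestep, then broadcasts it over the image shape; a sketch assuming that convention:

```python
import numpy as np
import torch

def _extract_into_tensor(arr: np.ndarray, timesteps: torch.Tensor, broadcast_shape):
    # pick arr[t] for each timestep in the batch ...
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # ... then right-pad with singleton dims until it broadcasts over images
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
```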
@@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
         # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
         upsample_temp = 0.997
 
-        image = self.upscale_noise_scheduler.sample_noise(
-            (batch_size, 3, 256, 256), device=torch_device, generator=generator
-        ) * upsample_temp
+        image = (
+            self.upscale_noise_scheduler.sample_noise(
+                (batch_size, 3, 256, 256), device=torch_device, generator=generator
+            )
+            * upsample_temp
+        )
 
         num_timesteps = len(self.upscale_noise_scheduler)
-        for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
+        for t in tqdm.tqdm(
+            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
+        ):
             # i) define coefficients for time step t
             clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
             clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
-            image_coeff = (
-                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
-            )
-            clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            image_coeff = (
+                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
+            clipped_coeff = (
+                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
+                * self.upscale_noise_scheduler.get_beta(t)
+                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
+            )
 
             # ii) predict noise residual
             time_input = torch.tensor([t] * image.shape[0], device=torch_device)
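The coefficients in step i) are the standard DDPM quantities: `clipped_image_coeff` and `clipped_noise_coeff` invert the forward process to recover x_0 from (x_t, eps), while `image_coeff` and `clipped_coeff` are the q(x_{t-1} | x_t, x_0) posterior-mean weights. A self-contained sanity check, with an illustrative linear beta schedule standing in for the scheduler internals (which this diff does not show):

```python
import torch

# Illustrative stand-ins for get_beta / get_alpha / get_alpha_prod (assumed):
betas = torch.linspace(1e-4, 0.02, 1000)       # assumed linear schedule
alphas = 1.0 - betas                           # alpha_t
alpha_prods = torch.cumprod(alphas, dim=0)     # alpha_prod_t = prod(alpha_1..alpha_t)

t = 500
alpha_prod_t, alpha_prod_prev = alpha_prods[t], alpha_prods[t - 1]

# x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps inverts to
# x_0 = clipped_image_coeff * x_t - clipped_noise_coeff * eps:
clipped_image_coeff = 1 / torch.sqrt(alpha_prod_t)
clipped_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)

# Posterior-mean weights from the loop above (clipped_coeff multiplies the
# clipped x_0 estimate, image_coeff multiplies x_t):
clipped_coeff = torch.sqrt(alpha_prod_prev) * betas[t] / (1 - alpha_prod_t)
image_coeff = (1 - alpha_prod_prev) * torch.sqrt(alphas[t]) / (1 - alpha_prod_t)

# Consistency: a noiseless x_t = sqrt(alpha_prod_t) * x_0 must map to the
# posterior mean sqrt(alpha_prod_prev) * x_0.
assert torch.allclose(
    clipped_coeff + image_coeff * torch.sqrt(alpha_prod_t),
    torch.sqrt(alpha_prod_prev),
)
```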
@@ -216,8 +236,9 @@
             prev_image = clipped_coeff * pred_mean + image_coeff * image
 
             # iv) sample variance
-            prev_variance = self.upscale_noise_scheduler.sample_variance(
-                t, prev_image.shape, device=torch_device, generator=generator)
+            prev_variance = self.upscale_noise_scheduler.sample_variance(
+                t, prev_image.shape, device=torch_device, generator=generator
+            )
 
             # v) sample x_{t-1} ~ N(prev_image, prev_variance)
             sampled_prev_image = prev_image + prev_variance
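Step v) adds `prev_variance` directly to the mean, so `sample_variance` must return an already-scaled noise sample sigma_t * eps rather than the variance itself, despite its name; otherwise `prev_image + prev_variance` would not be a draw from N(prev_image, sigma_t^2). A hypothetical equivalent, assuming the scheduler exposes the clipped posterior log-variance:

```python
import torch

def sample_variance_sketch(posterior_log_variance, t, shape, device, generator=None):
    # Hypothetical stand-in for upscale_noise_scheduler.sample_variance:
    # draw eps ~ N(0, I) and scale by sigma_t = exp(0.5 * log var).
    eps = torch.randn(shape, device=device, generator=generator)
    return torch.exp(0.5 * posterior_log_variance[t]) * eps
```

Under that reading, `sampled_prev_image = prev_image + prev_variance` is exactly the ancestral sampling step x_{t-1} = mu_t + sigma_t * eps.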