anton-l HF staff committed on
Commit
ccdcc08
·
1 Parent(s): e9b25a6

make style

Browse files
Files changed (1) hide show
  1. modeling_glide.py +38 -17
modeling_glide.py CHANGED
@@ -18,7 +18,14 @@ import numpy as np
18
  import torch
19
 
20
  import tqdm
21
- from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
 
 
 
 
 
 
 
22
  from transformers import GPT2Tokenizer
23
 
24
 
@@ -46,12 +53,16 @@ class GLIDE(DiffusionPipeline):
46
  text_encoder: CLIPTextModel,
47
  tokenizer: GPT2Tokenizer,
48
  upscale_unet: GLIDESuperResUNetModel,
49
- upscale_noise_scheduler: GlideDDIMScheduler
50
  ):
51
  super().__init__()
52
  self.register_modules(
53
- text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
54
- upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
 
 
 
 
55
  )
56
 
57
  def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@@ -67,9 +78,7 @@ class GLIDE(DiffusionPipeline):
67
  + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
68
  )
69
  posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
70
- posterior_log_variance_clipped = _extract_into_tensor(
71
- scheduler.posterior_log_variance_clipped, t, x_t.shape
72
- )
73
  assert (
74
  posterior_mean.shape[0]
75
  == posterior_variance.shape[0]
@@ -190,19 +199,30 @@ class GLIDE(DiffusionPipeline):
190
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
191
  upsample_temp = 0.997
192
 
193
- image = self.upscale_noise_scheduler.sample_noise(
194
- (batch_size, 3, 256, 256), device=torch_device, generator=generator
195
- ) * upsample_temp
 
 
 
196
 
197
  num_timesteps = len(self.upscale_noise_scheduler)
198
- for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
 
 
199
  # i) define coefficients for time step t
200
  clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
201
  clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
202
- image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
203
- self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
204
- clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
205
- t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
 
 
 
 
 
 
206
 
207
  # ii) predict noise residual
208
  time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@@ -216,8 +236,9 @@ class GLIDE(DiffusionPipeline):
216
  prev_image = clipped_coeff * pred_mean + image_coeff * image
217
 
218
  # iv) sample variance
219
- prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
220
- generator=generator)
 
221
 
222
  # v) sample x_{t-1} ~ N(prev_image, prev_variance)
223
  sampled_prev_image = prev_image + prev_variance
 
18
  import torch
19
 
20
  import tqdm
21
+ from diffusers import (
22
+ ClassifierFreeGuidanceScheduler,
23
+ CLIPTextModel,
24
+ DiffusionPipeline,
25
+ GlideDDIMScheduler,
26
+ GLIDESuperResUNetModel,
27
+ GLIDETextToImageUNetModel,
28
+ )
29
  from transformers import GPT2Tokenizer
30
 
31
 
 
53
  text_encoder: CLIPTextModel,
54
  tokenizer: GPT2Tokenizer,
55
  upscale_unet: GLIDESuperResUNetModel,
56
+ upscale_noise_scheduler: GlideDDIMScheduler,
57
  ):
58
  super().__init__()
59
  self.register_modules(
60
+ text_unet=text_unet,
61
+ text_noise_scheduler=text_noise_scheduler,
62
+ text_encoder=text_encoder,
63
+ tokenizer=tokenizer,
64
+ upscale_unet=upscale_unet,
65
+ upscale_noise_scheduler=upscale_noise_scheduler,
66
  )
67
 
68
  def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
 
78
  + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
79
  )
80
  posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
81
+ posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
 
 
82
  assert (
83
  posterior_mean.shape[0]
84
  == posterior_variance.shape[0]
 
199
  # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
200
  upsample_temp = 0.997
201
 
202
+ image = (
203
+ self.upscale_noise_scheduler.sample_noise(
204
+ (batch_size, 3, 256, 256), device=torch_device, generator=generator
205
+ )
206
+ * upsample_temp
207
+ )
208
 
209
  num_timesteps = len(self.upscale_noise_scheduler)
210
+ for t in tqdm.tqdm(
211
+ reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
212
+ ):
213
  # i) define coefficients for time step t
214
  clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
215
  clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
216
+ image_coeff = (
217
+ (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
218
+ * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
219
+ / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
220
+ )
221
+ clipped_coeff = (
222
+ torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
223
+ * self.upscale_noise_scheduler.get_beta(t)
224
+ / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
225
+ )
226
 
227
  # ii) predict noise residual
228
  time_input = torch.tensor([t] * image.shape[0], device=torch_device)
 
236
  prev_image = clipped_coeff * pred_mean + image_coeff * image
237
 
238
  # iv) sample variance
239
+ prev_variance = self.upscale_noise_scheduler.sample_variance(
240
+ t, prev_image.shape, device=torch_device, generator=generator
241
+ )
242
 
243
  # v) sample x_{t-1} ~ N(prev_image, prev_variance)
244
  sampled_prev_image = prev_image + prev_variance