debisoft committed on
Commit
3e7e692
·
1 Parent(s): becafd9
Files changed (1) hide show
  1. app.py +36 -2
app.py CHANGED
@@ -188,11 +188,45 @@ def sample_ddim_context(n_sample, context, n=20):
188
  intermediate = np.stack(intermediate)
189
  return samples, intermediate
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  def greet(input):
192
  steps = int(input)
193
  #samples, intermediate = sample_ddim(32, n=steps)
194
- ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
195
- samples, intermediate = sample_ddim_context(32, ctx, steps)
 
196
  #response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
197
  #response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
198
  #response = im.fromarray(intermediate[24][0][1]).convert("RGB")
 
188
  intermediate = np.stack(intermediate)
189
  return samples, intermediate
190
 
191
# Helper: undo the predicted noise at step t, then re-inject a small amount
# of fresh noise so the reverse process does not collapse to its mean.
def denoise_add_noise(x, t, pred_noise, z=None):
    """Perform one DDPM reverse-diffusion step.

    Removes the model-predicted noise from x at timestep t using the
    precomputed schedules (a_t, b_t, ab_t are module-level globals),
    then adds back scaled Gaussian noise z to avoid mode collapse.

    Args:
        x: current noisy sample batch.
        t: integer timestep index into the noise schedules.
        pred_noise: the network's noise prediction eps_theta(x_t, t).
        z: optional pre-drawn noise; drawn fresh if None.

    Returns:
        The partially denoised sample for timestep t-1.
    """
    if z is None:
        z = torch.randn_like(x)
    # Posterior mean: strip the predicted noise, rescaled by the schedule.
    posterior_mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
    # Re-injected noise, scaled by the per-step variance sqrt(beta_t).
    reinjected = b_t.sqrt()[t] * z
    return posterior_mean + reinjected
198
+
199
# Sample images with the standard (full-length) DDPM algorithm.
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    """Run the full DDPM reverse process from pure noise.

    Starts from x_T ~ N(0, I) and iterates the denoising step for every
    timestep, recording periodic snapshots for plotting. Relies on the
    module-level globals nn_model, height, timesteps and device.

    Args:
        n_sample: number of images to generate in the batch.
        save_rate: snapshot cadence (every save_rate-th timestep is kept).

    Returns:
        (samples, intermediate): the final batch and a numpy stack of the
        recorded intermediate states.
    """
    # x_T ~ N(0, 1): start from pure Gaussian noise.
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # Snapshots of the evolving batch, kept for visualization.
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # Normalized time, reshaped to broadcast over (B, C, H, W).
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # Fresh noise to re-inject; none on the final step (i == 1).
        z = torch.randn_like(samples) if i > 1 else 0

        # Predicted noise eps_theta(x_t, t), then one reverse step.
        eps = nn_model(samples, t)
        samples = denoise_add_noise(samples, i, eps, z)

        # Record at the cadence, plus the first step and the last few.
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate
223
+
224
  def greet(input):
225
  steps = int(input)
226
  #samples, intermediate = sample_ddim(32, n=steps)
227
+ #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
228
+ #samples, intermediate = sample_ddim_context(32, ctx, steps)
229
+ samples, intermediate = sample_ddpm(32, )
230
  #response = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[-1]))
231
  #response2 = transform2(transform(np.moveaxis(samples.detach().cpu().numpy(),1,3)[1]))
232
  #response = im.fromarray(intermediate[24][0][1]).convert("RGB")