debisoft commited on
Commit
feade85
·
1 Parent(s): 9e53766

sample_ddpm_context

Browse files
Files changed (1) hide show
  1. app.py +26 -2
app.py CHANGED
@@ -198,7 +198,7 @@ def denoise_add_noise(x, t, pred_noise, z=None):
198
 
199
  # sample using standard algorithm
200
  @torch.no_grad()
201
- def sample_ddpm(n_sample, save_rate=20):
202
  # x_T ~ N(0, 1), sample initial noise
203
  samples = torch.randn(n_sample, 3, height, height).to(device)
204
 
@@ -221,6 +221,30 @@ def sample_ddpm(n_sample, save_rate=20):
221
  intermediate = np.stack(intermediate)
222
  return samples, intermediate
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def greet(input):
225
  steps = int(input)
226
 
@@ -233,7 +257,7 @@ def greet(input):
233
  ctx = torch.from_numpy(mtx_2d).to(device=device).float()
234
 
235
  #samples, intermediate = sample_ddim_context(32, ctx, n=steps)
236
- samples, intermediate = sample_ddpm(32, steps)
237
 
238
  #samples, intermediate = sample_ddim(32, n=steps)
239
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
 
198
 
199
  # sample using standard algorithm
200
  @torch.no_grad()
201
+ def sample_ddpm(n_sample, context, save_rate=20):
202
  # x_T ~ N(0, 1), sample initial noise
203
  samples = torch.randn(n_sample, 3, height, height).to(device)
204
 
 
221
  intermediate = np.stack(intermediate)
222
  return samples, intermediate
223
 
224
+ @torch.no_grad()
225
+ def sample_ddpm_context(n_sample, save_rate=20):
226
+ # x_T ~ N(0, 1), sample initial noise
227
+ samples = torch.randn(n_sample, 3, height, height).to(device)
228
+
229
+ # array to keep track of generated steps for plotting
230
+ intermediate = []
231
+ for i in range(timesteps, 0, -1):
232
+ print(f'sampling timestep {i:3d}', end='\r')
233
+
234
+ # reshape time tensor
235
+ t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
236
+
237
+ # sample some random noise to inject back in. For i = 1, don't add back in noise
238
+ z = torch.randn_like(samples) if i > 1 else 0
239
+
240
+ eps = nn_model(samples, t, c=context) # predict noise e_(x_t,t)
241
+ samples = denoise_add_noise(samples, i, eps, z)
242
+ if i % save_rate ==0 or i==timesteps or i<8:
243
+ intermediate.append(samples.detach().cpu().numpy())
244
+
245
+ intermediate = np.stack(intermediate)
246
+ return samples, intermediate
247
+
248
  def greet(input):
249
  steps = int(input)
250
 
 
257
  ctx = torch.from_numpy(mtx_2d).to(device=device).float()
258
 
259
  #samples, intermediate = sample_ddim_context(32, ctx, n=steps)
260
+ samples, intermediate = sample_ddpm_context(32, ctx, steps)
261
 
262
  #samples, intermediate = sample_ddim(32, n=steps)
263
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()