debisoft commited on
Commit
14d3ba8
·
1 Parent(s): b6a29ba
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -222,7 +222,7 @@ def sample_ddpm(n_sample, context, save_rate=20):
222
  return samples, intermediate
223
 
224
  @torch.no_grad()
225
- def sample_ddpm_context(n_sample,timesteps, context, save_rate=20):
226
  # x_T ~ N(0, 1), sample initial noise
227
  samples = torch.randn(n_sample, 3, height, height).to(device)
228
 
@@ -259,7 +259,7 @@ def greet(input):
259
  #samples, intermediate = sample_ddim_context(32, ctx, n=steps)
260
 
261
  #samples, intermediate = sample_ddpm_context(image_count, steps, ctx)
262
- samples, intermediate = sample_ddim_context(image_count, steps, ctx)
263
 
264
  #samples, intermediate = sample_ddim(32, n=steps)
265
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
 
222
  return samples, intermediate
223
 
224
  @torch.no_grad()
225
+ def sample_ddpm_context(n_sample, timesteps, context, save_rate=20):
226
  # x_T ~ N(0, 1), sample initial noise
227
  samples = torch.randn(n_sample, 3, height, height).to(device)
228
 
 
259
  #samples, intermediate = sample_ddim_context(32, ctx, n=steps)
260
 
261
  #samples, intermediate = sample_ddpm_context(image_count, steps, ctx)
262
+ samples, intermediate = sample_ddim_context(image_count, ctx, steps)
263
 
264
  #samples, intermediate = sample_ddim(32, n=steps)
265
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()