debisoft commited on
Commit
62b4021
·
1 Parent(s): dbd4772
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -222,7 +222,7 @@ def sample_ddpm(n_sample, context, save_rate=20):
222
  return samples, intermediate
223
 
224
  @torch.no_grad()
225
- def sample_ddpm_context(n_sample, context, save_rate=20):
226
  # x_T ~ N(0, 1), sample initial noise
227
  samples = torch.randn(n_sample, 3, height, height).to(device)
228
 
@@ -257,7 +257,7 @@ def greet(input):
257
  ctx = torch.from_numpy(mtx_2d).to(device=device).float()
258
 
259
  #samples, intermediate = sample_ddim_context(32, ctx, n=steps)
260
- samples, intermediate = sample_ddpm_context(32, ctx, steps)
261
 
262
  #samples, intermediate = sample_ddim(32, n=steps)
263
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
 
222
  return samples, intermediate
223
 
224
  @torch.no_grad()
225
+ def sample_ddpm_context(n_sample,timesteps, context, save_rate=20):
226
  # x_T ~ N(0, 1), sample initial noise
227
  samples = torch.randn(n_sample, 3, height, height).to(device)
228
 
 
257
  ctx = torch.from_numpy(mtx_2d).to(device=device).float()
258
 
259
  #samples, intermediate = sample_ddim_context(32, ctx, n=steps)
260
+ samples, intermediate = sample_ddpm_context(32, steps, ctx)
261
 
262
  #samples, intermediate = sample_ddim(32, n=steps)
263
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()