debisoft commited on
Commit
f5c772b
·
1 Parent(s): 8e4ca1d
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -223,7 +223,17 @@ def sample_ddpm(n_sample, save_rate=20):
223
 
224
  def greet(input):
225
  steps = int(input)
226
- samples, intermediate = sample_ddim(32, n=steps)
 
 
 
 
 
 
 
 
 
 
227
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
228
  #samples, intermediate = sample_ddim_context(32, ctx, steps)
229
  #samples, intermediate = sample_ddpm(steps)
 
223
 
224
  def greet(input):
225
  steps = int(input)
226
+
227
+ #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
228
+
229
+ # hero, non-hero, food, spell, side-facing
230
+ shape = (32, 5)
231
+ mtx_2d = np.ones(shape) * one_hot_enc
232
+ ctx = mtx_2d.to(device=device).float()
233
+
234
+ samples, intermediate = sample_ddim_ctx(32, ctx, n=steps)
235
+
236
+ #samples, intermediate = sample_ddim(32, n=steps)
237
  #ctx = F.one_hot(torch.randint(0, 5, (32,)), 5).to(device=device).float()
238
  #samples, intermediate = sample_ddim_context(32, ctx, steps)
239
  #samples, intermediate = sample_ddpm(steps)