AjayP13 commited on
Commit
58979a1
·
verified ·
1 Parent(s): 7f2c0d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -0
app.py CHANGED
@@ -27,6 +27,7 @@ def get_target_style_embeddings(target_texts_batch):
27
  mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
28
  return mean_embeddings.float().cpu().numpy()
29
 
 
30
  def get_luar_embeddings(texts_batch):
31
  assert len(set([len(texts) for texts in texts_batch])) == 1
32
  episodes = texts_batch
 
27
  mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
28
  return mean_embeddings.float().cpu().numpy()
29
 
30
+ @torch.no_grad()
31
  def get_luar_embeddings(texts_batch):
32
  assert len(set([len(texts) for texts in texts_batch])) == 1
33
  episodes = texts_batch