AjayP13 commited on
Commit
051330e
·
verified ·
1 Parent(s): aa07629

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -47,7 +47,7 @@ def get_luar_embeddings(texts_batch):
47
  return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
48
 
49
  def compute_mis(texts, target_texts_batch):
50
- a_texts = list(itertools.chain.from_iterable([[st] * len(target_texts) for st, target_texts in zip(source_texts, target_texts_batch)]))
51
  b_texts = list(itertools.chain.from_iterable(target_texts_batch))
52
  scores = mis.compute(a_texts, b_texts, batch_size=len(a_texts))
53
  for idx, (score, a_text, b_text) in enumerate(zip(scores, a_texts, b_texts)):
 
47
  return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
48
 
49
  def compute_mis(texts, target_texts_batch):
50
+ a_texts = list(itertools.chain.from_iterable([[t] * len(target_texts) for t, target_texts in zip(texts, target_texts_batch)]))
51
  b_texts = list(itertools.chain.from_iterable(target_texts_batch))
52
  scores = mis.compute(a_texts, b_texts, batch_size=len(a_texts))
53
  for idx, (score, a_text, b_text) in enumerate(zip(scores, a_texts, b_texts)):