andreslu commited on
Commit
cbff704
·
1 Parent(s): b67aed7

Update inductor.py

Browse files
Files changed (1) hide show
  1. inductor.py +3 -2
inductor.py CHANGED
@@ -85,11 +85,12 @@ class BartInductor(object):
85
  def generate(self, inputs, k=10, topk=10):
86
  with torch.no_grad():
87
  tB_probs = self.generate_rule(inputs, k)
88
- ret = [t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>') for t in tB_probs]
 
89
 
90
  new_ret = []
91
  for temp in ret:
92
- temp = self.clean(temp.strip())
93
  if len(new_ret) < topk and temp not in new_ret:
94
  new_ret.append(temp)
95
 
 
85
  def generate(self, inputs, k=10, topk=10):
86
  with torch.no_grad():
87
  tB_probs = self.generate_rule(inputs, k)
88
+ #ret = [t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>') for t in tB_probs]
89
+ ret = [(t[0].replace('<ent0>','<mask>').replace('<ent1>','<mask>'), t[1]) for t in tB_probs]
90
 
91
  new_ret = []
92
  for temp in ret:
93
+ temp = (self.clean(temp[0].strip()), temp[1])
94
  if len(new_ret) < topk and temp not in new_ret:
95
  new_ret.append(temp)
96