andreslu committed on
Commit
feac146
·
1 Parent(s): f770c44

Update inductor.py

Browse files
Files changed (1) hide show
  1. inductor.py +6 -9
inductor.py CHANGED
@@ -129,15 +129,12 @@ class BartInductor(object):
129
  def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
130
  spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
131
  generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device)
132
- generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
133
- #num_beam_groups=max(120, k),
134
- max_length=generated_ids.size(1) + 15,
135
- num_return_sequences=max(120, k), #min_length=generated_ids.size(1),
136
- #diversity_penalty=2.0,
137
- #length_penalty= 0.8,
138
- #early_stopping=True, bad_words_ids=bad_words_ids, no_repeat_ngram_size=2,
139
- output_scores=True,
140
- return_dict_in_generate=True)
141
  summary_ids = generated_ret['sequences']
142
  probs = F.softmax(generated_ret['sequences_scores'])
143
  txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
 
129
  def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
130
  spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
131
  generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device)
132
+ generated_ret = self.orion_instance_generator.generate(generated_ids.astype(np.float32), num_beams=max(120, k),
133
+ min_length=min_length,
134
+ max_length=max_length,
135
+ early_stopping=early_stopping,
136
+ **generator_kwargs)
137
+
 
 
 
138
  summary_ids = generated_ret['sequences']
139
  probs = F.softmax(generated_ret['sequences_scores'])
140
  txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]