andreslu committed on
Commit
97e6893
·
1 Parent(s): 0044d3c

Upload 2 files

Browse files
Files changed (1) hide show
  1. inductor.py +17 -14
inductor.py CHANGED
@@ -43,11 +43,11 @@ class BartInductor(object):
43
  self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
44
 
45
  if group_beam:
46
- self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval().half()
47
  else:
48
- self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval().half()
49
 
50
- self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval().half()
51
 
52
  self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
53
  self.word_length = 2
@@ -128,15 +128,18 @@ class BartInductor(object):
128
 
129
  def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
130
  spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
131
- generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device)
132
- generated_ret = self.orion_instance_generator.generate(generated_ids.astype(np.float32), num_beams=max(120, k),
133
- min_length=min_length,
134
- max_length=max_length,
135
- early_stopping=early_stopping,
136
- **generator_kwargs)
137
-
 
 
 
138
  summary_ids = generated_ret['sequences']
139
- probs = F.softmax(generated_ret['sequences_scores'])
140
  txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
141
  ret = []
142
 
@@ -221,7 +224,7 @@ class BartInductor(object):
221
  index_words[len(index_words)] = '\t'.join(words)
222
  # index_words[len(templates)-1] = '\t'.join(words)
223
  if (len(templates) == batch_size) or enum==len(words_prob_sorted)-1 or (words_prob_sorted[enum+1][2]!=words_prob_sorted[enum][2]):
224
- generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device)
225
  generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
226
  num_beam_groups=num_beams,
227
  max_length=28, #template_length+5,
@@ -236,7 +239,7 @@ class BartInductor(object):
236
  top_p=0.95,
237
  )
238
  summary_ids = generated_ret['sequences'].reshape((len(templates),num_beams,-1))
239
- probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates),num_beams)),dim=1)
240
  for ii in range(summary_ids.size(0)):
241
  txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
242
  summary_ids[ii]]
@@ -303,7 +306,7 @@ class BartInductor(object):
303
 
304
  class CometInductor(object):
305
  def __init__(self):
306
- self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval() # .half()
307
  self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
308
  self.task = "summarization"
309
  self.use_task_specific_params()
 
43
  self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
44
 
45
  if group_beam:
46
+ self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
47
  else:
48
+ self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
49
 
50
+ self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval()
51
 
52
  self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
53
  self.word_length = 2
 
128
 
129
  def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
130
  spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
131
+ generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
132
+ generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
133
+ #num_beam_groups=max(120, k),
134
+ max_length=generated_ids.size(1) + 15,
135
+ num_return_sequences=max(120, k), #min_length=generated_ids.size(1),
136
+ #diversity_penalty=2.0,
137
+ #length_penalty= 0.8,
138
+ #early_stopping=True, bad_words_ids=bad_words_ids, no_repeat_ngram_size=2,
139
+ output_scores=True,
140
+ return_dict_in_generate=True)
141
  summary_ids = generated_ret['sequences']
142
+ probs = F.softmax(generated_ret['sequences_scores'].to(torch.float32))
143
  txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
144
  ret = []
145
 
 
224
  index_words[len(index_words)] = '\t'.join(words)
225
  # index_words[len(templates)-1] = '\t'.join(words)
226
  if (len(templates) == batch_size) or enum==len(words_prob_sorted)-1 or (words_prob_sorted[enum+1][2]!=words_prob_sorted[enum][2]):
227
+ generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device).to(torch.int64)
228
  generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
229
  num_beam_groups=num_beams,
230
  max_length=28, #template_length+5,
 
239
  top_p=0.95,
240
  )
241
  summary_ids = generated_ret['sequences'].reshape((len(templates),num_beams,-1))
242
+ probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates),num_beams)),dim=1).to(torch.float32)
243
  for ii in range(summary_ids.size(0)):
244
  txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
245
  summary_ids[ii]]
 
306
 
307
  class CometInductor(object):
308
  def __init__(self):
309
+ self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval().float() # .half()->float
310
  self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
311
  self.task = "summarization"
312
  self.use_task_specific_params()