Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- inductor.py +17 -14
inductor.py
CHANGED
@@ -43,11 +43,11 @@ class BartInductor(object):
|
|
43 |
self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
|
44 |
|
45 |
if group_beam:
|
46 |
-
self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
|
47 |
else:
|
48 |
-
self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
|
49 |
|
50 |
-
self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval()
|
51 |
|
52 |
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
53 |
self.word_length = 2
|
@@ -128,15 +128,18 @@ class BartInductor(object):
|
|
128 |
|
129 |
def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
|
130 |
spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
|
131 |
-
generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device)
|
132 |
-
generated_ret = self.orion_instance_generator.generate(generated_ids
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
|
|
138 |
summary_ids = generated_ret['sequences']
|
139 |
-
probs = F.softmax(generated_ret['sequences_scores'])
|
140 |
txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
|
141 |
ret = []
|
142 |
|
@@ -221,7 +224,7 @@ class BartInductor(object):
|
|
221 |
index_words[len(index_words)] = '\t'.join(words)
|
222 |
# index_words[len(templates)-1] = '\t'.join(words)
|
223 |
if (len(templates) == batch_size) or enum==len(words_prob_sorted)-1 or (words_prob_sorted[enum+1][2]!=words_prob_sorted[enum][2]):
|
224 |
-
generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device)
|
225 |
generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
|
226 |
num_beam_groups=num_beams,
|
227 |
max_length=28, #template_length+5,
|
@@ -236,7 +239,7 @@ class BartInductor(object):
|
|
236 |
top_p=0.95,
|
237 |
)
|
238 |
summary_ids = generated_ret['sequences'].reshape((len(templates),num_beams,-1))
|
239 |
-
probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates),num_beams)),dim=1)
|
240 |
for ii in range(summary_ids.size(0)):
|
241 |
txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
|
242 |
summary_ids[ii]]
|
@@ -303,7 +306,7 @@ class BartInductor(object):
|
|
303 |
|
304 |
class CometInductor(object):
|
305 |
def __init__(self):
|
306 |
-
self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval() # .half()
|
307 |
self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
|
308 |
self.task = "summarization"
|
309 |
self.use_task_specific_params()
|
|
|
43 |
self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
|
44 |
|
45 |
if group_beam:
|
46 |
+
self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
|
47 |
else:
|
48 |
+
self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
|
49 |
|
50 |
+
self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval()
|
51 |
|
52 |
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
53 |
self.word_length = 2
|
|
|
128 |
|
129 |
def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
|
130 |
spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
|
131 |
+
generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
|
132 |
+
generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
|
133 |
+
#num_beam_groups=max(120, k),
|
134 |
+
max_length=generated_ids.size(1) + 15,
|
135 |
+
num_return_sequences=max(120, k), #min_length=generated_ids.size(1),
|
136 |
+
#diversity_penalty=2.0,
|
137 |
+
#length_penalty= 0.8,
|
138 |
+
#early_stopping=True, bad_words_ids=bad_words_ids, no_repeat_ngram_size=2,
|
139 |
+
output_scores=True,
|
140 |
+
return_dict_in_generate=True)
|
141 |
summary_ids = generated_ret['sequences']
|
142 |
+
probs = F.softmax(generated_ret['sequences_scores'].to(torch.float32))
|
143 |
txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]
|
144 |
ret = []
|
145 |
|
|
|
224 |
index_words[len(index_words)] = '\t'.join(words)
|
225 |
# index_words[len(templates)-1] = '\t'.join(words)
|
226 |
if (len(templates) == batch_size) or enum==len(words_prob_sorted)-1 or (words_prob_sorted[enum+1][2]!=words_prob_sorted[enum][2]):
|
227 |
+
generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device).to(torch.int64)
|
228 |
generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
|
229 |
num_beam_groups=num_beams,
|
230 |
max_length=28, #template_length+5,
|
|
|
239 |
top_p=0.95,
|
240 |
)
|
241 |
summary_ids = generated_ret['sequences'].reshape((len(templates),num_beams,-1))
|
242 |
+
probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates),num_beams)),dim=1).to(torch.float32)
|
243 |
for ii in range(summary_ids.size(0)):
|
244 |
txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
|
245 |
summary_ids[ii]]
|
|
|
306 |
|
307 |
class CometInductor(object):
|
308 |
def __init__(self):
|
309 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval().float() # .half()->float
|
310 |
self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
|
311 |
self.task = "summarization"
|
312 |
self.use_task_specific_params()
|