Spaces:
Runtime error
Runtime error
File size: 2,571 Bytes
d0ed3d4 10556f2 d0ed3d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from transformers import BartTokenizer, BartForConditionalGeneration
import sys
# shortTokenizer = BartTokenizer.from_pretrained('./ml/distilbart-xsum-12-6/', local_files_only=True)
# shortModel = BartForConditionalGeneration.from_pretrained('./ml/distilbart-xsum-12-6/', local_files_only=True)
# longTokenizer = BartTokenizer.from_pretrained('./ml/distilbart-cnn-12-6/', local_files_only=True)
# longModel = BartForConditionalGeneration.from_pretrained('./ml/distilbart-cnn-12-6/', local_files_only=True)
def _load_bart(checkpoint):
    """Load the BART tokenizer and model for one Hugging Face checkpoint id.

    The tokenizer is loaded first, then the model.
    """
    tokenizer = BartTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint)
    return tokenizer, model


# Pair used by summarize() to produce the short summary.
shortTokenizer, shortModel = _load_bart('sshleifer/distilbart-xsum-12-6')
# Pair used by summarize() to produce the long summary.
longTokenizer, longModel = _load_bart('datien228/distilbart-cnn-12-6-ftn-multi_news')
def _generate_summary(tokenizer, model, text, **generate_kwargs):
    """Encode *text*, run ``model.generate`` on it and decode the first output.

    The input is truncated to 1024 tokens before generation; every keyword
    argument is forwarded unchanged to ``model.generate``.
    """
    input_ids = tokenizer.encode(text, return_tensors='pt', max_length=1024, truncation=True)
    summary_ids = model.generate(input_ids, **generate_kwargs)
    # The batch holds a single sequence: index it instead of ``squeeze()``,
    # which would also collapse a length-1 output into a 0-d tensor.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)


def summarize(text, num_beams=5, length_penalty=2.0, max_length=50, min_length=15, no_repeat_ngram_size=3):
    """Return a ``(short_summary, long_summary)`` pair for *text*.

    The short summary is generated by the module-level ``shortModel`` and
    honours ``max_length``/``min_length``. The long summary is generated by
    ``longModel`` and intentionally does NOT receive the length arguments,
    so it uses whatever limits that model's own generation config provides.

    Numeric arguments are coerced with ``int()``/``float()``, so string
    values (e.g. taken straight from a web form) are accepted.

    Args:
        text: Source document. Every run of whitespace, including line
            breaks, is collapsed to a single space before encoding.
        num_beams: Beam width, used for both generations.
        length_penalty: Length penalty, used for both generations.
        max_length: Maximum length (in tokens) of the short summary only.
        min_length: Minimum length (in tokens) of the short summary only.
        no_repeat_ngram_size: ``no_repeat_ngram_size`` for both generations.

    Returns:
        tuple[str, str]: ``(short_summary_txt, long_summary_txt)``.
    """
    # Collapse line breaks (and any other whitespace run, incl. CRLF) into a
    # single space. Simply deleting the newline glued the words on either
    # side of it together ("end<newline>start" -> "endstart").
    text = ' '.join(text.split())

    # Decoding settings shared by both models.
    beam_kwargs = dict(
        num_beams=int(num_beams),
        length_penalty=float(length_penalty),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        # top_k only influences sampling-based decoding; kept for parity with
        # the previous calls (it appears inert under plain beam search).
        top_k=50,
    )

    short_summary_txt = _generate_summary(
        shortTokenizer, shortModel, text,
        max_length=int(max_length),
        min_length=int(min_length),
        **beam_kwargs,
    )
    print('Short summary done', file=sys.stderr)

    # NOTE: length limits are deliberately not forwarded to the long model.
    long_summary_txt = _generate_summary(longTokenizer, longModel, text, **beam_kwargs)
    print('Long summary done', file=sys.stderr)

    return short_summary_txt, long_summary_txt