Spaces:
Sleeping
Sleeping
File size: 1,329 Bytes
56da2e5 06108ff 56da2e5 580f5a5 56da2e5 06108ff 56da2e5 4a25d51 56da2e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
def load_summarizers():
    """Load the three summarization models and their tokenizers.

    Returns:
        dict[str, tuple]: model name -> (model, tokenizer). Each model is
        moved to CUDA when available, otherwise CPU.
    """
    models = {
        "Pegasus": "google/pegasus-cnn_dailymail",
        "T5": "Overglitch/t5-small-cnn-dailymail",
        "BART": "facebook/bart-large-cnn",
    }
    device = "cuda" if torch.cuda.is_available() else "cpu"
    summarizers = {}
    for model_name, model_path in models.items():
        # Pegasus needs its dedicated slow tokenizer class; the others can
        # rely on AutoTokenizer resolution.
        # BUG FIX: the original re-assigned tokenizer = AutoTokenizer(...)
        # after this branch, silently discarding the PegasusTokenizer.
        if model_name == "Pegasus":
            tokenizer = PegasusTokenizer.from_pretrained(model_path)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        summarizers[model_name] = (model, tokenizer)
    return summarizers
def abstractive_summary(summarizers, model_name, text, max_length, num_beams):
    """Generate an abstractive summary of *text* with the chosen model.

    Args:
        summarizers: mapping of model name -> (model, tokenizer) pairs,
            as produced by ``load_summarizers``.
        model_name: key selecting which summarizer to use.
        text: source document to summarize (truncated to 1024 tokens).
        max_length: maximum length of the generated summary.
        num_beams: beam width for beam-search decoding.

    Returns:
        str: the decoded summary, with special tokens stripped.
    """
    model, tokenizer = summarizers[model_name]
    # Tokenize on the same device the model lives on.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(model.device)
    summary_ids = model.generate(
        encoded["input_ids"],
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|