# NOTE(review): the lines below are Hugging Face web-page residue captured
# alongside the source; kept as comments so the file remains valid Python.
# Spaces: Sleeping
# File size: 1,295 Bytes
# Commits: 56da2e5 f9b176d
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
def load_summarizers():
    """Load the T5 and BART summarization checkpoints with their tokenizers.

    Downloads (or reads from cache) each checkpoint via ``from_pretrained``.

    Returns:
        dict: model name ("T5", "BART") -> ``(model, tokenizer)`` tuple,
        with each model moved to GPU when available, CPU otherwise.
    """
    models = {
        "T5": "Overglitch/t5-small-cnn-dailymail",
        "BART": "facebook/bart-large-cnn",
    }
    # Resolve the target device once instead of querying CUDA per model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    summarizers = {}
    for model_name, model_path in models.items():
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        summarizers[model_name] = (model, tokenizer)
    return summarizers
def load_pegasus_model_and_tokenizer(model_name: str):
    """Fetch a Pegasus checkpoint together with its matching tokenizer.

    Args:
        model_name: Hub identifier or local path of the Pegasus checkpoint.

    Returns:
        tuple: ``(PegasusForConditionalGeneration, PegasusTokenizer)``.
    """
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name)
    return model, tokenizer
def abstractive_summary(summarizers, model_name, text, max_length, num_beams):
    """Generate an abstractive summary of *text* with the selected model.

    Args:
        summarizers: mapping of model name -> ``(model, tokenizer)``, as
            produced by ``load_summarizers``.
        model_name: key into *summarizers* selecting which model to use.
        text: input document; input is truncated to 512 tokens.
        max_length: maximum length (in tokens) of the generated summary.
        num_beams: beam width for beam-search decoding.

    Returns:
        str: the decoded summary with special tokens removed.
    """
    model, tokenizer = summarizers[model_name]
    inputs = tokenizer(
        text, return_tensors="pt", max_length=512, truncation=True
    ).to(model.device)
    # Inference only: disable autograd so no computation graph is built,
    # and forward the attention mask so the model ignores padding tokens.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|