import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Summarization Device: {device}")

# PEGASUS fine-tuned on CNN/DailyMail for abstractive news summarization.
model_ckpt = "google/pegasus-cnn_dailymail"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model_pegasus = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)
def summarize_text(text: str) -> str:
    """Summarize `text` with PEGASUS, truncating input to the model's 1024-token limit."""
    input_ids = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(device)
    try:
        summary_ids = model_pegasus.generate(
            input_ids,
            max_length=130,
            min_length=30,
            do_sample=False,  # deterministic decoding instead of sampling
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
    except RuntimeError as e:
        # Typically raised on out-of-memory or sequence-length issues.
        print(f"Summarization Error: {e}")
        return "Error: Could not generate summary due to length constraints."