File size: 898 Bytes
b396e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Pick the compute device once at import time: use the GPU when torch can
# see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Summarization Device: {device}")

# Checkpoint identifier shared by the tokenizer and the seq2seq model so the
# two always come from the same pretrained weights.
model_ckpt = "google/pegasus-cnn_dailymail"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Load the model first, then move it onto the selected device.
model_pegasus = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)
model_pegasus = model_pegasus.to(device)

def summarize_text(text: str, *, max_length: int = 130, min_length: int = 30) -> str:
    """Summarize *text* using the module-level tokenizer and seq2seq model.

    The input is tokenized with truncation to at most 1024 tokens, moved to
    the module-level ``device``, and passed to ``model_pegasus.generate``
    with sampling disabled.

    Args:
        text: The document to summarize.
        max_length: ``max_length`` forwarded to ``generate`` (default 130).
        min_length: ``min_length`` forwarded to ``generate`` (default 30).

    Returns:
        The decoded summary with special tokens stripped. An empty string if
        *text* is empty, whitespace-only, or ``None``. A fixed
        ``"Error: ..."`` message if generation raises ``RuntimeError``.
    """
    # Nothing to summarize: skip the model entirely. Otherwise generation
    # would be asked for at least `min_length` tokens with no real input to
    # base them on.
    if not text or not text.strip():
        return ""

    # Truncate over-long documents at tokenization time rather than letting
    # the model receive more tokens than requested here.
    input_ids = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(device)
    try:
        summary_ids = model_pegasus.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
        )
        # A single input sequence was encoded, so the first output row is
        # the summary for it.
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except RuntimeError as e:
        # Best-effort behavior: report the failure and hand the caller a
        # displayable message instead of propagating the exception.
        print(f"Summarization Error: {e}")
        return "Error: Could not generate summary due to length constraints."