Spaces:
Runtime error
Runtime error
import nltk

# Fetch NLTK's Punkt sentence-tokenizer data at startup. It is not referenced
# by the code below; presumably it is needed by newspaper's text helpers.
# NOTE(review): confirm it is still required.
nltk.download("punkt")

# Main implementation: model, article fetching, summarization and UI.
import torch
from transformers import (
    PegasusForConditionalGeneration,
    PegasusTokenizer,
)
from newspaper import Article
import gradio as gr
import warnings

# Suppress every Python warning issued from here on. Warnings raised while
# the imports above were executing are not affected.
warnings.filterwarnings(action="ignore")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize model and tokenizer.
model_name = "google/pegasus-large"

# Bind both names before attempting the load. If loading fails, the app still
# starts (best effort, as before), but the names exist and are None. Code that
# uses them can then test for None instead of failing later with a NameError.
tokenizer = None
model = None
try:
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)
    print("Model loaded successfully!")
except Exception as e:
    # Never leave a half-initialized pair behind (e.g. tokenizer loaded but
    # model failed, or the model is stuck on the wrong device).
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")
def fetch_article_text(url):
    """Download the page at ``url`` and return its extracted article text.

    Failures are never raised. Any exception raised while constructing,
    downloading or parsing the article is turned into a returned string of
    the form ``"Error fetching article: <exception>"``.
    """
    try:
        page = Article(url)
        page.download()
        page.parse()
        extracted = page.text
    except Exception as exc:
        return f"Error fetching article: {exc}"
    return extracted
def summarize_text(text, max_length=150, min_length=40):
    """Generate an abstractive summary of ``text`` with the Pegasus model.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: Source text. Only the first 1024 tokens are used (truncation).
        max_length: Maximum summary length passed to ``generate()``.
        min_length: Minimum summary length passed to ``generate()``.

    Returns:
        The decoded summary string. On any failure, a string beginning with
        ``"Error generating summary: "`` is returned instead of raising.
    """
    try:
        # If the startup load left the model unset (None), report that
        # plainly instead of failing with "'NoneType' object is not callable".
        if tokenizer is None or model is None:
            raise RuntimeError("model is not loaded")

        # UI sliders may deliver numeric values as floats; generate() expects
        # integer length limits.
        max_length = int(max_length)
        min_length = int(min_length)

        # A single sequence needs no padding. Padding it to max_length made the
        # encoder process up to 1024 pad positions on every request. The
        # tokenizer's attention mask is passed explicitly below instead of
        # relying on generate() to infer it from pad tokens.
        inputs = tokenizer(
            text,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )

        # Only one input was encoded, so take the first (best-beam) sequence.
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating summary: {e}"
def process_input(input_text, input_type, max_length=150, min_length=40):
    """Summarize either an article URL or raw article text.

    Args:
        input_text: A URL when ``input_type == "URL"``; otherwise the article
            text itself. ``None`` is treated as empty input.
        input_type: ``"URL"`` to fetch ``input_text`` as a web page; any other
            value means ``input_text`` is the text to summarize.
        max_length: Forwarded to ``summarize_text``.
        min_length: Forwarded to ``summarize_text``.

    Returns:
        The summary, or an error-message string. Exceptions are not raised.
    """
    try:
        # An empty UI field may arrive as None; normalise to "".
        raw = input_text or ""
        if input_type == "URL":
            text = fetch_article_text(raw.strip())
            # Recognise only fetch_article_text's own failure message. A
            # substring test for "Error" would also reject real articles that
            # merely mention the word, and return their full text verbatim.
            if text.startswith("Error fetching article:"):
                return text
        else:
            text = raw

        if not text or len(text.strip()) < 100:
            return "Error: Input text is too short or empty."

        return summarize_text(text, max_length, min_length)
    except Exception as e:
        return f"Error processing input: {e}"
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the summarizer.

    Layout, top to bottom:
    - a title and a short instruction line;
    - the input-type radio;
    - the input textbox;
    - the max/min length sliders side by side;
    - the submit button;
    - the output textbox.

    Clicking the button calls ``process_input(text, type, max, min)`` and
    shows its return value in the output textbox.
    """
    with gr.Blocks(title="Research Article Summarizer") as ui:
        gr.Markdown("# Research Article Summarizer")
        gr.Markdown("Enter either a URL or paste the article text directly.")

        with gr.Row():
            mode_choice = gr.Radio(
                choices=["URL", "Text"],
                value="URL",
                label="Input Type",
            )

        with gr.Row():
            source_box = gr.Textbox(
                lines=5,
                placeholder="Enter URL or paste article text here...",
                label="Input",
            )

        # Both length controls share one row.
        with gr.Row():
            max_len_slider = gr.Slider(
                minimum=50,
                maximum=500,
                value=150,
                step=10,
                label="Maximum Summary Length",
            )
            min_len_slider = gr.Slider(
                minimum=20,
                maximum=200,
                value=40,
                step=10,
                label="Minimum Summary Length",
            )

        with gr.Row():
            generate_button = gr.Button("Generate Summary")

        with gr.Row():
            summary_box = gr.Textbox(lines=5, label="Generated Summary")

        # The input order must match process_input's positional parameters.
        generate_button.click(
            fn=process_input,
            inputs=[source_box, mode_choice, max_len_slider, min_len_slider],
            outputs=summary_box,
        )

    return ui
# Build the UI once and start serving it. Per the Gradio docs, debug=True
# keeps launch() blocking and prints errors to the console, and share=True
# requests a public share link.
# NOTE(review): share links are reportedly unsupported on Hugging Face
# Spaces; confirm.
demo = create_interface()
demo.launch(share=True, debug=True)