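"""Streamlit app: fetch a news article from a URL and produce an abstractive
summary with a fine-tuned T5 model (weights loaded from a local .pth file)."""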
import streamlit as st
import requests
from bs4 import BeautifulSoup
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Note: T5Tokenizer requires the sentencepiece package to be installed.
# Initialize session state for model and tokenizer
if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None
@st.cache_resource
def load_model():
    try:
        # Check if CUDA is available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load the model
        model = T5ForConditionalGeneration.from_pretrained('t5-base')
        # Load the saved weights with appropriate map_location
        checkpoint = torch.load('abstractive-model-sihanas.pth', map_location=device)
        model.load_state_dict(checkpoint)
        model.to(device)
        # Load tokenizer
        tokenizer = T5Tokenizer.from_pretrained('t5-base')
        return model, tokenizer, device
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None
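# @st.cache_resource memoizes the loaded model across Streamlit reruns, so the
# session_state initialization above is belt-and-braces rather than required.
# The checkpoint path is relative: 'abstractive-model-sihanas.pth' is assumed
# to sit in the app's working directory.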
def clean_text(text):
    """Clean and preprocess the input text"""
    # Remove extra whitespace
    text = ' '.join(text.split())
    # Remove very long words (likely garbage)
    text = ' '.join(word for word in text.split() if len(word) < 100)
    return text
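# Example (hypothetical input): clean_text("Breaking:   news\n\nstory")
# returns "Breaking: news story" -- whitespace is collapsed, and any token of
# 100+ characters (e.g. an inlined data blob) is discarded.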
def summarize_text(text, model, tokenizer, device):
    try:
        # Clean the text
        cleaned_text = clean_text(text)
        # Tokenize with the T5 summarization prefix, truncating to the
        # 512-token encoder limit
        inputs = tokenizer.encode("summarize: " + cleaned_text,
                                  return_tensors='pt',
                                  max_length=512,
                                  truncation=True).to(device)
        # Generate the summary
        summary_ids = model.generate(
            inputs,
            max_length=150,
            min_length=40,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return summary
    except Exception as e:
        st.error(f"Error in summarization: {str(e)}")
        return None
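# Generation settings: 4-beam search with length_penalty=2.0 biases scoring
# toward longer candidates (values above 1.0 favor length), min/max_length
# bound the summary to roughly 40-150 tokens, and early_stopping=True halts
# the search once num_beams complete candidates have been found.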
def fetch_article(url):
    """Fetch article content and metadata from URL"""
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()  # Raise an exception for bad status codes
        soup = BeautifulSoup(response.content, 'html.parser')
        # Extract metadata
        # Prefer og:title, falling back to the <title> tag; a plain Tag has no
        # 'content' attribute, so it must be read via .string instead of .get()
        title_tag = soup.find('meta', property='og:title')
        if title_tag and title_tag.get('content'):
            title = title_tag['content'].strip()
        elif soup.title and soup.title.string:
            title = soup.title.string.strip()
        else:
            title = 'No title found'
        authors = soup.find('meta', {'name': 'author'})
        authors = authors.get('content', '').strip() if authors else 'No author information'
        publish_date = soup.find('meta', {'property': 'article:published_time'})
        publish_date = publish_date.get('content', '').strip() if publish_date else 'No publish date found'
        publisher = soup.find('meta', {'property': 'og:site_name'})
        publisher = publisher.get('content', '').strip() if publisher else 'No publisher information'
        # Remove scripts, styles, and navigation elements
        for element in soup(['script', 'style', 'nav', 'header', 'footer']):
            element.decompose()
        text = soup.get_text(separator=' ', strip=True)
        return title, authors, publish_date, publisher, text
    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching the article: {str(e)}")
        return None, None, None, None, None
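# Note: get_text() on the stripped page is a blunt extractor -- it keeps
# sidebar, menu, and comment text along with the article body. A dedicated
# extraction library (e.g. trafilatura or newspaper3k, neither used here)
# would isolate the article text more cleanly.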
def main():
    st.title("News Article Summarizer")
    st.write("Enter a news article URL to get a summary.")
    # Load model and tokenizer
    model, tokenizer, device = load_model()
    if model is None or tokenizer is None:
        st.error("Failed to load the model. Please check your model file and dependencies.")
        return
    # URL input
    url = st.text_input("News Article URL")
    if st.button("Summarize"):
        if not url:
            st.warning("Please enter a URL")
            return
        with st.spinner("Fetching article and generating summary..."):
            # Fetch article
            title, authors, publish_date, publisher, article_text = fetch_article(url)
            if article_text:
                # Display metadata
                st.write(f"**Title**: {title}")
                st.write(f"**Authors**: {authors}")
                st.write(f"**Publish Date**: {publish_date}")
                st.write(f"**Publisher**: {publisher}")
                # Generate summary
                summary = summarize_text(article_text, model, tokenizer, device)
                if summary:
                    st.success("Summary generated successfully!")
                    st.write("### Summary")
                    st.write(summary)
                    # Display original text (collapsed)
                    with st.expander("Show original article"):
                        st.write(article_text)
            else:
                st.error("Failed to fetch the article. Please check the URL and try again.")

if __name__ == "__main__":
    main()
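# To run locally (the filename app.py is illustrative, not prescribed by the
# source): streamlit run app.py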