Dhanush S Gowda
Update app.py
0330532 verified
raw
history blame
3.55 kB
import streamlit as st
from transformers import pipeline
import traceback
@st.cache_resource
def _load_models():
    """Instantiate the summarization pipelines once per server process.

    Wrapped in ``st.cache_resource`` so that Streamlit reruns (the whole script
    re-executes on every widget interaction) reuse the same pipeline objects
    instead of re-loading every model's weights each time. Exceptions are
    deliberately NOT caught here: a call that raises is not cached, so a
    transient failure (e.g. a network error while downloading) is retried on
    the next run.

    Returns:
        dict: Mapping of display name -> summarization pipeline, in display order.
    """
    return {
        'BART': pipeline("summarization", model="facebook/bart-large-cnn"),
        'T5': pipeline("summarization", model="t5-large"),
        # The first positional argument of ``pipeline`` is the TASK; the model
        # id must go in ``model=`` (passing it as the task raised "Unknown task").
        'Pegasus': pipeline("summarization", model="google/pegasus-cnn_dailymail"),
    }


def load_pipelines():
    """
    Load summarization pipelines with error handling.

    Returns:
        dict: Dictionary of model pipelines keyed by display name
            ('BART', 'T5', 'Pegasus'), or an empty dict if loading failed
            (the error and traceback are shown in the Streamlit page).
    """
    try:
        return _load_models()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}
def generate_summary(pipeline, text, model_name):
    """
    Generate summary for a specific model with error handling.

    Args:
        pipeline: Hugging Face summarization pipeline (any callable returning
            ``[{'summary_text': str}, ...]``). This parameter shadows the
            module-level ``transformers.pipeline`` inside this function; the
            name is kept so keyword callers keep working.
        text (str): Input text to summarize.
        model_name (str): Name of the model, used only in the error message.

    Returns:
        str: Generated summary, or an error message if generation failed
            (the error is also shown in the Streamlit page).
    """
    try:
        # Feed the raw document. These checkpoints are not instruction-tuned,
        # so a natural-language prompt would be treated as part of the text
        # and can leak into the summary. T5's "summarize: " task prefix is
        # already added by the summarization pipeline from the model config.
        outputs = pipeline(
            text,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            # Clip inputs longer than the encoder's maximum length instead of
            # failing with an indexing error inside the model.
            truncation=True,
        )
        return outputs[0]['summary_text']
    except Exception as e:
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg
def main():
    """Render the summarization UI.

    Shows a text area and a "Generate Summary" button. On click, validates the
    input, loads the pipelines, and fills one column per loaded model in turn,
    so each summary appears as soon as its model finishes.
    """
    st.title("Text Summarization with Pre-trained Models")
    text_input = st.text_area("Enter text to summarize:")

    if not st.button("Generate Summary"):
        return
    # Reject whitespace-only input too, not just the empty string.
    if not text_input or not text_input.strip():
        st.error("Please enter text to summarize.")
        return

    pipelines = load_pipelines()
    if not pipelines:
        st.error("Failed to load models. Please check your internet connection or try again later.")
        return

    # One column per model, filled left to right in the dict's insertion order.
    columns = st.columns(len(pipelines))
    for column, (model_name, model_pipeline) in zip(columns, pipelines.items()):
        with column:
            with st.spinner(f'Generating {model_name} Summary...'):
                progress = st.progress(0)
                progress.progress(50)
                summary = generate_summary(model_pipeline, text_input, model_name)
                progress.progress(100)
            st.subheader(f"{model_name} Summary")
            st.write(summary)
# Script entry point: build the page only when this file is executed directly
# (as opposed to being imported as a module).
if __name__ == "__main__":
    main()