File size: 3,554 Bytes
caf0283 d769310 0330532 53545b3 0330532 d769310 0330532 d769310 0330532 6895495 0330532 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import streamlit as st
from transformers import pipeline
import traceback
def load_pipelines():
    """
    Load the summarization pipelines with error handling.

    Loading is all-or-nothing: if any model fails to load, the error and its
    traceback are shown in the Streamlit UI and an empty dict is returned.

    Returns:
        dict: Maps 'BART', 'T5' and 'Pegasus' to their summarization
            pipelines, or {} when loading failed.
    """
    # NOTE(review): the models are rebuilt on every call; if the installed
    # Streamlit version provides resource caching, consider using it here.
    try:
        bart_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
        t5_pipeline = pipeline("summarization", model="t5-large")
        # The first positional argument of pipeline() is the task name, so the
        # checkpoint must be given through ``model=``. Passing the checkpoint
        # name as the task raised and made the whole function return {}.
        pegasus_pipeline = pipeline(
            "summarization", model="google/pegasus-cnn_dailymail"
        )
        return {
            'BART': bart_pipeline,
            'T5': t5_pipeline,
            'Pegasus': pegasus_pipeline,
        }
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}
def generate_summary(pipeline, text, model_name):
    """
    Generate a summary with one model, reporting failures instead of raising.

    Args:
        pipeline: Hugging Face summarization pipeline; called with the text
            and generation keyword arguments, and expected to return a list
            whose first item has a 'summary_text' key
        text (str): Input text to summarize
        model_name (str): Name of the model, used in error messages

    Returns:
        str: Generated summary or error message
    """
    try:
        # Pass the text unchanged. These are summarization checkpoints, not
        # instruction-following models, so an instruction prepended to the
        # input would be treated as part of the document to summarize.
        # (Presumably any model-specific prefix, e.g. for T5, is applied by
        # the pipeline itself; confirm against the transformers docs.)
        result = pipeline(
            text,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            # Clip over-long inputs to the model's maximum input length
            # rather than letting them fail inside the model.
            truncation=True,
        )
        return result[0]['summary_text']
    except Exception as e:
        # Show the failure in the UI and also return it, so the caller always
        # receives a string to display.
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg
def main():
    """
    Render the Streamlit page and run the summarizers on demand.

    Shows a text area and a button. When the button is pressed with non-blank
    text, the models are loaded and each model's summary is displayed in its
    own column, one after another.
    """
    st.title("Text Summarization with Pre-trained Models")
    # Text input
    text_input = st.text_area("Enter text to summarize:")

    # Nothing to do until the user presses the button.
    if not st.button("Generate Summary"):
        return

    # Reject blank input, including whitespace-only text, before doing the
    # expensive model loading.
    text = text_input.strip()
    if not text:
        st.error("Please enter text to summarize.")
        return

    # Load pipelines
    pipelines = load_pipelines()
    if not pipelines:
        st.error("Failed to load models. Please check your internet connection or try again later.")
        return

    # One column per model, filled left to right as each summary completes.
    model_names = ('BART', 'T5', 'Pegasus')
    columns = st.columns(len(model_names))
    for name, column in zip(model_names, columns):
        with column:
            with st.spinner(f'Generating {name} Summary...'):
                # Coarse progress indicator: 50% before the model runs,
                # 100% once it has returned.
                progress_bar = st.progress(0)
                progress_bar.progress(50)
                summary = generate_summary(pipelines[name], text, name)
                progress_bar.progress(100)
            st.subheader(f"{name} Summary")
            st.write(summary)
# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()