|
import streamlit as st |
|
from transformers import pipeline |
|
import traceback |
|
|
|
def _build_summarization_pipelines():
    """Instantiate the BART, T5 and Pegasus summarization pipelines.

    Nothing is caught here: any exception raised by ``pipeline`` propagates
    to the caller, so a resource cache wrapping this function never stores a
    failed load.

    Returns:
        dict: Mapping of display name ('BART', 'T5', 'Pegasus') to pipeline.
    """
    return {
        'BART': pipeline("summarization", model="facebook/bart-large-cnn"),
        'T5': pipeline("summarization", model="t5-large"),
        # ``pipeline``'s first positional argument is the task name; the
        # checkpoint must be given via ``model=`` (previously the checkpoint
        # name was passed as the task, which made every load fail).
        'Pegasus': pipeline("summarization", model="google/pegasus-cnn_dailymail"),
    }


def load_pipelines():
    """
    Load summarization pipelines with error handling.

    Models are loaded all-or-nothing: if any one of them fails, the error and
    traceback are shown in the Streamlit page and an empty dict is returned.

    Returns:
        dict: Dictionary of model pipelines keyed by 'BART', 'T5' and
        'Pegasus', or an empty dict if loading failed.
    """
    # The app calls this on every "Generate Summary" click; caching the loaded
    # models avoids re-instantiating three large models each time. Fall back
    # to uncached loading on Streamlit versions without ``cache_resource``.
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        loader = cache_resource(_build_summarization_pipelines)
    else:
        loader = _build_summarization_pipelines

    try:
        return loader()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}
|
|
|
def generate_summary(pipeline, text, model_name):
    """
    Generate summary for a specific model with error handling.

    The text is passed to the pipeline unchanged. BART-CNN and
    Pegasus-CNN/DailyMail are fine-tuned summarizers, not instruction-following
    models, so a natural-language instruction would just become part of the
    document being summarized. (The T5 summarization pipeline presumably
    applies T5's own task prefix from the model config — confirm against the
    installed transformers version.)

    Args:
        pipeline: Hugging Face summarization pipeline (any callable returning
            ``[{'summary_text': str}, ...]``). Note: this name shadows the
            module-level ``transformers.pipeline`` inside this function.
        text (str): Input text to summarize.
        model_name (str): Name of the model, used only in the error message.

    Returns:
        str: Generated summary, or an error message if generation failed
        (the error is also shown in the Streamlit page).
    """
    try:
        outputs = pipeline(
            text,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            # Clip inputs longer than the model's maximum source length
            # instead of failing on long articles.
            truncation=True,
        )
        return outputs[0]['summary_text']
    except Exception as e:
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg
|
|
|
def _render_model_summary(pipelines, text, model_name):
    """Generate and display one model's summary in the current Streamlit container."""
    summarizer = pipelines.get(model_name)
    if summarizer is None:
        # Defensive: show an error for a missing model instead of raising KeyError.
        st.error(f"{model_name} model is not available.")
        return

    with st.spinner(f'Generating {model_name} Summary...'):
        progress = st.progress(0)
        progress.progress(50)
        summary = generate_summary(summarizer, text, model_name)
        progress.progress(100)

    st.subheader(f"{model_name} Summary")
    st.write(summary)


def main():
    """Render the app: read text, then show BART, T5 and Pegasus summaries side by side."""
    st.title("Text Summarization with Pre-trained Models")

    text_input = st.text_area("Enter text to summarize:")

    if not st.button("Generate Summary"):
        return

    # Whitespace-only input has nothing to summarize; reject it like empty input.
    if not text_input or not text_input.strip():
        st.error("Please enter text to summarize.")
        return

    pipelines = load_pipelines()
    if not pipelines:
        st.error("Failed to load models. Please check your internet connection or try again later.")
        return

    # Column order matches the display order; names match load_pipelines() keys.
    model_names = ('BART', 'T5', 'Pegasus')
    columns = st.columns(len(model_names))
    for column, model_name in zip(columns, model_names):
        with column:
            _render_model_summary(pipelines, text_input, model_name)
|
|
|
# Entry point: launch the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()