File size: 3,554 Bytes
caf0283
d769310
0330532
53545b3
0330532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d769310
0330532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d769310
0330532
 
6895495
0330532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from transformers import pipeline
import traceback

# Use Streamlit's resource cache when available (Streamlit >= 1.18) so the
# models are loaded once per server process instead of on every button
# click; on older Streamlit versions this falls back to no caching.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _build_pipelines():
    """Construct the three summarization pipelines, raising on failure.

    Kept separate from ``load_pipelines`` so that only successful results
    are cached: an exception propagates out of the cached function and the
    load is simply retried on the next call.
    """
    return {
        'BART': pipeline("summarization", model="facebook/bart-large-cnn"),
        'T5': pipeline("summarization", model="t5-large"),
        # ``pipeline``'s first positional argument is the task name, so the
        # Pegasus checkpoint must be passed as ``model=``; passing only the
        # checkpoint name is rejected as an unknown task.
        'Pegasus': pipeline("summarization", model="google/pegasus-cnn_dailymail"),
    }


def load_pipelines():
    """
    Load summarization pipelines with error handling.

    Errors are reported in the Streamlit UI (message plus full traceback)
    instead of being raised.

    Returns:
        dict: Mapping of display name ('BART', 'T5', 'Pegasus') to its
        pipeline, or an empty dict if any model failed to load. The dict
        may be shared through the resource cache, so callers should not
        mutate it.
    """
    try:
        return _build_pipelines()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}

def generate_summary(pipeline, text, model_name, *, max_length=150,
                     min_length=50, prompt=None):
    """
    Generate summary for a specific model with error handling.

    Args:
        pipeline: Hugging Face summarization pipeline (any callable with the
            same call convention and ``[{'summary_text': ...}]`` result
            works). Inside this function the name shadows the module-level
            ``transformers.pipeline`` import.
        text (str): Input text to summarize.
        model_name (str): Name of the model, used only in the error message.
        max_length (int): Maximum summary length forwarded to generation.
        min_length (int): Minimum summary length forwarded to generation.
        prompt (str | None): Optional instruction placed on its own line
            before ``text``. Defaults to None (no prefix), because whatever
            is prepended is fed to the model as part of the document being
            summarized.

    Returns:
        str: Generated summary or error message.
    """
    model_input = f"{prompt}\n{text}" if prompt else text
    try:
        outputs = pipeline(
            model_input,
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            # Cut inputs that exceed the model's maximum input length
            # instead of failing on long pasted text.
            truncation=True,
        )
        return outputs[0]['summary_text']
    except Exception as e:
        # UI boundary: surface the failure to the user and return it as the
        # "summary" so the other models' columns still render.
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg

def main():
    """Render the summarization UI and, on request, show one summary per model.

    Each model's summary is rendered in its own column, in the order BART,
    T5, Pegasus, as soon as that model finishes.
    """
    st.title("Text Summarization with Pre-trained Models")

    # Text input
    text_input = st.text_area("Enter text to summarize:")

    # Nothing to do until the user clicks the button.
    if not st.button("Generate Summary"):
        return

    # Reject whitespace-only input as well as empty input, so it is never
    # sent to the models.
    if not text_input.strip():
        st.error("Please enter text to summarize.")
        return

    pipelines = load_pipelines()
    if not pipelines:
        st.error("Failed to load models. Please check your internet connection or try again later.")
        return

    # One column per model for progressive display; names match the keys
    # returned by load_pipelines().
    model_names = ('BART', 'T5', 'Pegasus')
    for column, model_name in zip(st.columns(len(model_names)), model_names):
        with column:
            with st.spinner(f'Generating {model_name} Summary...'):
                progress = st.progress(0)
                progress.progress(50)
                summary = generate_summary(pipelines[model_name], text_input, model_name)
                progress.progress(100)
                st.subheader(f"{model_name} Summary")
                st.write(summary)


if __name__ == "__main__":
    main()