File size: 4,638 Bytes
038a86f
caf0283
d769310
0330532
53545b3
038a86f
 
 
 
 
 
 
 
 
 
 
 
 
0330532
 
038a86f
0330532
 
 
 
 
038a86f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0330532
 
 
 
 
 
 
 
 
d769310
0330532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d769310
0330532
 
6895495
038a86f
 
 
0330532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import streamlit as st
from transformers import pipeline
import traceback

# Use Hugging Face Spaces' recommended persistent storage
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")

def ensure_cache_dir():
    """
    Make sure the model cache directory exists on disk.

    Returns:
        str: Path to the (now existing) cache directory.
    """
    path = CACHE_DIR
    # exist_ok avoids a race/error when the directory was already created.
    os.makedirs(path, exist_ok=True)
    return path

@st.cache_resource
def load_pipelines():
    """
    Load the three summarization pipelines with persistent on-disk caching.

    Decorated with ``st.cache_resource`` so the heavyweight models are
    loaded once per server process instead of on every Streamlit rerun
    (each button press reruns the script top to bottom).

    Returns:
        dict: Mapping of display name ('BART', 'T5', 'Pegasus') to its
        summarization pipeline; empty dict if loading failed.
    """
    try:
        cache_dir = ensure_cache_dir()

        # Display name -> (Hub model id, cache subdirectory).
        model_specs = {
            'BART': ("facebook/bart-large-cnn", "bart-large-cnn"),
            'T5': ("t5-large", "t5-large"),
            'Pegasus': ("google/pegasus-cnn_dailymail", "pegasus-cnn_dailymail"),
        }

        pipelines = {}
        for name, (model_id, subdir) in model_specs.items():
            # NOTE: `pipeline()` has no top-level `cache_dir` parameter;
            # it must be forwarded to `from_pretrained` via `model_kwargs`.
            pipelines[name] = pipeline(
                "summarization",
                model=model_id,
                model_kwargs={"cache_dir": os.path.join(cache_dir, subdir)},
            )
        return pipelines
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}

def generate_summary(pipeline, text, model_name):
    """
    Generate a summary of *text* with the given pipeline, with error handling.

    Args:
        pipeline: Hugging Face summarization pipeline (any callable
            returning ``[{'summary_text': ...}]``).
        text (str): Input text to summarize.
        model_name (str): Display name of the model ('BART', 'T5', 'Pegasus').

    Returns:
        str: Generated summary, or an error message on failure.
    """
    try:
        # These are plain seq2seq summarizers, not instruction-tuned models:
        # a natural-language prompt would just get summarized along with the
        # input. Only T5 expects a task prefix, and it is "summarize: ".
        model_input = f"summarize: {text}" if model_name == 'T5' else text
        result = pipeline(
            model_input,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            truncation=True,  # avoid errors on inputs beyond the model's limit
        )
        return result[0]['summary_text']
    except Exception as e:
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg

def main():
    """Streamlit entry point: collect text and show three model summaries."""
    st.title("Text Summarization with Pre-trained Models")

    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")

    # Text input
    text_input = st.text_area("Enter text to summarize:")

    # Generate button
    if st.button("Generate Summary"):
        if not text_input:
            st.error("Please enter text to summarize.")
            return

        # Load pipelines
        pipelines = load_pipelines()
        if not pipelines:
            st.error("Failed to load models. Please check your internet connection or try again later.")
            return

        # One column per model; the identical display logic is driven by a
        # loop instead of three copy-pasted blocks.
        columns = st.columns(3)
        for column, model_name in zip(columns, ('BART', 'T5', 'Pegasus')):
            with column:
                with st.spinner(f'Generating {model_name} Summary...'):
                    progress = st.progress(0)
                    # Coarse progress signal only: the pipeline call is a
                    # single blocking step with no incremental feedback.
                    progress.progress(50)
                    summary = generate_summary(pipelines[model_name], text_input, model_name)
                    progress.progress(100)
                    st.subheader(f"{model_name} Summary")
                    st.write(summary)

# Run the app when executed as a script (e.g. `streamlit run <file>`).
if __name__ == "__main__":
    main()