import os
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import traceback
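
# Streamlit app comparing summaries from BART, T5, and Pegasus side by side.
# Assumed usage (not stated in the original source): run with
#   streamlit run app.py
# with streamlit, transformers, and a PyTorch install available
# (return_tensors="pt" below requires torch).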

# Cache downloaded model files under the app's working directory so repeated
# loads reuse them (note: a cwd-based cache is not guaranteed to persist across
# Hugging Face Space rebuilds)
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")

def ensure_cache_dir():
    """
    Ensure the cache directory exists.
    
    Returns:
        str: Path to the cache directory
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR

# Cache loaded models/tokenizers for the session so they are not re-created on
# every Streamlit rerun (i.e. every button press)
@st.cache_resource
def load_model_and_tokenizer(model_name):
    """
    Load model and tokenizer with persistent caching.
    
    Args:
        model_name (str): Name of the model to load
    
    Returns:
        tuple: (model, tokenizer)
    """
    try:
        # Ensure cache directory exists
        cache_dir = ensure_cache_dir()
        
        # Construct full cache path for this model
        model_cache_path = os.path.join(cache_dir, model_name.replace('/', '_'))
        
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, 
            cache_dir=model_cache_path
        )
        
        # Load model
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, 
            cache_dir=model_cache_path
        )
        
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading {model_name}: {str(e)}")
        st.error(traceback.format_exc())
        return None, None

def generate_summary(model, tokenizer, text, max_length=150):
    """
    Generate summary using a specific model and tokenizer.
    
    Args:
        model: Hugging Face model
        tokenizer: Hugging Face tokenizer
        text (str): Input text to summarize
        max_length (int): Maximum length of summary
    
    Returns:
        str: Generated summary
    """
    try:
        # T5 checkpoints expect a "summarize:" task prefix; BART and Pegasus
        # work on the raw text
        prompt = f"summarize: {text}" if model.config.model_type == "t5" else text
        inputs = tokenizer(
            prompt,
            max_length=512,
            return_tensors="pt",
            truncation=True
        )
        
        # Generate summary with beam search, passing the attention mask explicitly
        summary_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=4,
            max_length=max_length,
            early_stopping=True
        )
        
        # Decode summary
        summary = tokenizer.decode(
            summary_ids[0], 
            skip_special_tokens=True
        )
        
        return summary
    except Exception as e:
        error_msg = f"Error in summarization: {str(e)}"
        st.error(error_msg)
        return error_msg

def main():
    st.title("Text Summarization with Pre-trained Models")
    
    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")
    
    # Models to compare: display name -> Hugging Face Hub model id
    models_to_load = {
        'BART': 'facebook/bart-large-cnn',
        'T5': 't5-large',
        'Pegasus': 'google/pegasus-cnn_dailymail'
    }
    
    # Text input
    text_input = st.text_area("Enter text to summarize:")
    
    # Generate button
    if st.button("Generate Summary"):
        if not text_input.strip():
            st.error("Please enter text to summarize.")
            return
        
        # Create columns for progressive display
        bart_col, t5_col, pegasus_col = st.columns(3)
        
        # Function to process each model
        def process_model(col, model_name, model_path):
            with col:
                with st.spinner(f'Generating {model_name} Summary...'):
                    progress = st.progress(0)
                    progress.progress(50)
                    
                    # Load model and tokenizer
                    model, tokenizer = load_model_and_tokenizer(model_path)
                    
                    if model and tokenizer:
                        # Generate summary
                        summary = generate_summary(model, tokenizer, text_input)
                        
                        progress.progress(100)
                        st.subheader(f"{model_name} Summary")
                        st.write(summary)
                    else:
                        st.error(f"Failed to load {model_name} model")
        
        # Process each model in its own column, reusing the mapping defined above
        for col, (model_name, model_path) in zip(
            (bart_col, t5_col, pegasus_col), models_to_load.items()
        ):
            process_model(col, model_name, model_path)

if __name__ == "__main__":
    main()