Dhanush S Gowda committed on
Commit d769310 · verified · 1 Parent(s): 756850c

Update app.py

Files changed (1): app.py +60 -206
app.py CHANGED
@@ -1,216 +1,70 @@
 import streamlit as st
-import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
-import concurrent.futures
-import numpy as np
-import psutil
 import os
 
-class UltraOptimizedSummarizer:
-    def __init__(self):
-        # Advanced caching and memory management
-        self.models = {}
-        self.tokenizers = {}
-        self.device = self._get_optimal_device()
-
-    def _get_optimal_device(self):
-        """Intelligently select the best computational device."""
-        if torch.cuda.is_available():
-            # Find the GPU with most free memory
-            gpu_memory = [torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())]
-            best_gpu = np.argmin(gpu_memory)
-            return torch.device(f'cuda:{best_gpu}')
-        elif torch.backends.mps.is_available():
-            return torch.device('mps')
-        return torch.device('cpu')
-
-    def _load_model(self, model_name):
-        """
-        Optimized model loading with advanced memory management.
-        Uses half-precision (float16) for reduced memory footprint.
-        """
-        if model_name in self.models:
-            return self.models[model_name], self.tokenizers[model_name]
-
-        try:
-            # Load tokenizer
-            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/tmp/huggingface_cache')
-
-            # Load model with optimization
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_name,
-                cache_dir='/tmp/huggingface_cache',
-                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
-                low_cpu_mem_usage=True
-            ).to(self.device)
-
-            # Optional: Model compilation for additional speed (PyTorch 2.0+)
-            if hasattr(torch, 'compile'):
-                model = torch.compile(model)
-
-            self.models[model_name] = model
-            self.tokenizers[model_name] = tokenizer
-
-            return model, tokenizer
-
-        except Exception as e:
-            st.error(f"Model loading error for {model_name}: {e}")
-            return None, None
-
-    def summarize(self, text, model_name, max_length=150, min_length=50):
-        """
-        Ultra-optimized summarization with intelligent truncation.
-        """
-        model, tokenizer = self._load_model(model_name)
-        if not model or not tokenizer:
-            return "Summarization failed."
-
-        try:
-            # Intelligent text truncation
-            inputs = tokenizer(
-                text,
-                max_length=1024,  # Prevent OOM errors
-                truncation=True,
-                return_tensors='pt'
-            ).to(self.device)
-
-            # Generate summary with optimized parameters
-            summary_ids = model.generate(
-                inputs['input_ids'],
-                num_beams=4,
-                max_length=max_length,
-                min_length=min_length,
-                early_stopping=True,
-                no_repeat_ngram_size=2,
-                do_sample=False
-            )
-
-            # Decode summary
-            summary = tokenizer.decode(
-                summary_ids[0],
-                skip_special_tokens=True
-            )
-
-            return summary
-
-        except Exception as e:
-            st.error(f"Summarization error for {model_name}: {e}")
-            return "Could not generate summary."
-
-    def parallel_summarize(self, text, max_length=150, min_length=50):
-        """
-        Concurrent summarization with advanced thread pooling.
-        """
-        model_configs = [
-            "facebook/bart-large-cnn",
-            "t5-large",
-            "google/pegasus-cnn_dailymail"
-        ]
-
-        # Dynamic thread count based on system resources
-        max_workers = min(
-            len(model_configs),
-            psutil.cpu_count(logical=False),  # Physical cores
-            4  # Cap at 4 to prevent resource exhaustion
-        )
-
-        # Use concurrent futures for true parallel processing
-        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-            # Submit summarization tasks
-            future_to_model = {
-                executor.submit(
-                    self.summarize,
-                    text,
-                    model,
-                    max_length,
-                    min_length
-                ): model for model in model_configs
-            }
-
-            # Collect results as they complete
-            summaries = {}
-            for future in concurrent.futures.as_completed(future_to_model):
-                model = future_to_model[future]
-                try:
-                    summaries[model] = future.result()
-                except Exception as e:
-                    summaries[model] = f"Error: {e}"
-
-            return summaries
-
-def main():
-    st.set_page_config(
-        page_title="Ultra-Optimized Summarization",
-        page_icon="🚀",
-        layout="wide"
-    )
-
-    st.title("🔬 Hyper-Optimized Text Summarization")
-
-    # Initialize optimized summarizer
-    summarizer = UltraOptimizedSummarizer()
-
-    # Input and processing
-    text_input = st.text_area(
-        "Enter text for advanced summarization:",
-        height=300
-    )
-
-    # Advanced compression control
-    col1, col2 = st.columns(2)
-    with col1:
-        max_length = st.slider(
-            "Max Summary Length",
-            min_value=50,
-            max_value=300,
-            value=150
-        )
-
-    with col2:
-        compression_rate = st.slider(
-            "Compression Aggressiveness",
-            min_value=0.1,
-            max_value=0.5,
-            value=0.3,
-            step=0.05
-        )
-
-    if st.button("Generate Hyper-Optimized Summaries"):
-        if not text_input:
-            st.warning("Please provide text to summarize.")
-            return
-
-        # Progress tracking
-        progress_bar = st.progress(0)
-        status_text = st.empty()
-
-        try:
-            # Perform parallel summarization
-            status_text.info("Initializing ultra-optimized summarization...")
-            progress_bar.progress(20)
-
-            summaries = summarizer.parallel_summarize(
-                text_input,
-                max_length=max_length,
-                min_length=int(max_length * 0.5)
-            )
-
-            progress_bar.progress(100)
-            status_text.success("Summarization Complete!")
-
-            # Display results
-            cols = st.columns(3)
-            for (col, (model, summary)) in zip(cols, summaries.items()):
-                with col:
-                    st.subheader(model.split('/')[-1].upper())
-                    st.write(summary)
-
-        except Exception as e:
-            st.error(f"Optimization failed: {e}")
-
-        finally:
-            progress_bar.empty()
-            status_text.empty()
-
-if __name__ == "__main__":
-    main()
 import streamlit as st
+from transformers import pipeline
 import os
+# Set Hugging Face cache directory
+os.environ['TRANSFORMERS_CACHE'] = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))
+
+# Function to load all three models
+@st.cache_resource
+def load_models():
+    bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+    t5_summarizer = pipeline("summarization", model="t5-large")
+    pegasus_summarizer = pipeline("summarization", model="google/pegasus-cnn_dailymail")
+    return bart_summarizer, t5_summarizer, pegasus_summarizer
+
+# Streamlit app layout
+st.title("Text Summarization with Pre-trained Models: BART, T5, Pegasus")
+
+# Load models
+with st.spinner("Loading models..."):
+    bart_model, t5_model, pegasus_model = load_models()
+
+# Input text
+text_input = st.text_area("Enter text to summarize:")
+
+# Compression rate slider
+compression_rate = st.slider(
+    "Summary Compression Rate",
+    min_value=0.1,
+    max_value=0.5,
+    value=0.3,
+    step=0.05,
+    help="Adjust how much of the original text to keep in the summary"
+)
+
+if text_input:
+    word_count = len(text_input.split())
+    st.write(f"**Input Word Count:** {word_count}")
+
+    if st.button("Generate Summaries"):
+        with st.spinner("Generating summaries..."):
+            # Calculate dynamic max length based on compression rate
+            max_length = max(50, int(word_count * compression_rate))
+
+            # Generate summaries
+            bart_summary = bart_model(
+                text_input, max_length=max_length, min_length=30, num_beams=4, early_stopping=True
+            )[0]['summary_text']
+
+            t5_summary = t5_model(
+                text_input, max_length=max_length, min_length=30, num_beams=4, early_stopping=True
+            )[0]['summary_text']
+
+            pegasus_summary = pegasus_model(
+                text_input, max_length=max_length, min_length=30, num_beams=4, early_stopping=True
+            )[0]['summary_text']
+
+            # Display summaries
+            st.subheader("BART Summary")
+            st.write(bart_summary)
+            st.write(f"**Word Count:** {len(bart_summary.split())}")
+
+            st.subheader("T5 Summary")
+            st.write(t5_summary)
+            st.write(f"**Word Count:** {len(t5_summary.split())}")
+
+            st.subheader("Pegasus Summary")
+            st.write(pegasus_summary)
+            st.write(f"**Word Count:** {len(pegasus_summary.split())}")
+else:
+    st.warning("Please enter text to summarize.")