Dhanush S Gowda committed on
Commit 756850c · verified · 1 Parent(s): 6895495

Update app.py

Files changed (1)
  1. app.py +177 -104
app.py CHANGED
@@ -1,143 +1,216 @@
 import streamlit as st
-import multiprocessing
-from transformers import pipeline
-import os
 import torch

-# Optimize model loading and caching
-@st.cache_resource
-def load_model(model_name):
-    """Efficiently load a summarization model."""
-    try:
-        # Use GPU if available
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        return pipeline("summarization", model=model_name, device=device)
-    except Exception as e:
-        st.error(f"Error loading model {model_name}: {e}")
-        return None
-
-def generate_summary(model, text, length_percentage=0.3):
-    """
-    Generate summary with intelligent length control.
-
-    Args:
-        model: Hugging Face summarization pipeline
-        text: Input text to summarize
-        length_percentage: Percentage of original text to use for summary
-
-    Returns:
-        Generated summary
-    """
-    # Intelligent length calculation
-    word_count = len(text.split())
-    max_length = max(50, int(word_count * length_percentage))
-    min_length = max(30, int(word_count * 0.1))
-
-    try:
-        summary = model(
-            text,
-            max_length=max_length,
-            min_length=min_length,
-            num_beams=4,
-            early_stopping=True
-        )[0]['summary_text']
-        return summary
-    except Exception as e:
-        st.error(f"Summarization error: {e}")
-        return "Could not generate summary."
-
-def parallel_summarize(text, length_percentage=0.3):
-    """
-    Generate summaries in parallel using multiprocessing.
-
-    Args:
-        text: Input text to summarize
-        length_percentage: Percentage of original text to use for summary
-
-    Returns:
-        Dictionary of summaries from different models
-    """
-    model_configs = [
-        ("facebook/bart-large-cnn", "BART"),
-        ("t5-large", "T5"),
-        ("google/pegasus-cnn_dailymail", "Pegasus")
-    ]
-
-    with multiprocessing.Pool(processes=min(len(model_configs), os.cpu_count())) as pool:
-        args = [(load_model(model_name), text, length_percentage)
-                for model_name, _ in model_configs]
-
-        results = pool.starmap(generate_summary, args)
-
-        return {name: summary for (_, name), summary in zip(model_configs, results)}
-
 def main():
     st.set_page_config(
-        page_title="Multi-Model Text Summarization",
-        page_icon="📝",
         layout="wide"
     )

-    # Title and Description
-    st.title("🤖 Advanced Text Summarization")
-    st.markdown("""
-    Generate concise summaries using multiple state-of-the-art models.
-    Intelligently adapts summary length based on input text.
-    """)

-    # Text Input
     text_input = st.text_area(
-        "Paste your text here:",
-        height=250,
-        help="Enter the text you want to summarize"
     )

-    # Length Control
-    length_control = st.slider(
-        "Summary Compression Rate",
-        min_value=0.1,
-        max_value=0.5,
-        value=0.3,
-        step=0.05,
-        help="Adjust how much of the original text to keep in the summary"
-    )

-    if st.button("Generate Summaries", type="primary"):
         if not text_input:
-            st.warning("Please enter some text to summarize.")
             return

-        progress_text = st.empty()
         progress_bar = st.progress(0)
-
-        stages = ["Initializing Models", "Running BART", "Running T5", "Running Pegasus", "Completed"]

         try:
-            for i, stage in enumerate(stages[:-1], 1):
-                progress_text.info(stage)
-                progress_bar.progress(i * 20)
-
-                if i == 2:
-                    summaries = parallel_summarize(text_input, length_control)

-            progress_text.success("Summarization Complete!")
             progress_bar.progress(100)

-            st.subheader("📝 Generated Summaries")
             cols = st.columns(3)
-
             for (col, (model, summary)) in zip(cols, summaries.items()):
                 with col:
-                    st.markdown(f"### {model} Summary")
                     st.write(summary)
-                    st.caption(f"Word Count: {len(summary.split())}")

         except Exception as e:
-            st.error(f"An error occurred: {e}")

         finally:
-            progress_text.empty()
             progress_bar.empty()

 if __name__ == "__main__":
     main()
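A note on the `multiprocessing` design removed above: `Pool.starmap` pickles every argument into the worker processes, so this version would ship a full copy of each loaded pipeline across a process boundary, and a CUDA-backed model generally cannot cross a fork/spawn at all; the thread-based rewrite below avoids both problems. A minimal sketch of the serialization cost (illustrative only; the small distilled checkpoint here is not part of this commit, and pickling may fail outright for GPU or compiled models):

import pickle
from transformers import pipeline

# Load a small summarization pipeline purely to measure its pickled size.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

# This is roughly what Pool.starmap would do to each argument tuple: for a
# CPU pipeline it usually succeeds, but the blob is the size of the weights.
blob = pickle.dumps(summarizer)
print(f"pickled pipeline: ~{len(blob) / 1e6:.0f} MB")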
 
 import streamlit as st
 import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+import concurrent.futures
+import numpy as np
+import psutil
+import os

+class UltraOptimizedSummarizer:
+    def __init__(self):
+        # Advanced caching and memory management
+        self.models = {}
+        self.tokenizers = {}
+        self.device = self._get_optimal_device()
+
+    def _get_optimal_device(self):
+        """Intelligently select the best computational device."""
+        if torch.cuda.is_available():
+            # Find the GPU with most free memory
+            gpu_memory = [torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())]
+            best_gpu = np.argmin(gpu_memory)
+            return torch.device(f'cuda:{best_gpu}')
+        elif torch.backends.mps.is_available():
+            return torch.device('mps')
+        return torch.device('cpu')
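One caveat on `_get_optimal_device`: `torch.cuda.memory_allocated(i)` reports only memory held by the current process's caching allocator, so on a machine shared with other jobs every GPU can look equally empty. A hedged alternative sketch that asks the driver for actual free memory via `torch.cuda.mem_get_info` (available in recent PyTorch releases):

import torch

def freest_gpu() -> torch.device:
    """Pick the CUDA device with the most driver-reported free memory."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    # mem_get_info(i) returns (free_bytes, total_bytes) for device i.
    free = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
    return torch.device(f'cuda:{free.index(max(free))}')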
+    def _load_model(self, model_name):
+        """
+        Optimized model loading with advanced memory management.
+        Uses half-precision (float16) for reduced memory footprint.
+        """
+        if model_name in self.models:
+            return self.models[model_name], self.tokenizers[model_name]
+
+        try:
+            # Load tokenizer
+            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/tmp/huggingface_cache')
+
+            # Load model with optimization
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name,
+                cache_dir='/tmp/huggingface_cache',
+                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
+                low_cpu_mem_usage=True
+            ).to(self.device)
+
+            # Optional: Model compilation for additional speed (PyTorch 2.0+)
+            if hasattr(torch, 'compile'):
+                model = torch.compile(model)
+
+            self.models[model_name] = model
+            self.tokenizers[model_name] = tokenizer
+
+            return model, tokenizer
+
+        except Exception as e:
+            st.error(f"Model loading error for {model_name}: {e}")
+            return None, None
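Since `_load_model` memoizes into `self.models` and `self.tokenizers`, the cache lives only as long as the instance, and Streamlit re-executes the whole script (re-instantiating the class in `main()` below) on every widget interaction. A common fix, sketched here but not part of this commit, is to cache the instance itself with the same `st.cache_resource` decorator the old version used:

import streamlit as st

@st.cache_resource
def get_summarizer():
    # Built once per server process and reused across reruns, so the
    # models/tokenizers dictionaries actually persist between clicks.
    return UltraOptimizedSummarizer()  # class defined in app.py above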
+    def summarize(self, text, model_name, max_length=150, min_length=50):
+        """
+        Ultra-optimized summarization with intelligent truncation.
+        """
+        model, tokenizer = self._load_model(model_name)
+        if not model or not tokenizer:
+            return "Summarization failed."
+
+        try:
+            # Intelligent text truncation
+            inputs = tokenizer(
+                text,
+                max_length=1024,  # Prevent OOM errors
+                truncation=True,
+                return_tensors='pt'
+            ).to(self.device)
+
+            # Generate summary with optimized parameters
+            summary_ids = model.generate(
+                inputs['input_ids'],
+                num_beams=4,
+                max_length=max_length,
+                min_length=min_length,
+                early_stopping=True,
+                no_repeat_ngram_size=2,
+                do_sample=False
+            )
+
+            # Decode summary
+            summary = tokenizer.decode(
+                summary_ids[0],
+                skip_special_tokens=True
+            )
+
+            return summary
+
+        except Exception as e:
+            st.error(f"Summarization error for {model_name}: {e}")
+            return "Could not generate summary."
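One small point on `summarize`: recent `transformers` releases already run `model.generate` under `torch.no_grad`, so no autograd graph is built during beam search; wrapping the call in `torch.inference_mode()` would be an optional belt-and-braces measure rather than a required fix. A self-contained illustration of what that context manager guarantees:

import torch

x = torch.ones(2, requires_grad=True)
with torch.inference_mode():
    y = x * 2  # no autograd history is recorded inside this block
print(y.requires_grad)  # False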
+    def parallel_summarize(self, text, max_length=150, min_length=50):
+        """
+        Concurrent summarization with advanced thread pooling.
+        """
+        model_configs = [
+            "facebook/bart-large-cnn",
+            "t5-large",
+            "google/pegasus-cnn_dailymail"
+        ]
+
+        # Dynamic thread count based on system resources
+        max_workers = min(
+            len(model_configs),
+            psutil.cpu_count(logical=False),  # Physical cores
+            4  # Cap at 4 to prevent resource exhaustion
+        )
+
+        # Use concurrent futures for true parallel processing
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # Submit summarization tasks
+            future_to_model = {
+                executor.submit(
+                    self.summarize,
+                    text,
+                    model,
+                    max_length,
+                    min_length
+                ): model for model in model_configs
+            }
+
+            # Collect results as they complete
+            summaries = {}
+            for future in concurrent.futures.as_completed(future_to_model):
+                model = future_to_model[future]
+                try:
+                    summaries[model] = future.result()
+                except Exception as e:
+                    summaries[model] = f"Error: {e}"
+
+        return summaries
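Two notes on `parallel_summarize`: threads genuinely overlap here because PyTorch releases the GIL inside its C++ kernels, and `psutil.cpu_count(logical=False)` is documented to return `None` when the physical-core count cannot be determined, which would make the `min()` above raise a `TypeError`. A hedged hardening sketch for the worker count:

import psutil

# Fall back to logical cores, then to 1, if the physical count is unknown.
physical = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
max_workers = min(3, physical, 4)  # 3 model configs, capped at 4 workers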
 def main():
     st.set_page_config(
+        page_title="Ultra-Optimized Summarization",
+        page_icon="🚀",
         layout="wide"
     )

+    st.title("🔬 Hyper-Optimized Text Summarization")

+    # Initialize optimized summarizer
+    summarizer = UltraOptimizedSummarizer()
+
+    # Input and processing
     text_input = st.text_area(
+        "Enter text for advanced summarization:",
+        height=300
     )

+    # Advanced compression control
+    col1, col2 = st.columns(2)
+    with col1:
+        max_length = st.slider(
+            "Max Summary Length",
+            min_value=50,
+            max_value=300,
+            value=150
+        )

+    with col2:
+        compression_rate = st.slider(
+            "Compression Aggressiveness",
+            min_value=0.1,
+            max_value=0.5,
+            value=0.3,
+            step=0.05
+        )
+
+    if st.button("Generate Hyper-Optimized Summaries"):
         if not text_input:
+            st.warning("Please provide text to summarize.")
             return

+        # Progress tracking
         progress_bar = st.progress(0)
+        status_text = st.empty()

         try:
+            # Perform parallel summarization
+            status_text.info("Initializing ultra-optimized summarization...")
+            progress_bar.progress(20)
+
+            summaries = summarizer.parallel_summarize(
+                text_input,
+                max_length=max_length,
+                min_length=int(max_length * 0.5)
+            )

             progress_bar.progress(100)
+            status_text.success("Summarization Complete!")

+            # Display results
             cols = st.columns(3)
             for (col, (model, summary)) in zip(cols, summaries.items()):
                 with col:
+                    st.subheader(model.split('/')[-1].upper())
                     st.write(summary)

         except Exception as e:
+            st.error(f"Optimization failed: {e}")

         finally:
             progress_bar.empty()
+            status_text.empty()

 if __name__ == "__main__":
     main()
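Two closing observations: the `compression_rate` slider's value is read but never passed into `parallel_summarize` (`min_length` is derived from `max_length` instead), and the class is easy to smoke-test outside Streamlit. A usage sketch, assuming the file above is saved as `app.py` on the import path and `article.txt` is any long-form text file (both names illustrative; `main()` is guarded, so importing only runs the top-level imports):

from app import UltraOptimizedSummarizer

summarizer = UltraOptimizedSummarizer()
with open("article.txt") as f:
    article = f.read()

# Runs all three checkpoints concurrently and prints each model's summary.
for model_name, summary in summarizer.parallel_summarize(
        article, max_length=120, min_length=60).items():
    print(f"== {model_name} ==\n{summary}\n")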