import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import concurrent.futures
import numpy as np
import psutil


class UltraOptimizedSummarizer:
    def __init__(self):
        # Per-instance caches for loaded models and tokenizers
        self.models = {}
        self.tokenizers = {}
        self.device = self._get_optimal_device()

    def _get_optimal_device(self):
        """Select the best available computational device."""
        if torch.cuda.is_available():
            # Pick the GPU with the least memory already allocated by this process
            gpu_memory = [torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())]
            best_gpu = int(np.argmin(gpu_memory))
            return torch.device(f'cuda:{best_gpu}')
        elif torch.backends.mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')
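
    # Note: torch.cuda.memory_allocated() only counts tensors allocated by this
    # process, so the heuristic above is not a true "most free memory" check.
    # If actual free memory matters, one alternative (a sketch, not used by
    # this app) is torch.cuda.mem_get_info, which returns
    # (free_bytes, total_bytes) per device:
    #
    #   free = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
    #   best_gpu = int(np.argmax(free))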

    def _load_model(self, model_name):
        """
        Optimized model loading with advanced memory management.
        Uses half-precision (float16) on CUDA for a reduced memory footprint.
        """
        if model_name in self.models:
            return self.models[model_name], self.tokenizers[model_name]
        try:
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='/tmp/huggingface_cache')
            # Load model with optimization
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir='/tmp/huggingface_cache',
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                low_cpu_mem_usage=True
            ).to(self.device)
            # Optional: model compilation for additional speed (PyTorch 2.0+);
            # note that the first call after compile pays a warm-up cost
            if hasattr(torch, 'compile'):
                model = torch.compile(model)
            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer
            return model, tokenizer
        except Exception as e:
            st.error(f"Model loading error for {model_name}: {e}")
            return None, None

    def summarize(self, text, model_name, max_length=150, min_length=50):
        """
        Optimized summarization with input truncation to avoid OOM errors.
        """
        model, tokenizer = self._load_model(model_name)
        if not model or not tokenizer:
            return "Summarization failed."
        try:
            # Truncate long inputs to a safe context size
            inputs = tokenizer(
                text,
                max_length=1024,  # Prevent OOM errors
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            # Generate summary with beam search
            summary_ids = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],  # avoid pad-token ambiguity warnings
                num_beams=4,
                max_length=max_length,
                min_length=min_length,
                early_stopping=True,
                no_repeat_ngram_size=2,
                do_sample=False
            )
            # Decode summary
            summary = tokenizer.decode(
                summary_ids[0],
                skip_special_tokens=True
            )
            return summary
        except Exception as e:
            st.error(f"Summarization error for {model_name}: {e}")
            return "Could not generate summary."

    def parallel_summarize(self, text, max_length=150, min_length=50):
        """
        Concurrent summarization across all models using a thread pool.
        """
        model_configs = [
            "facebook/bart-large-cnn",
            "t5-large",
            "google/pegasus-cnn_dailymail"
        ]
        # Dynamic thread count based on system resources;
        # cpu_count(logical=False) can return None, so fall back to 1
        max_workers = min(
            len(model_configs),
            psutil.cpu_count(logical=False) or 1,  # physical cores
            4  # cap at 4 to prevent resource exhaustion
        )
        # Threads overlap effectively here because PyTorch releases the GIL
        # during heavy tensor operations
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit one summarization task per model
            future_to_model = {
                executor.submit(
                    self.summarize,
                    text,
                    model,
                    max_length,
                    min_length
                ): model for model in model_configs
            }
            # Collect results as they complete
            summaries = {}
            for future in concurrent.futures.as_completed(future_to_model):
                model = future_to_model[future]
                try:
                    summaries[model] = future.result()
                except Exception as e:
                    summaries[model] = f"Error: {e}"
        return summaries
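

# Streamlit re-runs this script top-to-bottom on every interaction, so a bare
# UltraOptimizedSummarizer() constructed in main() would be rebuilt each run
# and lose its in-memory model cache. One idiomatic way to keep a single
# shared instance alive is st.cache_resource (the helper name get_summarizer
# is an addition for this purpose):
@st.cache_resource
def get_summarizer():
    return UltraOptimizedSummarizer()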


def main():
    st.set_page_config(
        page_title="Ultra-Optimized Summarization",
        page_icon="🚀",
        layout="wide"
    )
    st.title("🔬 Hyper-Optimized Text Summarization")
    # Reuse the cached summarizer so loaded models survive Streamlit reruns
    summarizer = get_summarizer()
    # Input and processing
    text_input = st.text_area(
        "Enter text for advanced summarization:",
        height=300
    )
    # Summary length and compression controls
    col1, col2 = st.columns(2)
    with col1:
        max_length = st.slider(
            "Max Summary Length",
            min_value=50,
            max_value=300,
            value=150
        )
    with col2:
        compression_rate = st.slider(
            "Compression Aggressiveness",
            min_value=0.1,
            max_value=0.5,
            value=0.3,
            step=0.05
        )
    if st.button("Generate Hyper-Optimized Summaries"):
        if not text_input:
            st.warning("Please provide text to summarize.")
            return
        # Progress tracking
        progress_bar = st.progress(0)
        status_text = st.empty()
        try:
            # Perform parallel summarization
            status_text.info("Initializing ultra-optimized summarization...")
            progress_bar.progress(20)
            # Higher compression aggressiveness lowers the minimum length,
            # permitting shorter summaries
            summaries = summarizer.parallel_summarize(
                text_input,
                max_length=max_length,
                min_length=int(max_length * (1 - compression_rate))
            )
            progress_bar.progress(100)
            status_text.success("Summarization Complete!")
            # Display results in a stable order (as_completed yields them
            # in nondeterministic order)
            cols = st.columns(3)
            for col, (model, summary) in zip(cols, sorted(summaries.items())):
                with col:
                    st.subheader(model.split('/')[-1].upper())
                    st.write(summary)
        except Exception as e:
            st.error(f"Optimization failed: {e}")
        finally:
            progress_bar.empty()
            status_text.empty()


if __name__ == "__main__":
    main()
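
# Run the app locally with:
#   streamlit run app.py
# Requires streamlit, torch, transformers (plus sentencepiece for the
# T5/Pegasus tokenizers), psutil, and numpy; the three models are
# downloaded to /tmp/huggingface_cache on first use.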