import multiprocessing
import os

import streamlit as st
import torch
from transformers import pipeline
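# To launch the app (assuming this script is saved as app.py):
#   streamlit run app.py
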
@st.cache_resource
def load_model(model_name):
    """Load a summarization model once and cache it across Streamlit reruns."""
    try:
        # The integer device form works across transformers versions:
        # 0 = first GPU, -1 = CPU.
        device = 0 if torch.cuda.is_available() else -1
        return pipeline("summarization", model=model_name, device=device)
    except Exception as e:
        st.error(f"Error loading model {model_name}: {e}")
        return None
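
# Illustrative usage: load_model("facebook/bart-large-cnn") returns a cached
# transformers summarization pipeline, or None if loading fails.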


def generate_summary(model, text, length_percentage=0.3):
    """
    Generate a summary with adaptive length control.

    Args:
        model: Hugging Face summarization pipeline, or None if loading failed
        text: Input text to summarize
        length_percentage: Fraction of the original word count to target
            as the summary's maximum length

    Returns:
        The generated summary, or an error message if summarization fails.
    """
    if model is None:
        return "Could not generate summary."

    # max_length/min_length are measured in tokens; the word count serves as
    # a rough proxy. Floors keep very short inputs from producing unusable
    # bounds (max_length >= 50, min_length >= 30).
    word_count = len(text.split())
    max_length = max(50, int(word_count * length_percentage))
    min_length = max(30, int(word_count * 0.1))

    try:
        summary = model(
            text,
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            early_stopping=True
        )[0]['summary_text']
        return summary
    except Exception as e:
        st.error(f"Summarization error: {e}")
        return "Could not generate summary."


def _summarize_worker(model_name, text, length_percentage):
    """Load the model inside the worker process, then summarize.

    Loaded pipelines (especially CUDA-backed ones) do not pickle reliably,
    so each worker receives only the model name and builds its own pipeline.
    """
    model = load_model(model_name)
    return generate_summary(model, text, length_percentage)


def parallel_summarize(text, length_percentage=0.3):
    """
    Generate summaries in parallel using multiprocessing.

    Args:
        text: Input text to summarize
        length_percentage: Fraction of the original word count to target

    Returns:
        Dictionary mapping model display names to their summaries.
    """
    model_configs = [
        ("facebook/bart-large-cnn", "BART"),
        ("t5-large", "T5"),
        ("google/pegasus-cnn_dailymail", "Pegasus")
    ]

    # os.cpu_count() can return None, so fall back to a single process.
    workers = min(len(model_configs), os.cpu_count() or 1)
    with multiprocessing.Pool(processes=workers) as pool:
        # Pass model names rather than loaded pipelines, so the large models
        # are never pickled across process boundaries.
        args = [(model_name, text, length_percentage)
                for model_name, _ in model_configs]

        results = pool.starmap(_summarize_worker, args)

    return {name: summary for (_, name), summary in zip(model_configs, results)}
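
# Illustrative shape of the result (placeholder text, not real output):
#   {"BART": "...", "T5": "...", "Pegasus": "..."}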


def main():
    st.set_page_config(
        page_title="Multi-Model Text Summarization",
        page_icon="📝",
        layout="wide"
    )

    st.title("🤖 Advanced Text Summarization")
    st.markdown("""
    Generate concise summaries using multiple state-of-the-art models.
    Intelligently adapts summary length based on the input text.
    """)

    text_input = st.text_area(
        "Paste your text here:",
        height=250,
        help="Enter the text you want to summarize"
    )

    length_control = st.slider(
        "Summary Compression Rate",
        min_value=0.1,
        max_value=0.5,
        value=0.3,
        step=0.05,
        help="Adjust how much of the original text to keep in the summary"
    )

    if st.button("Generate Summaries", type="primary"):
        if not text_input:
            st.warning("Please enter some text to summarize.")
            return

        progress_text = st.empty()
        progress_bar = st.progress(0)

        stages = ["Initializing Models", "Running BART", "Running T5", "Running Pegasus", "Completed"]

        try:
            # The staged display is cosmetic: all three models actually run
            # in a single parallel_summarize call during the second stage.
            for i, stage in enumerate(stages[:-1], 1):
                progress_text.info(stage)
                progress_bar.progress(i * 20)

                if i == 2:
                    summaries = parallel_summarize(text_input, length_control)

            progress_text.success("Summarization Complete!")
            progress_bar.progress(100)

            st.subheader("📝 Generated Summaries")
            cols = st.columns(3)

            for col, (model, summary) in zip(cols, summaries.items()):
                with col:
                    st.markdown(f"### {model} Summary")
                    st.write(summary)
                    st.caption(f"Word Count: {len(summary.split())}")

        except Exception as e:
            st.error(f"An error occurred: {e}")

        finally:
            # Remove the transient progress widgets; the summaries stay on screen.
            progress_text.empty()
            progress_bar.empty()


if __name__ == "__main__":
    main()
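
# Note: the __main__ guard above matters here because multiprocessing's
# "spawn" start method (the default on Windows and macOS) re-imports this
# module in each child process.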