import streamlit as st
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor
import time
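
# Streamlit app: summarize user-provided text with three pre-trained models
# (BART, T5, Pegasus). The models are loaded concurrently and cached via
# st.cache_resource so they are only initialized once per session.
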
# Function to load a single summarization model by name
def load_model(model_name):
    if model_name == 'BART':
        return pipeline("summarization", model="facebook/bart-large-cnn")
    elif model_name == 'T5':
        return pipeline("summarization", model="t5-large")
    elif model_name == 'Pegasus':
        return pipeline("summarization", model="google/pegasus-cnn_dailymail")
    else:
        raise ValueError(f"Unknown model name: {model_name}")
# Function to load all models concurrently
@st.cache_resource
def load_all_models():
    model_names = ['BART', 'T5', 'Pegasus']
    models = {}
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(load_model, name): name for name in model_names}
        for future in futures:
            model_name = futures[future]
            models[model_name] = future.result()
    return models
# Streamlit app layout
st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")
text_input = st.text_area("Enter text to summarize:")
if text_input:
    # Display word count of input text
    word_count = len(text_input.split())
    st.write(f"**Word Count:** {word_count}")

    if st.button("Generate Summaries"):
        with st.spinner("Loading models and generating summaries..."):
            start_time = time.time()
            models = load_all_models()
            # Run each model on the input text and collect its summary
            summaries = {}
            for model_name, model in models.items():
                summary = model(
                    text_input,
                    max_length=150,
                    min_length=50,
                    length_penalty=2.0,
                    num_beams=4,
                    early_stopping=True,
                )[0]['summary_text']
                summaries[model_name] = summary
            end_time = time.time()
        st.subheader("Summaries")
        for model_name, summary in summaries.items():
            # Pegasus emits '<n>' tokens as sentence separators; replace them with spaces
            clean_summary = summary.replace('<n>', ' ').strip()
            summary_word_count = len(clean_summary.split())
            st.write(f"**{model_name}**")
            st.write(clean_summary)
            st.write(f"**Summary Word Count:** {summary_word_count}")
            st.write("---")
        st.write(f"**Total Time Taken:** {end_time - start_time:.2f} seconds")
else:
    st.error("Please enter text to summarize.")