|
import streamlit as st |
|
from transformers import pipeline |
|
import time |
|
|
|
|
|
# Hugging Face Hub checkpoint backing each model option offered in the UI.
_SUMMARIZATION_CHECKPOINTS = {
    "BART": "facebook/bart-large-cnn",
    "T5": "t5-large",
    "Pegasus": "google/pegasus-cnn_dailymail",
}


@st.cache_resource
def load_pipeline(model_name):
    """Build a Hugging Face summarization pipeline for the chosen model.

    Decorated with ``st.cache_resource``, so repeated calls with the same
    ``model_name`` reuse the already-constructed pipeline instead of
    loading the checkpoint again.

    Args:
        model_name: One of ``"BART"``, ``"T5"`` or ``"Pegasus"``.

    Returns:
        A ``transformers`` ``"summarization"`` pipeline for the checkpoint
        mapped to ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported options.
            Raising here gives a clear message instead of returning ``None``
            and failing later when the result is called.
    """
    try:
        checkpoint = _SUMMARIZATION_CHECKPOINTS[model_name]
    except KeyError:
        supported = ", ".join(_SUMMARIZATION_CHECKPOINTS)
        raise ValueError(
            f"Unsupported model {model_name!r}; expected one of: {supported}"
        ) from None
    return pipeline("summarization", model=checkpoint)
|
|
|
|
|
st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")

text_input = st.text_area("Enter text to summarize:")

# Use strip() so whitespace-only input is treated as empty rather than being
# sent to the model as a blank document.
if text_input.strip():
    word_count = len(text_input.split())
    st.write(f"**Word Count:** {word_count}")

    model_choice = st.selectbox("Choose a model:", ['BART', 'T5', 'Pegasus'])

    if st.button("Generate Summary"):
        with st.spinner(f"Generating summary using {model_choice}..."):
            # Obtain the (possibly cached) pipeline before starting the clock
            # so "Time Taken" measures summarization only, not a one-off
            # model download/load.
            summarizer = load_pipeline(model_choice)

            start_time = time.perf_counter()
            summary = summarizer(
                text_input,
                max_length=150,
                # NOTE(review): min_length=50 (tokens) can exceed very short
                # inputs, pushing the model to pad/repeat; consider scaling it
                # with the input length.
                min_length=50,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
                # Without truncation, inputs longer than the model's maximum
                # input length (e.g. BART's 1024 tokens) can raise an error.
                truncation=True,
            )[0]['summary_text']
            end_time = time.perf_counter()

        # "<n>" appears in model output (it is emitted by the Pegasus
        # checkpoint as a line-break marker). Replace it with a space, not "",
        # so adjacent sentences don't run together. Count words on the
        # cleaned text.
        summary = summary.replace('<n>', ' ').strip()
        summary_word_count = len(summary.split())

        st.subheader(f"Summary using {model_choice}")
        st.write(summary)
        st.write(f"**Summary Word Count:** {summary_word_count}")
        st.write(f"**Time Taken:** {end_time - start_time:.2f} seconds")
else:
    st.error("Please enter text to summarize.")
|
|