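"""Streamlit demo: summarize user-provided text with BART, T5, or Pegasus."""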
import os

import streamlit as st
from transformers import pipeline
# Cache directory for downloaded models (HF_HOME overrides the default hub cache)
CACHE_DIR = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))

# Load the selected summarization pipeline. st.cache_resource keeps the
# loaded pipeline in memory, so each model is downloaded and initialized
# only once per server process.
@st.cache_resource
def load_model(model_name):
    if model_name == 'BART':
        return pipeline("summarization", model="facebook/bart-large-cnn", cache_dir=CACHE_DIR)
    elif model_name == 'T5':
        return pipeline("summarization", model="t5-large", cache_dir=CACHE_DIR)
    elif model_name == 'Pegasus':
        return pipeline("summarization", model="google/pegasus-cnn_dailymail", cache_dir=CACHE_DIR)

# Streamlit app layout
st.title("Text Summarization with Pre-trained Models (BART, T5, Pegasus)")

text_input = st.text_area("Enter text to summarize:")

if text_input:
    # Display word count of the input text
    word_count = len(text_input.split())
    st.write(f"**Word Count:** {word_count}")

    model_choice = st.selectbox("Choose a model:", ['BART', 'T5', 'Pegasus'])

    if st.button("Generate Summary"):
        with st.spinner(f"Generating summary using {model_choice}..."):
            summarizer = load_model(model_choice)
            # Beam-search decoding, constrained to roughly 50-150 tokens
            summary = summarizer(
                text_input,
                max_length=150,
                min_length=50,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
            )[0]['summary_text']
            summary_word_count = len(summary.split())

            st.subheader(f"Summary using {model_choice}")
            # Pegasus marks sentence breaks with '<n>' tokens; strip them for display
            st.write(summary.replace('<n>', ''))
            st.write(f"**Summary Word Count:** {summary_word_count}")
else:
    st.error("Please enter text to summarize.")