import os
import streamlit as st
from transformers import pipeline
import traceback
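# Launch note: as a standard Streamlit app, this file is started with
# `streamlit run app.py`; on a Hugging Face Space using the Streamlit SDK
# that happens automatically.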
# Cache models under the app's working directory. Note: on Hugging Face
# Spaces this path lives on ephemeral disk; Spaces' optional persistent
# storage, when enabled, is mounted at /data instead.
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")
def ensure_cache_dir():
    """
    Ensure the cache directory exists.

    Returns:
        str: Path to the cache directory
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR
# Cache the loaded pipelines across Streamlit reruns so the models are
# loaded once per process instead of on every button click.
@st.cache_resource
def load_pipelines():
    """
    Load summarization pipelines with persistent caching.

    Returns:
        dict: Dictionary of model pipelines
    """
    try:
        # Ensure cache directory exists
        cache_dir = ensure_cache_dir()

        # Define model paths within the cache directory
        bart_cache = os.path.join(cache_dir, "bart-large-cnn")
        t5_cache = os.path.join(cache_dir, "t5-large")
        pegasus_cache = os.path.join(cache_dir, "pegasus-cnn_dailymail")

        # pipeline() does not accept cache_dir as a top-level argument;
        # it must be forwarded to from_pretrained() via model_kwargs.
        bart_pipeline = pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            model_kwargs={"cache_dir": bart_cache},
        )
        t5_pipeline = pipeline(
            "summarization",
            model="t5-large",
            model_kwargs={"cache_dir": t5_cache},
        )
        pegasus_pipeline = pipeline(
            "summarization",
            model="google/pegasus-cnn_dailymail",
            model_kwargs={"cache_dir": pegasus_cache},
        )

        return {
            'BART': bart_pipeline,
            'T5': t5_pipeline,
            'Pegasus': pegasus_pipeline,
        }
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}
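# Alternative caching setup (a sketch, not used here): the whole Hugging Face
# cache can be redirected once, before `transformers` is imported, instead of
# passing per-model cache directories:
#
#     import os
#     os.environ["HF_HOME"] = "/data/hf_cache"  # hypothetical path
#     from transformers import pipeline
#
# HF_HOME is read by huggingface_hub when models are first downloaded.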
def generate_summary(summarizer, text, model_name):
    """
    Generate a summary with a specific model, with error handling.

    Args:
        summarizer: Hugging Face summarization pipeline
        text (str): Input text to summarize
        model_name (str): Name of the model

    Returns:
        str: Generated summary or error message
    """
    try:
        # Pass the raw text: these are not instruction-following models, so a
        # natural-language prompt would just be summarized along with the
        # input. (The pipeline applies T5's "summarize: " task prefix itself.)
        # truncation=True keeps over-long inputs within each model's limit.
        summary = summarizer(
            text,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            truncation=True,
        )[0]['summary_text']
        return summary
    except Exception as e:
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg
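# Example usage (a sketch, assuming the models loaded successfully):
#
#     pipes = load_pipelines()
#     print(generate_summary(pipes['BART'], "Long article text ...", 'BART'))
#
# Note that min_length=50 forces at least ~50 generated tokens, so inputs much
# shorter than that can yield padded or repetitive summaries.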
def main():
    st.title("Text Summarization with Pre-trained Models")

    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")

    # Text input
    text_input = st.text_area("Enter text to summarize:")

    # Generate button
    if st.button("Generate Summary"):
        if not text_input:
            st.error("Please enter text to summarize.")
            return

        # Load pipelines (cached after the first run)
        pipelines = load_pipelines()
        if not pipelines:
            st.error("Failed to load models. Please check your internet connection or try again later.")
            return

        # One column per model, filled as each summary finishes
        for column, model_name in zip(st.columns(3), ('BART', 'T5', 'Pegasus')):
            with column:
                with st.spinner(f'Generating {model_name} Summary...'):
                    summary = generate_summary(pipelines[model_name], text_input, model_name)
                st.subheader(f"{model_name} Summary")
                st.write(summary)
if __name__ == "__main__":
    main()
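# Deployment note (an assumption about the surrounding repo, not shown here):
# a Space running this app needs at least these packages in requirements.txt:
#
#     streamlit
#     transformers
#     torch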