import os
import traceback

import streamlit as st
from transformers import pipeline

# Persistent model cache inside the working directory (Hugging Face Spaces'
# recommended pattern: downloads survive app reruns within the container).
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")


def ensure_cache_dir():
    """
    Ensure the cache directory exists.

    Returns:
        str: Path to the cache directory
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR


@st.cache_resource(show_spinner="Loading summarization models...")
def load_pipelines():
    """
    Load summarization pipelines with persistent caching.

    Decorated with ``st.cache_resource`` so the very expensive model loads
    happen once per server process instead of on every Streamlit rerun.

    Returns:
        dict: Dictionary of model pipelines (empty dict on failure)
    """
    try:
        cache_dir = ensure_cache_dir()

        # One sub-directory per model keeps downloads cleanly separated.
        model_specs = {
            "BART": ("facebook/bart-large-cnn", "bart-large-cnn"),
            "T5": ("t5-large", "t5-large"),
            "Pegasus": ("google/pegasus-cnn_dailymail", "pegasus-cnn_dailymail"),
        }

        pipelines = {}
        for name, (model_id, subdir) in model_specs.items():
            # NOTE: `cache_dir` is not a top-level keyword of `pipeline()`;
            # it must be routed to `from_pretrained` via `model_kwargs`.
            pipelines[name] = pipeline(
                "summarization",
                model=model_id,
                model_kwargs={"cache_dir": os.path.join(cache_dir, subdir)},
            )
        return pipelines
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}


def generate_summary(summarizer, text, model_name):
    """
    Generate summary for a specific model with error handling.

    Args:
        summarizer: Hugging Face summarization pipeline
        text (str): Input text to summarize
        model_name (str): Name of the model (used in error messages)

    Returns:
        str: Generated summary or error message
    """
    try:
        # Pass the raw text: BART/Pegasus are not instruction-tuned (a
        # hand-written prompt would leak into the generated summary), and
        # the T5 pipeline adds its own "summarize: " prefix automatically.
        result = summarizer(
            text,
            max_length=150,
            min_length=50,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        return result[0]["summary_text"]
    except Exception as e:
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg


def main():
    """Streamlit entry point: accept text and display one summary per model."""
    st.title("Text Summarization with Pre-trained Models")

    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")

    text_input = st.text_area("Enter text to summarize:")

    if st.button("Generate Summary"):
        # Reject empty or whitespace-only input before loading anything.
        if not text_input.strip():
            st.error("Please enter text to summarize.")
            return

        pipelines = load_pipelines()
        if not pipelines:
            st.error(
                "Failed to load models. Please check your internet "
                "connection or try again later."
            )
            return

        # One column per model; the per-model blocks were identical except
        # for the name, so drive them from a single loop.
        columns = st.columns(len(pipelines))
        for col, model_name in zip(columns, pipelines):
            with col:
                with st.spinner(f"Generating {model_name} Summary..."):
                    progress = st.progress(0)
                    progress.progress(50)
                    summary = generate_summary(
                        pipelines[model_name], text_input, model_name
                    )
                    progress.progress(100)
                    st.subheader(f"{model_name} Summary")
                    st.write(summary)


if __name__ == "__main__":
    main()