File size: 4,638 Bytes
038a86f caf0283 d769310 0330532 53545b3 038a86f 0330532 038a86f 0330532 038a86f 0330532 d769310 0330532 d769310 0330532 6895495 038a86f 0330532 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import streamlit as st
from transformers import pipeline
import traceback
# Persistent model cache, rooted in the working directory (the location
# recommended for Hugging Face Spaces persistent storage).
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")
def ensure_cache_dir():
    """
    Create the model cache directory if it does not already exist.
    Returns:
        str: Path to the cache directory (always CACHE_DIR)
    """
    # exist_ok makes this idempotent, so it is safe to call on every load.
    os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR
def load_pipelines():
    """
    Load summarization pipelines with persistent caching.
    Returns:
        dict: Mapping of display name ('BART', 'T5', 'Pegasus') to its
            Hugging Face summarization pipeline; empty dict on failure.
    """
    try:
        # Ensure cache directory exists
        cache_dir = ensure_cache_dir()
        # (display name) -> (hub model id, cache subdirectory)
        model_specs = {
            'BART': ("facebook/bart-large-cnn", "bart-large-cnn"),
            'T5': ("t5-large", "t5-large"),
            'Pegasus': ("google/pegasus-cnn_dailymail", "pegasus-cnn_dailymail"),
        }
        # NOTE: `pipeline()` does not take `cache_dir` as a top-level keyword.
        # For summarization pipelines an unknown kwarg is absorbed into the
        # generation kwargs, so the custom cache location was silently ignored.
        # It must be forwarded to `from_pretrained` via `model_kwargs`.
        pipelines = {}
        for name, (model_id, subdir) in model_specs.items():
            pipelines[name] = pipeline(
                "summarization",
                model=model_id,
                model_kwargs={"cache_dir": os.path.join(cache_dir, subdir)},
            )
        return pipelines
    except Exception as e:
        # Surface the full traceback in the UI so deployment problems
        # (no network, missing weights, out of disk) are diagnosable.
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}
def generate_summary(pipeline, text, model_name):
    """
    Generate summary for a specific model with error handling.
    Args:
        pipeline: Hugging Face summarization pipeline (any callable with the
            same interface)
        text (str): Input text to summarize
        model_name (str): Name of the model ('BART', 'T5', or 'Pegasus')
    Returns:
        str: Generated summary, or the error message on failure
    """
    # Rebind locally: the parameter name shadows the `transformers.pipeline`
    # factory imported at module level (kept for call compatibility).
    summarizer = pipeline
    try:
        # T5 is a text-to-text model and requires the "summarize: " task
        # prefix it was trained with. BART and Pegasus are fine-tuned
        # summarizers and take the raw text directly — a natural-language
        # prompt would just be summarized along with the input.
        model_input = f"summarize: {text}" if model_name == 'T5' else text
        result = summarizer(model_input,
                            max_length=150,
                            min_length=50,
                            length_penalty=2.0,
                            num_beams=4,
                            early_stopping=True)
        return result[0]['summary_text']
    except Exception as e:
        error_msg = f"Error in {model_name} summarization: {str(e)}"
        st.error(error_msg)
        return error_msg
def main():
    """Streamlit entry point: collect text and show one summary per model."""
    st.title("Text Summarization with Pre-trained Models")
    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")
    # Text input
    text_input = st.text_area("Enter text to summarize:")
    # Generate button
    if st.button("Generate Summary"):
        # strip() so whitespace-only input is rejected too
        if not text_input.strip():
            st.error("Please enter text to summarize.")
            return
        # Load pipelines once per session: Streamlit reruns this script on
        # every interaction, and reloading three large models per click is
        # prohibitively slow.
        if "pipelines" not in st.session_state:
            st.session_state["pipelines"] = load_pipelines()
        pipelines = st.session_state["pipelines"]
        if not pipelines:
            st.error("Failed to load models. Please check your internet connection or try again later.")
            # Drop the failed (empty) result so the next click retries.
            del st.session_state["pipelines"]
            return
        # One column per model, filled in insertion order (BART, T5, Pegasus).
        columns = st.columns(len(pipelines))
        for column, (model_name, summarizer) in zip(columns, pipelines.items()):
            with column:
                with st.spinner(f'Generating {model_name} Summary...'):
                    # Coarse progress indicator: pipelines expose no progress
                    # callback, so 0 -> 50 -> 100 is the best we can show.
                    progress = st.progress(0)
                    progress.progress(50)
                    summary = generate_summary(summarizer, text_input, model_name)
                    progress.progress(100)
                st.subheader(f"{model_name} Summary")
                st.write(summary)
# Launch the app only when executed as a script (e.g. `streamlit run app.py`),
# not when imported as a module.
if __name__ == "__main__":
    main()