import os
import traceback

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Local cache directory for downloaded models. Note: on Hugging Face Spaces the
# working directory is ephemeral; enable persistent storage (mounted at /data)
# if cached models should survive restarts.
CACHE_DIR = os.path.join(os.getcwd(), "model_cache")


def ensure_cache_dir():
    """
    Ensure the cache directory exists.

    Returns:
        str: Path to the cache directory
    """
    os.makedirs(CACHE_DIR, exist_ok=True)
    return CACHE_DIR
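

# Cache loaded models across Streamlit reruns with st.cache_resource, keyed on
# model_name, so each large model is downloaded and loaded only once per
# process. Caveat: a failed load returns (None, None), which is cached too.
@st.cache_resource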
def load_model_and_tokenizer(model_name):
    """
    Load a model and tokenizer with persistent caching.

    Args:
        model_name (str): Hugging Face model identifier to load

    Returns:
        tuple: (model, tokenizer), or (None, None) on failure
    """
    try:
        # Ensure the cache directory exists
        cache_dir = ensure_cache_dir()

        # Construct a per-model cache path
        model_cache_path = os.path.join(cache_dir, model_name.replace('/', '_'))

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=model_cache_path
        )

        # Load model
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            cache_dir=model_cache_path
        )

        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading {model_name}: {e}")
        st.error(traceback.format_exc())
        return None, None
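
# Optional: generation is much faster on a GPU. A minimal sketch, assuming
# torch (already a transformers dependency) and a CUDA device:
#
#   import torch
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)
#   summary_ids = model.generate(inputs.input_ids.to(device), ...)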


def generate_summary(model, tokenizer, text, max_length=150, prefix=""):
    """
    Generate a summary using a specific model and tokenizer.

    Args:
        model: Hugging Face model
        tokenizer: Hugging Face tokenizer
        text (str): Input text to summarize
        max_length (int): Maximum length of the summary
        prefix (str): Task prefix prepended to the input. T5 expects
            "summarize: "; BART and Pegasus expect no prefix.

    Returns:
        str: Generated summary
    """
    try:
        # Tokenize the (optionally prefixed) input, truncating to the
        # model's maximum input length
        inputs = tokenizer(
            prefix + text,
            max_length=512,
            return_tensors="pt",
            truncation=True
        )

        # Generate the summary with beam search
        summary_ids = model.generate(
            inputs.input_ids,
            num_beams=4,
            max_length=max_length,
            early_stopping=True
        )

        # Decode token IDs back to text
        summary = tokenizer.decode(
            summary_ids[0],
            skip_special_tokens=True
        )

        return summary
    except Exception as e:
        error_msg = f"Error in summarization: {e}"
        st.error(error_msg)
        return error_msg
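
# Example usage, assuming a model and tokenizer already loaded via
# load_model_and_tokenizer:
#
#   summary = generate_summary(model, tokenizer, article_text,
#                              max_length=150, prefix="summarize: ")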


def main():
    st.title("Text Summarization with Pre-trained Models")

    # Display cache directory info (optional)
    st.info(f"Models will be cached in: {CACHE_DIR}")

    # Models to compare: display name -> Hugging Face model identifier
    models_to_load = {
        'BART': 'facebook/bart-large-cnn',
        'T5': 't5-large',
        'Pegasus': 'google/pegasus-cnn_dailymail'
    }

    # Text input
    text_input = st.text_area("Enter text to summarize:")

    # Generate button
    if st.button("Generate Summary"):
        if not text_input.strip():
            st.error("Please enter text to summarize.")
            return

        # One column per model for side-by-side display
        columns = st.columns(len(models_to_load))

        # Summarize with one model and render the result in its column
        def process_model(col, model_name, model_path):
            with col:
                with st.spinner(f'Generating {model_name} Summary...'):
                    progress = st.progress(0)
                    progress.progress(50)

                    # Load model and tokenizer
                    model, tokenizer = load_model_and_tokenizer(model_path)

                    if model is not None and tokenizer is not None:
                        # T5 was trained with a task prefix; BART and
                        # Pegasus were not
                        prefix = "summarize: " if model_path.startswith("t5") else ""

                        # Generate summary
                        summary = generate_summary(model, tokenizer, text_input, prefix=prefix)
                        progress.progress(100)

                        st.subheader(f"{model_name} Summary")
                        st.write(summary)
                    else:
                        st.error(f"Failed to load {model_name} model")

        # Process each model in its own column
        for col, (model_name, model_path) in zip(columns, models_to_load.items()):
            process_model(col, model_name, model_path)


if __name__ == "__main__":
    main()
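
# Run locally with: streamlit run app.py
# (requires streamlit, transformers, and a backend such as torch)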