|
import os |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import traceback |
|
|
|
|
|
CACHE_DIR = os.path.join(os.getcwd(), "model_cache") |
|
|
|
def ensure_cache_dir():
    """Create the model cache directory if it does not exist yet.

    Returns:
        str: Absolute path of the cache directory (``CACHE_DIR``).
    """
    # Idempotent: exist_ok makes repeated calls safe.
    path = CACHE_DIR
    os.makedirs(path, exist_ok=True)
    return path
|
|
|
@st.cache_resource(show_spinner=False)
def load_model_and_tokenizer(model_name):
    """
    Load a seq2seq model and its tokenizer with persistent caching.

    Two layers of caching are used:
    - ``st.cache_resource`` keeps the loaded objects in memory across
      Streamlit reruns, so pressing the button repeatedly does not
      re-instantiate multi-GB models from disk.
    - ``cache_dir`` keeps the downloaded weights on disk across restarts.

    Args:
        model_name (str): Hugging Face model id (e.g. 'facebook/bart-large-cnn')

    Returns:
        tuple: (model, tokenizer), or (None, None) if loading failed
        (the error is reported in the Streamlit UI rather than raised).
    """
    try:
        cache_dir = ensure_cache_dir()

        # One subdirectory per model; '/' is not a valid path component.
        model_cache_path = os.path.join(cache_dir, model_name.replace('/', '_'))

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=model_cache_path
        )

        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            cache_dir=model_cache_path
        )

        return model, tokenizer
    except Exception as e:
        # Broad catch is deliberate: this is a UI boundary; surface the
        # failure to the user instead of crashing the app.
        st.error(f"Error loading {model_name}: {str(e)}")
        st.error(traceback.format_exc())
        return None, None
|
|
|
def generate_summary(model, tokenizer, text, max_length=150):
    """
    Generate an abstractive summary with a specific model and tokenizer.

    Args:
        model: Hugging Face seq2seq model
        tokenizer: Matching Hugging Face tokenizer
        text (str): Input text to summarize (truncated to 512 input tokens)
        max_length (int): Maximum token length of the generated summary

    Returns:
        str: Generated summary, or an error message string if
        generation failed (errors are also shown in the Streamlit UI).
    """
    try:
        # NOTE(review): the 'summarize: ' task prefix is only meaningful for
        # T5-style models; BART/Pegasus ignore it but see it as input text.
        inputs = tokenizer(
            f"summarize: {text}",
            max_length=512,
            return_tensors="pt",
            truncation=True
        )

        # Pass the full encoding (input_ids AND attention_mask) — calling
        # generate() with input_ids alone drops the attention mask, which
        # triggers warnings and can misbehave when pad_token == eos_token.
        summary_ids = model.generate(
            **inputs,
            num_beams=4,
            max_length=max_length,
            early_stopping=True
        )

        summary = tokenizer.decode(
            summary_ids[0],
            skip_special_tokens=True
        )

        return summary
    except Exception as e:
        # UI boundary: report and return the message rather than raising.
        error_msg = f"Error in summarization: {str(e)}"
        st.error(error_msg)
        return error_msg
|
|
|
def main():
    """Render the Streamlit UI: read text, summarize with three models side by side."""
    st.title("Text Summarization with Pre-trained Models")

    st.info(f"Models will be cached in: {CACHE_DIR}")

    # Single registry of display name -> Hugging Face model id.
    # Previously this dict was defined but never used and the paths were
    # duplicated in hard-coded calls below; drive everything from here.
    models_to_load = {
        'BART': 'facebook/bart-large-cnn',
        'T5': 't5-large',
        'Pegasus': 'google/pegasus-cnn_dailymail'
    }

    text_input = st.text_area("Enter text to summarize:")

    if st.button("Generate Summary"):
        # Guard clause: nothing to do without input text.
        if not text_input:
            st.error("Please enter text to summarize.")
            return

        # One column per model, in registry order.
        columns = st.columns(len(models_to_load))

        def process_model(col, model_name, model_path):
            # Load and run one model inside its own column so a failure
            # in one model does not prevent the others from rendering.
            with col:
                with st.spinner(f'Generating {model_name} Summary...'):
                    progress = st.progress(0)
                    progress.progress(50)

                    model, tokenizer = load_model_and_tokenizer(model_path)

                    # Explicit None checks: never rely on truthiness of
                    # model/tokenizer objects.
                    if model is not None and tokenizer is not None:
                        summary = generate_summary(model, tokenizer, text_input)

                        progress.progress(100)
                        st.subheader(f"{model_name} Summary")
                        st.write(summary)
                    else:
                        st.error(f"Failed to load {model_name} model")

        for col, (model_name, model_path) in zip(columns, models_to_load.items()):
            process_model(col, model_name, model_path)


if __name__ == "__main__":
    main()
|
|
|
if __name__ == "__main__": |
|
main() |