import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from text_processing import TextProcessor
import gc
from pathlib import Path
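
# Note on the local text_processing module: as used in main() below,
# TextProcessor.find_most_relevant_abstracts(question, abstracts, top_k) is
# expected to return a dict with 'top_indices' and 'scores' keys.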

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)

# Initialize session state
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
if 'summaries' not in st.session_state:
    st.session_state.summaries = None
if 'text_processor' not in st.session_state:
    st.session_state.text_processor = None
if 'processing_started' not in st.session_state:
    st.session_state.processing_started = False
if 'focused_summary_generated' not in st.session_state:
    st.session_state.focused_summary_generated = False
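
# Optional sketch (an assumption, not required by the app): pre-declare the keys
# that the focused-summary step fills in later, so every session key used below
# is visible in one place.
for key in ('focused_summary', 'relevant_papers', 'relevance_scores'):
    if key not in st.session_state:
        st.session_state[key] = None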

def load_model(model_type):
    """Load the appropriate model for the given type, with proper memory management."""
    try:
        # Clear any existing cached data
        torch.cuda.empty_cache()
        gc.collect()
        if model_type == "summarize":
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "facebook/bart-large-cnn",
                cache_dir="./models",
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            )
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/results",
                device_map="auto",
                torch_dtype=torch.float32
            )
            tokenizer = AutoTokenizer.from_pretrained(
                "facebook/bart-large-cnn",
                cache_dir="./models"
            )
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            )
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                device_map="auto",
                torch_dtype=torch.float32
            )
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )
        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise
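
# Usage sketch: load, use, then release each model in turn to keep the
# footprint small on a memory-constrained Space, e.g.:
#     model, tokenizer = load_model("summarize")
#     ...generate summaries...
#     cleanup_model(model, tokenizer)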

def cleanup_model(model, tokenizer):
    """Properly clean up model resources."""
    try:
        del model
        del tokenizer
        torch.cuda.empty_cache()
        gc.collect()
    except Exception:
        pass
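
# Caveat (plain Python semantics): `del` above removes only this function's local
# references. The caller's `model`/`tokenizer` names still keep the objects alive,
# so memory is actually reclaimed only once the caller's references are gone too.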

def process_excel(uploaded_file):
    """Process uploaded Excel file"""
    try:
        df = pd.read_excel(uploaded_file)
        required_columns = ['Abstract', 'Article Title', 'Authors',
                            'Source Title', 'Publication Year', 'DOI']
        # Check required columns
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            st.error(f"Missing required columns: {', '.join(missing_columns)}")
            return None
        return df[required_columns]
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None

def preprocess_text(text):
    """Preprocess text to add appropriate formatting before summarization"""
    if not isinstance(text, str) or not text.strip():
        return text
    # Split text into sentences (basic implementation)
    sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
    # Remove empty sentences
    sentences = [s for s in sentences if s]
    # Join with proper line breaks
    formatted_text = '\n'.join(sentences)
    return formatted_text
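
# Worked example of the sentence splitting above:
#     preprocess_text("Aim. We tested X. Results: Y improved.")
# returns
#     "Aim.\nWe tested X.\nResults: Y improved."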

def generate_summary(text, model, tokenizer):
    """Generate a summary for a single abstract."""
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."
    # Check if abstract is too short
    word_count = len(text.split())
    if word_count < 50:  # Threshold for "short" abstracts
        return text  # Return original text for very short abstracts
    # Preprocess the text first
    formatted_text = preprocess_text(text)
    # Adjust generation parameters based on input length
    max_length = min(150, word_count + 50)  # Dynamic max length
    min_length = min(50, word_count)  # Dynamic min length
    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
    # Keep inputs on the same device as the (possibly dispatched) model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
            no_repeat_ngram_size=3  # Prevent repetition of phrases
        )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    # Fall back to the original text if the summary is too similar to it
    if summary.lower() == text.lower() or len(summary.split()) / word_count > 0.9:
        return text
    return summary
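
# Generation notes: num_beams=4 enables beam search; no_repeat_ngram_size=3
# forbids any 3-gram from appearing twice, curbing the phrase-looping BART can
# exhibit; length_penalty > 1.0 biases beam scoring toward longer outputs.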

def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate a question-focused summary over the selected abstracts."""
    # Preprocess each abstract
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
    combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)
    inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
    # Keep inputs on the same device as the (possibly dispatched) model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=200,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
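
# Caveat: the combined prompt is truncated at 1024 tokens, so with top_k=5 the
# later abstracts may be partially or entirely clipped before generation.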

def main():
    st.title("🔬 Biomedical Papers Analysis")
    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )
    # Placeholder for the question input, rendered once a file is uploaded
    question_container = st.empty()
    question = ""
    if uploaded_file is not None:
        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
                if df is not None:
                    st.session_state.processed_data = df.dropna(subset=["Abstract"])
        if st.session_state.processed_data is not None:
            df = st.session_state.processed_data
            st.write(f"📊 Loaded {len(df)} papers with abstracts")
            # Get question before processing
            with question_container:
                question = st.text_input(
                    "Enter your research question (optional):",
                    help="If provided, a question-focused summary will be generated after individual summaries"
                )
            # Single button for both processes
            if not st.session_state.get('processing_started', False):
                if st.button("Start Analysis"):
                    st.session_state.processing_started = True
            # Show processing status and results
            if st.session_state.get('processing_started', False):
                # Individual Summaries Section
                st.header("📝 Individual Paper Summaries")
                if st.session_state.summaries is None:
                    try:
                        with st.spinner("Generating summaries..."):
                            # Load summarization model
                            model, tokenizer = load_model("summarize")
                            # Process abstracts with real-time updates
                            summaries = []
                            progress_bar = st.progress(0)
                            summary_display = st.empty()
                            for i, (_, row) in enumerate(df.iterrows()):
                                summary = generate_summary(row['Abstract'], model, tokenizer)
                                summaries.append(summary)
                                # Update progress and show current summary
                                progress = (i + 1) / len(df)
                                progress_bar.progress(progress)
                                summary_display.write(f"Processing paper {i+1}/{len(df)}:\n{row['Article Title']}")
                            st.session_state.summaries = summaries
                            # Cleanup first model
                            cleanup_model(model, tokenizer)
                    except Exception as e:
                        st.error(f"Error generating summaries: {str(e)}")
                # Display summaries with improved sorting
                if st.session_state.summaries is not None:
                    col1, col2 = st.columns(2)
                    with col1:
                        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title']
                        sort_column = st.selectbox("Sort by:", sort_options)
                    with col2:
                        ascending = st.checkbox("Ascending order", True)
                    # Create display dataframe with formatted year
                    display_df = df.copy()
                    display_df['Summary'] = st.session_state.summaries
                    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
                    sorted_df = display_df.sort_values(by=sort_column, ascending=ascending)
                    # Apply custom formatting
                    st.markdown("""
                        <style>
                        .stDataFrame {
                            font-size: 16px;
                        }
                        .stDataFrame td {
                            white-space: normal !important;
                            padding: 8px !important;
                        }
                        </style>
                    """, unsafe_allow_html=True)
                    st.dataframe(
                        sorted_df[['Article Title', 'Authors', 'Source Title',
                                   'Publication Year', 'DOI', 'Summary']],
                        hide_index=True
                    )
                # Question-focused Summary Section (only if question provided)
                if question.strip():
                    st.header("❓ Question-focused Summary")
                    if not st.session_state.get('focused_summary_generated', False):
                        try:
                            with st.spinner("Analyzing relevant papers..."):
                                # Initialize text processor if needed
                                if st.session_state.text_processor is None:
                                    st.session_state.text_processor = TextProcessor()
                                # Find relevant abstracts
                                results = st.session_state.text_processor.find_most_relevant_abstracts(
                                    question,
                                    df['Abstract'].tolist(),
                                    top_k=5
                                )
                                # Load question-focused model
                                model, tokenizer = load_model("question_focused")
                                # Generate focused summary
                                relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                                focused_summary = generate_focused_summary(
                                    question,
                                    relevant_abstracts,
                                    model,
                                    tokenizer
                                )
                                # Store results
                                st.session_state.focused_summary = focused_summary
                                st.session_state.relevant_papers = df.iloc[results['top_indices']]
                                st.session_state.relevance_scores = results['scores']
                                st.session_state.focused_summary_generated = True
                                # Cleanup second model
                                cleanup_model(model, tokenizer)
                        except Exception as e:
                            st.error(f"Error generating focused summary: {str(e)}")
                    # Display focused summary results
                    if st.session_state.get('focused_summary_generated', False):
                        st.subheader("Summary")
                        st.write(st.session_state.focused_summary)
                        st.subheader("Most Relevant Papers")
                        relevant_papers = st.session_state.relevant_papers[
                            ['Article Title', 'Authors', 'Publication Year', 'DOI']
                        ].copy()
                        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
                        relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
                        st.dataframe(relevant_papers, hide_index=True)

if __name__ == "__main__":
    main()