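"""Streamlit app for analyzing biomedical papers: generates per-paper abstract summaries and an optional question-focused summary across the most relevant papers."""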
import streamlit as st
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import gc
import time
import nltk
from concurrent.futures import ThreadPoolExecutor
nltk.download('punkt')
# Configure page
st.set_page_config(
page_title="Biomedical Papers Analysis",
page_icon="πŸ”¬",
layout="wide"
)
# Initialize session state
if 'relevant_papers' not in st.session_state:
st.session_state.relevant_papers = None
if 'relevance_scores' not in st.session_state:
st.session_state.relevance_scores = None
if 'processed_data' not in st.session_state:
st.session_state.processed_data = None
if 'summaries' not in st.session_state:
st.session_state.summaries = None
if 'text_processor' not in st.session_state:
st.session_state.text_processor = None
if 'processing_started' not in st.session_state:
st.session_state.processing_started = False
if 'focused_summary_generated' not in st.session_state:
st.session_state.focused_summary_generated = False
if 'current_model' not in st.session_state:
st.session_state.current_model = None
if 'current_tokenizer' not in st.session_state:
st.session_state.current_tokenizer = None
if 'model_type' not in st.session_state:
st.session_state.model_type = None
if 'focused_summary' not in st.session_state:
st.session_state.focused_summary = None
# TextProcessor: use the real implementation from text_processing if available;
# otherwise fall back to a trivial stub that returns the first top_k abstracts
# with uniform relevance scores.
try:
from text_processing import TextProcessor
except ImportError:
class TextProcessor:
def find_most_relevant_abstracts(self, question, abstracts, top_k=5):
return {
'top_indices': list(range(min(top_k, len(abstracts)))),
'scores': [1.0] * min(top_k, len(abstracts))
}
def load_model(model_type):
"""Load appropriate model based on type with proper memory management"""
try:
# Clear any existing cached data
gc.collect()
torch.cuda.empty_cache()
device = "cpu" # Force CPU usage
if model_type == "summarize":
# Load the new fine-tuned model directly
model = AutoModelForSeq2SeqLM.from_pretrained(
"pendar02/bart-large-pubmedd",
cache_dir="./models",
torch_dtype=torch.float32
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
"pendar02/bart-large-pubmedd",
cache_dir="./models"
)
else: # question_focused
base_model = AutoModelForSeq2SeqLM.from_pretrained(
"GanjinZero/biobart-base",
cache_dir="./models",
torch_dtype=torch.float32
).to(device)
model = PeftModel.from_pretrained(
base_model,
"pendar02/biobart-finetune",
is_trainable=False
).to(device)
tokenizer = AutoTokenizer.from_pretrained(
"GanjinZero/biobart-base",
cache_dir="./models"
)
model.eval()
return model, tokenizer
except Exception as e:
st.error(f"Error loading model: {str(e)}")
raise
def get_model(model_type):
"""Get model from session state or load if needed"""
try:
if (st.session_state.current_model is None or
st.session_state.model_type != model_type):
# Clean up existing model
if st.session_state.current_model is not None:
cleanup_model(st.session_state.current_model,
st.session_state.current_tokenizer)
# Load new model
model, tokenizer = load_model(model_type)
st.session_state.current_model = model
st.session_state.current_tokenizer = tokenizer
st.session_state.model_type = model_type
return st.session_state.current_model, st.session_state.current_tokenizer
except Exception as e:
st.error(f"Error loading model: {str(e)}")
st.session_state.processing_started = False
return None, None
def cleanup_model(model, tokenizer):
"""Properly cleanup model resources"""
try:
del model
del tokenizer
torch.cuda.empty_cache()
gc.collect()
except Exception:
pass
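# Note: main() calls reset_processing_state() after errors, but the helper is not
# defined in this file. The following is a minimal sketch of an assumed
# implementation that clears the analysis flags and cached results so the user
# can restart processing; adjust the fields to match the real helper if it exists elsewhere.
def reset_processing_state():
    """Reset analysis-related session state after an error (assumed implementation)."""
    st.session_state.processing_started = False
    st.session_state.summaries = None
    st.session_state.focused_summary = None
    st.session_state.focused_summary_generated = False
    st.session_state.relevant_papers = None
    st.session_state.relevance_scores = None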
@st.cache_data
def process_excel(uploaded_file):
"""Process uploaded Excel file"""
try:
df = pd.read_excel(uploaded_file)
required_columns = ['Abstract', 'Article Title', 'Authors',
'Source Title', 'Publication Year', 'DOI',
'Times Cited, All Databases']
# Check required columns first
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
st.error("❌ Missing required columns: " + ", ".join(missing_columns))
st.error("Please ensure your Excel file contains all required columns.")
return None
# Only proceed with validation if all required columns exist
if len(df) > 5:
st.error("❌ Your file contains more than 5 papers. Please upload a file with maximum 5 papers.")
return None
# Now safe to validate structure as we know columns exist
is_valid, messages = validate_excel_structure(df)
if not is_valid:
for msg in messages:
st.error(f"❌ {msg}")
return None
return df[required_columns]
except Exception as e:
st.error(f"❌ Error reading file: {str(e)}")
st.error("Please check if your file is in the correct Excel format (.xlsx or .xls)")
return None
def validate_excel_structure(df):
"""Validate the structure and content of the Excel file"""
validation_messages = []
# Check for minimum content
if len(df) == 0:
validation_messages.append("File contains no data")
return False, validation_messages
try:
# Check publication year format - this is useful for sorting/filtering
df['Publication Year'] = pd.to_numeric(df['Publication Year'], errors='coerce')
if df['Publication Year'].isna().any():
validation_messages.append("Some publication years are invalid. Please ensure all years are in numeric format (e.g., 2024)")
else:
years = df['Publication Year'].dropna()
if len(years) > 0:
if years.min() < 1900 or years.max() > 2025:
validation_messages.append("Publication years must be between 1900 and 2025")
# For short abstracts - just show a warning
short_abstracts = df['Abstract'].fillna('').astype(str).str.len() < 50
if short_abstracts.any():
st.warning("ℹ️ Some abstracts are quite short, but will still be processed")
except Exception as e:
validation_messages.append(f"Error checking data format: {str(e)}")
return len(validation_messages) == 0, validation_messages
def preprocess_text(text):
"""Clean biomedical text by handling common formatting issues and standardizing structure."""
if not isinstance(text, str) or not text.strip():
return text
# Remove extra whitespace
text = ' '.join(text.split())
# Roman numeral conversion
roman_map = {'i': '1', 'ii': '2', 'iii': '3', 'iv': '4', 'v': '5',
'vi': '6', 'vii': '7', 'viii': '8', 'ix': '9', 'x': '10'}
def replace_roman(match):
roman = match.group(1).lower()
return f"({roman_map.get(roman, roman)})"
text = re.sub(r'\(([ivx]+)\)', replace_roman, text)
# Clean enumerated lists
for roman in roman_map:
text = re.sub(f"\\b{roman}\\)", f"{roman_map[roman]})", text, flags=re.IGNORECASE)
# Standardize section headers
section_patterns = {
r'\b(?:introduction|purpose|background|objectives?|context)\s*:?\s*': 'Background: ',
r'\b(?:materials?\s+and\s+methods?|methods?|approach|study\s+design)\s*:?\s*': 'Methods: ',
r'\b(?:results?|findings?|observations?)\s*:?\s*': 'Results: ',
r'\b(?:conclusions?|summary|final\s+remarks?)\s*:?\s*': 'Conclusions: ',
r'\b(?:results?\s+and\s+conclusions?)\s*:?\s*(?=.*?:)': '', # Remove if followed by another section
r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)': 'Results: ' # Fix malformed combination
}
for pattern, replacement in section_patterns.items():
text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
# Ensure complete sentences in sections
text = re.sub(r'(?<=:)\s*([^.!?\n]*?)(?=\s*(?:[A-Z][^:]*:|$))',
lambda m: f" {m.group(1)}." if m.group(1) and not m.group(1).strip().endswith('.') else m.group(0),
text)
# Fix truncated sentences
text = re.sub(r'(?<=:)\s*([^.!?\n]*?)\s*(?=[A-Z][^:]*:)',
lambda m: f" {m.group(1)}." if m.group(1) else "",
text)
# Clean formatting
text = re.sub(r'[\r\n]+', ' ', text)
text = re.sub(r'\s*:\s*', ': ', text)
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
text = re.sub(r'β€’|\*|β– |β–‘|β†’|βœ“', '', text)
text = re.sub(r'\\n|\\r', ' ', text)
text = re.sub(r'\s*\(\s*', ' (', text)
text = re.sub(r'\s*\)\s*', ') ', text)
# Fix statistical notations
text = re.sub(r'p\s*[<=>]\s*0\.\d+', lambda m: m.group().replace(' ', ''), text)
text = re.sub(r'(?<=\d)\s*%', '%', text)
# Fix abbreviations spacing
text = re.sub(r'(?<=\w)vs\.(?=\w)', 'vs. ', text)
text = re.sub(r'(?<=\w)et\s+al\.(?=\w)', 'et al. ', text)
# Remove repeated punctuation
text = re.sub(r'([.!?])\1+', r'\1', text)
# Final cleanup
text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
text = text.strip()
if not text.endswith('.'):
text += '.'
return text
# """Enhanced text preprocessing with better section handling and prompt removal."""
# if not isinstance(text, str) or not text.strip():
# return text
# # Remove prompt leakage
# prompt_patterns = [
# r'Generate a structured summary addressing this question:.*?(?=\w+:)',
# r'Focus on key findings and methods\.',
# r'is a structured summary addressing this question:'
# ]
# for pattern in prompt_patterns:
# text = re.sub(pattern, '', text, flags=re.IGNORECASE)
# # Clean section headers more aggressively
# section_patterns = {
# r'\b(?:introduction|purpose|background|objectives?|context)\s*:?\s*': 'Background: ',
# r'\b(?:materials?\s+and\s+methods?|methods?|approach|study\s+design)\s*:?\s*': 'Methods: ',
# r'\b(?:results?|findings?|observations?)\s*:?\s*': 'Results: ',
# r'\b(?:conclusions?|summary|final\s+remarks?)\s*:?\s*': 'Conclusions: '
# }
# # Apply section normalization
# for pattern, replacement in section_patterns.items():
# text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
# # Remove combined section headers
# combined_headers = [
# r'\bmethods?\s+and\s+conclusions?\b',
# r'\bresults?\s+and\s+conclusions?\b',
# r'\bmaterials?\s+and\s+methods?\b'
# ]
# for pattern in combined_headers:
# text = re.sub(pattern, 'Methods:', text, flags=re.IGNORECASE)
# # Clean up sentences
# sentences = text.split('.')
# cleaned_sentences = []
# for sentence in sentences:
# # Remove redundant section references
# sentence = re.sub(r'\b(?:first|second|third|fourth|fifth)\s+sections?\b', '', sentence, flags=re.IGNORECASE)
# # Remove comparative phrases about section details
# sentence = re.sub(r'\b(?:more|less)\s+detailed\s+than.*', '', sentence, flags=re.IGNORECASE)
# if sentence.strip():
# cleaned_sentences.append(sentence.strip())
# # Rejoin and format
# text = '. '.join(cleaned_sentences)
# text = re.sub(r'\s+', ' ', text) # Remove extra spaces
# text = re.sub(r'\s*:\s*', ': ', text) # Fix spacing around colons
# return text.strip()
def generate_focused_summary(question, abstracts, model, tokenizer):
"""Generate a structured summary of the given abstracts, guided by the question."""
formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts if abstract.strip()]
abstracts_content = " [SEP] ".join(formatted_abstracts)
prompt = f"""
Address the following: {question}
Provide a factual summary structured as:
- Background: Context and origin only if present
- Methods: Key procedures and approaches
- Results: Specific findings with numbers
- Conclusions: Main implications
Requirements:
- Present sections sequentially
- Merge related points within sections
- Complete all sentences
- Avoid repeating section headers
- Use original terminology
Content: {abstracts_content}
"""
inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
summary_ids = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=512,
min_length=200,
num_beams=4,
length_penalty=2.0,
no_repeat_ngram_size=3,
do_sample=False  # deterministic beam search; temperature is ignored when sampling is off
)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return post_process_summary(summary)
def post_process_summary(summary):
"""Post-process summary with improved section handling and formatting."""
if not summary:
return summary
valid_sections = ['Background', 'Methods', 'Results', 'Conclusions']
sections = {}
current_section = None
current_content = []
# Pre-clean section headers
summary = re.sub(r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)', 'Results:', summary, flags=re.IGNORECASE)
summary = re.sub(r'\bresults?\s*and\s*conclusions?\s*:', 'Results:', summary, flags=re.IGNORECASE)
# Process line by line
lines = [line.strip() for line in summary.split('.') if line.strip()]
for i, line in enumerate(lines):
section_match = None
for section in valid_sections:
if re.match(fr'\b{section}:', line, re.IGNORECASE):
section_match = section
break
if section_match:
if current_section:
content = ' '.join(current_content)
if content:
sections[current_section] = content
current_section = section_match
content = re.sub(fr'\b{section_match}:\s*', '', line, flags=re.IGNORECASE)
current_content = [content] if content else []
elif current_section:
# Prevent section header splitting
if not any(sect.lower() in line.lower() for sect in valid_sections):
current_content.append(line)
if current_section and current_content:
sections[current_section] = ' '.join(current_content)
# Format sections
formatted_sections = []
for section in valid_sections:
if section in sections:
content = sections[section].strip()
if content:
# Complete truncated sentences
if not re.search(r'[.!?]$', content):
if len(content.split()) >= 3: # Only complete if substantial
content += '.'
# Ensure capitalization
content = content[0].upper() + content[1:]
# Fix double periods
content = re.sub(r'\.+', '.', content)
formatted_sections.append(f"{section}: {content}")
return ' '.join(formatted_sections)
def process_papers_in_batches(df, model, tokenizer, batch_size=2):
"""Summarize each paper's abstract in parallel worker threads (batch_size is currently unused)."""
abstracts = df['Abstract'].tolist()
with ThreadPoolExecutor(max_workers=4) as executor:
# Submit one summarization task per abstract; results are collected in submission order
futures = [executor.submit(generate_focused_summary,
"Focus on key findings and methods.",
[abstract], model, tokenizer)
for abstract in abstracts]
summaries = [future.result() for future in futures]
return summaries
def create_filter_controls(df, sort_column):
"""Create appropriate filter controls based on the selected column"""
filtered_df = df.copy()
if sort_column == 'Publication Year':
# Year range slider
year_min = int(df['Publication Year'].min())
year_max = int(df['Publication Year'].max())
col1, col2 = st.columns(2)
with col1:
start_year = st.number_input('From Year',
min_value=year_min,
max_value=year_max,
value=year_min)
with col2:
end_year = st.number_input('To Year',
min_value=year_min,
max_value=year_max,
value=year_max)
filtered_df = filtered_df[
(filtered_df['Publication Year'] >= start_year) &
(filtered_df['Publication Year'] <= end_year)
]
elif sort_column == 'Authors':
# Multi-select for authors
unique_authors = sorted(set(
author.strip()
for authors in df['Authors'].dropna()
for author in authors.split(';')
))
selected_authors = st.multiselect(
'Select Authors',
unique_authors
)
if selected_authors:
filtered_df = filtered_df[
filtered_df['Authors'].apply(
lambda x: any(author in str(x) for author in selected_authors)
)
]
elif sort_column == 'Source Title':
# Multi-select for source titles
unique_sources = sorted(df['Source Title'].dropna().unique())
selected_sources = st.multiselect(
'Select Sources',
unique_sources
)
if selected_sources:
filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]
elif sort_column == 'Article Title':
# Only alphabetical sorting, no filtering
pass
return filtered_df
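# Main app flow: upload Excel -> validate -> per-paper summaries (fine-tuned BART)
# -> optional question-focused summary over the most relevant abstracts (BioBART + PEFT adapter).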
def main():
st.title("πŸ”¬ Biomedical Papers Analysis")
st.info("""
**πŸ“‹ File Upload Requirements:**
- Excel file (.xlsx or .xls) with **maximum 5 papers**
- Must contain these columns:
β€’ Abstract
β€’ Article Title
β€’ Authors
β€’ Source Title
β€’ Publication Year
β€’ DOI
β€’ Times Cited, All Databases
""")
# File upload section
uploaded_file = st.file_uploader(
"Upload Excel file containing papers (max 5 papers)",
type=['xlsx', 'xls'],
help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
)
# Question input - moved up but hidden initially
question_container = st.empty()
question = ""
if uploaded_file is not None:
# Process Excel file
if st.session_state.processed_data is None:
with st.spinner("Processing file..."):
df = process_excel(uploaded_file)
if df is not None:
df = df.dropna(subset=["Abstract"])
if len(df) > 0:
st.session_state.processed_data = df
st.success(f"βœ… Successfully loaded {len(df)} papers with abstracts")
else:
st.error("❌ No valid papers found after processing. Please check your file.")
if st.session_state.processed_data is not None:
df = st.session_state.processed_data
st.write(f"πŸ“Š Loaded {len(df)} papers with abstracts")
# Get question before processing
with question_container:
question = st.text_input(
"Enter your research question (optional):",
help="If provided, a question-focused summary will be generated after individual summaries"
)
# Single button for both processes
if not st.session_state.get('processing_started', False):
if st.button("Start Analysis"):
st.session_state.processing_started = True
# Show processing status and results
if st.session_state.get('processing_started', False):
# Individual Summaries Section
st.header("πŸ“ Individual Paper Summaries")
# Generate summaries if not already done
if st.session_state.summaries is None:
try:
with st.spinner("Generating individual paper summaries..."):
model, tokenizer = get_model("summarize")
if model is None or tokenizer is None:
reset_processing_state()
return
start_time = time.time()
st.session_state.summaries = process_papers_in_batches(
df, model, tokenizer, batch_size=2
)
end_time = time.time()
st.write(f"Processing time: {end_time - start_time:.2f} seconds")
except Exception as e:
st.error(f"Error generating summaries: {str(e)}")
reset_processing_state()
# Display summaries with improved sorting and filtering
if st.session_state.summaries is not None:
col1, col2 = st.columns(2)
with col1:
sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
sort_column = st.selectbox("Sort/Filter by:", sort_options)
with col2:
if sort_column == 'Article Title':
ascending = st.radio(
"Sort order",
["A to Z", "Z to A"],
horizontal=True
) == "A to Z"
elif sort_column == 'Times Cited':
ascending = st.radio(
"Sort order",
["Most cited first", "Least cited first"],
horizontal=True
) == "Least cited first"
else:
ascending = True # Default for other columns
# Create display dataframe
display_df = df.copy()
display_df['Summary'] = st.session_state.summaries
display_df['Publication Year'] = display_df['Publication Year'].astype(int)
display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)
# Apply filters
filtered_df = create_filter_controls(display_df, sort_column)
# Apply sorting
if sort_column == 'Times Cited':
sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
elif sort_column == 'Article Title':
sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
else:
sorted_df = filtered_df
# Show number of filtered results
if len(sorted_df) != len(display_df):
st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")
# Apply custom styling
st.markdown("""
<style>
.paper-info {
border: 1px solid #ddd;
padding: 15px;
margin-bottom: 20px;
border-radius: 5px;
}
.paper-section {
margin-bottom: 10px;
}
.section-header {
font-weight: bold;
color: #555;
margin-bottom: 8px;
}
.paper-title {
margin-top: 5px;
margin-bottom: 10px;
}
.paper-meta {
font-size: 0.9em;
color: #666;
}
.doi-link {
color: #0366d6;
}
</style>
""", unsafe_allow_html=True)
# Display papers using the filtered and sorted dataframe
for _, row in sorted_df.iterrows():
paper_info_cols = st.columns([1, 1])
with paper_info_cols[0]: # PAPER column
st.markdown('<div class="paper-section"><div class="section-header">PAPER</div>', unsafe_allow_html=True)
st.markdown(f"""
<div class="paper-info">
<div class="paper-title">{row['Article Title']}</div>
<div class="paper-meta">
<strong>Authors:</strong> {row['Authors']}<br>
<strong>Source:</strong> {row['Source Title']}<br>
<strong>Publication Year:</strong> {row['Publication Year']}<br>
<strong>Times Cited:</strong> {row['Times Cited']}<br>
<strong>DOI:</strong> {row['DOI'] if pd.notna(row['DOI']) else 'None'}
</div>
</div>
""", unsafe_allow_html=True)
with paper_info_cols[1]: # SUMMARY column
st.markdown('<div class="paper-section"><div class="section-header">SUMMARY</div>', unsafe_allow_html=True)
st.markdown(f"""
<div class="paper-info">
{row['Summary']}
</div>
""", unsafe_allow_html=True)
# Add spacing between papers
st.markdown("<div style='margin-bottom: 20px;'></div>", unsafe_allow_html=True)
# Question-focused Summary Section (only if question provided)
if question.strip():
st.header("❓ Question-focused Summary")
if not st.session_state.get('focused_summary_generated', False):
model, tokenizer = None, None
try:
with st.spinner("Analyzing relevant papers..."):
if st.session_state.text_processor is None:
st.session_state.text_processor = TextProcessor()
model, tokenizer = get_model("question_focused")
if model is None or tokenizer is None:
raise Exception("Failed to load question-focused model")
results = st.session_state.text_processor.find_most_relevant_abstracts(
question,
df['Abstract'].tolist(),
top_k=5
)
if not results['top_indices']:
st.warning("No papers found relevant to your question")
return
# Store relevant papers and scores
st.session_state.relevant_papers = df.iloc[results['top_indices']]
st.session_state.relevance_scores = results['scores']
relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
st.session_state.focused_summary = generate_focused_summary(
question,
relevant_abstracts,
model,
tokenizer
)
st.session_state.focused_summary_generated = True
except Exception as e:
st.error(f"Error generating focused summary: {str(e)}")
reset_processing_state()
finally:
# Clean up only if the question-focused model was actually loaded
if model is not None:
cleanup_model(model, tokenizer)
# Display focused summary results
if st.session_state.get('focused_summary_generated', False):
st.subheader("Summary")
st.write(st.session_state.focused_summary)
st.subheader("Most Relevant Papers")
relevant_papers = st.session_state.relevant_papers[
['Article Title', 'Authors', 'Publication Year', 'DOI']
].copy()
relevant_papers['Relevance Score'] = st.session_state.relevance_scores
relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
st.dataframe(relevant_papers, hide_index=True)
if __name__ == "__main__":
main()