# biomedical / app.py
# pendar02's picture
# Update app.py
# 0a57b0f verified
# raw
# history blame
# 24.9 kB
import streamlit as st
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from text_processing import TextProcessor
import gc
from pathlib import Path
# Configure page (must run before any other Streamlit call renders output)
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide",
)

# Initialize session state.
# Each key is only seeded when absent so values survive Streamlit's reruns.
_SESSION_DEFAULTS = {
    'processed_data': None,            # DataFrame produced from the uploaded file
    'summaries': None,                 # list of per-paper summaries
    'text_processor': None,            # lazily created TextProcessor instance
    'processing_started': False,       # set once "Start Analysis" is pressed
    'focused_summary_generated': False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def preprocess_text(text):
    """Normalize whitespace and common statistical notation before summarization.

    Args:
        text: The abstract text. Non-string or blank values are returned
            unchanged so callers can pass raw spreadsheet cells.

    Returns:
        The cleaned string, or the original value when it is not a
        non-blank string.
    """
    if not isinstance(text, str) or not text.strip():
        return text

    # Collapse every run of whitespace (including newlines) to one space.
    text = re.sub(r'\s+', ' ', text).strip()

    # "45 %" -> "45%"
    text = re.sub(r'(\d+)\s*%', r'\1%', text)
    # "( N = 120 )" -> "(n=120)"
    text = re.sub(r'\(\s*[Nn]\s*=\s*(\d+)\s*\)', r'(n=\1)', text)
    # "p < 0.05" -> "p<0.05". The word boundary keeps words that merely end
    # in "p" (e.g. "BP > 140") from being rewritten; "=", "≤", "≥" and
    # leading-dot decimals ("p = .03") are normalized too.
    text = re.sub(r'\b([Pp])\s*([<>=≤≥])\s*(\.?\d)', r'\1\2\3', text)
    return text
def verify_facts(summary, original_text):
    """Compare the numbers and relationship keywords of a summary to its source.

    Args:
        summary: Generated summary text.
        original_text: Text the summary was generated from.
            Non-string values for either argument are treated as empty text.

    Returns:
        A dict with:
            'is_valid': True when every number and every relationship keyword
                found in the summary also occurs in the original text.
            'missing_numbers': numbers present in the original but not in
                the summary.
            'missing_relations': relationship keywords present in the
                original but not in the summary.
    """
    relation_terms = (
        'associated with', 'predicted', 'correlated',
        'increased', 'decreased', 'significant',
    )

    def extract_numbers(text):
        # A decimal point only counts when digits follow it, so the period
        # ending a sentence ("... was 45.") is not glued onto the number.
        return set(re.findall(r'\d+(?:\.\d+)?', text))

    def extract_relationships(text):
        lowered = text.lower()
        return {term for term in relation_terms if term in lowered}

    summary = summary if isinstance(summary, str) else ""
    original_text = original_text if isinstance(original_text, str) else ""

    original_numbers = extract_numbers(original_text)
    summary_numbers = extract_numbers(summary)
    original_relations = extract_relationships(original_text)
    summary_relations = extract_relationships(summary)

    return {
        'is_valid': summary_numbers.issubset(original_numbers) and
                    summary_relations.issubset(original_relations),
        'missing_numbers': original_numbers - summary_numbers,
        'missing_relations': original_relations - summary_relations,
    }
def load_model(model_type):
    """Load the model/tokenizer pair for the requested task onto the CPU.

    Args:
        model_type: "summarize" selects the standalone summarization
            checkpoint; any other value selects the base model wrapped with
            the fine-tuned PEFT adapter.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.

    Raises:
        Exception: Whatever loading raised, re-raised after the message has
            been shown through ``st.error``.
    """
    cache_location = "./models"
    target_device = "cpu"
    try:
        # Release whatever a previously loaded model left behind first.
        gc.collect()
        torch.cuda.empty_cache()

        if model_type != "summarize":
            backbone = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir=cache_location,
                torch_dtype=torch.float32
            ).to(target_device)
            model = PeftModel.from_pretrained(
                backbone,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(target_device)
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir=cache_location
            )
        else:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir=cache_location,
                torch_dtype=torch.float32
            ).to(target_device)
            tokenizer = AutoTokenizer.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir=cache_location
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise
def cleanup_model(model, tokenizer):
    """Best-effort release of memory held by a model/tokenizer pair.

    Only this function's own references are dropped; the caller must also
    let go of its references before the objects can actually be freed.

    Args:
        model: The model to release.
        tokenizer: The tokenizer to release.

    Returns:
        None. Failures are reported on stdout and never propagated, so
        cleanup cannot mask the result of the work that preceded it.
    """
    try:
        del model
        del tokenizer
        # Collect first: the CUDA cache can only hand back blocks whose
        # tensors have already been destroyed.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as e:
        print(f"Error during model cleanup: {str(e)}")
def process_excel(uploaded_file):
    """Read the uploaded workbook and keep only the columns the app uses.

    Args:
        uploaded_file: Anything ``pd.read_excel`` accepts (here, the object
            returned by the Streamlit file uploader).

    Returns:
        A DataFrame restricted to the required columns, or ``None`` when the
        file cannot be read or a required column is absent. In both failure
        cases the reason is shown through ``st.error``.
    """
    wanted = ['Abstract', 'Article Title', 'Authors',
              'Source Title', 'Publication Year', 'DOI', 'Times Cited, All Databases']
    try:
        frame = pd.read_excel(uploaded_file)
        absent = [name for name in wanted if name not in frame.columns]
        if not absent:
            return frame[wanted]
        st.error(f"Missing required columns: {', '.join(absent)}")
        return None
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None
def improve_summary_generation(text, model, tokenizer):
    """Generate a structured four-section summary of one abstract.

    Args:
        text: The abstract. Non-string or blank values short-circuit.
        model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The post-processed summary, or one of the fixed messages
        "No abstract available to summarize.",
        "Error: Could not generate summary." or "Error generating summary."
        Errors are printed and never raised, so one bad abstract does not
        abort a batch.
    """
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."

    try:
        # Simplified prompt
        formatted_text = (
            "Summarize this biomedical abstract into four sections:\n"
            "1. Background/Objectives: State the main purpose and population\n"
            "2. Methods: Describe what was done\n"
            "3. Key findings: Include ALL numerical results and statistical relationships\n"
            "4. Conclusions: State main implications\n\n"
            "Important: Preserve all numbers, measurements, and statistical findings.\n\n"
            "Text: " + preprocess_text(text)
        )

        inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Deterministic beam search. `temperature` was dropped because it is
        # a sampling-only setting and has no effect without sampling.
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=300,
                min_length=100,
                num_beams=5,
                do_sample=False,
                length_penalty=2.0,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
            )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        if not summary:
            return "Error: Could not generate summary."

        # Never hand back an empty string: if post-processing discards
        # everything, fall back to the raw decoded text.
        return post_process_summary(summary) or summary.strip()
    except Exception as e:
        print(f"Error in summary generation: {str(e)}")
        return "Error generating summary."
# Canonical section names, in output order.
_SUMMARY_SECTIONS = ('Background and objectives', 'Methods', 'Key findings', 'Conclusions')

# A header is a known label immediately followed by a colon. Requiring the
# colon (and no letter directly before the label) keeps ordinary words such
# as "claims", "methodology" or "resulted" from being treated as headers.
_SUMMARY_HEADER_RE = re.compile(
    r'(?i)(?<![A-Za-z])'
    r'(?P<label>'
    r'background(?:\s*(?:and|&|/)\s*(?:objectives?|aims?|purpose))?'
    r'|(?:materials?\s*and\s*)?methods?'
    r'|(?:key\s*)?findings?|results?'
    r'|conclusions?'
    r'|(?:study\s*)?aims?|goals?|purpose|objectives?|outcomes?|discussion'
    r')\s*:'
)


def _canonical_section(label):
    """Map a matched header label to its canonical section, or None to just drop the label."""
    key = label.lower()
    if key.startswith('background'):
        return 'Background and objectives'
    if 'method' in key:
        return 'Methods'
    if 'finding' in key or key.startswith('result'):
        return 'Key findings'
    if key.startswith('conclusion'):
        return 'Conclusions'
    return None


def _clean_fragment(fragment):
    """Collapse whitespace and repeated periods and trim stray punctuation at the edges."""
    fragment = re.sub(r'\s+', ' ', fragment)
    fragment = re.sub(r'\.{2,}', '.', fragment)
    # Leading periods are kept so a value such as ".05" is not damaged.
    return fragment.lstrip(': ').rstrip('.: ')


def post_process_summary(summary):
    """Normalize a generated summary into the four standard sections.

    Header variants ("Results:", "Materials and methods:", ...) are mapped
    onto the canonical sections. Secondary labels ("Aim:", "Outcomes:",
    "Discussion:", ...) are removed and their text stays in the section
    currently being filled. Text that precedes the first header is kept as
    background. A section that appears more than once is concatenated.

    Args:
        summary: Raw text produced by the summarization model.

    Returns:
        Sections formatted as "<Section>: <content>." and separated by blank
        lines, in canonical order. When no section header is recognized the
        cleaned text is returned as a single paragraph. Falsy input is
        returned unchanged.
    """
    if not summary:
        return summary

    collected = {name: [] for name in _SUMMARY_SECTIONS}
    preamble = []
    current = None
    saw_section = False
    cursor = 0

    def store(fragment):
        cleaned = _clean_fragment(fragment)
        if cleaned:
            (collected[current] if current else preamble).append(cleaned)

    for match in _SUMMARY_HEADER_RE.finditer(summary):
        store(summary[cursor:match.start()])
        section = _canonical_section(match.group('label'))
        if section is not None:
            current = section
            saw_section = True
        cursor = match.end()
    store(summary[cursor:])

    if not saw_section:
        # Nothing to organize: keep the content rather than discarding it.
        return '. '.join(preamble) + '.' if preamble else ''

    if preamble:
        background = 'Background and objectives'
        collected[background] = preamble + collected[background]

    return '\n\n'.join(
        f"{name}: {'. '.join(chunks)}."
        for name, chunks in collected.items()
        if chunks
    )
def validate_summary(summary, original_text):
    """Run heuristic sanity checks on a summary against its source text.

    Args:
        summary: Generated summary text.
        original_text: Text the summary was generated from.

    Returns:
        True only when all checks pass:
          * ``verify_facts`` reports the summary as valid,
          * the summary contains at most one "<number> year(s)" mention,
          * at most one sentence is a repeat of an earlier one,
          * the summary has at least 20 words and no more than 80% of the
            original's word count.
    """
    # Perform fact verification
    verification = verify_facts(summary, original_text)
    if not verification.get('is_valid', False):
        return False

    # Check for age inconsistencies.
    # NOTE(review): this rejects ANY summary with two "N years" mentions,
    # including legitimate ones (e.g. an age and a follow-up duration);
    # kept as-is pending confirmation of the intended rule.
    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
    if len(age_mentions) > 1:
        return False

    # Check for repetitive sentences. Split on sentence-ending punctuation
    # followed by whitespace so decimals ("0.05") stay intact, and drop empty
    # fragments so they are not miscounted as duplicates.
    sentences = [
        piece.strip(' .!?').lower()
        for piece in re.split(r'(?<=[.!?])\s+', summary)
    ]
    sentences = [sentence for sentence in sentences if sentence]
    if len(sentences) - len(set(sentences)) > 1:  # More than one duplicate
        return False

    # Check summary isn't too long or too short compared to original
    summary_words = len(summary.split())
    original_words = len(original_text.split())
    if summary_words < 20 or summary_words > original_words * 0.8:
        return False

    return True
def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate one summary of several abstracts, focused on a question.

    Args:
        question: The research question the summary should address.
        abstracts: Iterable of abstract texts. Non-string and blank entries
            are skipped instead of aborting the whole generation.
        model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The decoded summary, or "Error generating focused summary." when
        anything fails (the error is printed, never raised).
    """
    try:
        # Preprocess each usable abstract
        formatted_abstracts = [
            preprocess_text(abstract)
            for abstract in abstracts
            if isinstance(abstract, str) and abstract.strip()
        ]
        combined_input = (
            f"Question: {question}\nSummarize these abstracts to answer the question:\n"
            + "\n---\n".join(formatted_abstracts)
        )

        inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Deterministic beam search. `temperature` was dropped because it is
        # a sampling-only setting and has no effect without sampling.
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=300,
                min_length=100,
                num_beams=5,
                do_sample=False,
                length_penalty=2.0,
                repetition_penalty=2.5,
            )

        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error in focused summary generation: {str(e)}")
        return "Error generating focused summary."
def _split_author_names(cell):
    """Return the set of individual names in a ';'-separated authors cell."""
    if not isinstance(cell, str):
        return set()
    return {name.strip() for name in cell.split(';') if name.strip()}


def _numeric_range_filter(df, column, from_label, to_label):
    """Render a from/to pair of number inputs and keep the rows of df inside that range."""
    lowest = int(df[column].min())
    highest = int(df[column].max())
    left, right = st.columns(2)
    with left:
        lower = st.number_input(from_label,
                                min_value=lowest,
                                max_value=highest,
                                value=lowest)
    with right:
        upper = st.number_input(to_label,
                                min_value=lowest,
                                max_value=highest,
                                value=highest)
    return df[(df[column] >= lower) & (df[column] <= upper)]


def create_filter_controls(df, sort_column):
    """Render the filter widgets for the selected column and apply them.

    Args:
        df: DataFrame of papers. It is never modified.
        sort_column: 'Publication Year' and 'Times Cited' get a numeric
            range; 'Authors' and 'Source Title' get a multi-select. Any
            other value renders no widget.

    Returns:
        A new DataFrame holding the rows that pass the filter (all rows when
        no filter applies or nothing is selected).
    """
    filtered_df = df.copy()
    if filtered_df.empty:
        # No rows means no min/max or option list to build widgets from.
        return filtered_df

    if sort_column == 'Publication Year':
        filtered_df = _numeric_range_filter(
            filtered_df, 'Publication Year', 'From Year', 'To Year')

    elif sort_column == 'Authors':
        unique_authors = sorted(
            set().union(*(_split_author_names(cell) for cell in df['Authors']))
        )
        selected_authors = st.multiselect(
            'Select Authors',
            unique_authors
        )
        if selected_authors:
            chosen = set(selected_authors)
            # Compare whole names: a substring test would make "Li, J" also
            # select "Li, Jing".
            keep = filtered_df['Authors'].apply(
                lambda cell: bool(chosen & _split_author_names(cell))
            ).astype(bool)
            filtered_df = filtered_df[keep]

    elif sort_column == 'Source Title':
        # Drop missing titles before sorting: NaN and str do not compare.
        unique_sources = sorted(df['Source Title'].dropna().unique())
        selected_sources = st.multiselect(
            'Select Sources',
            unique_sources
        )
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    elif sort_column == 'Times Cited':
        filtered_df = _numeric_range_filter(
            filtered_df, 'Times Cited', 'From Cited Count', 'To Cited Count')

    return filtered_df
# Styling for the paper/summary cards. Kept flush-left so Markdown does not
# interpret the lines as an indented code block.
_PAPER_CARD_CSS = """
<style>
.paper-info {
border: 1px solid #ddd;
padding: 15px;
margin-bottom: 20px;
border-radius: 5px;
}
.paper-section {
margin-bottom: 10px;
}
.section-header {
font-weight: bold;
color: #555;
margin-bottom: 8px;
}
.paper-title {
margin-top: 5px;
margin-bottom: 10px;
}
.paper-meta {
font-size: 0.9em;
color: #666;
}
.doi-link {
color: #0366d6;
}
</style>
"""


def _summarize_all(abstracts):
    """Summarize every abstract with the summarization model, updating a progress bar."""
    model, tokenizer = load_model("summarize")
    progress_bar = st.progress(0)
    try:
        total = len(abstracts)
        summaries = []
        for idx, abstract in enumerate(abstracts):
            summaries.append(improve_summary_generation(abstract, model, tokenizer))
            progress_bar.progress((idx + 1) / total)
        return summaries
    finally:
        # Release the model and clear the bar even when a summary fails.
        cleanup_model(model, tokenizer)
        progress_bar.empty()


def _render_paper(row):
    """Render one paper's metadata card and summary card side by side."""
    import html  # local import keeps this block self-contained

    def esc(value):
        # Cell values come from an uploaded spreadsheet and are injected into
        # raw HTML below, so they must be escaped.
        return html.escape(str(value))

    title = esc(row['Article Title'])
    authors = esc(row['Authors'])
    source = esc(row['Source Title'])
    year = esc(row['Publication Year'])
    cited = esc(row['Times Cited'])
    doi = esc(row['DOI'] if pd.notna(row['DOI']) else 'None')
    # A blank line would end the HTML block in Markdown, so line breaks in
    # the summary are turned into <br> tags.
    summary = esc(row['Summary']).replace('\n', '<br>')

    paper_col, summary_col = st.columns([1, 1])
    with paper_col:  # PAPER column
        st.markdown(
            '<div class="paper-section"><div class="section-header">PAPER</div></div>',
            unsafe_allow_html=True,
        )
        st.markdown(
            '\n<div class="paper-info">\n'
            f'<div class="paper-title">{title}</div>\n'
            '<div class="paper-meta">\n'
            f'<strong>Authors:</strong> {authors}<br>\n'
            f'<strong>Source:</strong> {source}<br>\n'
            f'<strong>Publication Year:</strong> {year}<br>\n'
            f'<strong>Times Cited:</strong> {cited}<br>\n'
            f'<strong>DOI:</strong> {doi}\n'
            '</div>\n'
            '</div>\n',
            unsafe_allow_html=True,
        )
    with summary_col:  # SUMMARY column
        st.markdown(
            '<div class="paper-section"><div class="section-header">SUMMARY</div></div>',
            unsafe_allow_html=True,
        )
        st.markdown(
            f'\n<div class="paper-info">\n{summary}\n</div>\n',
            unsafe_allow_html=True,
        )
    # Add spacing between papers
    st.markdown("<div style='margin-bottom: 20px;'></div>", unsafe_allow_html=True)


def _render_individual_summaries(df):
    """Show the sort/filter controls and the resulting list of paper cards."""
    col1, col2 = st.columns(2)
    with col1:
        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
        sort_column = st.selectbox("Sort/Filter by:", sort_options)
    with col2:
        # A sort order is only offered for Article Title and Times Cited
        if sort_column == 'Article Title':
            ascending = st.radio(
                "Sort order",
                ["A to Z", "Z to A"],
                horizontal=True
            ) == "A to Z"
        elif sort_column == 'Times Cited':
            ascending = st.radio(
                "Sort order",
                ["Most cited", "Least cited"],
                horizontal=True
            ) == "Least cited"
        else:
            ascending = True  # Default for other columns

    # Create display dataframe
    display_df = df.copy()
    display_df['Summary'] = st.session_state.summaries
    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
    display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

    # Apply filters
    filtered_df = create_filter_controls(display_df, sort_column)

    if sort_column in ('Article Title', 'Times Cited'):
        # Honour the chosen order; a stable sort keeps ties in file order.
        sorted_df = filtered_df.sort_values(
            by=sort_column, ascending=ascending, kind='mergesort')
    else:
        # Keep original order for other columns after filtering
        sorted_df = filtered_df

    # Show number of filtered results
    if len(sorted_df) != len(display_df):
        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

    # Apply custom styling
    st.markdown(_PAPER_CARD_CSS, unsafe_allow_html=True)

    # Display papers using the filtered and sorted dataframe
    for _, row in sorted_df.iterrows():
        _render_paper(row)


def _render_focused_summary(question, df):
    """Generate (once per session) and display the question-focused summary."""
    st.header("❓ Question-focused Summary")

    if not st.session_state.get('focused_summary_generated', False):
        try:
            with st.spinner("Analyzing relevant papers..."):
                # Initialize text processor if needed
                if st.session_state.text_processor is None:
                    st.session_state.text_processor = TextProcessor()

                # Find relevant abstracts.
                # NOTE(review): relies on the result exposing 'top_indices'
                # and 'scores'; confirm against TextProcessor.
                results = st.session_state.text_processor.find_most_relevant_abstracts(
                    question,
                    df['Abstract'].tolist(),
                    top_k=5
                )

                # Load question-focused model
                model, tokenizer = load_model("question_focused")
                try:
                    relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                    focused_summary = generate_focused_summary(
                        question,
                        relevant_abstracts,
                        model,
                        tokenizer
                    )
                finally:
                    # Cleanup second model even when generation fails
                    cleanup_model(model, tokenizer)

                # Store results
                st.session_state.focused_summary = focused_summary
                st.session_state.relevant_papers = df.iloc[results['top_indices']]
                st.session_state.relevance_scores = results['scores']
                st.session_state.focused_summary_generated = True
        except Exception as e:
            st.error(f"Error generating focused summary: {str(e)}")

    # Display focused summary results
    if st.session_state.get('focused_summary_generated', False):
        st.subheader("Summary")
        st.write(st.session_state.focused_summary)

        st.subheader("Most Relevant Papers")
        relevant_papers = st.session_state.relevant_papers[
            ['Article Title', 'Authors', 'Publication Year', 'DOI']
        ].copy()
        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
        relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
        st.dataframe(relevant_papers, hide_index=True)


def main():
    """Streamlit entry point: upload, summarize each paper, then answer an optional question.

    Results are kept in ``st.session_state`` so they survive Streamlit's
    reruns of this function on every interaction.
    """
    st.title("🔬 Biomedical Papers Analysis")

    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, "
             "Publication Year, DOI, 'Times Cited, All Databases'"
    )
    # Reserve the slot for the question box directly under the uploader.
    question_container = st.empty()

    if uploaded_file is None:
        return

    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
        if df is not None:
            st.session_state.processed_data = df.dropna(subset=["Abstract"])

    df = st.session_state.processed_data
    if df is None:
        return

    st.write(f"📊 Loaded {len(df)} papers with abstracts")

    with question_container:
        question = st.text_input(
            "Enter your research question (optional):",
            help="If provided, a focused summary will be generated after individual summaries"
        )

    # Single button for both processes
    if not st.session_state.get('processing_started', False):
        if st.button("Start Analysis"):
            st.session_state.processing_started = True

    if not st.session_state.get('processing_started', False):
        return

    # Individual Summaries Section
    st.header("📝 Individual Paper Summaries")

    # Generate summaries if not already done
    if st.session_state.summaries is None:
        try:
            with st.spinner("Generating individual paper summaries..."):
                st.session_state.summaries = _summarize_all(df['Abstract'])
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")
            st.session_state.processing_started = False

    # Display summaries with sorting and filtering
    if st.session_state.summaries is not None:
        _render_individual_summaries(df)

    # Question-focused Summary Section (only if question provided)
    if question.strip():
        _render_focused_summary(question, df)


if __name__ == "__main__":
    main()