import streamlit as st
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import gc
import time
from concurrent.futures import ThreadPoolExecutor
import nltk

# Ensure the punkt tokenizer data is available up front; the TextProcessor
# helper may rely on NLTK sentence tokenization
nltk.download('punkt', quiet=True)
# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)
# Initialize session state
for _key, _default in {
    'processed_data': None,
    'summaries': None,
    'text_processor': None,
    'processing_started': False,
    'focused_summary_generated': False,
    'current_model': None,
    'current_tokenizer': None,
    'model_type': None,
}.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# TextProcessor: use the real implementation if available; otherwise fall back
# to a stub that simply returns the first top_k abstracts with dummy scores
try:
    from text_processing import TextProcessor
except ImportError:
    class TextProcessor:
        def find_most_relevant_abstracts(self, question, abstracts, top_k=5):
            return {
                'top_indices': list(range(min(top_k, len(abstracts)))),
                'scores': [1.0] * min(top_k, len(abstracts))
            }
def load_model(model_type):
    """Load the appropriate model for the given type, with memory management"""
    try:
        # Clear any existing cached data
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        device = "cpu"  # Force CPU usage

        if model_type == "summarize":
            # Load the fine-tuned summarization model directly
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models"
            )
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise
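
# A minimal usage sketch (assumes the Hugging Face checkpoints above are reachable):
#   model, tokenizer = load_model("summarize")
#   ids = tokenizer("Aspirin reduced event rates in the trial...", return_tensors="pt").input_ids
#   print(tokenizer.decode(model.generate(ids, max_length=60)[0], skip_special_tokens=True))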
def get_model(model_type):
    """Get model from session state, loading (and swapping) it if needed"""
    try:
        if (st.session_state.current_model is None or
                st.session_state.model_type != model_type):
            # Clean up the existing model before loading a different one
            if st.session_state.current_model is not None:
                cleanup_model(st.session_state.current_model,
                              st.session_state.current_tokenizer)
            # Load new model
            model, tokenizer = load_model(model_type)
            st.session_state.current_model = model
            st.session_state.current_tokenizer = tokenizer
            st.session_state.model_type = model_type
        return st.session_state.current_model, st.session_state.current_tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.session_state.processing_started = False
        return None, None
def cleanup_model(model, tokenizer):
    """Properly clean up model resources"""
    try:
        del model
        del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    except Exception:
        pass

def reset_processing_state():
    """Reset processing flags and cached results after an error"""
    st.session_state.processing_started = False
    st.session_state.summaries = None
    st.session_state.focused_summary_generated = False
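
# Note: with the device forced to "cpu", dropping references plus gc.collect() is
# what actually frees model memory; torch.cuda.empty_cache() only matters on GPU.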
def process_excel(uploaded_file):
    """Process uploaded Excel file"""
    try:
        df = pd.read_excel(uploaded_file)
        required_columns = ['Abstract', 'Article Title', 'Authors',
                            'Source Title', 'Publication Year', 'DOI',
                            'Times Cited, All Databases']

        # Check required columns first
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            st.error("❌ Missing required columns: " + ", ".join(missing_columns))
            st.error("Please ensure your Excel file contains all required columns.")
            return None

        # Enforce the paper limit before any further validation
        if len(df) > 5:
            st.error("❌ Your file contains more than 5 papers. Please upload a file with a maximum of 5 papers.")
            return None

        # Now safe to validate content, since we know all columns exist
        is_valid, messages = validate_excel_structure(df)
        if not is_valid:
            for msg in messages:
                st.error(f"❌ {msg}")
            return None

        return df[required_columns]
    except Exception as e:
        st.error(f"❌ Error reading file: {str(e)}")
        st.error("Please check if your file is in the correct Excel format (.xlsx or .xls)")
        return None
def validate_excel_structure(df):
    """Validate the structure and content of the Excel file"""
    validation_messages = []

    # Check for minimum content
    if len(df) == 0:
        validation_messages.append("File contains no data")
        return False, validation_messages

    try:
        # Check publication year format - useful for sorting/filtering later
        df['Publication Year'] = pd.to_numeric(df['Publication Year'], errors='coerce')
        if df['Publication Year'].isna().any():
            validation_messages.append("Some publication years are invalid. Please ensure all years are in numeric format (e.g., 2024)")
        else:
            years = df['Publication Year'].dropna()
            if len(years) > 0:
                if years.min() < 1900 or years.max() > 2025:
                    validation_messages.append("Publication years must be between 1900 and 2025")

        # Short abstracts only trigger a warning, not a failure
        short_abstracts = df['Abstract'].fillna('').astype(str).str.len() < 50
        if short_abstracts.any():
            st.warning("ℹ️ Some abstracts are quite short, but will still be processed")
    except Exception as e:
        validation_messages.append(f"Error checking data format: {str(e)}")

    return len(validation_messages) == 0, validation_messages
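
# For example, a 'Publication Year' column holding strings like "2024" is coerced
# to numeric here; entries such as "N/A" become NaN and are reported as invalid.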
def preprocess_text(text):
    """Enhanced text preprocessing with improved header and list handling"""
    if not isinstance(text, str) or not text.strip():
        return text

    # Initial cleanup
    text = re.sub(r'\s+', ' ', text.strip())

    # Standardize case for all-caps terms (e.g., PRIME -> Prime)
    text = re.sub(r'\b([A-Z]{2,})\b', lambda m: m.group(1).title(), text)

    # Fix spacing around punctuation and parentheses
    text = re.sub(r'\s*:\s*', ': ', text)
    text = re.sub(r'\s*,\s*', ', ', text)
    text = re.sub(r'\(\s*([ivx\d]+)\s*\)', r'(\1)', text)

    # Convert numbered lists to a consistent format: "1. " -> "(1) "
    text = re.sub(r'(?m)^\s*(\d+)\.\s*', r'(\1) ', text)

    # Normalize section headers (using comprehensive patterns)
    section_patterns = {
        r'\b(?:Introduction|Background|Objectives|Purpose|Context)\s*:': 'Background and Objectives: ',
        r'\b(?:Methods|Materials and Methods|Approach|Study Design|Experimental Design)\s*:': 'Methods: ',
        r'\b(?:Results|Findings|Observations|Key Findings)\s*:': 'Results: ',
        r'\b(?:Discussion|Analysis|Implications|Interpretation)\s*:': 'Discussion: ',
        r'\b(?:Conclusion|Conclusions|Summary|Final Remarks)\s*:': 'Conclusions: '
    }

    # Remove nested headers such as "2. Methods:"
    nested_header_pattern = r'\d+\.\s*(?:Background|Objectives|Methods|Results|Discussion|Conclusions)\s*:'
    text = re.sub(nested_header_pattern, '', text)

    # Standardize section headers
    for pattern, replacement in section_patterns.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Split merged section headers
    text = re.sub(r'(?i)Results\s+and\s+Conclusions:', 'Results: ', text)

    # Normalize curly quotes and hyphen spacing
    text = re.sub(r'[\u201c\u201d]', '"', text)   # curly double quotes -> "
    text = re.sub(r'[\u2018\u2019]', "'", text)   # curly single quotes -> '
    text = re.sub(r'\s*-\s*', '-', text)

    # Split into sentences and uppercase only the first character of each,
    # so acronyms and the normalized headers keep their case
    # (str.capitalize would lowercase the rest of the sentence)
    sentences = re.split(r'(?<=\w[.!?])\s+|\n(?=\d+\.|\(\w+\)|-)', text)
    formatted_sentences = [s.strip()[0].upper() + s.strip()[1:] for s in sentences if s.strip()]
    return ' '.join(formatted_sentences)
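
# Illustrative effect: an all-caps header like "INTRODUCTION:" is title-cased and
# mapped to "Background and Objectives:", and leading "1. " list markers become "(1) ".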
def post_process_summary(summary):
    """Enhanced summary post-processing with improved formatting."""
    if not summary:
        return summary

    # Step 1: Remove empty or redundant section headers
    summary = re.sub(r'\b(?:Background|Objectives|Methods|Results|Conclusions)\s*:\s*\.?\s*', '', summary)

    # Step 2: Fix spacing issues in lists and parentheses
    summary = re.sub(r'\(\s*([ivx\d]+)\s*\)', r'(\1)', summary)   # no spaces inside parentheses
    summary = re.sub(r'\s*,\s*(\([ivx\d]+\))', r', \1', summary)  # spacing before list items

    # Step 3: Ensure proper punctuation and spacing
    summary = re.sub(r'(?<=[.!?])\s*([A-Z])', r' \1', summary)  # space after sentence enders
    summary = re.sub(r'\s*:\s*', ': ', summary)                 # spacing around colons

    # Step 4: Drop lines with too little content (3 words or fewer)
    sections = [s.strip() for s in summary.split('\n') if len(s.split()) > 3]
    summary = ' '.join(sections)

    # Step 5: Collapse runs of periods
    summary = re.sub(r'\.\.+', '.', summary)

    # Step 6: Ensure the summary ends with a single period
    summary = summary.strip()
    if not summary.endswith('.'):
        summary += '.'

    return summary
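
# For instance, "Methods: . The model(  i ) was tested.." comes out as
# "The model(i) was tested." after the steps above.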
def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate a structured summary based on the given question and abstracts."""
    # Preprocess and clean abstracts
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts if abstract.strip()]
    if not formatted_abstracts:
        raise ValueError("Abstracts list is empty or improperly formatted.")

    # Join abstracts with a separator token
    abstracts_content = " [SEP] ".join(formatted_abstracts)

    # Create the prompt
    prompt = f"""
    Generate a structured summary based on the given abstracts and the question. Follow these rules STRICTLY:

    **QUESTION:** {question}

    **SECTION FORMATTING RULES:**
    1. Each section MUST start with the section name followed by ": " (e.g., "Background: ").
    2. Each section MUST end with a period.
    3. Write complete, grammatically correct sentences.
    4. Do not use bullet points, lists, or combined section headers.
    5. Maintain the exact order of sections: Background, Objectives, Methods, Results, Conclusions.
    6. Avoid redundancies, incomplete thoughts, and cutting sentences mid-way.
    7. Use transition words (e.g., "Additionally," "Furthermore," "Moreover") to connect ideas naturally.

    **REQUIRED SECTIONS AND CONTENT:**
    1. **Background**:
       - Provide the context and motivation for the study.
       - Do not mention objectives, methods, or results in this section.
    2. **Objectives**:
       - Clearly state the aim(s) of the study.
       - Avoid referencing any methods or findings.
    3. **Methods**:
       - Describe the approach, tools, and procedures used.
       - Do not include any findings or results in this section.
    4. **Results**:
       - Summarize the key findings, including relevant statistics and outcomes.
       - Mention implications only if explicitly stated in the abstracts.
    5. **Conclusions**:
       - Highlight the overall interpretation of findings.
       - Emphasize the significance and implications of the study.

    **CRITICAL FORMAT RULES:**
    1. Each section title must be followed by a colon and a space.
    2. All sentences must be grammatically complete and coherent.
    3. Avoid bullet points, lists, and repeated sections.
    4. End each section with a period.

    **INPUT ABSTRACTS:** {abstracts_content}
    """

    # Tokenize the prompt (inputs beyond 1024 tokens are truncated)
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=280,
            min_length=100,
            num_beams=4,
            length_penalty=2.0,
            no_repeat_ngram_size=2,
            do_sample=False  # deterministic beam search; a temperature setting would be ignored here
        )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return post_process_summary(summary)
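
# Note that the instruction prompt itself consumes a large share of the 1024-token
# budget; truncation cuts from the end, so with several abstracts it is the
# [SEP]-joined abstract content that gets clipped first.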
def process_papers_in_batches(df, model, tokenizer, batch_size=2):
    """Summarize each paper's abstract, using a thread pool for modest concurrency"""
    abstracts = df['Abstract'].tolist()
    question = "Focus on key findings and methods."
    # executor.map preserves input order, so summaries line up with the dataframe rows
    # (batch_size is currently unused)
    with ThreadPoolExecutor(max_workers=4) as executor:
        summaries = list(executor.map(
            lambda abstract: generate_focused_summary(question, [abstract], model, tokenizer),
            abstracts
        ))
    return summaries
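
# Because the threads share one CPU-bound model, this mostly overlaps tokenization
# and framework overhead rather than giving true parallel inference. A sequential
# sketch, if thread-safety of generate() is a concern:
#   summaries = [generate_focused_summary(question, [a], model, tokenizer) for a in abstracts]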
def create_filter_controls(df, sort_column):
    """Create appropriate filter controls based on the selected column"""
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        # Year range inputs
        year_min = int(df['Publication Year'].min())
        year_max = int(df['Publication Year'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_year = st.number_input('From Year',
                                         min_value=year_min,
                                         max_value=year_max,
                                         value=year_min)
        with col2:
            end_year = st.number_input('To Year',
                                       min_value=year_min,
                                       max_value=year_max,
                                       value=year_max)
        filtered_df = filtered_df[
            (filtered_df['Publication Year'] >= start_year) &
            (filtered_df['Publication Year'] <= end_year)
        ]
    elif sort_column == 'Authors':
        # Multi-select over the semicolon-separated author lists
        unique_authors = sorted(set(
            author.strip()
            for authors in df['Authors'].dropna()
            for author in authors.split(';')
        ))
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            filtered_df = filtered_df[
                filtered_df['Authors'].apply(
                    lambda x: any(author in str(x) for author in selected_authors)
                )
            ]
    elif sort_column == 'Source Title':
        # Multi-select for source titles (dropna avoids mixed-type sort errors)
        unique_sources = sorted(df['Source Title'].dropna().unique())
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]
    elif sort_column == 'Article Title':
        # Alphabetical sorting only; no extra filter controls
        pass

    return filtered_df
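
# Example: with sort_column == 'Authors' and "Smith, J." selected, any row whose
# Authors string contains that substring is kept, so partial-name matches are possible.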
def main():
    st.title("🔬 Biomedical Papers Analysis")
    st.info("""
    **📋 File Upload Requirements:**
    - Excel file (.xlsx or .xls) with **maximum 5 papers**
    - Must contain these columns:
      • Abstract
      • Article Title
      • Authors
      • Source Title
      • Publication Year
      • DOI
      • Times Cited, All Databases
    """)

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers (max 5 papers)",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    # Question input - placed early but hidden until a file is loaded
    question_container = st.empty()
    question = ""
    if uploaded_file is not None:
        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
                if df is not None:
                    df = df.dropna(subset=["Abstract"])
                    if len(df) > 0:
                        st.session_state.processed_data = df
                        st.success(f"✅ Successfully loaded {len(df)} papers with abstracts")
                    else:
                        st.error("❌ No valid papers found after processing. Please check your file.")

        if st.session_state.processed_data is not None:
            df = st.session_state.processed_data
            st.write(f"📊 Loaded {len(df)} papers with abstracts")

            # Get question before processing
            with question_container:
                question = st.text_input(
                    "Enter your research question (optional):",
                    help="If provided, a question-focused summary will be generated after individual summaries"
                )

            # A single button starts both processes
            if not st.session_state.get('processing_started', False):
                if st.button("Start Analysis"):
                    st.session_state.processing_started = True

            # Show processing status and results
            if st.session_state.get('processing_started', False):
                # Individual Summaries Section
                st.header("📝 Individual Paper Summaries")

                # Generate summaries if not already done
                if st.session_state.summaries is None:
                    try:
                        with st.spinner("Generating individual paper summaries..."):
                            model, tokenizer = get_model("summarize")
                            if model is None or tokenizer is None:
                                reset_processing_state()
                                return
                            start_time = time.time()
                            st.session_state.summaries = process_papers_in_batches(
                                df, model, tokenizer, batch_size=2
                            )
                            end_time = time.time()
                            st.write(f"Processing time: {end_time - start_time:.2f} seconds")
                    except Exception as e:
                        st.error(f"Error generating summaries: {str(e)}")
                        reset_processing_state()
                # Display summaries with sorting and filtering
                if st.session_state.summaries is not None:
                    col1, col2 = st.columns(2)
                    with col1:
                        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
                        sort_column = st.selectbox("Sort/Filter by:", sort_options)
                    with col2:
                        if sort_column == 'Article Title':
                            ascending = st.radio(
                                "Sort order",
                                ["A to Z", "Z to A"],
                                horizontal=True
                            ) == "A to Z"
                        elif sort_column == 'Times Cited':
                            ascending = st.radio(
                                "Sort order",
                                ["Most cited first", "Least cited first"],
                                horizontal=True
                            ) == "Least cited first"
                        else:
                            ascending = True  # Default for other columns

                    # Create display dataframe
                    display_df = df.copy()
                    display_df['Summary'] = st.session_state.summaries
                    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
                    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
                    display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

                    # Apply filters
                    filtered_df = create_filter_controls(display_df, sort_column)

                    # Apply sorting (only Title and Times Cited expose an explicit order)
                    if sort_column in ('Article Title', 'Times Cited'):
                        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
                    else:
                        sorted_df = filtered_df

                    # Show number of filtered results
                    if len(sorted_df) != len(display_df):
                        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")
                    # Apply custom styling
                    st.markdown("""
                    <style>
                    .paper-info {
                        border: 1px solid #ddd;
                        padding: 15px;
                        margin-bottom: 20px;
                        border-radius: 5px;
                    }
                    .paper-section {
                        margin-bottom: 10px;
                    }
                    .section-header {
                        font-weight: bold;
                        color: #555;
                        margin-bottom: 8px;
                    }
                    .paper-title {
                        margin-top: 5px;
                        margin-bottom: 10px;
                    }
                    .paper-meta {
                        font-size: 0.9em;
                        color: #666;
                    }
                    .doi-link {
                        color: #0366d6;
                    }
                    </style>
                    """, unsafe_allow_html=True)

                    # Display papers using the filtered and sorted dataframe
                    for _, row in sorted_df.iterrows():
                        paper_info_cols = st.columns([1, 1])
                        with paper_info_cols[0]:  # PAPER column
                            st.markdown('<div class="paper-section"><div class="section-header">PAPER</div>', unsafe_allow_html=True)
                            st.markdown(f"""
                            <div class="paper-info">
                                <div class="paper-title">{row['Article Title']}</div>
                                <div class="paper-meta">
                                    <strong>Authors:</strong> {row['Authors']}<br>
                                    <strong>Source:</strong> {row['Source Title']}<br>
                                    <strong>Publication Year:</strong> {row['Publication Year']}<br>
                                    <strong>Times Cited:</strong> {row['Times Cited']}<br>
                                    <strong>DOI:</strong> {row['DOI'] if pd.notna(row['DOI']) else 'None'}
                                </div>
                            </div>
                            """, unsafe_allow_html=True)
                        with paper_info_cols[1]:  # SUMMARY column
                            st.markdown('<div class="paper-section"><div class="section-header">SUMMARY</div>', unsafe_allow_html=True)
                            st.markdown(f"""
                            <div class="paper-info">
                                {row['Summary']}
                            </div>
                            """, unsafe_allow_html=True)
                        # Add spacing between papers
                        st.markdown("<div style='margin-bottom: 20px;'></div>", unsafe_allow_html=True)
                # Question-focused Summary Section (only if a question was provided)
                if question.strip():
                    st.header("❓ Question-focused Summary")
                    if not st.session_state.get('focused_summary_generated', False):
                        try:
                            with st.spinner("Analyzing relevant papers..."):
                                # Initialize text processor if needed
                                if st.session_state.text_processor is None:
                                    st.session_state.text_processor = TextProcessor()

                                # Find relevant abstracts
                                results = st.session_state.text_processor.find_most_relevant_abstracts(
                                    question,
                                    df['Abstract'].tolist(),
                                    top_k=5
                                )
                                if not results['top_indices']:
                                    st.warning("No relevant papers found for your question")
                                    return

                                # Load question-focused model
                                model, tokenizer = get_model("question_focused")
                                if model is None or tokenizer is None:
                                    return

                                # Generate focused summary
                                try:
                                    relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                                    focused_summary = generate_focused_summary(
                                        question,
                                        relevant_abstracts,
                                        model,
                                        tokenizer
                                    )
                                    # Store results
                                    st.session_state.focused_summary = focused_summary
                                    st.session_state.relevant_papers = df.iloc[results['top_indices']]
                                    st.session_state.relevance_scores = results['scores']
                                    st.session_state.focused_summary_generated = True
                                finally:
                                    # Clean up the second model
                                    cleanup_model(model, tokenizer)
                        except Exception as e:
                            st.error(f"Error generating focused summary: {str(e)}")
                            reset_processing_state()
                    # Display focused summary results
                    if st.session_state.get('focused_summary_generated', False):
                        st.subheader("Summary")
                        st.write(st.session_state.focused_summary)

                        st.subheader("Most Relevant Papers")
                        relevant_papers = st.session_state.relevant_papers[
                            ['Article Title', 'Authors', 'Publication Year', 'DOI']
                        ].copy()
                        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
                        relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
                        st.dataframe(relevant_papers, hide_index=True)
if __name__ == "__main__":
    main()