import gc
import re
from pathlib import Path

import pandas as pd
import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from text_processing import TextProcessor

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)

# Initialize session state: every key is created once and then survives reruns.
_SESSION_DEFAULTS = {
    'processed_data': None,
    'summaries': None,
    'text_processor': None,
    'processing_started': False,
    'focused_summary_generated': False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# A number is digits with an optional fractional part. A sentence-final period
# ("enrolled 40.") is deliberately NOT part of the number, otherwise "40." in
# one text and "40" in the other would be treated as different facts.
_NUMBER_RE = re.compile(r'\d+(?:\.\d+)?')

# Phrases treated as "relationship" claims by verify_facts().
_RELATION_TERMS = (
    'associated with',
    'predicted',
    'correlated',
    'increased',
    'decreased',
    'significant',
)


def preprocess_text(text):
    """Normalise an abstract before it is handed to a summarisation model.

    Collapses whitespace and tightens a few common notations:
    ``45 %`` -> ``45%``, ``( N = 20 )`` -> ``(N=20)`` and ``p < 0.05`` ->
    ``p<0.05``.

    Args:
        text: The abstract text.

    Returns:
        The normalised string. Non-string or blank input is returned
        unchanged so callers can decide how to treat missing abstracts.
    """
    if not isinstance(text, str) or not text.strip():
        return text

    # Clean up whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    # Fix percentage format
    text = re.sub(r'(\d+)\s*%', r'\1%', text)
    # Fix sample size format. The letter's case is preserved: "N" (total
    # sample) and "n" (subgroup) are not interchangeable.
    text = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(\1=\2)', text)
    # Fix p-value format. \b keeps words that merely end in "p"
    # (e.g. "group > 5") from being rewritten.
    text = re.sub(r'\b([Pp])\s*([<>])\s*(\d)', r'\1\2\3', text)

    return text


def verify_facts(summary, original_text):
    """Check that numbers and relationship terms in a summary come from the source.

    Args:
        summary: The generated summary.
        original_text: The text the summary was generated from.

    Returns:
        A dict with:
            ``is_valid``: True when every number and every relationship term
                found in the summary also occurs in the original text.
            ``missing_numbers``: numbers in the original but not the summary.
            ``missing_relations``: relationship terms in the original but not
                the summary.
    """
    def extract_numbers(text):
        # Non-string input (e.g. a missing abstract) contributes no facts.
        if not isinstance(text, str):
            return set()
        return set(_NUMBER_RE.findall(text))

    def extract_relationships(text):
        if not isinstance(text, str):
            return set()
        lowered = text.lower()
        return {term for term in _RELATION_TERMS if term in lowered}

    original_numbers = extract_numbers(original_text)
    summary_numbers = extract_numbers(summary)
    original_relations = extract_relationships(original_text)
    summary_relations = extract_relationships(summary)

    return {
        'is_valid': (
            summary_numbers.issubset(original_numbers)
            and summary_relations.issubset(original_relations)
        ),
        'missing_numbers': original_numbers - summary_numbers,
        'missing_relations': original_relations - summary_relations,
    }


def load_model(model_type):
    """Load a model/tokenizer pair on the CPU and put the model in eval mode.

    Args:
        model_type: ``"summarize"`` loads ``pendar02/bart-large-pubmedd``.
            Any other value loads ``GanjinZero/biobart-base`` wrapped with the
            ``pendar02/biobart-finetune`` PEFT adapter.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: Whatever the loading call raised; the error is shown with
            ``st.error`` first and then re-raised.
    """
    try:
        # Release whatever an earlier model left behind before loading.
        gc.collect()
        torch.cuda.empty_cache()

        device = "cpu"

        if model_type == "summarize":
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models"
            )
        else:
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise


def cleanup_model(model, tokenizer):
    """Best-effort release of a model/tokenizer pair.

    Only this function's own references are dropped here; the memory is not
    reclaimable until the caller lets go of its references as well.

    Args:
        model: The model to release.
        tokenizer: The tokenizer to release.
    """
    try:
        del model
        del tokenizer
        # Collect first so freed tensors can be returned by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()
    except Exception as e:
        # Cleanup must never mask the caller's real result, but a failure
        # should at least leave a trace.
        print(f"Warning: model cleanup failed: {str(e)}")


def process_excel(uploaded_file):
    """Read the uploaded workbook and keep only the columns the app uses.

    Args:
        uploaded_file: A file-like object accepted by ``pd.read_excel``.

    Returns:
        A DataFrame restricted to the required columns, or ``None`` when the
        file cannot be read or a required column is missing (the problem is
        reported with ``st.error``).
    """
    required_columns = [
        'Abstract',
        'Article Title',
        'Authors',
        'Source Title',
        'Publication Year',
        'DOI',
        'Times Cited, All Databases',
    ]
    try:
        df = pd.read_excel(uploaded_file)

        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            st.error(f"Missing required columns: {', '.join(missing_columns)}")
            return None

        return df[required_columns]
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None


def improve_summary_generation(text, model, tokenizer):
    """Generate a structured summary of a single abstract.

    Args:
        text: The abstract to summarise.
        model: A seq2seq model exposing ``generate`` and ``device``.
        tokenizer: The tokenizer matching ``model``.

    Returns:
        The post-processed summary. If post-processing leaves nothing, the
        raw model output is returned instead of an empty string. A fixed
        message is returned for missing abstracts or generation failures.
    """
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."

    try:
        # Simplified prompt
        formatted_text = (
            "Summarize this biomedical abstract into four sections:\n"
            "1. Background/Objectives: State the main purpose and population\n"
            "2. Methods: Describe what was done\n"
            "3. Key findings: Include ALL numerical results and statistical relationships\n"
            "4. Conclusions: State main implications\n\n"
            "Important: Preserve all numbers, measurements, and statistical findings.\n\n"
            "Text: " + preprocess_text(text)
        )

        inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Single beam-search attempt. No `temperature` is passed: it only
        # affects sampling (do_sample=True), so with beam search it was
        # ignored and merely produced a warning.
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=300,
                min_length=100,
                num_beams=5,
                length_penalty=2.0,
                no_repeat_ngram_size=3,
                repetition_penalty=2.5,
            )

        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        if not summary:
            return "Error: Could not generate summary."

        structured = post_process_summary(summary)
        # Never hand the UI an empty summary when the model did produce text.
        return structured if structured else summary.strip()

    except Exception as e:
        print(f"Error in summary generation: {str(e)}")
        return "Error generating summary."
# Canonical section titles, in the order they appear in the final summary.
_SECTION_TITLES = (
    'Background and objectives',
    'Methods',
    'Key findings',
    'Conclusions',
)

# A section label is a known heading word (optionally preceded by list
# numbering such as "2. ") that is followed by a colon. Requiring the colon and
# a leading word boundary keeps ordinary prose ("the results show", "claims")
# from being mistaken for a heading. Numbering is only consumed at the start of
# the text or right after sentence-ending punctuation, so a trailing value such
# as "n = 40." is never swallowed.
_SECTION_LABEL_RE = re.compile(
    r'(?:(?:^|(?<=[.!?]\s)|(?<=\n))\s*\d{1,2}[.)]\s*)?'
    r'\b(?P<label>'
    r'background(?:\s*(?:and|&|/)\s*(?:objectives?|aims?|purpose))?'
    r'|(?:materials?\s+and\s+)?methods?'
    r'|(?:key\s+)?findings?|results?'
    r'|conclusions?'
    r'|(?:study\s+)?aims?|goals?|purpose|objectives?|outcomes?|discussion'
    r')\s*:',
    re.IGNORECASE,
)


def _canonical_section(label):
    """Map a matched label to its canonical title, or None for labels that are only stripped."""
    key = label.lower()
    if key.startswith('background'):
        return 'Background and objectives'
    if 'method' in key:
        return 'Methods'
    if 'finding' in key or key.startswith('result'):
        return 'Key findings'
    if key.startswith('conclusion'):
        return 'Conclusions'
    return None


def _tidy_section_text(text):
    """Collapse whitespace and repeated periods, then trim stray punctuation at both ends."""
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\.{2,}', '.', text)
    return text.strip('.: ')


def post_process_summary(summary):
    """Rebuild a generated summary as four canonically titled sections.

    Headings such as "Results:" or "Materials and methods:" are mapped onto
    the canonical titles and the sections are emitted in a fixed order,
    separated by blank lines, as ``"<Title>: <content>."``. Secondary labels
    ("Aims:", "Objectives:", "Outcomes:", "Discussion:", ...) are removed and
    their text stays in the section that is open at that point. Text that
    precedes the first recognised heading is dropped.

    Args:
        summary: Raw text produced by the summarisation model.

    Returns:
        The restructured summary. When no canonical heading is found the
        whitespace-normalised input is returned, so a model that ignored the
        requested structure still yields a usable summary. Falsy input is
        returned unchanged.
    """
    if not summary:
        return summary

    pieces = {title: [] for title in _SECTION_TITLES}
    active = None   # canonical title of the section currently being filled
    cursor = 0      # index just past the most recent label

    for match in _SECTION_LABEL_RE.finditer(summary):
        if active is not None:
            pieces[active].append(summary[cursor:match.start()])
        cursor = match.end()
        # Secondary labels map to None and leave the open section unchanged.
        active = _canonical_section(match.group('label')) or active

    if active is not None:
        pieces[active].append(summary[cursor:])

    rendered = []
    for title in _SECTION_TITLES:
        body = _tidy_section_text(' '.join(pieces[title]))
        if body:
            rendered.append(f"{title}: {body}.")

    if not rendered:
        return re.sub(r'\s+', ' ', summary).strip()

    return '\n\n'.join(rendered)


def validate_summary(summary, original_text):
    """Apply heuristic sanity checks to a generated summary.

    Args:
        summary: The generated summary.
        original_text: The text the summary was generated from.

    Returns:
        True when the summary passes every check below, otherwise False:
        fact verification via ``verify_facts``, at most one "<number> years"
        mention, at most one repeated sentence, and a length of at least 20
        words but no more than 80% of the original's word count.
    """
    # Perform fact verification
    verification = verify_facts(summary, original_text)
    if not verification.get('is_valid', False):
        return False

    # Check for age inconsistencies (more than one "<n> years" mention)
    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
    if len(age_mentions) > 1:
        return False

    # Check for repetitive sentences. Split only where sentence punctuation is
    # followed by whitespace, so decimals ("0.05") stay intact, and discard
    # empty fragments before counting so they cannot pass as duplicates.
    sentences = [
        fragment.strip().rstrip('.!?').lower()
        for fragment in re.split(r'(?<=[.!?])\s+', summary)
        if fragment.strip()
    ]
    if len(sentences) - len(set(sentences)) > 1:  # More than one duplicate
        return False

    # Check summary isn't too long or too short compared to original
    summary_words = len(summary.split())
    original_words = len(original_text.split())
    if summary_words < 20 or summary_words > original_words * 0.8:
        return False

    return True


def generate_focused_summary(question, abstracts, model, tokenizer):
    """Summarise several abstracts with respect to a research question.

    Args:
        question: The research question placed at the top of the prompt.
        abstracts: Iterable of abstract texts; non-string or blank entries
            are skipped.
        model: A seq2seq model exposing ``generate`` and ``device``.
        tokenizer: The tokenizer matching ``model``.

    Returns:
        The decoded summary, or a fixed message when there is nothing to
        summarise or generation fails.
    """
    try:
        # Preprocess each abstract, dropping entries that cannot be joined.
        formatted_abstracts = [
            cleaned
            for cleaned in (preprocess_text(abstract) for abstract in abstracts)
            if isinstance(cleaned, str) and cleaned.strip()
        ]
        if not formatted_abstracts:
            return "No abstracts available to summarize."

        combined_input = (
            f"Question: {question}\nSummarize these abstracts to answer the question:\n"
            + "\n---\n".join(formatted_abstracts)
        )

        inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Beam search. No `temperature` is passed: it only affects sampling
        # (do_sample=True), so it was ignored here and merely caused a warning.
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=300,
                min_length=100,
                num_beams=5,
                length_penalty=2.0,
                repetition_penalty=2.5,
            )

        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    except Exception as e:
        print(f"Error in focused summary generation: {str(e)}")
        return "Error generating focused summary."
def _split_authors(cell):
    """Return the individual, non-empty author names held in one ';'-separated cell."""
    return [name.strip() for name in str(cell).split(';') if name.strip()]


def _numeric_range_filter(df, column, low_label, high_label):
    """Render From/To number inputs for ``column`` and keep the rows inside the chosen range."""
    known = df[column].dropna()
    if known.empty:
        # Nothing to build a range from; leave the rows untouched.
        return df

    lowest, highest = int(known.min()), int(known.max())
    col1, col2 = st.columns(2)
    with col1:
        start = st.number_input(low_label, min_value=lowest, max_value=highest, value=lowest)
    with col2:
        end = st.number_input(high_label, min_value=lowest, max_value=highest, value=highest)

    in_range = (df[column] >= start) & (df[column] <= end)
    # Rows with a missing value compare as NA; treat them as "not in range".
    return df[in_range.fillna(False).astype(bool)]


def create_filter_controls(df, sort_column):
    """Create filter widgets for the selected column and apply them.

    Args:
        df: The display DataFrame (expects the columns 'Publication Year',
            'Authors', 'Source Title' and 'Times Cited' for the matching
            ``sort_column`` values).
        sort_column: Column chosen in the "Sort/Filter by" selector. Columns
            without a dedicated filter leave the data unfiltered.

    Returns:
        A filtered copy of ``df``; the input frame is never modified.
    """
    filtered_df = df.copy()
    if df.empty:
        return filtered_df

    if sort_column == 'Publication Year':
        filtered_df = _numeric_range_filter(
            filtered_df, 'Publication Year', 'From Year', 'To Year'
        )

    elif sort_column == 'Authors':
        unique_authors = sorted({
            name
            for cell in df['Authors'].dropna()
            for name in _split_authors(cell)
        })
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            wanted = set(selected_authors)
            # Compare whole names: substring matching would let "Li, J"
            # also select papers by "Li, JX".
            has_selected = filtered_df['Authors'].apply(
                lambda cell: bool(pd.notna(cell))
                and not wanted.isdisjoint(_split_authors(cell))
            )
            filtered_df = filtered_df[has_selected.astype(bool)]

    elif sort_column == 'Source Title':
        # Missing sources are excluded so sorting never compares NaN with str.
        unique_sources = sorted(df['Source Title'].dropna().unique(), key=str)
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    elif sort_column == 'Times Cited':
        filtered_df = _numeric_range_filter(
            filtered_df, 'Times Cited', 'From Cited Count', 'To Cited Count'
        )

    return filtered_df


def _to_nullable_int(series):
    """Coerce a column to pandas' nullable Int64 dtype; missing or unparseable cells become <NA>."""
    return pd.to_numeric(series, errors='coerce').round().astype('Int64')


def _format_field(value):
    """Return an HTML-escaped display string for one cell ('None' when the value is missing)."""
    from html import escape

    if pd.isna(value):
        return 'None'
    return escape(str(value))


def _reset_analysis_state():
    """Forget everything derived from the previously uploaded file."""
    st.session_state.processed_data = None
    st.session_state.summaries = None
    st.session_state.processing_started = False
    st.session_state.focused_summary_generated = False


def _generate_individual_summaries(df):
    """Summarise every abstract in ``df`` and store the list in ``st.session_state.summaries``."""
    model = tokenizer = None
    try:
        with st.spinner("Generating individual paper summaries..."):
            model, tokenizer = load_model("summarize")
            summaries = []
            progress_bar = st.progress(0)
            total = len(df)
            for position, abstract in enumerate(df['Abstract'], start=1):
                summaries.append(improve_summary_generation(abstract, model, tokenizer))
                progress_bar.progress(position / total)
            st.session_state.summaries = summaries
            progress_bar.empty()
    except Exception as e:
        st.error(f"Error generating summaries: {str(e)}")
        st.session_state.processing_started = False
    finally:
        # Release the model on failure too, not only on success.
        if model is not None:
            cleanup_model(model, tokenizer)


def _show_individual_summaries(df):
    """Render the sort/filter controls and one PAPER/SUMMARY row per paper."""
    col1, col2 = st.columns(2)
    with col1:
        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
        sort_column = st.selectbox("Sort/Filter by:", sort_options)
    with col2:
        # Only titles and citation counts offer a sort direction.
        if sort_column == 'Article Title':
            ascending = st.radio(
                "Sort order", ["A to Z", "Z to A"], horizontal=True
            ) == "A to Z"
        elif sort_column == 'Times Cited':
            ascending = st.radio(
                "Sort order", ["Most cited", "Least cited"], horizontal=True
            ) == "Least cited"
        else:
            ascending = True  # Default for other columns

    # Create display dataframe
    display_df = df.copy()
    display_df['Summary'] = st.session_state.summaries
    # Only rows without an abstract were dropped, so years and citation counts
    # may still be missing; coerce instead of crashing on astype(int).
    display_df['Publication Year'] = _to_nullable_int(display_df['Publication Year'])
    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
    display_df['Times Cited'] = (
        pd.to_numeric(display_df['Times Cited'], errors='coerce').fillna(0).round().astype(int)
    )

    # Apply filters
    filtered_df = create_filter_controls(display_df, sort_column)

    if sort_column in ('Article Title', 'Times Cited'):
        # Honour the chosen direction; a stable sort keeps the original order among ties.
        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending, kind='stable')
    else:
        # Keep original order for other columns after filtering
        sorted_df = filtered_df

    # Show number of filtered results
    if len(sorted_df) != len(display_df):
        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

    # NOTE(review): the stylesheet and HTML wrappers passed to the
    # st.markdown(..., unsafe_allow_html=True) calls below are not visible in
    # the reviewed source; only the visible text is reproduced. Every
    # interpolated value is HTML-escaped because it originates from the
    # uploaded spreadsheet.
    # Apply custom styling
    st.markdown(""" """, unsafe_allow_html=True)

    # Display papers using the filtered and sorted dataframe
    for _, row in sorted_df.iterrows():
        paper_info_cols = st.columns([1, 1])

        with paper_info_cols[0]:  # PAPER column
            st.markdown('\nPAPER\n', unsafe_allow_html=True)
            st.markdown(f"""
{_format_field(row['Article Title'])}
Authors: {_format_field(row['Authors'])}
Source: {_format_field(row['Source Title'])}
Publication Year: {_format_field(row['Publication Year'])}
Times Cited: {_format_field(row['Times Cited'])}
DOI: {_format_field(row['DOI'])}
""", unsafe_allow_html=True)

        with paper_info_cols[1]:  # SUMMARY column
            st.markdown('\nSUMMARY\n', unsafe_allow_html=True)
            st.markdown(f"""
{_format_field(row['Summary'])}
""", unsafe_allow_html=True)

        # Add spacing between papers
        st.markdown("\n", unsafe_allow_html=True)


def _show_focused_summary(question, df):
    """Generate (when needed) and display the question-focused summary."""
    st.header("❓ Question-focused Summary")

    # Regenerate when the question changed since the stored answer was made.
    question_changed = st.session_state.get('focused_question') != question
    if question_changed or not st.session_state.get('focused_summary_generated', False):
        st.session_state.focused_summary_generated = False
        model = tokenizer = None
        try:
            with st.spinner("Analyzing relevant papers..."):
                # Initialize text processor if needed
                if st.session_state.text_processor is None:
                    st.session_state.text_processor = TextProcessor()

                # Find relevant abstracts
                results = st.session_state.text_processor.find_most_relevant_abstracts(
                    question,
                    df['Abstract'].tolist(),
                    top_k=5
                )

                # Load question-focused model
                model, tokenizer = load_model("question_focused")

                # Generate focused summary
                relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                focused_summary = generate_focused_summary(
                    question,
                    relevant_abstracts,
                    model,
                    tokenizer
                )

                # Store results
                st.session_state.focused_summary = focused_summary
                st.session_state.relevant_papers = df.iloc[results['top_indices']]
                st.session_state.relevance_scores = results['scores']
                st.session_state.focused_question = question
                st.session_state.focused_summary_generated = True
        except Exception as e:
            st.error(f"Error generating focused summary: {str(e)}")
        finally:
            # Cleanup second model, also when generation failed.
            if model is not None:
                cleanup_model(model, tokenizer)

    # Display focused summary results
    if st.session_state.get('focused_summary_generated', False):
        st.subheader("Summary")
        st.write(st.session_state.focused_summary)

        st.subheader("Most Relevant Papers")
        relevant_papers = st.session_state.relevant_papers[
            ['Article Title', 'Authors', 'Publication Year', 'DOI']
        ].copy()
        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
        relevant_papers['Publication Year'] = _to_nullable_int(relevant_papers['Publication Year'])
        st.dataframe(relevant_papers, hide_index=True)


def main():
    """Run the Streamlit page: upload, per-paper summaries, optional focused summary."""
    st.title("🔬 Biomedical Papers Analysis")

    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help=(
            "File must contain: Abstract, Article Title, Authors, Source Title, "
            "Publication Year, DOI, 'Times Cited, All Databases'"
        )
    )

    question_container = st.empty()
    question = ""

    if uploaded_file is None:
        return

    # A different upload invalidates everything computed for the previous one.
    file_key = (uploaded_file.name, getattr(uploaded_file, 'size', None))
    if st.session_state.get('uploaded_file_key') != file_key:
        _reset_analysis_state()
        st.session_state.uploaded_file_key = file_key

    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
            if df is not None:
                st.session_state.processed_data = df.dropna(subset=["Abstract"])

    if st.session_state.processed_data is None:
        return

    df = st.session_state.processed_data
    st.write(f"📊 Loaded {len(df)} papers with abstracts")

    with question_container:
        question = st.text_input(
            "Enter your research question (optional):",
            help="If provided, a focused summary will be generated after individual summaries"
        )

    # Single button for both processes
    if not st.session_state.get('processing_started', False):
        if st.button("Start Analysis"):
            st.session_state.processing_started = True

    # Show processing status and results
    if st.session_state.get('processing_started', False):
        # Individual Summaries Section
        st.header("📝 Individual Paper Summaries")

        # Generate summaries if not already done
        if st.session_state.summaries is None:
            _generate_individual_summaries(df)

        # Display summaries with sorting and filtering
        if st.session_state.summaries is not None:
            _show_individual_summaries(df)

        # Question-focused Summary Section (only if question provided)
        if question.strip():
            _show_focused_summary(question, df)


if __name__ == "__main__":
    main()