Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import torch | |
import re | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from peft import PeftModel | |
from text_processing import TextProcessor | |
import gc | |
from pathlib import Path | |
import concurrent.futures | |
import time | |
import nltk | |
from nltk.tokenize import sent_tokenize | |
from concurrent.futures import ThreadPoolExecutor # Add this import | |
# Fetch the NLTK "punkt" tokenizer only when it is not installed yet.
# Streamlit re-executes this script on every rerun, and nltk.download() may
# contact the download server each time it is called, so the unconditional
# call repeated that work on every user interaction.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="π¬",
    layout="wide"
)

# Initialize session state: each key below is created once per browser session
# with its starting value; values that already exist survive reruns untouched.
_SESSION_DEFAULTS = {
    'relevant_papers': None,
    'relevance_scores': None,
    'processed_data': None,
    'summaries': None,
    'text_processor': None,
    'processing_started': False,
    'focused_summary_generated': False,
    'current_model': None,
    'current_tokenizer': None,
    'model_type': None,
    'focused_summary': None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# TextProcessor class definition.
# Prefer the project's implementation; when it cannot be imported, fall back to
# a stub that performs no ranking at all.
# NOTE(review): TextProcessor is also imported unconditionally near the top of
# this file, so an ImportError would be raised there before this fallback is
# ever reached.
try:
    from text_processing import TextProcessor
except ImportError:
    class TextProcessor:
        """No-op relevance ranker that keeps the first ``top_k`` abstracts in input order."""

        def find_most_relevant_abstracts(self, question, abstracts, top_k=5):
            """Return the leading ``top_k`` indices, each scored 1.0; ``question`` is ignored."""
            keep = min(top_k, len(abstracts))
            indices = [*range(keep)]
            return {'top_indices': indices, 'scores': [1.0 for _ in indices]}
def load_model(model_type):
    """Load the model/tokenizer pair for ``model_type`` on the CPU.

    ``"summarize"`` loads the fine-tuned checkpoint ``pendar02/bart-large-pubmedd``.
    Any other value loads ``GanjinZero/biobart-base`` and wraps it with the
    ``pendar02/biobart-finetune`` PEFT adapter (not trainable). Hugging Face
    downloads are cached under ``./models``.

    Args:
        model_type: ``"summarize"`` or anything else (treated as question-focused).

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.

    Raises:
        Re-raises whatever loading error occurred, after reporting it with ``st.error``.
    """
    try:
        # Release leftovers from earlier loads before allocating new weights.
        gc.collect()
        torch.cuda.empty_cache()

        device = "cpu"  # inference is pinned to the CPU
        if model_type != "summarize":
            # question_focused: BioBART backbone + fine-tuned adapter
            backbone = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            model = PeftModel.from_pretrained(
                backbone,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(device)
            tokenizer_source = "GanjinZero/biobart-base"
        else:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            tokenizer_source = "pendar02/bart-large-pubmedd"

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, cache_dir="./models")
        model.eval()
        return model, tokenizer
    except Exception as exc:
        st.error(f"Error loading model: {str(exc)}")
        raise
def get_model(model_type):
    """Return the (model, tokenizer) pair for ``model_type``, loading it on demand.

    A single pair is cached in ``st.session_state.current_model`` /
    ``current_tokenizer`` and tagged with ``st.session_state.model_type``.
    Requesting a different type replaces the cached pair.

    Args:
        model_type: ``"summarize"`` or ``"question_focused"`` (forwarded to ``load_model``).

    Returns:
        ``(model, tokenizer)`` on success. On failure, returns ``(None, None)``;
        the error is shown with ``st.error`` and ``processing_started`` is reset.
    """
    state = st.session_state
    try:
        if state.current_model is None or state.model_type != model_type:
            # Forget the cached pair *before* loading the replacement. Session
            # state is what keeps the old model alive (cleanup_model() can only
            # delete its own local names), so clearing it here lets the
            # gc.collect() at the start of load_model() reclaim the old weights,
            # once no other reference remains, instead of holding both models
            # in memory at the same time.
            state.current_model = None
            state.current_tokenizer = None
            state.model_type = None

            model, tokenizer = load_model(model_type)
            state.current_model = model
            state.current_tokenizer = tokenizer
            state.model_type = model_type
        return state.current_model, state.current_tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        state.processing_started = False
        return None, None
def cleanup_model(model, tokenizer):
    """Best-effort release of model memory.

    Deletes this function's own references to ``model`` and ``tokenizer``, then
    asks PyTorch to drop its cached CUDA memory and runs the garbage collector.
    Objects still referenced elsewhere (for example from ``st.session_state``)
    stay alive. Any error is deliberately ignored.
    """
    try:
        del model, tokenizer  # only drops this function's own references
        for release in (torch.cuda.empty_cache, gc.collect):
            release()
    except Exception:
        # Housekeeping only: a failed cleanup must never interrupt the app.
        pass
def process_excel(uploaded_file):
    """Read an uploaded Excel workbook and keep only the columns the app uses.

    The upload is rejected (an explanation is shown with ``st.error`` and
    ``None`` is returned) in any of these cases:
    - a required column is missing;
    - the sheet holds more than five rows;
    - ``validate_excel_structure`` reports problems;
    - reading the file raises.

    Returns:
        The DataFrame restricted to the required columns, or ``None``.
    """
    required_columns = ['Abstract', 'Article Title', 'Authors',
                        'Source Title', 'Publication Year', 'DOI',
                        'Times Cited, All Databases']
    try:
        df = pd.read_excel(uploaded_file)

        # Column presence comes first: every later check indexes these columns.
        absent = [name for name in required_columns if name not in df.columns]
        if absent:
            st.error("β Missing required columns: " + ", ".join(absent))
            st.error("Please ensure your Excel file contains all required columns.")
            return None

        if len(df) > 5:
            st.error("β Your file contains more than 5 papers. Please upload a file with maximum 5 papers.")
            return None

        is_valid, messages = validate_excel_structure(df)
        if not is_valid:
            for message in messages:
                st.error(f"β {message}")
            return None

        return df[required_columns]
    except Exception as exc:
        st.error(f"β Error reading file: {str(exc)}")
        st.error("Please check if your file is in the correct Excel format (.xlsx or .xls)")
        return None
def validate_excel_structure(df, min_year=1900, max_year=None):
    """Validate the structure and content of the uploaded papers table.

    Rules:
    - the table must contain at least one row;
    - every ``Publication Year`` must be numeric and lie within
      ``[min_year, max_year]``;
    - abstracts shorter than 50 characters only trigger an ``st.warning``.

    Side effect: ``df['Publication Year']`` is converted in place with
    ``pd.to_numeric(errors='coerce')``. ``process_excel`` passes this same
    (converted) column on to the rest of the app.

    Args:
        df: DataFrame containing ``Publication Year`` and ``Abstract`` columns.
        min_year: Earliest accepted publication year (inclusive).
        max_year: Latest accepted publication year (inclusive). Defaults to the
            current calendar year. The previous hard-coded 2025 would start
            rejecting legitimate papers as time passes.

    Returns:
        ``(is_valid, messages)``, where ``messages`` lists every problem found.
    """
    validation_messages = []

    # Check for minimum content
    if len(df) == 0:
        validation_messages.append("File contains no data")
        return False, validation_messages

    if max_year is None:
        max_year = pd.Timestamp.now().year

    try:
        # Publication years are used for sorting/filtering, so they must be numeric.
        df['Publication Year'] = pd.to_numeric(df['Publication Year'], errors='coerce')
        if df['Publication Year'].isna().any():
            validation_messages.append("Some publication years are invalid. Please ensure all years are in numeric format (e.g., 2024)")
        else:
            years = df['Publication Year'].dropna()
            if len(years) > 0 and (years.min() < min_year or years.max() > max_year):
                validation_messages.append(
                    f"Publication years must be between {min_year} and {max_year}"
                )

        # Short abstracts are still processed; the user only gets a warning.
        short_abstracts = df['Abstract'].fillna('').astype(str).str.len() < 50
        if short_abstracts.any():
            st.warning("βΉοΈ Some abstracts are quite short, but will still be processed")
    except Exception as e:
        validation_messages.append(f"Error checking data format: {str(e)}")

    return len(validation_messages) == 0, validation_messages
def preprocess_text(text):
    """Clean biomedical text by handling common formatting issues and standardizing structure.

    The cleanup runs in this order:

    1. collapse all whitespace runs to single spaces;
    2. turn roman-numeral enumerations such as ``(ii)`` or ``iv)`` into digits;
    3. rename section headers (a header word *followed by a colon*) to
       ``Background:``, ``Methods:``, ``Results:`` or ``Conclusions:``;
    4. repair sentence endings around headers, normalize spacing and
       parentheses, strip bullet glyphs and asterisks, tighten p-values and
       percentages, and collapse repeated punctuation;
    5. make sure the text ends with a period.

    Args:
        text: Raw abstract text. Non-strings and blank strings are returned unchanged.

    Returns:
        The cleaned text.
    """
    if not isinstance(text, str) or not text.strip():
        return text

    # Remove extra whitespace
    text = ' '.join(text.split())

    # Roman numeral conversion: "(ii)" -> "(2)"
    roman_map = {'i': '1', 'ii': '2', 'iii': '3', 'iv': '4', 'v': '5',
                 'vi': '6', 'vii': '7', 'viii': '8', 'ix': '9', 'x': '10'}

    def replace_roman(match):
        roman = match.group(1).lower()
        return f"({roman_map.get(roman, roman)})"

    text = re.sub(r'\(([ivx]+)\)', replace_roman, text)

    # Clean enumerated lists: "ii)" -> "2)"
    for roman in roman_map:
        text = re.sub(f"\\b{roman}\\)", f"{roman_map[roman]})", text, flags=re.IGNORECASE)

    # Standardize section headers.
    # A colon is required. With an optional colon, every plain use of these words
    # was rewritten ("These results suggest" -> "These Results: suggest") and word
    # prefixes were split apart ("resulted" -> "Results: ed").
    # The combined "results and conclusions" header is handled first; otherwise
    # the single-word "results" rule would consume it and the rule could never match.
    section_patterns = [
        (r'\bresults?\s*:?\s*and\s*conclusions?\s*:\s*', 'Results: '),
        (r'\b(?:introduction|purpose|background|objectives?|context)\s*:\s*', 'Background: '),
        (r'\b(?:materials?\s+and\s+methods?|methods?|approach|study\s+design)\s*:\s*', 'Methods: '),
        (r'\b(?:results?|findings?|observations?)\s*:\s*', 'Results: '),
        (r'\b(?:conclusions?|summary|final\s+remarks?)\s*:\s*', 'Conclusions: '),
    ]
    for pattern, replacement in section_patterns:
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Ensure complete sentences in sections
    text = re.sub(r'(?<=:)\s*([^.!?\n]*?)(?=\s*(?:[A-Z][^:]*:|$))',
                  lambda m: f" {m.group(1)}." if m.group(1) and not m.group(1).strip().endswith('.') else m.group(0),
                  text)

    # Fix truncated sentences
    text = re.sub(r'(?<=:)\s*([^.!?\n]*?)\s*(?=[A-Z][^:]*:)',
                  lambda m: f" {m.group(1)}." if m.group(1) else "",
                  text)

    # Clean formatting
    text = re.sub(r'[\r\n]+', ' ', text)
    text = re.sub(r'\s*:\s*', ': ', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
    # Strip common bullet glyphs (•, ■, □, ●, ○) and asterisks. The glyphs are
    # written as escapes so re-encoding the file cannot corrupt the pattern.
    # A mis-decoded copy of this line had a bare "β" alternative and deleted
    # every beta from terms such as "TGF-β".
    text = re.sub(r'[\u2022\u25a0\u25a1\u25cf\u25cb*]', '', text)
    text = re.sub(r'\\n|\\r', ' ', text)
    text = re.sub(r'\s*\(\s*', ' (', text)
    text = re.sub(r'\s*\)\s*', ') ', text)

    # Fix statistical notations: "p < 0.05" -> "p<0.05", "5 %" -> "5%"
    text = re.sub(r'p\s*[<=>]\s*0\.\d+', lambda m: m.group().replace(' ', ''), text)
    text = re.sub(r'(?<=\d)\s*%', '%', text)

    # Fix abbreviations spacing
    text = re.sub(r'(?<=\w)vs\.(?=\w)', 'vs. ', text)
    text = re.sub(r'(?<=\w)et\s+al\.(?=\w)', 'et al. ', text)

    # Remove repeated punctuation
    text = re.sub(r'([.!?])\1+', r'\1', text)

    # Final cleanup
    text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', ' ', text)
    text = text.strip()
    if not text.endswith('.'):
        text += '.'
    return text
def generate_focused_summary(question, abstracts, model, tokenizer):
    """Summarize one or more abstracts into a single sectioned summary.

    Processing steps:
    1. Each non-blank string abstract is cleaned with ``preprocess_text``.
    2. The cleaned abstracts are joined with `` [SEP] `` and embedded in a fixed
       instruction prompt.
    3. The prompt is truncated to 1024 tokens and decoded with deterministic
       4-beam search.
    4. The output is passed through ``post_process_summary``.

    NOTE(review): ``question`` is not inserted into the prompt. Any focus on the
    question comes only from which abstracts the caller chooses to pass in.
    Confirm this is intended.

    Args:
        question: Research question (currently unused; see note).
        abstracts: Iterable of abstracts; non-string and blank entries are skipped.
        model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The post-processed summary string.
    """
    # Skip non-strings (e.g. NaN cells) as well as blanks; calling .strip() on
    # them used to raise AttributeError.
    formatted_abstracts = [
        preprocess_text(abstract)
        for abstract in abstracts
        if isinstance(abstract, str) and abstract.strip()
    ]
    abstracts_content = " [SEP] ".join(formatted_abstracts)
    prompt = f"""
Provide a factual summary structured as:
- Background: Context and origin only if present
- Methods: Key procedures and approaches
- Results: Specific findings with numbers
- Conclusions: Main implications
Requirements:
- Present sections sequentially
- Merge related points within sections
- Complete all sentences
- Avoid repeating section headers
- Use original terminology
Content: {abstracts_content}
"""
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        # Deterministic beam search. No ``temperature`` is passed: it only
        # affects sampling (do_sample=True), so the former 0.7 was ignored and
        # merely produced a generation-config warning.
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=512,
            min_length=200,
            num_beams=4,
            length_penalty=2.0,
            no_repeat_ngram_size=3,
            do_sample=False,
        )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return post_process_summary(summary)
def post_process_summary(summary):
    """Reorganize a generated summary into ordered, tidy sections.

    How the text is reassembled:
    - It is cut at every ``Background:``, ``Methods:``, ``Results:`` or
      ``Conclusions:`` header (case-insensitive, wherever it appears).
    - The pieces are re-emitted in that canonical order as ``"Header: text"``,
      separated by single spaces.
    - Text before the first header is dropped.
    - A header that appears more than once has its pieces merged in order.
    - If the summary contains no header at all, the whitespace-normalized text
      is returned instead of an empty string.

    Args:
        summary: Raw decoded model output; falsy values are returned as-is.

    Returns:
        The formatted summary string.
    """
    if not summary:
        return summary

    valid_sections = ['Background', 'Methods', 'Results', 'Conclusions']

    # Collapse combined "Results and Conclusions" headers into a single Results header.
    summary = re.sub(r'\b(?:results?\s*:\s*and\s*conclusions?\s*:)', 'Results:', summary, flags=re.IGNORECASE)
    summary = re.sub(r'\bresults?\s*and\s*conclusions?\s*:', 'Results:', summary, flags=re.IGNORECASE)

    # Split on the headers themselves rather than on every ".". Splitting on
    # periods caused three problems:
    # - decimals broke apart ("0.05" -> "0 05");
    # - the periods between sentences were lost;
    # - sentences that merely mentioned a word like "results" were discarded.
    parts = re.split(
        r'\b(' + '|'.join(valid_sections) + r')\s*:',
        summary,
        flags=re.IGNORECASE,
    )
    if len(parts) == 1:
        # No recognizable header: keep the text instead of returning "".
        return ' '.join(summary.split())

    # parts == [preamble, header1, body1, header2, body2, ...]; the preamble is dropped.
    sections = {}
    for header, body in zip(parts[1::2], parts[2::2]):
        piece = _format_section_piece(body)
        if piece:
            sections.setdefault(header.capitalize(), []).append(piece)

    formatted_sections = [
        f"{section}: {' '.join(sections[section])}"
        for section in valid_sections
        if section in sections
    ]
    return ' '.join(formatted_sections)


def _format_section_piece(text):
    """Tidy one section body: collapse whitespace, finish the sentence, capitalize."""
    content = ' '.join(text.split())
    if not content:
        return content
    # Add a closing period only to substantial fragments (three or more words).
    if not re.search(r'[.!?]$', content) and len(content.split()) >= 3:
        content += '.'
    content = content[0].upper() + content[1:]
    # Fix double periods
    return re.sub(r'\.+', '.', content)
def process_papers_in_batches(df, model, tokenizer, batch_size=2):
    """Summarize every abstract in ``df`` individually, a few at a time.

    Each abstract is summarized on its own with ``generate_focused_summary``,
    using the fixed focus string "Focus on key findings and methods.". A thread
    pool of ``batch_size`` workers bounds how many papers are generated
    concurrently.

    Args:
        df: DataFrame with an ``Abstract`` column.
        model: Summarization model shared by all workers.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Maximum number of papers summarized at the same time;
            values below 1 are treated as 1.

    Returns:
        List of summaries in the same order as ``df['Abstract']``.
    """
    abstracts = df['Abstract'].tolist()
    if not abstracts:
        return []

    def summarize_one(abstract):
        return generate_focused_summary(
            "Focus on key findings and methods.", [abstract], model, tokenizer
        )

    with ThreadPoolExecutor(max_workers=max(1, batch_size)) as executor:
        # map() yields results in input order and re-raises the first failing
        # task's exception when its result is reached.
        return list(executor.map(summarize_one, abstracts))
def _author_names(cell):
    """Return the set of author names in one ';'-separated ``Authors`` cell."""
    if not isinstance(cell, str):
        return set()
    return {name.strip() for name in cell.split(';') if name.strip()}


def create_filter_controls(df, sort_column):
    """Render filter widgets for ``sort_column`` and return the matching rows.

    Filters by column:
    - ``Publication Year``: "From Year"/"To Year" inputs bounded by the data;
      keeps rows inside the inclusive range.
    - ``Authors``: multiselect of individual author names; keeps rows listing
      at least one selected author (whole-name match).
    - ``Source Title``: multiselect of sources; keeps rows from the chosen ones.
    - Anything else (e.g. ``Article Title``, ``Times Cited``): no filter.

    An empty multiselect means "no filter". ``df`` itself is never modified;
    a (possibly filtered) copy is returned.
    """
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        # Year range inputs bounded by the years present in the data.
        year_min = int(df['Publication Year'].min())
        year_max = int(df['Publication Year'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_year = st.number_input('From Year',
                                         min_value=year_min,
                                         max_value=year_max,
                                         value=year_min)
        with col2:
            end_year = st.number_input('To Year',
                                       min_value=year_min,
                                       max_value=year_max,
                                       value=year_max)
        filtered_df = filtered_df[
            (filtered_df['Publication Year'] >= start_year) &
            (filtered_df['Publication Year'] <= end_year)
        ]
    elif sort_column == 'Authors':
        unique_authors = sorted(set().union(*(_author_names(cell) for cell in df['Authors'])))
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            wanted = set(selected_authors)
            # Compare whole names. A substring test let "Li, J" also select
            # rows that only list "Li, JH".
            filtered_df = filtered_df[
                filtered_df['Authors'].apply(
                    lambda cell: not wanted.isdisjoint(_author_names(cell))
                )
            ]
    elif sort_column == 'Source Title':
        # dropna(): a missing source (NaN) cannot be sorted together with strings.
        unique_sources = sorted(df['Source Title'].dropna().unique())
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]
    # 'Article Title' (and any other column) is sort-only: no filter widget.

    return filtered_df
def _reset_processing_state():
    """Clear the progress flags after a failed step so the user can retry.

    Summaries already stored in session state are kept; only the flags are
    reset, so "Start Analysis" has to be pressed again.
    """
    st.session_state.processing_started = False
    st.session_state.focused_summary_generated = False


def _forget_previous_file(uploaded_file):
    """Drop results derived from an earlier upload when a different file arrives.

    The upload is identified by its name and size. Without this check, the
    first file's table and summaries stayed in session state and were shown
    for every later upload.
    """
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get('uploaded_file_key') == file_key:
        return
    st.session_state.uploaded_file_key = file_key
    for key in ('processed_data', 'summaries', 'relevant_papers',
                'relevance_scores', 'focused_summary', 'focused_question'):
        st.session_state[key] = None
    st.session_state.processing_started = False
    st.session_state.focused_summary_generated = False


def _render_paper_summaries(df):
    """Show sort/filter controls and one PAPER / SUMMARY card per paper."""
    import html  # escapes spreadsheet and model text placed inside raw HTML

    col1, col2 = st.columns(2)
    with col1:
        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
        sort_column = st.selectbox("Sort/Filter by:", sort_options)
    with col2:
        if sort_column == 'Article Title':
            ascending = st.radio(
                "Sort order",
                ["A to Z", "Z to A"],
                horizontal=True
            ) == "A to Z"
        elif sort_column == 'Times Cited':
            ascending = st.radio(
                "Sort order",
                ["Most cited first", "Least cited first"],
                horizontal=True
            ) == "Least cited first"
        else:
            ascending = True  # Default for other columns

    # Create display dataframe
    display_df = df.copy()
    display_df['Summary'] = st.session_state.summaries
    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
    display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

    # Apply filters, then sorting (only titles and citation counts are sorted).
    filtered_df = create_filter_controls(display_df, sort_column)
    if sort_column in ('Times Cited', 'Article Title'):
        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
    else:
        sorted_df = filtered_df

    # Show number of filtered results
    if len(sorted_df) != len(display_df):
        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

    # Apply custom styling
    st.markdown("""
<style>
.paper-info {
border: 1px solid #ddd;
padding: 15px;
margin-bottom: 20px;
border-radius: 5px;
}
.paper-section {
margin-bottom: 10px;
}
.section-header {
font-weight: bold;
color: #555;
margin-bottom: 8px;
}
.paper-title {
margin-top: 5px;
margin-bottom: 10px;
}
.paper-meta {
font-size: 0.9em;
color: #666;
}
.doi-link {
color: #0366d6;
}
</style>
""", unsafe_allow_html=True)

    for _, row in sorted_df.iterrows():
        # Escape every free-text value: the cells come from the uploaded file
        # and the summary from the model, and both are inserted into raw HTML.
        title = html.escape(str(row['Article Title']))
        authors = html.escape(str(row['Authors']))
        source = html.escape(str(row['Source Title']))
        doi = html.escape(str(row['DOI'])) if pd.notna(row['DOI']) else 'None'
        summary = html.escape(str(row['Summary']))

        paper_info_cols = st.columns([1, 1])
        with paper_info_cols[0]:  # PAPER column
            st.markdown('<div class="paper-section"><div class="section-header">PAPER</div>', unsafe_allow_html=True)
            st.markdown(f"""
<div class="paper-info">
<div class="paper-title">{title}</div>
<div class="paper-meta">
<strong>Authors:</strong> {authors}<br>
<strong>Source:</strong> {source}<br>
<strong>Publication Year:</strong> {row['Publication Year']}<br>
<strong>Times Cited:</strong> {row['Times Cited']}<br>
<strong>DOI:</strong> {doi}
</div>
</div>
""", unsafe_allow_html=True)
        with paper_info_cols[1]:  # SUMMARY column
            st.markdown('<div class="paper-section"><div class="section-header">SUMMARY</div>', unsafe_allow_html=True)
            st.markdown(f"""
<div class="paper-info">
{summary}
</div>
""", unsafe_allow_html=True)
        # Add spacing between papers
        st.markdown("<div style='margin-bottom: 20px;'></div>", unsafe_allow_html=True)


def _render_focused_summary(question, df):
    """Generate (once per question) and display the question-focused summary."""
    st.header("β Question-focused Summary")

    # A summary produced for an earlier question must not be shown for a new one.
    if st.session_state.get('focused_question') != question:
        st.session_state.focused_summary_generated = False

    if not st.session_state.get('focused_summary_generated', False):
        # No cleanup_model() call after this step: the model stays cached in
        # session state for reuse by get_model(), so deleting local names would
        # not free it. The old `finally` also failed on an unbound `model`
        # whenever an error occurred before the model was assigned.
        try:
            with st.spinner("Analyzing relevant papers..."):
                if st.session_state.text_processor is None:
                    st.session_state.text_processor = TextProcessor()
                model, tokenizer = get_model("question_focused")
                if model is None or tokenizer is None:
                    raise Exception("Failed to load question-focused model")
                results = st.session_state.text_processor.find_most_relevant_abstracts(
                    question,
                    df['Abstract'].tolist(),
                    top_k=5
                )
                if not results['top_indices']:
                    st.warning("No papers found relevant to your question")
                    return
                # Store relevant papers and scores
                st.session_state.relevant_papers = df.iloc[results['top_indices']]
                st.session_state.relevance_scores = results['scores']
                relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                st.session_state.focused_summary = generate_focused_summary(
                    question,
                    relevant_abstracts,
                    model,
                    tokenizer
                )
                st.session_state.focused_question = question
                st.session_state.focused_summary_generated = True
        except Exception as e:
            st.error(f"Error generating focused summary: {str(e)}")
            _reset_processing_state()

    # Display focused summary results
    if st.session_state.get('focused_summary_generated', False):
        st.subheader("Summary")
        st.write(st.session_state.focused_summary)
        st.subheader("Most Relevant Papers")
        relevant_papers = st.session_state.relevant_papers[
            ['Article Title', 'Authors', 'Publication Year', 'DOI']
        ].copy()
        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
        relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
        st.dataframe(relevant_papers, hide_index=True)


def main():
    """Streamlit entry point: upload papers, summarize each, optionally answer a question."""
    st.title("π¬ Biomedical Papers Analysis")
    st.info("""
**π File Upload Requirements:**
- Excel file (.xlsx or .xls) with **maximum 5 papers**
- Must contain these columns:
β’ Abstract
β’ Article Title
β’ Authors
β’ Source Title
β’ Publication Year
β’ DOI
β’ Times Cited, All Databases
""")

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers (max 5 papers)",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )
    # Placeholder so the question input appears above the results once a file is loaded.
    question_container = st.empty()
    question = ""
    if uploaded_file is None:
        return

    _forget_previous_file(uploaded_file)

    # Process Excel file (once per uploaded file)
    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
            if df is not None:
                df = df.dropna(subset=["Abstract"])
                if len(df) > 0:
                    st.session_state.processed_data = df
                    st.success(f"β Successfully loaded {len(df)} papers with abstracts")
                else:
                    st.error("β No valid papers found after processing. Please check your file.")
    if st.session_state.processed_data is None:
        return

    df = st.session_state.processed_data
    st.write(f"π Loaded {len(df)} papers with abstracts")

    # Get question before processing
    with question_container:
        question = st.text_input(
            "Enter your research question (optional):",
            help="If provided, a question-focused summary will be generated after individual summaries"
        )

    # Single button for both processes
    if not st.session_state.get('processing_started', False):
        if st.button("Start Analysis"):
            st.session_state.processing_started = True
    if not st.session_state.get('processing_started', False):
        return

    # Individual Summaries Section
    st.header("π Individual Paper Summaries")
    if st.session_state.summaries is None:
        try:
            with st.spinner("Generating individual paper summaries..."):
                model, tokenizer = get_model("summarize")
                if model is None or tokenizer is None:
                    _reset_processing_state()
                    return
                start_time = time.time()
                st.session_state.summaries = process_papers_in_batches(
                    df, model, tokenizer, batch_size=2
                )
                end_time = time.time()
                st.write(f"Processing time: {end_time - start_time:.2f} seconds")
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")
            _reset_processing_state()
            return

    if st.session_state.summaries is not None:
        _render_paper_summaries(df)

    # Question-focused Summary Section (only if question provided)
    if question.strip():
        _render_focused_summary(question, df)


if __name__ == "__main__":
    main()