import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
from text_processing import TextProcessor
import gc
import time
from pathlib import Path

# Browser-tab title/icon and a full-width layout for the whole app.
st.set_page_config(page_title="Biomedical Papers Analysis", page_icon="🔬", layout="wide")

# Seed the per-session slots with None so later code can simply test
# "is None" to decide whether work still needs to be done.
for _state_key in ('processed_data', 'summaries', 'text_processor'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
del _state_key

def load_model(model_type):
    """Load a PEFT-adapted seq2seq model together with its tokenizer.

    Args:
        model_type: ``"summarize"`` for the facebook/bart-large-cnn base with
            the ``pendar02/results`` adapter, or ``"question_focused"`` for the
            GanjinZero/biobart-base base with the ``pendar02/biobart-finetune``
            adapter.

    Returns:
        ``(model, tokenizer)``; the model has been switched to eval mode.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names
            (previously any unknown value silently loaded the BioBART model).
    """
    # (base checkpoint, PEFT adapter) per supported model type.
    checkpoints = {
        "summarize": ("facebook/bart-large-cnn", "pendar02/results"),
        "question_focused": ("GanjinZero/biobart-base", "pendar02/biobart-finetune"),
    }
    try:
        base_name, adapter_name = checkpoints[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(checkpoints)}"
        ) from None

    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_name)
    model = PeftModel.from_pretrained(base_model, adapter_name)
    # Used only for generation here, so make inference mode explicit.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_name)

    return model, tokenizer

@st.cache_data
def process_excel(uploaded_file):
    """Read an uploaded Excel workbook and keep only the columns the app uses.

    Returns the DataFrame restricted (and ordered) to the required columns,
    or None after showing an error when a column is missing or reading fails.
    """
    needed = ['Abstract', 'Article Title', 'Authors',
              'Source Title', 'Publication Year', 'DOI']
    try:
        frame = pd.read_excel(uploaded_file)
        absent = [name for name in needed if name not in frame.columns]
        if not absent:
            return frame[needed]
        st.error(f"Missing required columns: {', '.join(absent)}")
        return None
    except Exception as exc:
        st.error(f"Error processing file: {str(exc)}")
        return None

def generate_summary(text, model, tokenizer, *, max_length=150, min_length=50, num_beams=4):
    """Generate an abstractive summary for a single abstract.

    Args:
        text: The abstract text. Missing values (e.g. NaN from an empty
            spreadsheet cell, or None) and blank strings are tolerated.
        model: A seq2seq model exposing ``generate``.
        tokenizer: The tokenizer matching ``model``.
        max_length: Maximum length of the generated summary, in tokens.
        min_length: Minimum length of the generated summary, in tokens.
        num_beams: Beam-search width.

    Returns:
        The decoded summary, or ``""`` when ``text`` is missing or blank.
    """
    # Empty cells arrive as NaN (a float) or None; the tokenizer expects str,
    # and one bad row would otherwise abort the whole summarization loop.
    if not isinstance(text, str) or not text.strip():
        return ""

    # Inputs beyond 1024 tokens are truncated rather than rejected.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)

    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            length_penalty=2.0,
            early_stopping=True,
        )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate focused summary based on question"""
    combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(abstracts)
    
    inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
    
    with torch.no_grad():
        summary_ids = model.generate(
            **{
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"],
                "max_length": 200,
                "min_length": 50,
                "num_beams": 4,
                "length_penalty": 2.0,
                "early_stopping": True
            }
        )
    
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def _render_sidebar():
    """Show the static 'About' panel in the sidebar."""
    st.sidebar.header("About")
    st.sidebar.info(
        "This app analyzes biomedical research papers. Upload an Excel file "
        "containing paper details and abstracts to:"
        "\n- Generate individual summaries"
        "\n- Get question-focused insights"
    )


def _reset_state_if_file_changed(uploaded_file):
    """Discard data and summaries cached for a previously uploaded file.

    Without this, uploading a second workbook keeps showing the first one's
    papers, and the old summaries no longer line up with the new rows.
    """
    if uploaded_file is None:
        file_key = None
    else:
        # file_id / size may not exist on every Streamlit version, hence getattr.
        file_key = (
            getattr(uploaded_file, "file_id", None),
            uploaded_file.name,
            getattr(uploaded_file, "size", None),
        )
    if st.session_state.get("uploaded_file_key") != file_key:
        st.session_state.uploaded_file_key = file_key
        st.session_state.processed_data = None
        st.session_state.summaries = None


def _release_model_memory():
    """Reclaim memory after the caller has dropped its model references."""
    # Collect first so the tensors of the dropped model are actually freed,
    # then ask the CUDA caching allocator to hand back its cached blocks.
    gc.collect()
    torch.cuda.empty_cache()


def _render_individual_summaries(df):
    """Offer per-paper summarization and show the results as a sortable table."""
    st.header("📝 Individual Paper Summaries")

    if st.session_state.summaries is None and st.button("Generate Individual Summaries"):
        model = tokenizer = None
        try:
            with st.spinner("Generating summaries..."):
                model, tokenizer = load_model("summarize")

                progress_bar = st.progress(0)
                summaries = []
                total = len(df)
                for i, abstract in enumerate(df['Abstract']):
                    summaries.append(generate_summary(abstract, model, tokenizer))
                    progress_bar.progress((i + 1) / total)

                st.session_state.summaries = summaries
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")
        finally:
            # Runs on failure too, so a crashed run does not pin the weights.
            del model, tokenizer
            _release_model_memory()

    if st.session_state.summaries is None:
        return

    col1, col2 = st.columns(2)
    with col1:
        sort_column = st.selectbox("Sort by:", df.columns)
    with col2:
        ascending = st.checkbox("Ascending order", True)

    display_df = df.copy()
    display_df['Summary'] = st.session_state.summaries
    sorted_df = display_df.sort_values(by=sort_column, ascending=ascending)

    st.dataframe(
        sorted_df,
        column_config={
            "Abstract": st.column_config.TextColumn(
                "Abstract",
                width="medium",
                help="Original abstract text"
            ),
            "Summary": st.column_config.TextColumn(
                "Summary",
                width="medium",
                help="Generated summary"
            )
        },
        hide_index=True
    )


def _render_focused_summary(df):
    """Answer a research question from the most relevant abstracts."""
    st.header("❓ Question-focused Summary")
    question = st.text_input("Enter your research question:")

    # The button is only rendered once a question has been typed.
    if not (question and st.button("Generate Focused Summary")):
        return

    model = tokenizer = None
    try:
        with st.spinner("Analyzing relevant papers..."):
            results = st.session_state.text_processor.find_most_relevant_abstracts(
                question,
                df['Abstract'].tolist(),
                top_k=5
            )
            processed = results['processed_question']

            # Show the spell-check suggestion only if the question was changed.
            if processed['original'] != processed['corrected']:
                st.info(f"Did you mean: {processed['corrected']}?")

            model, tokenizer = load_model("question_focused")

            top_indices = results['top_indices']
            relevant_abstracts = df['Abstract'].iloc[top_indices].tolist()
            focused_summary = generate_focused_summary(
                processed['corrected'],
                relevant_abstracts,
                model,
                tokenizer
            )

            st.subheader("Summary")
            st.write(focused_summary)

            st.subheader("Most Relevant Papers")
            # .copy(): we add a column below, so own the data rather than a slice of df.
            relevant_papers = df.iloc[top_indices][
                ['Article Title', 'Authors', 'Publication Year', 'DOI']
            ].copy()
            relevant_papers['Relevance Score'] = results['scores']
            st.dataframe(relevant_papers, hide_index=True)

            st.subheader("Identified Medical Terms")
            st.write(", ".join(processed['medical_entities']))
    except Exception as e:
        st.error(f"Error generating focused summary: {str(e)}")
    finally:
        del model, tokenizer
        _release_model_memory()


def main():
    """Render the app: upload papers, summarize each, and answer questions.

    The parsed workbook and generated summaries live in ``st.session_state``
    so they survive Streamlit reruns; both are reset when a different file is
    uploaded.
    """
    st.title("🔬 Biomedical Papers Analysis")
    _render_sidebar()

    # The text processor is built once per session.
    if st.session_state.text_processor is None:
        with st.spinner("Loading NLP models..."):
            st.session_state.text_processor = TextProcessor()

    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )
    _reset_state_if_file_changed(uploaded_file)

    if uploaded_file is None:
        return

    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
            if df is not None:
                st.session_state.processed_data = df

    df = st.session_state.processed_data
    if df is None:
        # process_excel already reported the problem.
        return

    st.write(f"📊 Loaded {len(df)} papers")
    _render_individual_summaries(df)
    _render_focused_summary(df)

# Entry point: render the app when this module is executed as a script.
if __name__ == "__main__":
    main()