# Medication data workbook read by load_medication_data(): cleaned_bmd_medication_data.xlsx

import streamlit as st
import pandas as pd
import plotly.graph_objs as go

# Constants from linear regression
# Reference BMD statistics used to convert a BMD value into a T-score via
# calculate_tscore(): T = (BMD - mu) / sigma.
# Structure: {site label: {gender: {'mu': reference mean, 'sigma': reference SD}}}.
# The site labels must match the keys of `site_mapping` in main(), and the
# gender keys must match the "Gender" selectbox options, because main() indexes
# this table directly with those selections.
# NOTE(review): presumably specific to the LUNAR DEXA machine (the only option
# offered in main()) and expressed in g/cm² — confirm with the data source.
REGRESSION_CONSTANTS = {
    'Femoral Neck': {
        'Female': {'mu': 0.916852, 'sigma': 0.120754},
        'Male': {'mu': 0.9687385325352573, 'sigma': 0.121870698023835}
    },
    'Total Hip': {
        'Female': {'mu': 0.955439, 'sigma': 0.125406},
        'Male': {'mu': 0.967924895046735, 'sigma': 0.13081439619361657}
    },
    'Lumbar spine (L1-L4)': {
        'Female': {'mu': 1.131649, 'sigma': 0.139618},
        'Male': {'mu': 1.1309707991669353, 'sigma': 0.1201836924980611}
    }
}

# Load medication data
@st.cache_data
def load_medication_data(file_path="cleaned_bmd_medication_data.xlsx"):
    """Read the medication BMD-change table from an Excel workbook.

    The result is cached by Streamlit (``st.cache_data``) keyed on the
    arguments, so the workbook is read once per distinct ``file_path`` rather
    than on every script rerun. Because the cache key is the path, not the
    file contents, edits to the workbook are not picked up until the cache is
    cleared.

    Args:
        file_path: Path to the workbook. Defaults to the data file the app
            has always used, so existing zero-argument calls are unchanged.

    Returns:
        A DataFrame with the contents of the workbook's first sheet. The rest
        of the app expects ``Medication`` and ``Site`` columns plus one column
        per follow-up period holding the fractional BMD change from baseline.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    return pd.read_excel(file_path)

# Calculate predicted BMD after medication
def calculate_bmd(bmd, percentage_increase):
    """Return ``bmd`` grown by a fractional increase.

    ``percentage_increase`` is a fraction, not a percent: 0.05 means +5%
    and a negative value means a loss.
    """
    growth_factor = 1 + percentage_increase
    return bmd * growth_factor

# Convert BMD to T-score
def calculate_tscore(bmd, mu, sigma):
    """Standardize ``bmd`` against a reference mean and standard deviation.

    Computes ``(bmd - mu) / sigma``. With plain floats a ``sigma`` of 0
    raises ZeroDivisionError.
    """
    deviation = bmd - mu
    return deviation / sigma

# Generate prediction table for all drugs at one site
def generate_predictions(medication_data, site, bmd, mu, sigma):
    """Project BMD and T-score over time for every medication at ``site``.

    For each row of ``medication_data`` whose ``Site`` equals ``site``, the
    baseline point (year "0") is the patient's current ``bmd``. Every
    non-missing year column then contributes a point ``bmd * (1 + change)``,
    converted to a T-score with ``mu``/``sigma``.

    Args:
        medication_data: DataFrame with ``Medication`` and ``Site`` columns;
            every other column is treated as a follow-up period (presumably
            labelled like "1st Year") holding the fractional BMD change from
            baseline.
        site: Site code matched against the ``Site`` column.
        bmd: Patient's baseline BMD.
        mu: Reference mean passed to ``calculate_tscore``.
        sigma: Reference standard deviation passed to ``calculate_tscore``.

    Returns:
        A list of ``{'Drug': name, 'Predictions': dict}`` entries, one per
        matching row. The dict holds parallel lists: ``Year`` (labels with
        " Year" removed), ``Year Index`` (0, 1, 2, ... for plotting),
        ``Predicted BMD`` (rounded to 3 dp) and ``Predicted T-score``
        (rounded to 1 dp).
    """
    # Select year columns by name rather than by position, so the result does
    # not depend on where "Medication" and "Site" sit in the sheet. A
    # positional slice would feed the Site string into the arithmetic if
    # "Site" were not the last column.
    metadata_columns = {'Medication', 'Site'}
    year_columns = [col for col in medication_data.columns if col not in metadata_columns]

    site_data = medication_data[medication_data['Site'] == site]
    baseline_tscore = round(calculate_tscore(bmd, mu, sigma), 1)
    all_results = []

    for _, row in site_data.iterrows():
        predictions = {
            'Year': ['0'],
            'Year Index': [0],  # Numeric x-axis for plotting
            'Predicted BMD': [round(bmd, 3)],
            'Predicted T-score': [baseline_tscore],
        }

        for year in year_columns:
            change = row[year]
            if pd.isna(change):
                # Missing years are skipped, so Year Index counts only years
                # that have data; the charts label ticks with the real names.
                continue
            predicted_bmd = calculate_bmd(bmd, change)
            predictions['Year'].append(str(year).replace(" Year", ""))  # Simplify year label
            predictions['Year Index'].append(len(predictions['Year Index']))
            predictions['Predicted BMD'].append(round(predicted_bmd, 3))
            predictions['Predicted T-score'].append(
                round(calculate_tscore(predicted_bmd, mu, sigma), 1)
            )

        all_results.append({'Drug': row['Medication'], 'Predictions': predictions})
    return all_results


# Display results as table and plots
def display_results(predictions, site):
    """Render a table plus BMD and T-score charts for each medication.

    Args:
        predictions: List of ``{'Drug': ..., 'Predictions': ...}`` entries as
            produced by ``generate_predictions``.
        site: Human-readable site name used in the section heading.
    """
    st.subheader(f"Predictions for {site}")

    for result in predictions:
        drug = result['Drug']
        # Use a distinct name so the per-drug dict does not shadow the
        # `predictions` list being iterated.
        drug_predictions = result['Predictions']
        year_index = drug_predictions['Year Index']
        year_labels = drug_predictions['Year']

        # Display table
        st.write(f"### {drug}")
        st.dataframe(pd.DataFrame(drug_predictions))

        # Plot BMD and T-score using Year Index
        bmd_plot = go.Scatter(
            x=year_index, y=drug_predictions['Predicted BMD'], mode='lines+markers',
            name='Predicted BMD', line=dict(color='blue')
        )
        tscore_plot = go.Scatter(
            x=year_index, y=drug_predictions['Predicted T-score'], mode='lines+markers',
            name='Predicted T-score', line=dict(color='green')
        )

        # Combine plots in a single row
        col1, col2 = st.columns(2)
        with col1:
            st.plotly_chart(_year_axis_figure(
                bmd_plot, f"{drug} - Predicted BMD", "BMD (g/cm²)", year_index, year_labels
            ))
        with col2:
            st.plotly_chart(_year_axis_figure(
                tscore_plot, f"{drug} - Predicted T-score", "T-score", year_index, year_labels
            ))


def _year_axis_figure(trace, title, yaxis_title, year_index, year_labels):
    """Wrap one trace in a figure whose x ticks show the year labels at the numeric positions."""
    return go.Figure(data=[trace], layout=go.Layout(
        title=title, xaxis_title="Year", yaxis_title=yaxis_title,
        xaxis=dict(tickmode='array', tickvals=year_index, ticktext=year_labels)
    ))

# Generate summary of medications reaching the target T-score
def generate_goal_summary(predictions, target_tscore=-2.4):
    """List each medication's first year whose T-score reaches ``target_tscore``.

    Args:
        predictions: ``generate_predictions``-style list of
            ``{'Drug': ..., 'Predictions': {'Year': [...], 'Predicted T-score': [...]}}``.
        target_tscore: Goal T-score; a year qualifies when its T-score is
            greater than or equal to this value.

    Returns:
        ``[{'Medication': drug, 'Year': int}, ...]`` for drugs that reach the
        goal, ordered by that year (ties keep input order). Drugs that never
        reach the goal are omitted.
    """
    ordinal_suffix_chars = "stndrdth"

    def parse_year(label):
        # "1st" -> 1, "2nd" -> 2, "0" -> 0. Note rstrip removes any trailing
        # characters from the set, not a literal suffix; labels that still
        # are not an integer afterwards map to 0.
        stripped = label.rstrip(ordinal_suffix_chars)
        try:
            return int(stripped)
        except ValueError:
            return 0

    goal_reached = []
    for result in predictions:
        drug = result['Drug']
        series = result['Predictions']
        for label, tscore in zip(series['Year'], series['Predicted T-score']):
            if tscore >= target_tscore:
                break
        else:
            continue  # this drug never reaches the target
        goal_reached.append({'Medication': drug, 'Year': parse_year(label)})

    # Earlier achievement first; sorted() is stable so ties keep input order.
    return sorted(goal_reached, key=lambda entry: entry['Year'])

# Display summary of goal-reaching medications
def display_goal_summary(goal_summary, target_tscore=-2.4):
    """Show the goal-treatment table, or an info notice when it is empty.

    Args:
        goal_summary: ``[{'Medication': ..., 'Year': ...}, ...]`` as returned
            by ``generate_goal_summary``.
        target_tscore: Goal T-score shown in the heading. Pass the same value
            given to ``generate_goal_summary`` so the heading matches the
            data. The default reproduces the previous "-2.4" heading exactly.
    """
    st.subheader(f"Goal Treatment Summary (T-score ≥ {target_tscore})")

    if not goal_summary:
        st.info("No medications reach the target T-score.")
        return
    st.table(pd.DataFrame(goal_summary))

# Medication Selection with Collapsible Categories
# Medications offered in the selector, grouped by drug class. This is the
# single source of truth: the order here is the order of the expanders and
# checkboxes, and of the list returned when "Show All" is ticked.
MEDICATION_CATEGORIES = {
    "Bisphosphonates": [
        "Alendronate", "Risedronate", "Ibandronate oral",
        "Zoledronate", "Ibandronate IV (3mg)"
    ],
    "RANK Ligand Inhibitors": [
        "Denosumab", "Denosumab + Teriparatide"
    ],
    "Anabolic Agents": [
        "Teriparatide", "Teriparatide + Denosumab"
    ],
    "Sclerostin Inhibitors": [
        "Romosozumab", "Romosozumab + Denosumab",
        "Romosozumab + Alendronate", "Romosozumab + Ibandronate",
        "Romosozumab + Zoledronate"
    ]
}


def select_medications():
    """Render the medication picker and return the chosen medication names.

    When "Show All Medications" is ticked, every medication in
    ``MEDICATION_CATEGORIES`` is returned and the per-category checkboxes
    are not rendered. Otherwise one expander per category is shown, and the
    names whose checkboxes are ticked are returned in display order.

    Must be called at most once per script run: the checkbox keys are fixed,
    and Streamlit rejects duplicate widget keys.
    """
    st.subheader("Select Medications to Display")
    show_all = st.checkbox("Show All Medications", key="show_all")

    if show_all:
        # Derived from the category table so the two lists cannot drift apart.
        return [med for medications in MEDICATION_CATEGORIES.values() for med in medications]

    selected_medications = []
    for category, medications in MEDICATION_CATEGORIES.items():
        with st.expander(category):
            for med in medications:
                # Prefix with the category so each checkbox key is unique.
                if st.checkbox(med, key=f"{category}_{med}"):
                    selected_medications.append(med)
    return selected_medications

# Streamlit UI
# Main function
def main():
    """Build the prediction UI and render results when "Predict" is clicked."""
    st.title("BMD and T-score Prediction Tool")

    # DEXA Machine Selection. Only LUNAR is offered and the value is not used
    # further. NOTE(review): REGRESSION_CONSTANTS presumably apply to this
    # machine only — confirm before adding more options.
    st.selectbox("DEXA Machine", ["LUNAR"])

    # Gender Selection (keys must match REGRESSION_CONSTANTS)
    gender = st.selectbox("Gender", ["Female", "Male"])

    # Display label -> site code stored in the medication table's "Site" column.
    # Labels must match the site keys of REGRESSION_CONSTANTS.
    site_mapping = {
        'Lumbar spine (L1-L4)': 'LS',
        'Femoral Neck': 'FN',
        'Total Hip': 'TH'
    }
    selected_site = st.selectbox("Select Region (Site)", list(site_mapping))
    site = site_mapping[selected_site]

    # Input patient data
    bmd_patient = st.number_input(
        "Initial BMD",
        min_value=0.000, max_value=2.000,
        value=0.800, step=0.001,
        format="%.3f"
    )

    # Called exactly once per run: its checkbox keys would collide otherwise.
    selected_medications = select_medications()

    # Load medication data; report a missing file instead of a raw traceback.
    try:
        medication_data = load_medication_data()
    except FileNotFoundError as err:
        st.error(f"Could not load medication data: {err}")
        return
    constants = REGRESSION_CONSTANTS[selected_site][gender]
    target_tscore = -2.4

    if not st.button("Predict"):
        return

    if not selected_medications:
        st.warning("No medications selected. Please select at least one medication or use the 'Show All' option.")
        return

    all_predictions = generate_predictions(
        medication_data, site, bmd_patient, constants['mu'], constants['sigma']
    )
    wanted = set(selected_medications)
    filtered_predictions = [pred for pred in all_predictions if pred['Drug'] in wanted]

    # Medications were selected but none has data for this site. This is a
    # different situation from "nothing selected", so say so explicitly.
    if not filtered_predictions:
        st.warning(f"No prediction data is available for the selected medications at {selected_site}.")
        return

    # Generate and display goal treatment summary
    goal_summary = generate_goal_summary(filtered_predictions, target_tscore=target_tscore)
    display_goal_summary(goal_summary)

    # Display individual medication results
    display_results(filtered_predictions, selected_site)


# `streamlit run` executes this file as __main__, so the guard fires there too.
if __name__ == "__main__":
    main()