import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree, export_text
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve

# Training data: every column except "Target" is a feature.
data = pd.read_csv('exported_named_train_good.csv')
test_data = pd.read_csv('exported_named_test.csv')
X_train = data.drop("Target", axis=1).values
y_train = data['Target'].values

# Test data: labels can only be scored against if the file ships a "Target"
# column. When it does, the column must also be kept out of the feature matrix
# so that X_test has the same columns as X_train.
if "Target" in test_data.columns:
    X_test = test_data.drop("Target", axis=1).values
    y_test = test_data["Target"].values
else:
    X_test = test_data.values
    y_test = None

models={
    "Logisitic Regression":LogisticRegression(),
    "Decision Tree":DecisionTreeClassifier(),
    "Random Forest":RandomForestClassifier(),
    "Gradient Boost":GradientBoostingClassifier()
}


def _score_predictions(y_true, y_pred):
    """Return the reported metrics for one set of predictions, in print order.

    F1 is weighted across classes; precision, recall and ROC AUC use the
    scikit-learn defaults.
    NOTE(review): ROC AUC is computed from hard class predictions rather than
    from probabilities/scores — confirm this is intended.
    """
    return {
        'Accuracy': accuracy_score(y_true, y_pred),
        'F1 score': f1_score(y_true, y_pred, average='weighted'),
        'Precision': precision_score(y_true, y_pred),
        'Recall': recall_score(y_true, y_pred),
        'Roc Auc Score': roc_auc_score(y_true, y_pred),
    }


def _print_scores(scores):
    """Print one '- <metric>: <value>' line per metric with 4 decimals."""
    for label, value in scores.items():
        print('- {}: {:.4f}'.format(label, value))


# Fit every candidate model and print its metrics. The fitted estimators stay
# in `models`, which the rest of the file reads from.
# NOTE(review): this runs at module level, i.e. each time the script executes.
for name, model in models.items():

    model.fit(X_train, y_train)

    # Training set performance
    train_scores = _score_predictions(y_train, model.predict(X_train))

    # Test set performance (only when ground-truth labels are available)
    test_scores = None
    if y_test is not None:
        test_scores = _score_predictions(y_test, model.predict(X_test))

    print(name)

    print('Model performance for Training set')
    _print_scores(train_scores)

    print('----------------------------------')

    print('Model performance for Test set')
    if test_scores is None:
        print("- Skipped: the test file has no 'Target' column to score against")
    else:
        _print_scores(test_scores)

    print('='*35)
    print('\n')

def load_model_and_data():
    """Return the decision tree together with the data shown in the app.

    Returns:
        A ``(model, X, y)`` tuple: the ``'Decision Tree'`` entry of the
        module-level ``models`` dict, the feature columns of
        ``exported_named_train.csv`` (everything except ``"Target"``) and the
        ``"Target"`` column of that same file.
    """
    # Placeholder loader: a real deployment would load a persisted model and
    # its data here instead of reusing the estimator held in `models`.
    tree_model = models['Decision Tree']

    frame = pd.read_csv('exported_named_train.csv')
    features = frame.drop(columns="Target")
    target = frame['Target']

    return tree_model, features, target
    
   

def _create_cohorts(X, features):
    """Build one cohort label per row of ``X`` from the selected features.

    Columns of dtype ``object`` or with fewer than 10 distinct values are used
    as-is. Any other column is binned with ``pd.qcut`` into up to four
    quantile bins labelled ``Q1``..``Q4``. The per-feature labels of a row are
    joined with ``' & '``.
    """
    cohort_def = X[features].copy()
    for feat in features:
        column = X[feat]
        if column.dtype == 'object' or len(column.unique()) < 10:
            continue  # the copied raw values already identify the cohort
        # Fixed labels require exactly four bins, which fails when quantile
        # edges repeat (heavily tied data). Ask for integer bin codes, let
        # pandas drop duplicate edges, then derive the labels from the codes.
        bin_codes = pd.qcut(column, q=4, labels=False, duplicates='drop')
        cohort_def[feat] = bin_codes.map(
            lambda code: code if pd.isna(code) else f"Q{int(code) + 1}"
        )
    return cohort_def.apply(lambda row: ' & '.join(row.astype(str)), axis=1)


def _decision_rules(model, input_df, feature_names):
    """List the split conditions met by the first row of ``input_df``.

    Walks the nodes returned by ``model.decision_path`` and, for every node
    that is not the final leaf, renders ``"<feature> ≤ <threshold>"`` or
    ``"<feature> > <threshold>"`` depending on the row's value.
    """
    node_indicator = model.decision_path(input_df)
    leaf_id = model.apply(input_df)[0]
    visited_nodes = node_indicator.indices[
        node_indicator.indptr[0]:node_indicator.indptr[1]
    ]

    sample = input_df.iloc[0]
    tree = model.tree_
    rules = []
    for node_id in visited_nodes:
        if node_id == leaf_id:
            continue  # the leaf holds the prediction, not a split
        threshold = tree.threshold[node_id]
        feature = feature_names[tree.feature[node_id]]
        operator = "≤" if sample[feature] <= threshold else ">"
        rules.append(f"{feature} {operator} {threshold:.2f}")
    return rules


def app():
    """Render the decision-tree interpreter page.

    The page has four parts: a global view (feature importances and tree
    statistics), a textual rule explorer, a cohort analysis over up to two
    selected features, and a prediction simulator that also shows the decision
    path followed for the entered values.
    """
    st.title("Interpréteur d'Arbre de Décision")

    # Sidebar for the controls
    st.sidebar.header("Paramètres d'analyse")

    # Section 1: global view of the model
    st.header("Vue globale du modèle")
    col1, col2 = st.columns(2)

    model, X, _ = load_model_and_data()
    # A plain list from the start: export_text and the decision path both
    # need positional indexing.
    feature_names = list(X.columns)
    tree_depth = model.get_depth()

    with col1:
        st.subheader("Importance des caractéristiques")
        feature_importance = pd.DataFrame({
            'feature': feature_names,
            'importance': model.feature_importances_
        }).sort_values('importance', ascending=True)
        importance_fig, importance_ax = plt.subplots(figsize=(10, 6))
        importance_ax.barh(feature_importance['feature'], feature_importance['importance'])
        st.pyplot(importance_fig)
        # pyplot keeps every figure it creates alive until it is closed.
        plt.close(importance_fig)

    with col2:
        st.subheader("Statistiques du modèle")
        st.write(f"Profondeur de l'arbre: {tree_depth}")
        st.write(f"Nombre de feuilles: {model.get_n_leaves()}")

    # Section 2: rule explorer
    st.header("2. Explorateur de règles de décision")
    if tree_depth > 1:
        # Keep the default inside [1, tree_depth] for shallow trees.
        max_depth = st.slider(
            "Profondeur maximale à afficher", 1, tree_depth, min(3, tree_depth)
        )
    else:
        # No meaningful range to choose from (a slider needs min < max).
        max_depth = 1

    tree_text = export_text(model, feature_names=feature_names, max_depth=max_depth)
    st.text(tree_text)

    # Section 3: cohort analysis
    st.header("3. Analyse de cohortes")

    # Features used to define the cohorts
    selected_features = st.multiselect(
        "Sélectionnez les caractéristiques pour définir les cohortes",
        feature_names,
        max_selections=2
    )

    if selected_features:
        cohorts = _create_cohorts(X, selected_features)

        # Predictions broken down by cohort
        cohort_analysis = pd.DataFrame({
            'Cohorte': cohorts,
            'Prédiction': model.predict(X)
        })

        cohort_stats = cohort_analysis.groupby('Cohorte')['Prédiction'].agg(['count', 'mean'])
        cohort_stats.columns = ['Nombre d\'observations', 'Taux de prédiction positive']

        st.write("Statistiques par cohorte:")
        st.dataframe(cohort_stats)

        # Cohort visualisation
        cohort_fig, cohort_ax = plt.subplots(figsize=(10, 6))
        sns.barplot(data=cohort_analysis, x='Cohorte', y='Prédiction', ax=cohort_ax)
        cohort_ax.tick_params(axis='x', labelrotation=45)
        st.pyplot(cohort_fig)
        plt.close(cohort_fig)

    # Section 4: prediction simulator
    st.header("4. Simulateur de prédictions")

    # One input widget per feature, in the model's column order
    input_values = {}
    for feature in feature_names:
        column = X[feature]
        if column.dtype == 'object':
            input_values[feature] = st.selectbox(
                f"Sélectionnez {feature}",
                options=column.unique()
            )
            continue

        lowest, highest = float(column.min()), float(column.max())
        if lowest < highest:
            input_values[feature] = st.slider(
                f"Valeur pour {feature}",
                lowest,
                highest,
                float(column.mean())
            )
        else:
            # Constant column: there is no range to slide over, pin the value.
            input_values[feature] = lowest
            st.write(f"{feature}: {lowest} (valeur constante)")

    if st.button("Prédire"):
        # Single-row frame holding the entered values
        input_df = pd.DataFrame([input_values])

        prediction = model.predict_proba(input_df)

        # Label each probability with the class it belongs to; predict_proba
        # columns follow model.classes_, which need not be 0..n-1.
        st.write("Probabilités prédites:")
        st.write({
            f"Classe {label}": f"{prob:.2%}"
            for label, prob in zip(model.classes_, prediction[0])
        })

        # Decision path followed for this prediction
        st.subheader("Chemin de décision")
        for rule in _decision_rules(model, input_df, feature_names):
            st.write(rule)

# Render the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    app()