import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, silhouette_score)
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import plotly.express as px
import shap
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

def load_data():
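    """Load the pre-split train/test CSV exports (assumed to sit in the
    working directory) and separate the 'Target' column."""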
    data = pd.read_csv('exported_named_train_good.csv')
    data_test = pd.read_csv('exported_named_test_good.csv')
    X_train = data.drop("Target", axis=1)
    y_train = data['Target']
    X_test = data_test.drop('Target', axis=1)
    y_test = data_test['Target']
    return X_train, y_train, X_test, y_test, X_train.columns

def train_models(X_train, y_train, X_test, y_test):
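    """Fit each classifier and collect its train/test metrics.

    Returns a dict mapping model name to the fitted model and two metric
    dicts ('train_metrics' and 'test_metrics')."""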
    models = {
        "Logistic Regression": LogisticRegression(random_state=42),
        "Decision Tree": DecisionTreeClassifier(random_state=42),
        "Random Forest": RandomForestClassifier(n_estimators=100, min_samples_split=2,max_features=7, max_depth=None, random_state=42),
        "Gradient Boost": GradientBoostingClassifier(random_state=42),
        "Extreme Gradient Boosting": XGBClassifier(random_state=42, n_estimators=500, learning_rate=0.0789),
        "Light Gradient Boosting": LGBMClassifier(random_state=42, n_estimators=500, learning_rate=0.0789)
    }
    
    results = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        
        # Hard-label predictions for accuracy, precision, recall and F1
        y_train_pred = model.predict(X_train)
        y_test_pred = model.predict(X_test)

        # Positive-class probabilities for ROC AUC (scoring hard labels
        # instead of probabilities would understate the AUC)
        y_train_proba = model.predict_proba(X_train)[:, 1]
        y_test_proba = model.predict_proba(X_test)[:, 1]

        # Metrics. Note: F1 is class-weighted, while precision and recall
        # score the positive class (scikit-learn's binary default).
        results[name] = {
            'model': model,
            'train_metrics': {
                'accuracy': accuracy_score(y_train, y_train_pred),
                'f1': f1_score(y_train, y_train_pred, average='weighted'),
                'precision': precision_score(y_train, y_train_pred),
                'recall': recall_score(y_train, y_train_pred),
                'roc_auc': roc_auc_score(y_train, y_train_proba)
            },
            'test_metrics': {
                'accuracy': accuracy_score(y_test, y_test_pred),
                'f1': f1_score(y_test, y_test_pred, average='weighted'),
                'precision': precision_score(y_test, y_test_pred),
                'recall': recall_score(y_test, y_test_pred),
                'roc_auc': roc_auc_score(y_test, y_test_proba)
            }
        }
    
    return results

def plot_model_performance(results):
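    """Side-by-side bar charts comparing train and test metrics across models."""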
    metrics = ['accuracy', 'f1', 'precision', 'recall', 'roc_auc']
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    
    # Training metrics
    train_data = {model: [results[model]['train_metrics'][metric] for metric in metrics] 
                 for model in results.keys()}
    train_df = pd.DataFrame(train_data, index=metrics)
    train_df.plot(kind='bar', ax=axes[0], title='Training Performance')
    axes[0].set_ylim(0, 1)
    
    # Test metrics
    test_data = {model: [results[model]['test_metrics'][metric] for metric in metrics] 
                for model in results.keys()}
    test_df = pd.DataFrame(test_data, index=metrics)
    test_df.plot(kind='bar', ax=axes[1], title='Test Performance')
    axes[1].set_ylim(0, 1)
    
    plt.tight_layout()
    return fig

def plot_feature_importance(model, feature_names, model_type):
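    """Horizontal bar chart of feature importances (absolute coefficients
    for Logistic Regression, impurity/gain-based importances otherwise)."""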
    plt.figure(figsize=(10, 6))
    
    if model_type == "Logistic Regression":
        # Linear model: rank features by coefficient magnitude
        importance = np.abs(model.coef_[0])
    else:
        # Decision Tree, Random Forest, Gradient Boost, XGBoost and LightGBM
        # all expose feature_importances_
        importance = model.feature_importances_
    
    importance_df = pd.DataFrame({
        'feature': feature_names,
        'importance': importance
    }).sort_values('importance', ascending=True)
    
    plt.barh(importance_df['feature'], importance_df['importance'])
    plt.title(f"Feature Importance - {model_type}")
    return plt.gcf()

def prepare_clustering_data(data, numeric_columns):
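    """Standardize the selected numeric columns (zero mean, unit variance);
    K-Means is distance-based, so unscaled features would dominate."""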
    scaler = StandardScaler()
    scaled_features = scaler.fit_transform(data[numeric_columns])
    return scaled_features, scaler

def perform_clustering(scaled_features, n_clusters):
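    """Run K-Means with a fixed seed and return the fitted model and labels."""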
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(scaled_features)
    return kmeans, cluster_labels

def plot_clusters_3d(data, labels, features, product_category):
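    """Project the scaled features onto their first three principal
    components and draw an interactive 3D scatter colored by cluster."""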
    pca = PCA(n_components=3)
    components = pca.fit_transform(data)
    
    df_plot = pd.DataFrame({
        'PC1': components[:, 0],
        'PC2': components[:, 1],
        'PC3': components[:, 2],
        'Cluster': [f"Groupe {i}" for i in labels]
    })
    
    fig = px.scatter_3d(
        df_plot,
        x='PC1',
        y='PC2',
        z='PC3',
        color='Cluster',
        title=f'Analyse des sous-groupes pour {product_category}',
        labels={
            'PC1': 'Composante 1',
            'PC2': 'Composante 2',
            'PC3': 'Composante 3'
        }
    )
    
    fig.update_layout(
        scene=dict(
            xaxis_title='Composante 1',
            yaxis_title='Composante 2',
            zaxis_title='Composante 3'
        ),
        legend_title_text='Sous-groupes'
    )
    
    return fig

def analyze_clusters(data, cluster_labels, numeric_columns, product_category):
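    """Compute per-cluster size and per-feature means for comparison with
    the product-level averages."""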
    data_with_clusters = data.copy()
    data_with_clusters['Cluster'] = cluster_labels
    
    cluster_stats = []
    for cluster in range(len(np.unique(cluster_labels))):
        cluster_data = data_with_clusters[data_with_clusters['Cluster'] == cluster]
        stats = {
            'Cluster': cluster,
            'Taille': len(cluster_data),
            'Product': product_category,
            'Caractéristiques principales': {}
        }
        
        for col in numeric_columns:
            stats['Caractéristiques principales'][col] = cluster_data[col].mean()
        
        cluster_stats.append(stats)
    
    return cluster_stats

def add_clustering_analysis(data):
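    """Streamlit section: K-Means segmentation of the clients who accepted a
    product, run separately for each pitched product type."""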
    st.header("Analyse par Clustering des Produits Acceptés")
    
    if data is None:
        st.error("Veuillez charger des données pour l'analyse")
        return
        
    # Keep only the clients who accepted a product
    accepted_data = data[data['ProdTaken'] == 1]
    
    if len(accepted_data) == 0:
        st.error("Aucune donnée trouvée pour les produits acceptés")
        return
        
    st.write(f"Nombre total de produits acceptés: {len(accepted_data)}")
    
    # Distinct product types that were pitched
    product_types = accepted_data['ProductPitched'].unique()
    st.write(f"Types de produits disponibles: {', '.join(product_types)}")
    
    # Feature selection for the clustering
    numeric_columns = st.multiselect(
        "Sélectionner les caractéristiques pour l'analyse",
        data.select_dtypes(include=['float64', 'int64']).columns,
        help="Choisissez les caractéristiques numériques pertinentes pour l'analyse"
    )
    
    # The 3D PCA projection below needs at least three features
    if numeric_columns and len(numeric_columns) < 3:
        st.warning("Veuillez sélectionner au moins 3 caractéristiques pour la projection 3D")
    elif numeric_columns:
        for product in product_types:
            st.subheader(f"\nAnalyse du produit: {product}")
            
            product_data = accepted_data[accepted_data['ProductPitched'] == product]
            st.write(f"Nombre de clients ayant accepté ce produit: {len(product_data)}")
            
            max_clusters = min(len(product_data) - 1, 10)
            if max_clusters < 2:
                st.warning(f"Pas assez de données pour le clustering du produit {product}")
                continue
            
            n_clusters = st.slider(
                f"Nombre de sous-groupes pour {product}", 
                2, max_clusters, 
                min(3, max_clusters),
                key=f"slider_{product}"
            )
            
            scaled_features, _ = prepare_clustering_data(product_data, numeric_columns)
            kmeans, cluster_labels = perform_clustering(scaled_features, n_clusters)
            
            silhouette_avg = silhouette_score(scaled_features, cluster_labels)
            st.write(f"Score de silhouette: {silhouette_avg:.3f}")
            
            fig = plot_clusters_3d(scaled_features, cluster_labels, numeric_columns, product)
            st.plotly_chart(fig)
            
            st.write("### Caractéristiques des sous-groupes")
            cluster_stats = analyze_clusters(product_data, cluster_labels, numeric_columns, product)
            
            global_means = product_data[numeric_columns].mean()
            
            for stats in cluster_stats:
                st.write(f"\n**Sous-groupe {stats['Cluster']} ({stats['Taille']} clients)**")
                
                comparison_data = []
                for feat, value in stats['Caractéristiques principales'].items():
                    global_mean = global_means[feat]
                    diff_percent = ((value - global_mean) / global_mean * 100)
                    comparison_data.append({
                        'Caractéristique': feat,
                        'Valeur moyenne du groupe': f"{value:.2f}",
                        'Moyenne globale': f"{global_mean:.2f}",
                        'Différence (%)': f"{diff_percent:+.1f}%"
                    })
                
                comparison_df = pd.DataFrame(comparison_data)
                st.table(comparison_df)
                
                st.write("### Recommandations marketing")
                distinctive_features = []
                for col in numeric_columns:
                    cluster_mean = product_data[cluster_labels == stats['Cluster']][col].mean()
                    global_mean = product_data[col].mean()
                    diff_percent = ((cluster_mean - global_mean) / global_mean * 100)
                    
                    if abs(diff_percent) > 10:
                        distinctive_features.append({
                            'feature': col,
                            'diff': diff_percent,
                            'value': cluster_mean
                        })
                
                if distinctive_features:
                    recommendations = [
                        f"- Groupe avec {feat['feature']} {'supérieur' if feat['diff'] > 0 else 'inférieur'} " \
                        f"à la moyenne ({feat['diff']:+.1f}%)"
                        for feat in distinctive_features
                    ]
                    st.write("\n".join(recommendations))
                else:
                    st.write("- Pas de caractéristiques fortement distinctives identifiées")


def app():
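    """Streamlit entry point: trains the models once (cached in
    st.session_state) and routes between the analysis pages."""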
    st.title("Interpréteur de Modèles ML")
    
    # Load data
    X_train, y_train, X_test, y_test, feature_names = load_data()
    
    # Train models if not in session state
    if 'model_results' not in st.session_state:
        with st.spinner("Entraînement des modèles en cours..."):
            st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
    
    # Sidebar
    st.sidebar.title("Navigation")
    selected_model = st.sidebar.selectbox(
        "Sélectionnez un modèle",
        list(st.session_state.model_results.keys())
    )
    
    page = st.sidebar.radio(
        "Sélectionnez une section",
        ["Performance des modèles", 
         "Interprétation du modèle", 
         "Analyse des caractéristiques",
         "Simulateur de prédictions",
        "Analyse par Clustering"]
    )
    
    current_model = st.session_state.model_results[selected_model]['model']
    
    # Page: model performance
    if page == "Performance des modèles":
        st.header("Performance des modèles")
        
        # Plot global performance comparison
        st.subheader("Comparaison des performances")
        performance_fig = plot_model_performance(st.session_state.model_results)
        st.pyplot(performance_fig)
        
        # Detailed metrics for selected model
        st.subheader(f"Métriques détaillées - {selected_model}")
        col1, col2 = st.columns(2)
        
        with col1:
            st.write("Métriques d'entraînement:")
            for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
                st.write(f"{metric}: {value:.4f}")
        
        with col2:
            st.write("Métriques de test:")
            for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
                st.write(f"{metric}: {value:.4f}")
    
    # Page: model interpretation
    elif page == "Interprétation du modèle":
        st.header(f"Interprétation du modèle - {selected_model}")
        
        if selected_model == "Decision Tree":
            st.subheader("Visualisation de l'arbre")
            max_depth = st.slider("Profondeur maximale à afficher", 1, 5, 3)
            fig, ax = plt.subplots(figsize=(20, 10))
            plot_tree(current_model, feature_names=list(feature_names),
                      max_depth=max_depth, filled=True, rounded=True, ax=ax)
            st.pyplot(fig)

            st.subheader("Règles de décision importantes")
            st.text(export_text(current_model, feature_names=list(feature_names)))
        
        # SHAP values for all models
        st.subheader("SHAP Values")
        with st.spinner("Calcul des valeurs SHAP en cours..."):
            if selected_model == "Logistic Regression":
                explainer = shap.LinearExplainer(current_model, X_train)
            else:
                # All other models are tree-based, so TreeExplainer applies
                explainer = shap.TreeExplainer(current_model)
            # Only the first 100 samples, to keep the computation fast
            shap_values = explainer.shap_values(X_train[:100])

            fig, ax = plt.subplots(figsize=(10, 6))
            shap.summary_plot(shap_values, X_train[:100], feature_names=list(feature_names),
                              show=False)
            st.pyplot(fig)
    
    # Page: feature analysis
    elif page == "Analyse des caractéristiques":
        st.header("Analyse des caractéristiques")
        
        # Feature importance
        st.subheader("Importance des caractéristiques")
        importance_fig = plot_feature_importance(current_model, feature_names, selected_model)
        st.pyplot(importance_fig)
        
        # Feature correlation
        st.subheader("Matrice de corrélation")
        # Restrict to numeric columns so .corr() also works when some
        # features are categorical
        corr_matrix = X_train.select_dtypes(include=np.number).corr()
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0)
        st.pyplot(fig)
        
    elif page == "Analyse par Clustering":
        # Load a separate dataset for the clustering analysis
        uploaded_file = st.file_uploader("Charger les données pour le clustering (CSV)", type="csv")
        if uploaded_file is not None:
            data = pd.read_csv(uploaded_file)
            add_clustering_analysis(data)
        else:
            st.warning("Veuillez charger un fichier CSV pour l'analyse par clustering")

    
    # Page: prediction simulator (the remaining radio option)
    else:
        st.header("Simulateur de prédictions")
        
        input_values = {}
        for feature in feature_names:
            if X_train[feature].dtype == 'object':
                input_values[feature] = st.selectbox(
                    f"Sélectionnez {feature}",
                    options=X_train[feature].unique()
                )
            else:
                input_values[feature] = st.slider(
                    f"Valeur pour {feature}",
                    float(X_train[feature].min()),
                    float(X_train[feature].max()),
                    float(X_train[feature].mean())
                )
        
        if st.button("Prédire"):
            input_df = pd.DataFrame([input_values])
            
            prediction = current_model.predict_proba(input_df)
            
            st.write("Probabilités prédites:")
            st.write({f"Classe {i}": f"{prob:.2%}" for i, prob in enumerate(prediction[0])})
            
            if selected_model == "Decision Tree":
                st.subheader("Chemin de décision")
                node_indicator = current_model.decision_path(input_df)
                leaf_id = current_model.apply(input_df)
                
                node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
                
                rules = []
                for node_id in node_index:
                    if node_id != leaf_id[0]:
                        threshold = current_model.tree_.threshold[node_id]
                        feature = feature_names[current_model.tree_.feature[node_id]]
                        if input_df.iloc[0][feature] <= threshold:
                            rules.append(f"{feature} ≤ {threshold:.2f}")
                        else:
                            rules.append(f"{feature} > {threshold:.2f}")
                
                for rule in rules:
                    st.write(rule)

if __name__ == "__main__":
    app()