analist commited on
Commit
3a61454
·
verified ·
1 Parent(s): 473af1a

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +144 -0
main.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from sklearn.tree import plot_tree, export_text
6
+ import seaborn as sns
7
+ from sklearn.preprocessing import LabelEncoder
8
+
9
+ def load_model_and_data():
10
+ # Ici vous chargeriez votre modèle et données
11
+ # Pour l'exemple, on suppose qu'ils sont disponibles comme:
12
+ # model = loaded_model
13
+ # X = loaded_X
14
+ # y = loaded_y
15
+ # feature_names = X.columns
16
+ pass
17
+
18
+ def app():
19
+ st.title("Interpréteur d'Arbre de Décision")
20
+
21
+ # Sidebar pour les contrôles
22
+ st.sidebar.header("Paramètres d'analyse")
23
+
24
+ # Section 1: Vue globale du modèle
25
+ st.header("1. Vue globale du modèle")
26
+ col1, col2 = st.columns(2)
27
+
28
+ with col1:
29
+ st.subheader("Importance des caractéristiques")
30
+ importance_plot = plt.figure(figsize=(10, 6))
31
+ # Remplacer par vos features et leurs importances
32
+ feature_importance = pd.DataFrame({
33
+ 'feature': feature_names,
34
+ 'importance': model.feature_importances_
35
+ }).sort_values('importance', ascending=True)
36
+ plt.barh(feature_importance['feature'], feature_importance['importance'])
37
+ st.pyplot(importance_plot)
38
+
39
+ with col2:
40
+ st.subheader("Statistiques du modèle")
41
+ st.write(f"Profondeur de l'arbre: {model.get_depth()}")
42
+ st.write(f"Nombre de feuilles: {model.get_n_leaves()}")
43
+
44
+ # Section 2: Explorateur de règles
45
+ st.header("2. Explorateur de règles de décision")
46
+ max_depth = st.slider("Profondeur maximale à afficher", 1, model.get_depth(), 3)
47
+
48
+ tree_text = export_text(model, feature_names=list(feature_names), max_depth=max_depth)
49
+ st.text(tree_text)
50
+
51
+ # Section 3: Analyse de cohortes
52
+ st.header("3. Analyse de cohortes")
53
+
54
+ # Sélection des caractéristiques pour définir les cohortes
55
+ selected_features = st.multiselect(
56
+ "Sélectionnez les caractéristiques pour définir les cohortes",
57
+ feature_names,
58
+ max_selections=2
59
+ )
60
+
61
+ if len(selected_features) > 0:
62
+ # Création des cohortes basées sur les caractéristiques sélectionnées
63
+ def create_cohorts(X, features):
64
+ cohort_def = X[features].copy()
65
+ for feat in features:
66
+ if X[feat].dtype == 'object' or len(X[feat].unique()) < 10:
67
+ cohort_def[feat] = X[feat]
68
+ else:
69
+ cohort_def[feat] = pd.qcut(X[feat], q=4, labels=['Q1', 'Q2', 'Q3', 'Q4'])
70
+ return cohort_def.apply(lambda x: ' & '.join(x.astype(str)), axis=1)
71
+
72
+ cohorts = create_cohorts(X, selected_features)
73
+
74
+ # Analyse des prédictions par cohorte
75
+ cohort_analysis = pd.DataFrame({
76
+ 'Cohorte': cohorts,
77
+ 'Prédiction': model.predict(X)
78
+ })
79
+
80
+ cohort_stats = cohort_analysis.groupby('Cohorte')['Prédiction'].agg(['count', 'mean'])
81
+ cohort_stats.columns = ['Nombre d\'observations', 'Taux de prédiction positive']
82
+
83
+ st.write("Statistiques par cohorte:")
84
+ st.dataframe(cohort_stats)
85
+
86
+ # Visualisation des cohortes
87
+ cohort_viz = plt.figure(figsize=(10, 6))
88
+ sns.barplot(data=cohort_analysis, x='Cohorte', y='Prédiction')
89
+ plt.xticks(rotation=45)
90
+ st.pyplot(cohort_viz)
91
+
92
+ # Section 4: Simulateur de prédictions
93
+ st.header("4. Simulateur de prédictions")
94
+
95
+ # Interface pour entrer des valeurs
96
+ input_values = {}
97
+ for feature in feature_names:
98
+ if X[feature].dtype == 'object':
99
+ input_values[feature] = st.selectbox(
100
+ f"Sélectionnez {feature}",
101
+ options=X[feature].unique()
102
+ )
103
+ else:
104
+ input_values[feature] = st.slider(
105
+ f"Valeur pour {feature}",
106
+ float(X[feature].min()),
107
+ float(X[feature].max()),
108
+ float(X[feature].mean())
109
+ )
110
+
111
+ if st.button("Prédire"):
112
+ # Création du DataFrame pour la prédiction
113
+ input_df = pd.DataFrame([input_values])
114
+
115
+ # Prédiction
116
+ prediction = model.predict_proba(input_df)
117
+
118
+ # Affichage du résultat
119
+ st.write("Probabilités prédites:")
120
+ st.write({f"Classe {i}": f"{prob:.2%}" for i, prob in enumerate(prediction[0])})
121
+
122
+ # Chemin de décision pour cette prédiction
123
+ st.subheader("Chemin de décision")
124
+ node_indicator = model.decision_path(input_df)
125
+ leaf_id = model.apply(input_df)
126
+
127
+ feature_names = list(feature_names)
128
+ node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
129
+
130
+ rules = []
131
+ for node_id in node_index:
132
+ if node_id != leaf_id[0]:
133
+ threshold = model.tree_.threshold[node_id]
134
+ feature = feature_names[model.tree_.feature[node_id]]
135
+ if input_df.iloc[0][feature] <= threshold:
136
+ rules.append(f"{feature} ≤ {threshold:.2f}")
137
+ else:
138
+ rules.append(f"{feature} > {threshold:.2f}")
139
+
140
+ for rule in rules:
141
+ st.write(rule)
142
+
143
+ if __name__ == "__main__":
144
+ app()