analist commited on
Commit
2644c4d
·
verified ·
1 Parent(s): 1ad2cd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -107
app.py CHANGED
@@ -89,132 +89,169 @@ def load_model_and_data():
89
 
90
 
91
 
 
 
 
 
 
 
 
 
 
 
92
  def app():
93
  st.title("Interpréteur d'Arbre de Décision")
94
 
95
- # Sidebar pour les contrôles
96
- st.sidebar.header("Paramètres d'analyse")
97
-
98
- # Section 1: Vue globale du modèle
99
- st.header("Vue globale du modèle")
100
- col1, col2 = st.columns(2)
101
-
102
- with col1:
103
- model, X, y = load_model_and_data()
104
- feature_names = X.columns
105
- st.subheader("Importance des caractéristiques")
106
- importance_plot = plt.figure(figsize=(10, 6))
107
- # Remplacer par vos features et leurs importances
108
- feature_importance = pd.DataFrame({
109
- 'feature': feature_names,
110
- 'importance': model.feature_importances_
111
- }).sort_values('importance', ascending=True)
112
- plt.barh(feature_importance['feature'], feature_importance['importance'])
113
- st.pyplot(importance_plot)
114
 
115
- with col2:
116
- st.subheader("Statistiques du modèle")
117
- st.write(f"Profondeur de l'arbre: {model.get_depth()}")
118
- st.write(f"Nombre de feuilles: {model.get_n_leaves()}")
119
 
120
- # Section 2: Explorateur de règles
121
- st.header("2. Explorateur de règles de décision")
122
- max_depth = st.slider("Profondeur maximale à afficher", 1, model.get_depth(), 3)
123
-
124
- tree_text = export_text(model, feature_names=list(feature_names), max_depth=max_depth)
125
- st.text(tree_text)
126
-
127
- # Section 3: Analyse de cohortes
128
- st.header("3. Analyse de cohortes")
129
-
130
- # Sélection des caractéristiques pour définir les cohortes
131
- selected_features = st.multiselect(
132
- "Sélectionnez les caractéristiques pour définir les cohortes",
133
- feature_names,
134
- max_selections=2
135
  )
136
 
137
- if len(selected_features) > 0:
138
- # Création des cohortes basées sur les caractéristiques sélectionnées
139
- def create_cohorts(X, features):
140
- cohort_def = X[features].copy()
141
- for feat in features:
142
- if X[feat].dtype == 'object' or len(X[feat].unique()) < 10:
143
- cohort_def[feat] = X[feat]
144
- else:
145
- cohort_def[feat] = pd.qcut(X[feat], q=4, labels=['Q1', 'Q2', 'Q3', 'Q4'])
146
- return cohort_def.apply(lambda x: ' & '.join(x.astype(str)), axis=1)
147
 
148
- cohorts = create_cohorts(X, selected_features)
 
 
 
 
 
 
 
 
149
 
150
- # Analyse des prédictions par cohorte
151
- cohort_analysis = pd.DataFrame({
152
- 'Cohorte': cohorts,
153
- 'Prédiction': model.predict(X)
154
- })
 
 
 
155
 
156
- cohort_stats = cohort_analysis.groupby('Cohorte')['Prédiction'].agg(['count', 'mean'])
157
- cohort_stats.columns = ['Nombre d\'observations', 'Taux de prédiction positive']
 
 
158
 
159
- st.write("Statistiques par cohorte:")
160
- st.dataframe(cohort_stats)
161
 
162
- # Visualisation des cohortes
163
- cohort_viz = plt.figure(figsize=(10, 6))
164
- sns.barplot(data=cohort_analysis, x='Cohorte', y='Prédiction')
165
- plt.xticks(rotation=45)
166
- st.pyplot(cohort_viz)
167
-
168
- # Section 4: Simulateur de prédictions
169
- st.header("4. Simulateur de prédictions")
170
-
171
- # Interface pour entrer des valeurs
172
- input_values = {}
173
- for feature in feature_names:
174
- if X[feature].dtype == 'object':
175
- input_values[feature] = st.selectbox(
176
- f"Sélectionnez {feature}",
177
- options=X[feature].unique()
178
- )
179
  else:
180
- input_values[feature] = st.slider(
181
- f"Valeur pour {feature}",
182
- float(X[feature].min()),
183
- float(X[feature].max()),
184
- float(X[feature].mean())
 
 
 
 
185
  )
 
 
 
 
186
 
187
- if st.button("Prédire"):
188
- # Création du DataFrame pour la prédiction
189
- input_df = pd.DataFrame([input_values])
190
-
191
- # Prédiction
192
- prediction = model.predict_proba(input_df)
193
 
194
- # Affichage du résultat
195
- st.write("Probabilités prédites:")
196
- st.write({f"Classe {i}": f"{prob:.2%}" for i, prob in enumerate(prediction[0])})
 
 
197
 
198
- # Chemin de décision pour cette prédiction
199
- st.subheader("Chemin de décision")
200
- node_indicator = model.decision_path(input_df)
201
- leaf_id = model.apply(input_df)
202
-
203
- feature_names = list(feature_names)
204
- node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- rules = []
207
- for node_id in node_index:
208
- if node_id != leaf_id[0]:
209
- threshold = model.tree_.threshold[node_id]
210
- feature = feature_names[model.tree_.feature[node_id]]
211
- if input_df.iloc[0][feature] <= threshold:
212
- rules.append(f"{feature} ≤ {threshold:.2f}")
213
- else:
214
- rules.append(f"{feature} > {threshold:.2f}")
 
 
 
 
 
215
 
216
- for rule in rules:
217
- st.write(rule)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  if __name__ == "__main__":
220
  app()
 
89
 
90
 
91
 
92
+ import streamlit as st
93
+ import pandas as pd
94
+ import numpy as np
95
+ import matplotlib.pyplot as plt
96
+ from sklearn.tree import plot_tree, export_text
97
+ import seaborn as sns
98
+ from sklearn.preprocessing import LabelEncoder
99
+ from dtreeviz.trees import dtreeviz
100
+
101
+
102
  def app():
103
  st.title("Interpréteur d'Arbre de Décision")
104
 
105
+ # Chargement du modèle et des données
106
+ model, X, y, feature_names = load_model_and_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
+ if model is None:
109
+ st.warning("Veuillez charger un modèle pour commencer.")
110
+ return
 
111
 
112
+ # Sidebar avec les sections
113
+ st.sidebar.title("Navigation")
114
+ page = st.sidebar.radio(
115
+ "Sélectionnez une section",
116
+ ["Vue globale du modèle",
117
+ "Explorateur de règles",
118
+ "Analyse de cohortes",
119
+ "Simulateur de prédictions"]
 
 
 
 
 
 
 
120
  )
121
 
122
+ # Vue globale du modèle
123
+ if page == "Vue globale du modèle":
124
+ st.header("Vue globale du modèle")
125
+ col1, col2 = st.columns(2)
 
 
 
 
 
 
126
 
127
+ with col1:
128
+ st.subheader("Importance des caractéristiques")
129
+ importance_plot = plt.figure(figsize=(10, 6))
130
+ feature_importance = pd.DataFrame({
131
+ 'feature': feature_names,
132
+ 'importance': model.feature_importances_
133
+ }).sort_values('importance', ascending=True)
134
+ plt.barh(feature_importance['feature'], feature_importance['importance'])
135
+ st.pyplot(importance_plot)
136
 
137
+ with col2:
138
+ st.subheader("Statistiques du modèle")
139
+ st.write(f"Profondeur de l'arbre: {model.get_depth()}")
140
+ st.write(f"Nombre de feuilles: {model.get_n_leaves()}")
141
+
142
+ # Explorateur de règles
143
+ elif page == "Explorateur de règles":
144
+ st.header("Explorateur de règles de décision")
145
 
146
+ viz_type = st.radio(
147
+ "Type de visualisation",
148
+ ["Texte", "Graphique interactif"]
149
+ )
150
 
151
+ max_depth = st.slider("Profondeur maximale à afficher", 1, model.get_depth(), 3)
 
152
 
153
+ if viz_type == "Texte":
154
+ tree_text = export_text(model, feature_names=list(feature_names), max_depth=max_depth)
155
+ st.text(tree_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  else:
157
+ # Création de la visualisation dtreeviz
158
+ viz = dtreeviz(
159
+ model,
160
+ X,
161
+ y,
162
+ target_name="target",
163
+ feature_names=list(feature_names),
164
+ class_names=list(map(str, model.classes_)),
165
+ max_depth=max_depth
166
  )
167
+
168
+ # Conversion en PNG pour Streamlit
169
+ png_data = viz.save_png()
170
+ st.image(png_data, use_column_width=True)
171
 
172
+ # Analyse de cohortes
173
+ elif page == "Analyse de cohortes":
174
+ st.header("Analyse de cohortes")
 
 
 
175
 
176
+ selected_features = st.multiselect(
177
+ "Sélectionnez les caractéristiques pour définir les cohortes",
178
+ feature_names,
179
+ max_selections=2
180
+ )
181
 
182
+ if len(selected_features) > 0:
183
+ def create_cohorts(X, features):
184
+ cohort_def = X[features].copy()
185
+ for feat in features:
186
+ if X[feat].dtype == 'object' or len(X[feat].unique()) < 10:
187
+ cohort_def[feat] = X[feat]
188
+ else:
189
+ cohort_def[feat] = pd.qcut(X[feat], q=4, labels=['Q1', 'Q2', 'Q3', 'Q4'])
190
+ return cohort_def.apply(lambda x: ' & '.join(x.astype(str)), axis=1)
191
+
192
+ cohorts = create_cohorts(X, selected_features)
193
+
194
+ cohort_analysis = pd.DataFrame({
195
+ 'Cohorte': cohorts,
196
+ 'Prédiction': model.predict(X)
197
+ })
198
+
199
+ cohort_stats = cohort_analysis.groupby('Cohorte')['Prédiction'].agg(['count', 'mean'])
200
+ cohort_stats.columns = ['Nombre d\'observations', 'Taux de prédiction positive']
201
+
202
+ st.write("Statistiques par cohorte:")
203
+ st.dataframe(cohort_stats)
204
+
205
+ cohort_viz = plt.figure(figsize=(10, 6))
206
+ sns.barplot(data=cohort_analysis, x='Cohorte', y='Prédiction')
207
+ plt.xticks(rotation=45)
208
+ st.pyplot(cohort_viz)
209
+
210
+ # Simulateur de prédictions
211
+ else:
212
+ st.header("Simulateur de prédictions")
213
 
214
+ input_values = {}
215
+ for feature in feature_names:
216
+ if X[feature].dtype == 'object':
217
+ input_values[feature] = st.selectbox(
218
+ f"Sélectionnez {feature}",
219
+ options=X[feature].unique()
220
+ )
221
+ else:
222
+ input_values[feature] = st.slider(
223
+ f"Valeur pour {feature}",
224
+ float(X[feature].min()),
225
+ float(X[feature].max()),
226
+ float(X[feature].mean())
227
+ )
228
 
229
+ if st.button("Prédire"):
230
+ input_df = pd.DataFrame([input_values])
231
+
232
+ prediction = model.predict_proba(input_df)
233
+
234
+ st.write("Probabilités prédites:")
235
+ st.write({f"Classe {i}": f"{prob:.2%}" for i, prob in enumerate(prediction[0])})
236
+
237
+ st.subheader("Chemin de décision")
238
+ node_indicator = model.decision_path(input_df)
239
+ leaf_id = model.apply(input_df)
240
+
241
+ node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
242
+
243
+ rules = []
244
+ for node_id in node_index:
245
+ if node_id != leaf_id[0]:
246
+ threshold = model.tree_.threshold[node_id]
247
+ feature = feature_names[model.tree_.feature[node_id]]
248
+ if input_df.iloc[0][feature] <= threshold:
249
+ rules.append(f"{feature} ≤ {threshold:.2f}")
250
+ else:
251
+ rules.append(f"{feature} > {threshold:.2f}")
252
+
253
+ for rule in rules:
254
+ st.write(rule)
255
 
256
  if __name__ == "__main__":
257
  app()