analist commited on
Commit
8384234
·
verified ·
1 Parent(s): 42fa5c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -162
app.py CHANGED
@@ -100,212 +100,167 @@ import streamlit as st
100
  import pandas as pd
101
  import numpy as np
102
  import matplotlib.pyplot as plt
103
- from sklearn.tree import plot_tree, export_text
104
- import seaborn as sns
105
- from sklearn.preprocessing import LabelEncoder
106
- from sklearn.ensemble import RandomForestClassifier
107
  from sklearn.tree import DecisionTreeClassifier
108
- from sklearn.ensemble import GradientBoostingClassifier
109
  from sklearn.linear_model import LogisticRegression
110
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
111
- import shap
112
 
113
  # Configuration de la page
114
- st.set_page_config(
115
- page_title="ML Model Interpreter",
116
- layout="wide",
117
- initial_sidebar_state="expanded"
118
- )
119
 
120
- # CSS personnalisé
121
  st.markdown("""
122
  <style>
123
- .main-header {
124
- color: #0D47A1;
125
- text-align: center;
126
- padding: 1rem;
127
- background: linear-gradient(90deg, #FFFFFF 0%, #90CAF9 50%, #FFFFFF 100%);
128
- border-radius: 10px;
129
- margin-bottom: 2rem;
130
- }
131
-
132
- .metric-card {
133
- background-color: white;
134
  padding: 1.5rem;
135
  border-radius: 10px;
136
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
137
- margin-bottom: 1rem;
138
  }
139
-
140
- .sub-header {
 
 
 
141
  color: #1E88E5;
142
- border-bottom: 2px solid #90CAF9;
143
- padding-bottom: 0.5rem;
144
- margin-bottom: 1rem;
145
  }
146
 
147
- .metric-value {
148
- font-size: 1.5rem;
149
- font-weight: bold;
150
- color: #1E88E5;
 
 
151
  }
152
 
153
- div[data-testid="stMetricValue"] {
 
 
 
154
  color: #1E88E5;
155
  }
156
  </style>
157
  """, unsafe_allow_html=True)
158
 
159
- def custom_metric_card(title, value, prefix=""):
160
- return f"""
161
- <div class="metric-card">
162
- <h3 style="color: #1E88E5; margin-bottom: 0.5rem;">{title}</h3>
163
- <p class="metric-value">{prefix}{value:.4f}</p>
164
- </div>
165
- """
166
-
167
- def set_plot_style(fig):
168
- """Configure le style des graphiques"""
169
- colors = ['#1E88E5', '#90CAF9', '#0D47A1', '#42A5F5']
170
- for ax in fig.axes:
171
- ax.set_facecolor('#F8F9FA')
172
- ax.grid(True, linestyle='--', alpha=0.3, color='#666666')
173
- ax.spines['top'].set_visible(False)
174
- ax.spines['right'].set_visible(False)
175
- ax.tick_params(axis='both', colors='#666666')
176
- ax.set_axisbelow(True)
177
- return fig, colors
178
-
179
- def plot_model_performance(results):
180
- metrics = ['accuracy', 'f1', 'precision', 'recall', 'roc_auc']
181
- fig, axes = plt.subplots(1, 2, figsize=(15, 6))
182
- fig, colors = set_plot_style(fig)
183
-
184
- # Training metrics
185
- train_data = {model: [results[model]['train_metrics'][metric] for metric in metrics]
186
- for model in results.keys()}
187
- train_df = pd.DataFrame(train_data, index=metrics)
188
- train_df.plot(kind='bar', ax=axes[0], color=colors)
189
- axes[0].set_title('Performance d\'Entraînement', color='#0D47A1', pad=20)
190
- axes[0].set_ylim(0, 1)
191
 
192
- # Test metrics
193
- test_data = {model: [results[model]['test_metrics'][metric] for metric in metrics]
194
- for model in results.keys()}
195
- test_df = pd.DataFrame(test_data, index=metrics)
196
- test_df.plot(kind='bar', ax=axes[1], color=colors)
197
- axes[1].set_title('Performance de Test', color='#0D47A1', pad=20)
198
- axes[1].set_ylim(0, 1)
199
 
200
- # Style des graphiques
201
- for ax in axes:
202
- plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
203
- ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
204
 
205
- plt.tight_layout()
206
- return fig
207
-
208
- def plot_feature_importance(model, feature_names, model_type):
209
  fig, ax = plt.subplots(figsize=(10, 6))
210
- fig, colors = set_plot_style(fig)
211
-
212
- if model_type in ["Decision Tree", "Random Forest", "Gradient Boost"]:
213
- importance = model.feature_importances_
214
- elif model_type == "Logistic Regression":
215
- importance = np.abs(model.coef_[0])
216
 
217
- importance_df = pd.DataFrame({
218
- 'feature': feature_names,
219
- 'importance': importance
220
- }).sort_values('importance', ascending=True)
221
 
222
- ax.barh(importance_df['feature'], importance_df['importance'],
223
- color='#1E88E5', alpha=0.8)
224
- ax.set_title("Importance des Caractéristiques", color='#0D47A1', pad=20)
 
 
 
 
225
 
226
  return fig
227
 
228
- def plot_correlation_matrix(data):
229
- fig, ax = plt.subplots(figsize=(10, 8))
230
- fig, _ = set_plot_style(fig)
231
-
232
- sns.heatmap(data.corr(), annot=True, cmap='coolwarm', center=0,
233
- ax=ax, fmt='.2f', square=True)
234
- ax.set_title("Matrice de Corrélation", color='#0D47A1', pad=20)
235
-
236
- return fig
 
 
 
 
 
 
237
 
238
  def app():
239
- st.markdown('<h1 class="main-header">Interpréteur de Modèles ML</h1>',
240
- unsafe_allow_html=True)
241
 
242
- # Load data
243
  X_train, y_train, X_test, y_test, feature_names = load_data()
244
 
245
- # Train models if not in session state
246
- if 'model_results' not in st.session_state:
247
- with st.spinner("🔄 Entraînement des modèles en cours..."):
248
- st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
249
-
250
- # Sidebar
251
  with st.sidebar:
252
- st.markdown('<h2 style="color: #1E88E5;">Navigation</h2>',
253
- unsafe_allow_html=True)
254
-
255
  selected_model = st.selectbox(
256
- "📊 Sélectionnez un modèle",
257
- list(st.session_state.model_results.keys())
258
- )
259
-
260
- st.markdown('<hr style="margin: 1rem 0;">', unsafe_allow_html=True)
261
-
262
- page = st.radio(
263
- "📑 Sélectionnez une section",
264
- ["Performance des modèles",
265
- "Interprétation du modèle",
266
- "Analyse des caractéristiques",
267
- "Simulateur de prédictions"]
268
  )
269
 
270
- current_model = st.session_state.model_results[selected_model]['model']
 
 
 
271
 
272
- # Main content
273
- if page == "Performance des modèles":
274
- st.markdown('<h2 class="sub-header">Performance des modèles</h2>',
275
- unsafe_allow_html=True)
276
-
277
- performance_fig = plot_model_performance(st.session_state.model_results)
278
- st.pyplot(performance_fig)
279
 
280
- st.markdown('<h3 class="sub-header">Métriques détaillées</h3>',
281
- unsafe_allow_html=True)
282
 
283
- col1, col2 = st.columns(2)
284
- with col1:
285
- st.markdown('<h4 style="color: #1E88E5;">Entraînement</h4>',
286
- unsafe_allow_html=True)
287
- for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
288
- st.markdown(custom_metric_card(metric.capitalize(), value),
289
- unsafe_allow_html=True)
290
 
291
- with col2:
292
- st.markdown('<h4 style="color: #1E88E5;">Test</h4>',
293
- unsafe_allow_html=True)
294
- for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
295
- st.markdown(custom_metric_card(metric.capitalize(), value),
296
- unsafe_allow_html=True)
297
 
298
- elif page == "Analyse des caractéristiques":
299
- st.markdown('<h2 class="sub-header">Analyse des caractéristiques</h2>',
300
- unsafe_allow_html=True)
301
 
302
- importance_fig = plot_feature_importance(current_model, feature_names, selected_model)
303
- st.pyplot(importance_fig)
304
-
305
- st.markdown('<h3 class="sub-header">Corrélations</h3>',
306
- unsafe_allow_html=True)
307
- corr_fig = plot_correlation_matrix(X_train)
308
- st.pyplot(corr_fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  if __name__ == "__main__":
311
  app()
 
100
  import pandas as pd
101
  import numpy as np
102
  import matplotlib.pyplot as plt
 
 
 
 
103
  from sklearn.tree import DecisionTreeClassifier
104
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
105
  from sklearn.linear_model import LogisticRegression
106
+ from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score
107
+ import seaborn as sns
108
 
109
  # Configuration de la page
110
+ st.set_page_config(layout="wide", page_title="ML Dashboard")
 
 
 
 
111
 
112
+ # Style personnalisé
113
  st.markdown("""
114
  <style>
115
+ /* Cartes stylisées */
116
+ div.css-1r6slb0.e1tzin5v2 {
117
+ background-color: #FFFFFF;
118
+ border: 1px solid #EEEEEE;
 
 
 
 
 
 
 
119
  padding: 1.5rem;
120
  border-radius: 10px;
121
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
 
122
  }
123
+
124
+ /* Headers */
125
+ .main-header {
126
+ font-size: 2rem;
127
+ font-weight: 700;
128
  color: #1E88E5;
129
+ text-align: center;
130
+ margin-bottom: 2rem;
 
131
  }
132
 
133
+ /* Metric containers */
134
+ div.css-12w0qpk.e1tzin5v2 {
135
+ background-color: #F8F9FA;
136
+ padding: 1rem;
137
+ border-radius: 8px;
138
+ text-align: center;
139
  }
140
 
141
+ /* Metric values */
142
+ div.css-1xarl3l.e16fv1kl1 {
143
+ font-size: 1.8rem;
144
+ font-weight: 700;
145
  color: #1E88E5;
146
  }
147
  </style>
148
  """, unsafe_allow_html=True)
149
 
150
+ def plot_performance_comparison(results, metric='test_metrics'):
151
+ """Crée un graphique de comparaison des performances avec des couleurs distinctes"""
152
+ metrics = ['accuracy', 'f1', 'recall', 'roc_auc']
153
+ model_names = list(results.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ # Définir des couleurs distinctes pour chaque modèle
156
+ colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']
 
 
 
 
 
157
 
158
+ data = {model: [results[model][metric][m] for m in metrics]
159
+ for model in model_names}
 
 
160
 
 
 
 
 
161
  fig, ax = plt.subplots(figsize=(10, 6))
162
+ x = np.arange(len(metrics))
163
+ width = 0.2
 
 
 
 
164
 
165
+ for i, (model, values) in enumerate(data.items()):
166
+ ax.bar(x + i*width, values, width, label=model, color=colors[i])
 
 
167
 
168
+ ax.set_ylabel('Score')
169
+ ax.set_title(f'Comparaison des performances ({metric.split("_")[0].title()})')
170
+ ax.set_xticks(x + width * (len(model_names)-1)/2)
171
+ ax.set_xticklabels(metrics)
172
+ ax.legend()
173
+ ax.grid(True, alpha=0.3)
174
+ plt.ylim(0, 1)
175
 
176
  return fig
177
 
178
+ def create_metric_card(title, value):
179
+ """Crée une carte de métrique stylisée"""
180
+ st.markdown(f"""
181
+ <div style="
182
+ background-color: white;
183
+ padding: 1rem;
184
+ border-radius: 8px;
185
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
186
+ text-align: center;
187
+ margin-bottom: 1rem;
188
+ ">
189
+ <h3 style="color: #666; font-size: 1rem; margin-bottom: 0.5rem;">{title}</h3>
190
+ <p style="color: #1E88E5; font-size: 1.8rem; font-weight: bold; margin: 0;">{value:.3f}</p>
191
+ </div>
192
+ """, unsafe_allow_html=True)
193
 
194
  def app():
195
+ # Header
196
+ st.markdown('<h1 class="main-header">Tableau de Bord ML</h1>', unsafe_allow_html=True)
197
 
198
+ # Charger et préparer les données
199
  X_train, y_train, X_test, y_test, feature_names = load_data()
200
 
201
+ # Sidebar pour la sélection du modèle
 
 
 
 
 
202
  with st.sidebar:
203
+ st.markdown('<h2 style="color: #1E88E5;">Configuration</h2>', unsafe_allow_html=True)
 
 
204
  selected_model = st.selectbox(
205
+ "Sélectionner un modèle",
206
+ ["Logistic Regression", "Decision Tree", "Random Forest", "Gradient Boost"]
 
 
 
 
 
 
 
 
 
 
207
  )
208
 
209
+ # Entraînement des modèles si pas déjà fait
210
+ if 'model_results' not in st.session_state:
211
+ with st.spinner("⏳ Entraînement des modèles..."):
212
+ st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
213
 
214
+ # Layout principal
215
+ col1, col2 = st.columns([2, 1])
216
+
217
+ with col1:
218
+ # Graphiques de performance
219
+ st.markdown("### 📊 Comparaison des Performances")
 
220
 
221
+ tab1, tab2 = st.tabs(["🎯 Test", "📈 Entraînement"])
 
222
 
223
+ with tab1:
224
+ fig_test = plot_performance_comparison(st.session_state.model_results, 'test_metrics')
225
+ st.pyplot(fig_test)
 
 
 
 
226
 
227
+ with tab2:
228
+ fig_train = plot_performance_comparison(st.session_state.model_results, 'train_metrics')
229
+ st.pyplot(fig_train)
 
 
 
230
 
231
+ with col2:
232
+ # Métriques détaillées du modèle sélectionné
233
+ st.markdown(f"### 📌 Métriques - {selected_model}")
234
 
235
+ metrics = st.session_state.model_results[selected_model]['test_metrics']
236
+ for metric, value in metrics.items():
237
+ if metric != 'precision': # On exclut la précision
238
+ create_metric_card(metric.upper(), value)
239
+
240
+ # Section inférieure
241
+ st.markdown("### 🔍 Analyse Détaillée")
242
+ col3, col4 = st.columns(2)
243
+
244
+ with col3:
245
+ # Feature Importance
246
+ current_model = st.session_state.model_results[selected_model]['model']
247
+ if hasattr(current_model, 'feature_importances_') or hasattr(current_model, 'coef_'):
248
+ fig_importance = plt.figure(figsize=(10, 6))
249
+ if hasattr(current_model, 'feature_importances_'):
250
+ importances = current_model.feature_importances_
251
+ else:
252
+ importances = np.abs(current_model.coef_[0])
253
+
254
+ plt.barh(feature_names, importances)
255
+ plt.title("Importance des Caractéristiques")
256
+ st.pyplot(fig_importance)
257
+
258
+ with col4:
259
+ # Matrice de corrélation
260
+ fig_corr = plt.figure(figsize=(10, 8))
261
+ sns.heatmap(X_train.corr(), annot=True, cmap='coolwarm', center=0)
262
+ plt.title("Matrice de Corrélation")
263
+ st.pyplot(fig_corr)
264
 
265
  if __name__ == "__main__":
266
  app()