clockclock commited on
Commit
1b80388
·
verified ·
1 Parent(s): 1074b1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
@@ -81,40 +81,37 @@ class EnhancedAIvsRealGazeAnalyzer:
81
  valid_indices = target.notna()
82
  features, target = features[valid_indices], target[valid_indices]
83
  features = features.fillna(features.median()).fillna(0)
84
- if len(target.unique()) < 2: return "Not enough classes to train the model.", None
85
  X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=42, stratify=target)
86
  scaler = StandardScaler()
87
  X_train_scaled, X_test_scaled = scaler.fit_transform(X_train), scaler.transform(X_test)
88
  model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
89
  model.fit(X_train_scaled, y_train)
90
- y_pred_proba = model.predict_proba(X_test_scaled)[:, 1]
91
  y_pred = model.predict(X_test_scaled)
92
  report = classification_report(y_test, y_pred, target_names=['Incorrect', 'Correct'], output_dict=True)
93
- auc_score = roc_auc_score(y_test, y_pred_proba)
94
 
95
  # --- THIS IS THE KEY FIX ---
96
- # 1. Convert the report dictionary to a DataFrame
97
- report_df = pd.DataFrame(report).transpose().round(3)
98
- # 2. Use the built-in .to_markdown() method for perfect formatting
99
- report_table = report_df.to_markdown()
100
-
101
- report_md = f"""
102
  ### Model Performance
103
  - **AUC Score:** **{auc_score:.4f}**
104
  - **Overall Accuracy:** {report['accuracy']:.3f}
105
-
106
- **Classification Report:**
107
- {report_table}
108
  """
109
- # --- END OF FIX ---
110
-
 
 
111
  feature_importance = pd.DataFrame({'Feature': features.columns, 'Importance': model.feature_importances_})
112
  feature_importance = feature_importance.sort_values('Importance', ascending=False).head(15)
113
  fig, ax = plt.subplots(figsize=(10, 8))
114
  sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis')
115
  ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14)
116
  plt.tight_layout()
117
- return report_md, fig
 
 
 
118
 
119
  # --- DATA SETUP ---
120
  def setup_and_load_data():
@@ -124,7 +121,7 @@ def setup_and_load_data():
124
  print(f"Cloning data repository from {repo_url}...")
125
  git.Repo.clone_from(repo_url, repo_dir)
126
  else:
127
- print("Data repository already exists.")
128
  base_path = repo_dir
129
  response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
130
  analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
@@ -142,8 +139,9 @@ def update_rq1_visuals(metric_choice):
142
 
143
  def update_rq2_model(test_size, n_estimators):
144
  n_estimators = int(n_estimators)
145
- report, plot = analyzer.run_prediction_model(test_size, n_estimators)
146
- return report, plot
 
147
 
148
  # --- GRADIO INTERFACE DEFINITION ---
149
  description = """
@@ -161,23 +159,35 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
161
  rq1_summary_output = gr.Markdown(label="Statistical Summary")
162
  with gr.Column(scale=2):
163
  rq1_plot_output = gr.Plot(label="Metric Comparison")
 
164
  with gr.TabItem("RQ2: Predicting Correctness from Gaze"):
165
  with gr.Row():
166
  with gr.Column(scale=1):
167
  gr.Markdown("#### Tune Model Hyperparameters")
168
  rq2_test_size_slider = gr.Slider(minimum=0.1, maximum=0.5, step=0.05, value=0.3, label="Test Set Size")
169
  rq2_estimators_slider = gr.Slider(minimum=10, maximum=200, step=10, value=100, label="Number of Trees (n_estimators)")
 
 
170
  with gr.Column(scale=2):
171
- # Ensure this is gr.Markdown()
172
- rq2_report_output = gr.Markdown(label="Model Performance Report")
 
 
 
173
  rq2_plot_output = gr.Plot(label="Feature Importance")
 
174
 
 
 
 
 
175
  rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
176
- rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=[rq2_report_output, rq2_plot_output])
177
- rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=[rq2_report_output, rq2_plot_output])
178
 
179
  demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
180
- demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=[rq2_report_output, rq2_plot_output])
 
181
 
182
  if __name__ == "__main__":
183
  demo.launch()
 
1
+ # app.py
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
 
81
  valid_indices = target.notna()
82
  features, target = features[valid_indices], target[valid_indices]
83
  features = features.fillna(features.median()).fillna(0)
84
+ if len(target.unique()) < 2: return "Not enough classes to train.", None, None
85
  X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=42, stratify=target)
86
  scaler = StandardScaler()
87
  X_train_scaled, X_test_scaled = scaler.fit_transform(X_train), scaler.transform(X_test)
88
  model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
89
  model.fit(X_train_scaled, y_train)
 
90
  y_pred = model.predict(X_test_scaled)
91
  report = classification_report(y_test, y_pred, target_names=['Incorrect', 'Correct'], output_dict=True)
92
+ auc_score = roc_auc_score(y_test, model.predict_proba(X_test_scaled)[:, 1])
93
 
94
  # --- THIS IS THE KEY FIX ---
95
+ # 1. Create the summary text separately.
96
+ summary_md = f"""
 
 
 
 
97
  ### Model Performance
98
  - **AUC Score:** **{auc_score:.4f}**
99
  - **Overall Accuracy:** {report['accuracy']:.3f}
 
 
 
100
  """
101
+ # 2. Create the report DataFrame.
102
+ report_df = pd.DataFrame(report).transpose().round(3)
103
+
104
+ # 3. Create the feature importance plot.
105
  feature_importance = pd.DataFrame({'Feature': features.columns, 'Importance': model.feature_importances_})
106
  feature_importance = feature_importance.sort_values('Importance', ascending=False).head(15)
107
  fig, ax = plt.subplots(figsize=(10, 8))
108
  sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis')
109
  ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14)
110
  plt.tight_layout()
111
+
112
+ # 4. Return the three items separately.
113
+ return summary_md, report_df, fig
114
+ # --- END OF FIX ---
115
 
116
  # --- DATA SETUP ---
117
  def setup_and_load_data():
 
121
  print(f"Cloning data repository from {repo_url}...")
122
  git.Repo.clone_from(repo_url, repo_dir)
123
  else:
124
+ print("Data repository already. Skipping clone.")
125
  base_path = repo_dir
126
  response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
127
  analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
 
139
 
140
  def update_rq2_model(test_size, n_estimators):
141
  n_estimators = int(n_estimators)
142
+ # The function now returns three items
143
+ summary, report_df, plot = analyzer.run_prediction_model(test_size, n_estimators)
144
+ return summary, report_df, plot
145
 
146
  # --- GRADIO INTERFACE DEFINITION ---
147
  description = """
 
159
  rq1_summary_output = gr.Markdown(label="Statistical Summary")
160
  with gr.Column(scale=2):
161
  rq1_plot_output = gr.Plot(label="Metric Comparison")
162
+
163
  with gr.TabItem("RQ2: Predicting Correctness from Gaze"):
164
  with gr.Row():
165
  with gr.Column(scale=1):
166
  gr.Markdown("#### Tune Model Hyperparameters")
167
  rq2_test_size_slider = gr.Slider(minimum=0.1, maximum=0.5, step=0.05, value=0.3, label="Test Set Size")
168
  rq2_estimators_slider = gr.Slider(minimum=10, maximum=200, step=10, value=100, label="Number of Trees (n_estimators)")
169
+
170
+ # --- THIS IS THE KEY UI FIX ---
171
  with gr.Column(scale=2):
172
+ # 1. A Markdown component for the summary text.
173
+ rq2_summary_output = gr.Markdown(label="Model Performance Summary")
174
+ # 2. A Dataframe component for the table.
175
+ rq2_table_output = gr.Dataframe(label="Classification Report", interactive=False)
176
+ # 3. A Plot component for the chart.
177
  rq2_plot_output = gr.Plot(label="Feature Importance")
178
+ # --- END OF UI FIX ---
179
 
180
+ # --- THIS IS THE KEY WIRING FIX ---
181
+ # The outputs list now has 3 items to match the 3 components
182
+ outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
183
+
184
  rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
185
+ rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
186
+ rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
187
 
188
  demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
189
+ demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
190
+ # --- END OF WIRING FIX ---
191
 
192
  if __name__ == "__main__":
193
  demo.launch()