clockclock committed on
Commit
13605e4
Β·
verified Β·
1 Parent(s): 25e6b09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -47
app.py CHANGED
@@ -54,7 +54,6 @@ class EnhancedAIvsRealGazeAnalyzer:
54
  self.combined_data = pd.concat(all_dfs, ignore_index=True)
55
  self.combined_data.columns = self.combined_data.columns.str.strip()
56
 
57
- # Dynamically find participant ID columns
58
  self.et_id_col = next((c for c in self.combined_data.columns if 'participant' in c.lower()), 'Participant name')
59
  resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')
60
 
@@ -74,15 +73,17 @@ class EnhancedAIvsRealGazeAnalyzer:
74
 
75
  self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
76
  self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
77
- self.participant_list = sorted(self.combined_data[self.et_id_col].unique().tolist())
78
 
79
- # Pre-calculate group means for the explorer tab
 
 
 
 
80
  self.group_means = self.combined_data.groupby('Answer_Correctness')[self.numeric_cols].mean()
81
  print("Data loading complete.")
82
  return self
83
 
84
  def analyze_rq1_metric(self, metric):
85
- """Analyzes a single metric for RQ1."""
86
  if not metric: return None, "Metric not found."
87
  correct = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Correct', metric].dropna()
88
  incorrect = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Incorrect', metric].dropna()
@@ -92,7 +93,6 @@ class EnhancedAIvsRealGazeAnalyzer:
92
  return fig, summary
93
 
94
  def run_prediction_model(self, test_size, n_estimators):
95
- """Trains and evaluates the RandomForest model for RQ2."""
96
  leaky_features = ['Total_Correct', 'Overall_Accuracy', 'Correct', self.et_id_col]
97
  self.feature_names = [col for col in self.numeric_cols if col not in leaky_features and col in self.combined_data.columns]
98
  features = self.combined_data[self.feature_names].copy()
@@ -113,41 +113,29 @@ class EnhancedAIvsRealGazeAnalyzer:
113
  return summary_md, report_df, fig
114
 
115
  def analyze_individual_trial(self, participant, question):
116
- """Generates a detailed report for a single participant-question trial."""
117
  if not participant or not question:
118
  return "Please select a participant and a question.", None, None
119
 
120
- trial_data = self.combined_data[(self.combined_data[self.et_id_col] == participant) & (self.combined_data['Question'] == question)]
 
121
  if trial_data.empty:
122
  return f"No data found for {participant} on {question}.", None, None
123
 
124
  trial_data = trial_data.iloc[0]
125
  actual_answer = trial_data['Answer_Correctness']
126
-
127
- # Model Prediction for this specific trial
128
  trial_features = trial_data[self.feature_names].values.reshape(1, -1)
129
  trial_features_scaled = self.scaler.transform(trial_features)
130
  prediction_prob = self.model.predict_proba(trial_features_scaled)[0]
131
  predicted_answer = "Correct" if prediction_prob[1] > 0.5 else "Incorrect"
132
-
133
- # Summary Text
134
- summary_md = f"""
135
- ### Trial Breakdown: **{participant}** on **{question}**
136
- - **Actual Answer:** `{actual_answer}`
137
- - **Model Prediction:** `{predicted_answer}` (Confidence: {max(prediction_prob)*100:.1f}%)
138
- """
139
-
140
- # A vs B Gaze Bias Plot
141
  aoi_cols = [c for c in self.feature_names if ' A' in c or ' B' in c]
142
  a_cols = sorted([c for c in aoi_cols if ' A' in c])
143
  b_cols = sorted([c for c in aoi_cols if ' B' in c])
144
-
145
  plot_data = []
146
  for a_col, b_col in zip(a_cols, b_cols):
147
  base_name = a_col.replace(' A', '')
148
  plot_data.append({'AOI': base_name, 'Image': 'A', 'Value': trial_data[a_col]})
149
  plot_data.append({'AOI': base_name, 'Image': 'B', 'Value': trial_data[b_col]})
150
-
151
  fig, ax = plt.subplots(figsize=(10, 6))
152
  if plot_data:
153
  sns.barplot(data=pd.DataFrame(plot_data), x='Value', y='AOI', hue='Image', ax=ax, palette={'A': '#66b3ff', 'B': '#ff9999'})
@@ -155,26 +143,16 @@ class EnhancedAIvsRealGazeAnalyzer:
155
  else:
156
  ax.text(0.5, 0.5, 'No A/B Area of Interest data for this question.', ha='center')
157
  plt.tight_layout()
158
-
159
- # Feature Report Card
160
  top_features = self.model.feature_importances_.argsort()[-5:][::-1]
161
  top_feature_names = [self.feature_names[i] for i in top_features]
162
-
163
  report_card_data = []
164
  for feature in top_feature_names:
165
- report_card_data.append({
166
- 'Top Feature': feature,
167
- 'This Trial Value': f"{trial_data[feature]:.2f}",
168
- 'Avg (Correct)': f"{self.group_means.loc['Correct', feature]:.2f}",
169
- 'Avg (Incorrect)': f"{self.group_means.loc['Incorrect', feature]:.2f}"
170
- })
171
  report_card_df = pd.DataFrame(report_card_data)
172
-
173
  return summary_md, fig, report_card_df
174
 
175
- # --- DATA SETUP (RUNS ONCE AT STARTUP) ---
176
  def setup_and_load_data():
177
- """Clones the repo if not present and loads data."""
178
  repo_url = "https://github.com/RextonRZ/GenAIEyeTrackingCleanedDataset"
179
  repo_dir = "GenAIEyeTrackingCleanedDataset"
180
  if not os.path.exists(repo_dir):
@@ -209,9 +187,7 @@ def update_explorer_view(participant, question):
209
  # --- GRADIO INTERFACE DEFINITION ---
210
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
211
  gr.Markdown("# Interactive Dashboard: AI vs. Real Gaze Analysis\nExplore the eye-tracking dataset by interacting with the controls below. The data is automatically loaded from the public GitHub repository.")
212
-
213
  with gr.Tabs():
214
- # --- TAB 1: RQ1 ---
215
  with gr.TabItem("πŸ“Š RQ1: Viewing Time vs. Correctness"):
216
  with gr.Row():
217
  with gr.Column(scale=1):
@@ -219,8 +195,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
219
  rq1_summary_output = gr.Markdown(label="Statistical Summary")
220
  with gr.Column(scale=2):
221
  rq1_plot_output = gr.Plot(label="Metric Comparison")
222
-
223
- # --- TAB 2: RQ2 ---
224
  with gr.TabItem("πŸ€– RQ2: Predicting Correctness from Gaze"):
225
  with gr.Row():
226
  with gr.Column(scale=1):
@@ -231,8 +205,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
231
  rq2_summary_output = gr.Markdown(label="Model Performance Summary")
232
  rq2_table_output = gr.Dataframe(label="Classification Report", interactive=False)
233
  rq2_plot_output = gr.Plot(label="Feature Importance")
234
-
235
- # --- TAB 3: INNOVATIVE EXPLORER ---
236
  with gr.TabItem("πŸ”¬ Individual Trial Explorer"):
237
  gr.Markdown("### Deep Dive into a Single Trial\nSelect a participant and a question to see a detailed breakdown of their gaze behavior.")
238
  with gr.Row():
@@ -244,22 +216,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
244
  with gr.Column(scale=2):
245
  explorer_plot = gr.Plot(label="Gaze Bias (Image A vs. B)")
246
 
247
- # --- WIRING FOR ALL TABS ---
248
  outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
249
  outputs_explorer = [explorer_summary, explorer_plot, explorer_report_card]
250
-
251
- # Wiring for Tab 1
252
  rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
253
-
254
- # Wiring for Tab 2
255
  rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
256
  rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
257
-
258
- # Wiring for Tab 3
259
  explorer_participant.change(fn=update_explorer_view, inputs=[explorer_participant, explorer_question], outputs=outputs_explorer)
260
  explorer_question.change(fn=update_explorer_view, inputs=[explorer_participant, explorer_question], outputs=outputs_explorer)
261
-
262
- # Load initial state for all tabs when the app starts
263
  demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
264
  demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
265
 
 
54
  self.combined_data = pd.concat(all_dfs, ignore_index=True)
55
  self.combined_data.columns = self.combined_data.columns.str.strip()
56
 
 
57
  self.et_id_col = next((c for c in self.combined_data.columns if 'participant' in c.lower()), 'Participant name')
58
  resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')
59
 
 
73
 
74
  self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
75
  self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
 
76
 
77
+ # --- THIS IS THE CORRECTED LINE ---
78
+ # Convert all participant IDs to strings before sorting to handle mixed types.
79
+ self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
80
+ # --- END OF CORRECTION ---
81
+
82
  self.group_means = self.combined_data.groupby('Answer_Correctness')[self.numeric_cols].mean()
83
  print("Data loading complete.")
84
  return self
85
 
86
  def analyze_rq1_metric(self, metric):
 
87
  if not metric: return None, "Metric not found."
88
  correct = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Correct', metric].dropna()
89
  incorrect = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Incorrect', metric].dropna()
 
93
  return fig, summary
94
 
95
  def run_prediction_model(self, test_size, n_estimators):
 
96
  leaky_features = ['Total_Correct', 'Overall_Accuracy', 'Correct', self.et_id_col]
97
  self.feature_names = [col for col in self.numeric_cols if col not in leaky_features and col in self.combined_data.columns]
98
  features = self.combined_data[self.feature_names].copy()
 
113
  return summary_md, report_df, fig
114
 
115
  def analyze_individual_trial(self, participant, question):
 
116
  if not participant or not question:
117
  return "Please select a participant and a question.", None, None
118
 
119
+ # Convert participant ID to string for matching, as the list is now all strings
120
+ trial_data = self.combined_data[(self.combined_data[self.et_id_col].astype(str) == str(participant)) & (self.combined_data['Question'] == question)]
121
  if trial_data.empty:
122
  return f"No data found for {participant} on {question}.", None, None
123
 
124
  trial_data = trial_data.iloc[0]
125
  actual_answer = trial_data['Answer_Correctness']
 
 
126
  trial_features = trial_data[self.feature_names].values.reshape(1, -1)
127
  trial_features_scaled = self.scaler.transform(trial_features)
128
  prediction_prob = self.model.predict_proba(trial_features_scaled)[0]
129
  predicted_answer = "Correct" if prediction_prob[1] > 0.5 else "Incorrect"
130
+ summary_md = f"""### Trial Breakdown: **{participant}** on **{question}**\n- **Actual Answer:** `{actual_answer}`\n- **Model Prediction:** `{predicted_answer}` (Confidence: {max(prediction_prob)*100:.1f}%)"""
 
 
 
 
 
 
 
 
131
  aoi_cols = [c for c in self.feature_names if ' A' in c or ' B' in c]
132
  a_cols = sorted([c for c in aoi_cols if ' A' in c])
133
  b_cols = sorted([c for c in aoi_cols if ' B' in c])
 
134
  plot_data = []
135
  for a_col, b_col in zip(a_cols, b_cols):
136
  base_name = a_col.replace(' A', '')
137
  plot_data.append({'AOI': base_name, 'Image': 'A', 'Value': trial_data[a_col]})
138
  plot_data.append({'AOI': base_name, 'Image': 'B', 'Value': trial_data[b_col]})
 
139
  fig, ax = plt.subplots(figsize=(10, 6))
140
  if plot_data:
141
  sns.barplot(data=pd.DataFrame(plot_data), x='Value', y='AOI', hue='Image', ax=ax, palette={'A': '#66b3ff', 'B': '#ff9999'})
 
143
  else:
144
  ax.text(0.5, 0.5, 'No A/B Area of Interest data for this question.', ha='center')
145
  plt.tight_layout()
 
 
146
  top_features = self.model.feature_importances_.argsort()[-5:][::-1]
147
  top_feature_names = [self.feature_names[i] for i in top_features]
 
148
  report_card_data = []
149
  for feature in top_feature_names:
150
+ report_card_data.append({'Top Feature': feature, 'This Trial Value': f"{trial_data[feature]:.2f}", 'Avg (Correct)': f"{self.group_means.loc['Correct', feature]:.2f}", 'Avg (Incorrect)': f"{self.group_means.loc['Incorrect', feature]:.2f}"})
 
 
 
 
 
151
  report_card_df = pd.DataFrame(report_card_data)
 
152
  return summary_md, fig, report_card_df
153
 
154
+ # --- DATA SETUP ---
155
  def setup_and_load_data():
 
156
  repo_url = "https://github.com/RextonRZ/GenAIEyeTrackingCleanedDataset"
157
  repo_dir = "GenAIEyeTrackingCleanedDataset"
158
  if not os.path.exists(repo_dir):
 
187
  # --- GRADIO INTERFACE DEFINITION ---
188
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
189
  gr.Markdown("# Interactive Dashboard: AI vs. Real Gaze Analysis\nExplore the eye-tracking dataset by interacting with the controls below. The data is automatically loaded from the public GitHub repository.")
 
190
  with gr.Tabs():
 
191
  with gr.TabItem("πŸ“Š RQ1: Viewing Time vs. Correctness"):
192
  with gr.Row():
193
  with gr.Column(scale=1):
 
195
  rq1_summary_output = gr.Markdown(label="Statistical Summary")
196
  with gr.Column(scale=2):
197
  rq1_plot_output = gr.Plot(label="Metric Comparison")
 
 
198
  with gr.TabItem("πŸ€– RQ2: Predicting Correctness from Gaze"):
199
  with gr.Row():
200
  with gr.Column(scale=1):
 
205
  rq2_summary_output = gr.Markdown(label="Model Performance Summary")
206
  rq2_table_output = gr.Dataframe(label="Classification Report", interactive=False)
207
  rq2_plot_output = gr.Plot(label="Feature Importance")
 
 
208
  with gr.TabItem("πŸ”¬ Individual Trial Explorer"):
209
  gr.Markdown("### Deep Dive into a Single Trial\nSelect a participant and a question to see a detailed breakdown of their gaze behavior.")
210
  with gr.Row():
 
216
  with gr.Column(scale=2):
217
  explorer_plot = gr.Plot(label="Gaze Bias (Image A vs. B)")
218
 
 
219
  outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
220
  outputs_explorer = [explorer_summary, explorer_plot, explorer_report_card]
 
 
221
  rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
 
 
222
  rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
223
  rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
 
 
224
  explorer_participant.change(fn=update_explorer_view, inputs=[explorer_participant, explorer_question], outputs=outputs_explorer)
225
  explorer_question.change(fn=update_explorer_view, inputs=[explorer_participant, explorer_question], outputs=outputs_explorer)
 
 
226
  demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
227
  demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
228