clockclock commited on
Commit
cd63cff
·
verified ·
1 Parent(s): 1c82e5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -40
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (Gaze Playback & Real-Time Prediction)
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
@@ -20,16 +20,11 @@ plt.style.use('default')
20
  sns.set_palette("husl")
21
 
22
  class EnhancedAIvsRealGazeAnalyzer:
23
- """
24
- A comprehensive class to load, process, and analyze eye-tracking data.
25
- It supports statistical analysis (RQ1), predictive modeling (RQ2),
26
- and an innovative "Gaze Playback" dashboard.
27
- """
28
  def __init__(self):
29
  self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
30
  self.correct_answers = {'Pair1': 'B', 'Pair2': 'B', 'Pair3': 'B', 'Pair4': 'B', 'Pair5': 'B', 'Pair6': 'B'}
31
  self.combined_data = None
32
- self.fixation_data = {} # To store raw fixation data for each trial
33
  self.participant_list = []
34
  self.model = None
35
  self.scaler = None
@@ -37,7 +32,6 @@ class EnhancedAIvsRealGazeAnalyzer:
37
  self.et_id_col = 'Participant name'
38
 
39
  def load_and_process_data(self, base_path, response_file):
40
- """Loads aggregated metrics and raw fixation data, handling NaNs."""
41
  print("Loading and processing aggregated and raw fixation data...")
42
  self.response_data = pd.read_excel(response_file) if response_file.endswith('.xlsx') else pd.read_csv(response_file)
43
  self.response_data.columns = self.response_data.columns.str.strip()
@@ -48,27 +42,34 @@ class EnhancedAIvsRealGazeAnalyzer:
48
  file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{q}.xlsx"
49
  if os.path.exists(file_path):
50
  xls = pd.ExcelFile(file_path)
51
- # Load aggregated metrics (assuming it's the first sheet)
52
  metrics_df = pd.read_excel(xls, sheet_name=0)
53
  metrics_df['Question'] = q
54
  all_metrics_dfs.append(metrics_df)
55
 
56
- # Load raw fixation data if the sheet exists
57
  if 'Fixation-based AOI' in xls.sheet_names:
58
  fix_df = pd.read_excel(xls, sheet_name='Fixation-based AOI')
59
  fix_df['Question'] = q
60
- # ENSURE NO NULLS: Drop rows with invalid data crucial for playback
61
  fix_df.dropna(subset=['Fixation point X', 'Fixation point Y', 'Gaze event duration (ms)'], inplace=True)
62
- self.et_id_col = next((c for c in fix_df.columns if 'participant' in c.lower()), 'Participant name')
63
- for participant, group in fix_df.groupby(self.et_id_col):
64
- self.fixation_data[(str(participant), q)] = group.reset_index(drop=True)
 
 
65
 
66
  if not all_metrics_dfs: raise ValueError("No aggregated metrics files were found.")
67
  self.combined_data = pd.concat(all_metrics_dfs, ignore_index=True)
68
  self.combined_data.columns = self.combined_data.columns.str.strip()
69
 
70
- # Merge response data with aggregated metrics
71
- resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')
 
 
 
 
 
 
 
 
72
  for pair, ans in self.correct_answers.items():
73
  if pair in self.response_data.columns:
74
  self.response_data[f'{pair}_Correct'] = (self.response_data[pair].astype(str).str.strip().str.upper() == ans)
@@ -79,12 +80,13 @@ class EnhancedAIvsRealGazeAnalyzer:
79
  response_long = response_long.merge(correctness_long[[resp_id_col, 'Pair', 'Correct']], on=[resp_id_col, 'Pair'])
80
  q_to_pair = {f'Q{i+1}': f'Pair{i+1}' for i in range(6)}
81
  self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
 
 
82
  self.combined_data = self.combined_data.merge(response_long, left_on=[self.et_id_col, 'Pair'], right_on=[resp_id_col, 'Pair'], how='left')
83
  self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})
84
 
85
  self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
86
  self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
87
- # Convert all participant IDs to strings to prevent sorting errors
88
  self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
89
  print("Data loading complete.")
90
  return self
@@ -119,42 +121,30 @@ class EnhancedAIvsRealGazeAnalyzer:
119
  return summary_md, report_df, fig
120
 
121
  def _recalculate_features_from_fixations(self, fixations_df):
122
- """Helper to dynamically create a feature vector from a list of fixations."""
123
  feature_vector = pd.Series(0.0, index=self.feature_names)
124
- if fixations_df.empty:
125
- return feature_vector.values.reshape(1, -1)
126
-
127
- # Recalculate duration-based features
128
  if 'AOI name' in fixations_df.columns:
129
  for aoi_name, group in fixations_df.groupby('AOI name'):
130
  col_name = f'Total fixation duration on {aoi_name}'
131
  if col_name in feature_vector.index:
132
  feature_vector[col_name] = group['Gaze event duration (ms)'].sum()
133
-
134
  feature_vector['Total Recording Duration'] = fixations_df['Gaze event duration (ms)'].sum()
135
-
136
  return feature_vector.fillna(0).values.reshape(1, -1)
137
 
138
  def generate_gaze_playback(self, participant, question, fixation_num):
139
- """Generates the gaze playback and real-time prediction dashboard."""
140
  trial_key = (str(participant), question)
141
  if not participant or not question or trial_key not in self.fixation_data:
142
  return "Please select a valid trial with fixation data.", None, gr.Slider(interactive=False)
143
-
144
  all_fixations = self.fixation_data[trial_key]
145
  fixation_num = int(fixation_num)
146
-
147
  slider_max = len(all_fixations)
148
  if fixation_num > slider_max: fixation_num = slider_max
149
  current_fixations = all_fixations.iloc[:fixation_num]
150
-
151
  partial_features = self._recalculate_features_from_fixations(current_fixations)
152
  prediction_prob = self.model.predict_proba(self.scaler.transform(partial_features))[0]
153
  prob_correct = prediction_prob[1]
154
-
155
  fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), gridspec_kw={'height_ratios': [4, 1]})
156
  fig.suptitle(f"Gaze Playback for {participant} - {question}", fontsize=16, weight='bold')
157
-
158
  ax1.set_title(f"Displaying Fixations 1 through {fixation_num}/{slider_max}")
159
  ax1.set_xlim(0, 1920); ax1.set_ylim(1080, 0)
160
  ax1.set_aspect('equal'); ax1.tick_params(left=False, right=False, bottom=False, top=False, labelleft=False, labelbottom=False)
@@ -162,24 +152,19 @@ class EnhancedAIvsRealGazeAnalyzer:
162
  ax1.add_patch(patches.Rectangle((1920/2, 0), 1920/2, 1080, facecolor='blue', alpha=0.05))
163
  ax1.text(1920*0.25, 50, "Image A", ha='center', fontsize=14, alpha=0.5)
164
  ax1.text(1920*0.75, 50, "Image B", ha='center', fontsize=14, alpha=0.5)
165
-
166
  if not current_fixations.empty:
167
  points = current_fixations[['Fixation point X', 'Fixation point Y']]
168
  ax1.plot(points['Fixation point X'], points['Fixation point Y'], marker='o', color='grey', alpha=0.5, linestyle='-')
169
  ax1.scatter(points.iloc[-1]['Fixation point X'], points.iloc[-1]['Fixation point Y'], s=150, c='red', zorder=10, edgecolors='black')
170
-
171
  ax2.set_xlim(0, 1); ax2.set_yticks([])
172
  ax2.set_title("Live Prediction Confidence (Answer is 'Correct')")
173
  bar_color = 'green' if prob_correct > 0.5 else 'red'
174
  ax2.barh([0], [prob_correct], color=bar_color, height=0.5)
175
  ax2.axvline(0.5, color='black', linestyle='--', linewidth=1)
176
  ax2.text(prob_correct, 0, f" {prob_correct:.1%} ", va='center', ha='left' if prob_correct < 0.9 else 'right', color='white', weight='bold')
177
-
178
  plt.tight_layout(rect=[0, 0, 1, 0.95])
179
-
180
  trial_info = self.combined_data[(self.combined_data[self.et_id_col].astype(str) == str(participant)) & (self.combined_data['Question'] == question)].iloc[0]
181
  summary_text = f"**Actual Answer:** `{trial_info['Answer_Correctness']}`"
182
-
183
  return summary_text, fig, gr.Slider(maximum=slider_max, value=fixation_num, interactive=True)
184
 
185
  # --- DATA SETUP & GRADIO APP ---
@@ -205,7 +190,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
205
  rq1_summary_output=gr.Markdown(label="Statistical Summary")
206
  with gr.Column(scale=2):
207
  rq1_plot_output=gr.Plot(label="Metric Comparison")
208
-
209
  with gr.TabItem("🤖 RQ2: Predicting Correctness from Gaze"):
210
  with gr.Row():
211
  with gr.Column(scale=1):
@@ -216,7 +200,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
216
  rq2_summary_output=gr.Markdown(label="Model Performance Summary")
217
  rq2_table_output=gr.Dataframe(label="Classification Report", interactive=False)
218
  rq2_plot_output=gr.Plot(label="Feature Importance")
219
-
220
  with gr.TabItem("👁️ Gaze Playback & Real-Time Prediction"):
221
  gr.Markdown("### See the Prediction Evolve with Every Glance!")
222
  with gr.Row():
@@ -231,13 +214,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
231
 
232
  outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
233
  outputs_playback = [playback_summary, playback_plot, playback_slider]
234
-
235
  rq1_metric_dropdown.change(fn=analyzer.analyze_rq1_metric, inputs=rq1_metric_dropdown, outputs=[rq1_plot_output, rq1_summary_output])
236
  rq2_test_size_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
237
  rq2_estimators_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
238
-
239
  playback_inputs = [playback_participant, playback_question, playback_slider]
240
- # Reset slider to 0 when a new trial is selected
241
  playback_participant.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
242
  playback_question.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
243
  playback_slider.release(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
 
1
+ # app.py
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
 
20
  sns.set_palette("husl")
21
 
22
  class EnhancedAIvsRealGazeAnalyzer:
 
 
 
 
 
23
  def __init__(self):
24
  self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
25
  self.correct_answers = {'Pair1': 'B', 'Pair2': 'B', 'Pair3': 'B', 'Pair4': 'B', 'Pair5': 'B', 'Pair6': 'B'}
26
  self.combined_data = None
27
+ self.fixation_data = {}
28
  self.participant_list = []
29
  self.model = None
30
  self.scaler = None
 
32
  self.et_id_col = 'Participant name'
33
 
34
  def load_and_process_data(self, base_path, response_file):
 
35
  print("Loading and processing aggregated and raw fixation data...")
36
  self.response_data = pd.read_excel(response_file) if response_file.endswith('.xlsx') else pd.read_csv(response_file)
37
  self.response_data.columns = self.response_data.columns.str.strip()
 
42
  file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{q}.xlsx"
43
  if os.path.exists(file_path):
44
  xls = pd.ExcelFile(file_path)
 
45
  metrics_df = pd.read_excel(xls, sheet_name=0)
46
  metrics_df['Question'] = q
47
  all_metrics_dfs.append(metrics_df)
48
 
 
49
  if 'Fixation-based AOI' in xls.sheet_names:
50
  fix_df = pd.read_excel(xls, sheet_name='Fixation-based AOI')
51
  fix_df['Question'] = q
 
52
  fix_df.dropna(subset=['Fixation point X', 'Fixation point Y', 'Gaze event duration (ms)'], inplace=True)
53
+ # Use a local variable here to avoid confusion
54
+ fix_et_id_col = next((c for c in fix_df.columns if 'participant' in c.lower()), None)
55
+ if fix_et_id_col:
56
+ for participant, group in fix_df.groupby(fix_et_id_col):
57
+ self.fixation_data[(str(participant), q)] = group.reset_index(drop=True)
58
 
59
  if not all_metrics_dfs: raise ValueError("No aggregated metrics files were found.")
60
  self.combined_data = pd.concat(all_metrics_dfs, ignore_index=True)
61
  self.combined_data.columns = self.combined_data.columns.str.strip()
62
 
63
+ # --- THIS IS THE KEY FIX ---
64
+ # 1. Dynamically find the participant ID column in the COMBINED metrics data.
65
+ self.et_id_col = next((c for c in self.combined_data.columns if 'participant' in c.lower()), None)
66
+ if not self.et_id_col: raise KeyError("Could not find a 'participant' column in the aggregated metrics data.")
67
+
68
+ # 2. Dynamically find the participant ID column in the RESPONSE data.
69
+ resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), None)
70
+ if not resp_id_col: raise KeyError("Could not find a 'participant' column in the response sheet.")
71
+ # --- END OF FIX ---
72
+
73
  for pair, ans in self.correct_answers.items():
74
  if pair in self.response_data.columns:
75
  self.response_data[f'{pair}_Correct'] = (self.response_data[pair].astype(str).str.strip().str.upper() == ans)
 
80
  response_long = response_long.merge(correctness_long[[resp_id_col, 'Pair', 'Correct']], on=[resp_id_col, 'Pair'])
81
  q_to_pair = {f'Q{i+1}': f'Pair{i+1}' for i in range(6)}
82
  self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
83
+
84
+ # 3. Perform the merge using the correctly identified column names.
85
  self.combined_data = self.combined_data.merge(response_long, left_on=[self.et_id_col, 'Pair'], right_on=[resp_id_col, 'Pair'], how='left')
86
  self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})
87
 
88
  self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
89
  self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
 
90
  self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
91
  print("Data loading complete.")
92
  return self
 
121
  return summary_md, report_df, fig
122
 
123
  def _recalculate_features_from_fixations(self, fixations_df):
 
124
  feature_vector = pd.Series(0.0, index=self.feature_names)
125
+ if fixations_df.empty: return feature_vector.fillna(0).values.reshape(1, -1)
 
 
 
126
  if 'AOI name' in fixations_df.columns:
127
  for aoi_name, group in fixations_df.groupby('AOI name'):
128
  col_name = f'Total fixation duration on {aoi_name}'
129
  if col_name in feature_vector.index:
130
  feature_vector[col_name] = group['Gaze event duration (ms)'].sum()
 
131
  feature_vector['Total Recording Duration'] = fixations_df['Gaze event duration (ms)'].sum()
 
132
  return feature_vector.fillna(0).values.reshape(1, -1)
133
 
134
  def generate_gaze_playback(self, participant, question, fixation_num):
 
135
  trial_key = (str(participant), question)
136
  if not participant or not question or trial_key not in self.fixation_data:
137
  return "Please select a valid trial with fixation data.", None, gr.Slider(interactive=False)
 
138
  all_fixations = self.fixation_data[trial_key]
139
  fixation_num = int(fixation_num)
 
140
  slider_max = len(all_fixations)
141
  if fixation_num > slider_max: fixation_num = slider_max
142
  current_fixations = all_fixations.iloc[:fixation_num]
 
143
  partial_features = self._recalculate_features_from_fixations(current_fixations)
144
  prediction_prob = self.model.predict_proba(self.scaler.transform(partial_features))[0]
145
  prob_correct = prediction_prob[1]
 
146
  fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), gridspec_kw={'height_ratios': [4, 1]})
147
  fig.suptitle(f"Gaze Playback for {participant} - {question}", fontsize=16, weight='bold')
 
148
  ax1.set_title(f"Displaying Fixations 1 through {fixation_num}/{slider_max}")
149
  ax1.set_xlim(0, 1920); ax1.set_ylim(1080, 0)
150
  ax1.set_aspect('equal'); ax1.tick_params(left=False, right=False, bottom=False, top=False, labelleft=False, labelbottom=False)
 
152
  ax1.add_patch(patches.Rectangle((1920/2, 0), 1920/2, 1080, facecolor='blue', alpha=0.05))
153
  ax1.text(1920*0.25, 50, "Image A", ha='center', fontsize=14, alpha=0.5)
154
  ax1.text(1920*0.75, 50, "Image B", ha='center', fontsize=14, alpha=0.5)
 
155
  if not current_fixations.empty:
156
  points = current_fixations[['Fixation point X', 'Fixation point Y']]
157
  ax1.plot(points['Fixation point X'], points['Fixation point Y'], marker='o', color='grey', alpha=0.5, linestyle='-')
158
  ax1.scatter(points.iloc[-1]['Fixation point X'], points.iloc[-1]['Fixation point Y'], s=150, c='red', zorder=10, edgecolors='black')
 
159
  ax2.set_xlim(0, 1); ax2.set_yticks([])
160
  ax2.set_title("Live Prediction Confidence (Answer is 'Correct')")
161
  bar_color = 'green' if prob_correct > 0.5 else 'red'
162
  ax2.barh([0], [prob_correct], color=bar_color, height=0.5)
163
  ax2.axvline(0.5, color='black', linestyle='--', linewidth=1)
164
  ax2.text(prob_correct, 0, f" {prob_correct:.1%} ", va='center', ha='left' if prob_correct < 0.9 else 'right', color='white', weight='bold')
 
165
  plt.tight_layout(rect=[0, 0, 1, 0.95])
 
166
  trial_info = self.combined_data[(self.combined_data[self.et_id_col].astype(str) == str(participant)) & (self.combined_data['Question'] == question)].iloc[0]
167
  summary_text = f"**Actual Answer:** `{trial_info['Answer_Correctness']}`"
 
168
  return summary_text, fig, gr.Slider(maximum=slider_max, value=fixation_num, interactive=True)
169
 
170
  # --- DATA SETUP & GRADIO APP ---
 
190
  rq1_summary_output=gr.Markdown(label="Statistical Summary")
191
  with gr.Column(scale=2):
192
  rq1_plot_output=gr.Plot(label="Metric Comparison")
 
193
  with gr.TabItem("🤖 RQ2: Predicting Correctness from Gaze"):
194
  with gr.Row():
195
  with gr.Column(scale=1):
 
200
  rq2_summary_output=gr.Markdown(label="Model Performance Summary")
201
  rq2_table_output=gr.Dataframe(label="Classification Report", interactive=False)
202
  rq2_plot_output=gr.Plot(label="Feature Importance")
 
203
  with gr.TabItem("👁️ Gaze Playback & Real-Time Prediction"):
204
  gr.Markdown("### See the Prediction Evolve with Every Glance!")
205
  with gr.Row():
 
214
 
215
  outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
216
  outputs_playback = [playback_summary, playback_plot, playback_slider]
 
217
  rq1_metric_dropdown.change(fn=analyzer.analyze_rq1_metric, inputs=rq1_metric_dropdown, outputs=[rq1_plot_output, rq1_summary_output])
218
  rq2_test_size_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
219
  rq2_estimators_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
 
220
  playback_inputs = [playback_participant, playback_question, playback_slider]
 
221
  playback_participant.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
222
  playback_question.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
223
  playback_slider.release(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)