Update app.py
app.py CHANGED

@@ -1,4 +1,4 @@
-# app.py
+# app.py
 import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
@@ -20,16 +20,11 @@ plt.style.use('default')
 sns.set_palette("husl")
 
 class EnhancedAIvsRealGazeAnalyzer:
-    """
-    A comprehensive class to load, process, and analyze eye-tracking data.
-    It supports statistical analysis (RQ1), predictive modeling (RQ2),
-    and an innovative "Gaze Playback" dashboard.
-    """
     def __init__(self):
         self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
         self.correct_answers = {'Pair1': 'B', 'Pair2': 'B', 'Pair3': 'B', 'Pair4': 'B', 'Pair5': 'B', 'Pair6': 'B'}
         self.combined_data = None
-        self.fixation_data = {}
+        self.fixation_data = {}
         self.participant_list = []
         self.model = None
         self.scaler = None
@@ -37,7 +32,6 @@ class EnhancedAIvsRealGazeAnalyzer:
         self.et_id_col = 'Participant name'
 
     def load_and_process_data(self, base_path, response_file):
-        """Loads aggregated metrics and raw fixation data, handling NaNs."""
         print("Loading and processing aggregated and raw fixation data...")
         self.response_data = pd.read_excel(response_file) if response_file.endswith('.xlsx') else pd.read_csv(response_file)
         self.response_data.columns = self.response_data.columns.str.strip()
@@ -48,27 +42,34 @@
             file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{q}.xlsx"
             if os.path.exists(file_path):
                 xls = pd.ExcelFile(file_path)
-                # Load aggregated metrics (assuming it's the first sheet)
                 metrics_df = pd.read_excel(xls, sheet_name=0)
                 metrics_df['Question'] = q
                 all_metrics_dfs.append(metrics_df)
 
-                # Load raw fixation data if the sheet exists
                 if 'Fixation-based AOI' in xls.sheet_names:
                     fix_df = pd.read_excel(xls, sheet_name='Fixation-based AOI')
                     fix_df['Question'] = q
-                    # ENSURE NO NULLS: Drop rows with invalid data crucial for playback
                     fix_df.dropna(subset=['Fixation point X', 'Fixation point Y', 'Gaze event duration (ms)'], inplace=True)
-
-                    for
-
+                    # Use a local variable here to avoid confusion
+                    fix_et_id_col = next((c for c in fix_df.columns if 'participant' in c.lower()), None)
+                    if fix_et_id_col:
+                        for participant, group in fix_df.groupby(fix_et_id_col):
+                            self.fixation_data[(str(participant), q)] = group.reset_index(drop=True)
 
         if not all_metrics_dfs: raise ValueError("No aggregated metrics files were found.")
         self.combined_data = pd.concat(all_metrics_dfs, ignore_index=True)
         self.combined_data.columns = self.combined_data.columns.str.strip()
 
-        #
-
+        # --- THIS IS THE KEY FIX ---
+        # 1. Dynamically find the participant ID column in the COMBINED metrics data.
+        self.et_id_col = next((c for c in self.combined_data.columns if 'participant' in c.lower()), None)
+        if not self.et_id_col: raise KeyError("Could not find a 'participant' column in the aggregated metrics data.")
+
+        # 2. Dynamically find the participant ID column in the RESPONSE data.
+        resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), None)
+        if not resp_id_col: raise KeyError("Could not find a 'participant' column in the response sheet.")
+        # --- END OF FIX ---
+
         for pair, ans in self.correct_answers.items():
             if pair in self.response_data.columns:
                 self.response_data[f'{pair}_Correct'] = (self.response_data[pair].astype(str).str.strip().str.upper() == ans)
@@ -79,12 +80,13 @@
         response_long = response_long.merge(correctness_long[[resp_id_col, 'Pair', 'Correct']], on=[resp_id_col, 'Pair'])
         q_to_pair = {f'Q{i+1}': f'Pair{i+1}' for i in range(6)}
         self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
+
+        # 3. Perform the merge using the correctly identified column names.
         self.combined_data = self.combined_data.merge(response_long, left_on=[self.et_id_col, 'Pair'], right_on=[resp_id_col, 'Pair'], how='left')
         self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})
 
         self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
         self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
-        # Convert all participant IDs to strings to prevent sorting errors
         self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
         print("Data loading complete.")
         return self
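
The fix in the two hunks above reduces to one reusable pattern: rather than hard-coding a header such as 'Participant name', scan the frame's columns for the first one containing 'participant', case-insensitively, and fail loudly if none exists. Below is a minimal standalone sketch of that lookup; the two frames and the find_id_col helper are invented for illustration and are not part of app.py:

    import pandas as pd

    def find_id_col(df: pd.DataFrame, key: str = 'participant'):
        """Return the first column whose name contains `key` (case-insensitive), or None."""
        return next((c for c in df.columns if key in c.lower()), None)

    # Two exports whose ID headers drifted apart -- the situation the fix guards against.
    metrics = pd.DataFrame({'Participant name': ['P01', 'P02'], 'Total fixation duration': [812.0, 955.0]})
    responses = pd.DataFrame({'participant_id': ['P01', 'P02'], 'Pair1': ['B', 'A']})

    print(find_id_col(metrics))    # 'Participant name'
    print(find_id_col(responses))  # 'participant_id'
    if find_id_col(metrics) is None:
        raise KeyError("Could not find a 'participant' column.")
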
@@ -119,42 +121,30 @@
         return summary_md, report_df, fig
 
     def _recalculate_features_from_fixations(self, fixations_df):
-        """Helper to dynamically create a feature vector from a list of fixations."""
         feature_vector = pd.Series(0.0, index=self.feature_names)
-        if fixations_df.empty:
-            return feature_vector.values.reshape(1, -1)
-
-        # Recalculate duration-based features
+        if fixations_df.empty: return feature_vector.fillna(0).values.reshape(1, -1)
         if 'AOI name' in fixations_df.columns:
             for aoi_name, group in fixations_df.groupby('AOI name'):
                 col_name = f'Total fixation duration on {aoi_name}'
                 if col_name in feature_vector.index:
                     feature_vector[col_name] = group['Gaze event duration (ms)'].sum()
-
         feature_vector['Total Recording Duration'] = fixations_df['Gaze event duration (ms)'].sum()
-
         return feature_vector.fillna(0).values.reshape(1, -1)
 
     def generate_gaze_playback(self, participant, question, fixation_num):
-        """Generates the gaze playback and real-time prediction dashboard."""
         trial_key = (str(participant), question)
         if not participant or not question or trial_key not in self.fixation_data:
             return "Please select a valid trial with fixation data.", None, gr.Slider(interactive=False)
-
         all_fixations = self.fixation_data[trial_key]
         fixation_num = int(fixation_num)
-
         slider_max = len(all_fixations)
         if fixation_num > slider_max: fixation_num = slider_max
         current_fixations = all_fixations.iloc[:fixation_num]
-
         partial_features = self._recalculate_features_from_fixations(current_fixations)
         prediction_prob = self.model.predict_proba(self.scaler.transform(partial_features))[0]
         prob_correct = prediction_prob[1]
-
         fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), gridspec_kw={'height_ratios': [4, 1]})
         fig.suptitle(f"Gaze Playback for {participant} - {question}", fontsize=16, weight='bold')
-
         ax1.set_title(f"Displaying Fixations 1 through {fixation_num}/{slider_max}")
         ax1.set_xlim(0, 1920); ax1.set_ylim(1080, 0)
         ax1.set_aspect('equal'); ax1.tick_params(left=False, right=False, bottom=False, top=False, labelleft=False, labelbottom=False)
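
The playback in the hunk above re-scores the trial on every slider step: truncate the fixation list, rebuild the feature vector, scale it, and ask the trained classifier for predict_proba. Here is a self-contained sketch of that loop under stated assumptions: a fitted scikit-learn scaler and model stand in for the ones app.py trains elsewhere, and a one-feature recalc stands in for _recalculate_features_from_fixations:

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import StandardScaler

    # Toy training data: one feature (total gaze duration in ms) predicting correctness.
    rng = np.random.default_rng(0)
    X_train = rng.uniform(0, 5000, size=(40, 1))
    y_train = (X_train[:, 0] > 2500).astype(int)
    scaler = StandardScaler().fit(X_train)
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(scaler.transform(X_train), y_train)

    fixations = pd.DataFrame({'Gaze event duration (ms)': [300, 450, 600, 700, 900]})

    def recalc(fix_df):
        # Stand-in for _recalculate_features_from_fixations: sum the durations seen so far.
        return np.array([[fix_df['Gaze event duration (ms)'].sum()]])

    # One prediction per slider position, mirroring what generate_gaze_playback does.
    for n in range(1, len(fixations) + 1):
        prob_correct = model.predict_proba(scaler.transform(recalc(fixations.iloc[:n])))[0][1]
        print(f"after {n} fixations: P(correct) = {prob_correct:.2f}")
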
@@ -162,24 +152,19 @@
         ax1.add_patch(patches.Rectangle((1920/2, 0), 1920/2, 1080, facecolor='blue', alpha=0.05))
         ax1.text(1920*0.25, 50, "Image A", ha='center', fontsize=14, alpha=0.5)
         ax1.text(1920*0.75, 50, "Image B", ha='center', fontsize=14, alpha=0.5)
-
         if not current_fixations.empty:
             points = current_fixations[['Fixation point X', 'Fixation point Y']]
             ax1.plot(points['Fixation point X'], points['Fixation point Y'], marker='o', color='grey', alpha=0.5, linestyle='-')
             ax1.scatter(points.iloc[-1]['Fixation point X'], points.iloc[-1]['Fixation point Y'], s=150, c='red', zorder=10, edgecolors='black')
-
         ax2.set_xlim(0, 1); ax2.set_yticks([])
         ax2.set_title("Live Prediction Confidence (Answer is 'Correct')")
         bar_color = 'green' if prob_correct > 0.5 else 'red'
         ax2.barh([0], [prob_correct], color=bar_color, height=0.5)
         ax2.axvline(0.5, color='black', linestyle='--', linewidth=1)
         ax2.text(prob_correct, 0, f" {prob_correct:.1%} ", va='center', ha='left' if prob_correct < 0.9 else 'right', color='white', weight='bold')
-
         plt.tight_layout(rect=[0, 0, 1, 0.95])
-
         trial_info = self.combined_data[(self.combined_data[self.et_id_col].astype(str) == str(participant)) & (self.combined_data['Question'] == question)].iloc[0]
         summary_text = f"**Actual Answer:** `{trial_info['Answer_Correctness']}`"
-
         return summary_text, fig, gr.Slider(maximum=slider_max, value=fixation_num, interactive=True)
 
 # --- DATA SETUP & GRADIO APP ---
@@ -205,7 +190,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 rq1_summary_output=gr.Markdown(label="Statistical Summary")
             with gr.Column(scale=2):
                 rq1_plot_output=gr.Plot(label="Metric Comparison")
-
     with gr.TabItem("🤖 RQ2: Predicting Correctness from Gaze"):
         with gr.Row():
             with gr.Column(scale=1):
@@ -216,7 +200,6 @@
                 rq2_summary_output=gr.Markdown(label="Model Performance Summary")
                 rq2_table_output=gr.Dataframe(label="Classification Report", interactive=False)
                 rq2_plot_output=gr.Plot(label="Feature Importance")
-
     with gr.TabItem("👁️ Gaze Playback & Real-Time Prediction"):
         gr.Markdown("### See the Prediction Evolve with Every Glance!")
         with gr.Row():
@@ -231,13 +214,10 @@
 
     outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
     outputs_playback = [playback_summary, playback_plot, playback_slider]
-
     rq1_metric_dropdown.change(fn=analyzer.analyze_rq1_metric, inputs=rq1_metric_dropdown, outputs=[rq1_plot_output, rq1_summary_output])
     rq2_test_size_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
     rq2_estimators_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
-
     playback_inputs = [playback_participant, playback_question, playback_slider]
-    # Reset slider to 0 when a new trial is selected
     playback_participant.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
     playback_question.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
    playback_slider.release(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
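
The wiring in the final hunk relies on Gradio's event chaining: .change() first resets the slider to 0, and the chained .then() redraws the playback with the reset value, so selecting a new trial never starts mid-playback. A stripped-down sketch of the same reset-then-redraw pattern; the component names and the describe function are illustrative, not from app.py:

    import gradio as gr

    def describe(choice, step):
        return f"Trial {choice}, showing step {int(step)}"

    with gr.Blocks() as demo:
        trial = gr.Dropdown(['T1', 'T2'], label="Trial")
        step = gr.Slider(0, 10, step=1, label="Step")
        out = gr.Markdown()

        # Reset the slider to 0 on trial change, then redraw using the reset value.
        trial.change(lambda: 0, None, step).then(fn=describe, inputs=[trial, step], outputs=out)
        step.release(fn=describe, inputs=[trial, step], outputs=out)

    demo.launch()
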