Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
-
# app.py
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
import matplotlib.patches as patches
|
6 |
import seaborn as sns
|
7 |
from scipy import stats
|
8 |
-
from sklearn.preprocessing import StandardScaler
|
9 |
from sklearn.ensemble import RandomForestClassifier
|
10 |
from sklearn.model_selection import train_test_split
|
11 |
from sklearn.metrics import classification_report, roc_auc_score
|
@@ -23,37 +23,56 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
23 |
"""
|
24 |
A comprehensive class to load, process, and analyze eye-tracking data.
|
25 |
It supports statistical analysis (RQ1), predictive modeling (RQ2),
|
26 |
-
and an innovative
|
27 |
"""
|
28 |
def __init__(self):
|
29 |
self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
|
30 |
self.correct_answers = {'Pair1': 'B', 'Pair2': 'B', 'Pair3': 'B', 'Pair4': 'B', 'Pair5': 'B', 'Pair6': 'B'}
|
31 |
self.combined_data = None
|
|
|
32 |
self.participant_list = []
|
33 |
self.model = None
|
34 |
self.scaler = None
|
35 |
self.feature_names = []
|
36 |
-
self.group_means = None
|
37 |
self.et_id_col = 'Participant name'
|
38 |
|
39 |
def load_and_process_data(self, base_path, response_file):
|
40 |
-
|
|
|
41 |
self.response_data = pd.read_excel(response_file) if response_file.endswith('.xlsx') else pd.read_csv(response_file)
|
42 |
self.response_data.columns = self.response_data.columns.str.strip()
|
43 |
-
|
|
|
|
|
44 |
for q in self.questions:
|
45 |
file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{q}.xlsx"
|
46 |
if os.path.exists(file_path):
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
self.combined_data.columns = self.combined_data.columns.str.strip()
|
52 |
-
|
|
|
53 |
resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')
|
54 |
for pair, ans in self.correct_answers.items():
|
55 |
if pair in self.response_data.columns:
|
56 |
self.response_data[f'{pair}_Correct'] = (self.response_data[pair].astype(str).str.strip().str.upper() == ans)
|
|
|
57 |
response_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=self.correct_answers.keys(), var_name='Pair')
|
58 |
correctness_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=[f'{p}_Correct' for p in self.correct_answers.keys()], var_name='Pair_Correct_Col', value_name='Correct')
|
59 |
correctness_long['Pair'] = correctness_long['Pair_Correct_Col'].str.replace('_Correct', '')
|
@@ -62,15 +81,15 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
62 |
self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
|
63 |
self.combined_data = self.combined_data.merge(response_long, left_on=[self.et_id_col, 'Pair'], right_on=[resp_id_col, 'Pair'], how='left')
|
64 |
self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})
|
|
|
65 |
self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
|
66 |
self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
|
|
|
67 |
self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
|
68 |
-
self.group_means = self.combined_data.groupby('Answer_Correctness')[self.numeric_cols].mean()
|
69 |
print("Data loading complete.")
|
70 |
return self
|
71 |
|
72 |
def analyze_rq1_metric(self, metric):
|
73 |
-
# (Unchanged)
|
74 |
if not metric: return None, "Metric not found."
|
75 |
correct = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Correct', metric].dropna()
|
76 |
incorrect = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Incorrect', metric].dropna()
|
@@ -80,7 +99,6 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
80 |
return fig, summary
|
81 |
|
82 |
def run_prediction_model(self, test_size, n_estimators):
|
83 |
-
# (Unchanged)
|
84 |
leaky_features = ['Total_Correct', 'Overall_Accuracy', 'Correct', self.et_id_col]
|
85 |
self.feature_names = [col for col in self.numeric_cols if col not in leaky_features and col in self.combined_data.columns]
|
86 |
features = self.combined_data[self.feature_names].copy()
|
@@ -99,158 +117,133 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
99 |
feature_importance = pd.DataFrame({'Feature': self.feature_names, 'Importance': self.model.feature_importances_}).sort_values('Importance', ascending=False).head(15)
|
100 |
fig, ax = plt.subplots(figsize=(10, 8)); sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis'); ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14); plt.tight_layout()
|
101 |
return summary_md, report_df, fig
|
102 |
-
|
103 |
-
def
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
trial_data_row = self.combined_data[(self.combined_data[self.et_id_col].astype(str) == str(participant)) & (self.combined_data['Question'] == question)]
|
109 |
-
if trial_data_row.empty:
|
110 |
-
return f"No data found for {participant} on {question}.", None
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
# --- Plot 1: Schematic Heatmap ---
|
122 |
-
ax1.set_title('Schematic AOI Heatmap (Fixation Duration)', fontsize=14)
|
123 |
-
ax1.set_xlim(0, 10); ax1.set_ylim(0, 10); ax1.set_aspect('equal')
|
124 |
-
ax1.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False, labelbottom=False, labelleft=False)
|
125 |
|
126 |
-
|
127 |
-
'2 Face': (2.5, 7.5), '2 Hair': (2.5, 5), '1 Eyes': (7.5, 8.5), '2 Body': (7.5, 6),
|
128 |
-
'5 Sky': (2.5, 2.5), '5 Sea': (7.5, 3.5), '6 Background': (7.5, 1.5)
|
129 |
-
}
|
130 |
-
fixation_cols = [c for c in trial_data.index if 'Total fixation duration' in c]
|
131 |
-
aoi_values = {key.split('duration on ')[1]: trial_data[key] for key in fixation_cols if key.split('duration on ')[1] in aoi_layout}
|
132 |
-
max_val = max(aoi_values.values()) if aoi_values else 1.0
|
133 |
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
ax1.text(pos[0]-1, pos[1], "A", ha='center', va='center', color='white', weight='bold')
|
140 |
-
ax1.text(pos[0]+1, pos[1], "B", ha='center', va='center', color='white', weight='bold')
|
141 |
-
ax1.text(pos[0], pos[1]+0.7, name, ha='center', va='center')
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
labels = [self.feature_names[i].replace('Total fixation duration on ','') for i in top_features_indices]
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
self.group_means.loc['Incorrect', [self.feature_names[i] for i in top_features_indices]],
|
151 |
-
trial_data[[self.feature_names[i] for i in top_features_indices]]
|
152 |
-
], axis=1).T
|
153 |
-
profiles_scaled = scaler_minmax.fit_transform(profiles)
|
154 |
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
def add_to_radar(profile, color, label):
|
162 |
-
values = profile.tolist()
|
163 |
-
values += values[:1]
|
164 |
-
ax2.plot(angles, values, color=color, linewidth=2, linestyle='solid', label=label)
|
165 |
-
ax2.fill(angles, values, color=color, alpha=0.25)
|
166 |
-
|
167 |
-
add_to_radar(profiles_scaled[0,:], 'green', 'Avg. Correct Profile')
|
168 |
-
add_to_radar(profiles_scaled[1,:], 'red', 'Avg. Incorrect Profile')
|
169 |
-
add_to_radar(profiles_scaled[2,:], 'blue', 'This Trial Profile')
|
170 |
|
171 |
-
|
172 |
-
ax2.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
|
173 |
-
|
174 |
-
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
|
175 |
-
return summary_md, fig
|
176 |
|
177 |
-
# --- DATA SETUP ---
|
178 |
def setup_and_load_data():
|
179 |
repo_url = "https://github.com/RextonRZ/GenAIEyeTrackingCleanedDataset"
|
180 |
repo_dir = "GenAIEyeTrackingCleanedDataset"
|
181 |
if not os.path.exists(repo_dir): git.Repo.clone_from(repo_url, repo_dir)
|
182 |
-
else: print("Data repository already exists.
|
183 |
base_path = repo_dir
|
184 |
response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
|
185 |
analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
|
186 |
return analyzer
|
187 |
|
188 |
-
print("Starting application setup...")
|
189 |
analyzer = setup_and_load_data()
|
190 |
-
print("Application setup complete. Ready for interaction.")
|
191 |
-
|
192 |
-
# --- GRADIO INTERACTIVE FUNCTIONS ---
|
193 |
-
def update_rq1_visuals(metric_choice):
|
194 |
-
return analyzer.analyze_rq1_metric(metric_choice)
|
195 |
|
196 |
-
def update_rq2_model(test_size, n_estimators):
|
197 |
-
return analyzer.run_prediction_model(int(test_size*100)/100, int(n_estimators))
|
198 |
-
|
199 |
-
def update_explorer_view(participant, question):
|
200 |
-
return analyzer.generate_aoi_focus_dashboard(participant, question)
|
201 |
-
|
202 |
-
# --- GRADIO INTERFACE DEFINITION ---
|
203 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
204 |
-
gr.Markdown("# Interactive Dashboard: AI vs. Real Gaze Analysis
|
205 |
with gr.Tabs():
|
206 |
with gr.TabItem("π RQ1: Viewing Time vs. Correctness"):
|
207 |
-
# (UI Unchanged)
|
208 |
with gr.Row():
|
209 |
with gr.Column(scale=1):
|
210 |
-
rq1_metric_dropdown
|
211 |
-
rq1_summary_output
|
212 |
with gr.Column(scale=2):
|
213 |
-
rq1_plot_output
|
214 |
|
215 |
with gr.TabItem("π€ RQ2: Predicting Correctness from Gaze"):
|
216 |
-
# (UI Unchanged)
|
217 |
with gr.Row():
|
218 |
with gr.Column(scale=1):
|
219 |
gr.Markdown("#### Tune Model Hyperparameters")
|
220 |
-
rq2_test_size_slider
|
221 |
-
rq2_estimators_slider
|
222 |
with gr.Column(scale=2):
|
223 |
-
rq2_summary_output
|
224 |
-
rq2_table_output
|
225 |
-
rq2_plot_output
|
226 |
-
|
227 |
-
with gr.TabItem("
|
228 |
-
|
229 |
-
gr.Markdown("### Deep Dive into a Single Trial\nSelect a participant and a question to visualize their unique gaze pattern.")
|
230 |
with gr.Row():
|
231 |
with gr.Column(scale=1):
|
232 |
-
|
233 |
-
|
234 |
-
|
|
|
|
|
235 |
with gr.Column(scale=2):
|
236 |
-
|
237 |
-
|
238 |
-
# --- WIRING FOR ALL TABS ---
|
239 |
-
outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
|
240 |
-
outputs_explorer = [explorer_summary, explorer_plot]
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
245 |
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
#
|
252 |
-
|
253 |
-
|
|
|
|
|
|
|
|
|
254 |
|
255 |
if __name__ == "__main__":
|
256 |
demo.launch()
|
|
|
1 |
+
# app.py (Gaze Playback & Real-Time Prediction)
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
import matplotlib.patches as patches
|
6 |
import seaborn as sns
|
7 |
from scipy import stats
|
8 |
+
from sklearn.preprocessing import StandardScaler
|
9 |
from sklearn.ensemble import RandomForestClassifier
|
10 |
from sklearn.model_selection import train_test_split
|
11 |
from sklearn.metrics import classification_report, roc_auc_score
|
|
|
23 |
"""
|
24 |
A comprehensive class to load, process, and analyze eye-tracking data.
|
25 |
It supports statistical analysis (RQ1), predictive modeling (RQ2),
|
26 |
+
and an innovative "Gaze Playback" dashboard.
|
27 |
"""
|
28 |
def __init__(self):
|
29 |
self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
|
30 |
self.correct_answers = {'Pair1': 'B', 'Pair2': 'B', 'Pair3': 'B', 'Pair4': 'B', 'Pair5': 'B', 'Pair6': 'B'}
|
31 |
self.combined_data = None
|
32 |
+
self.fixation_data = {} # To store raw fixation data for each trial
|
33 |
self.participant_list = []
|
34 |
self.model = None
|
35 |
self.scaler = None
|
36 |
self.feature_names = []
|
|
|
37 |
self.et_id_col = 'Participant name'
|
38 |
|
39 |
def load_and_process_data(self, base_path, response_file):
|
40 |
+
"""Loads aggregated metrics and raw fixation data, handling NaNs."""
|
41 |
+
print("Loading and processing aggregated and raw fixation data...")
|
42 |
self.response_data = pd.read_excel(response_file) if response_file.endswith('.xlsx') else pd.read_csv(response_file)
|
43 |
self.response_data.columns = self.response_data.columns.str.strip()
|
44 |
+
|
45 |
+
all_metrics_dfs = []
|
46 |
+
# Load both aggregated metrics and raw fixations
|
47 |
for q in self.questions:
|
48 |
file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{q}.xlsx"
|
49 |
if os.path.exists(file_path):
|
50 |
+
xls = pd.ExcelFile(file_path)
|
51 |
+
# Load aggregated metrics (assuming it's the first sheet)
|
52 |
+
metrics_df = pd.read_excel(xls, sheet_name=0)
|
53 |
+
metrics_df['Question'] = q
|
54 |
+
all_metrics_dfs.append(metrics_df)
|
55 |
+
|
56 |
+
# Load raw fixation data if the sheet exists
|
57 |
+
if 'Fixation-based AOI' in xls.sheet_names:
|
58 |
+
fix_df = pd.read_excel(xls, sheet_name='Fixation-based AOI')
|
59 |
+
fix_df['Question'] = q
|
60 |
+
# ENSURE NO NULLS: Drop rows with invalid data crucial for playback
|
61 |
+
fix_df.dropna(subset=['Fixation point X', 'Fixation point Y', 'Gaze event duration (ms)'], inplace=True)
|
62 |
+
self.et_id_col = next((c for c in fix_df.columns if 'participant' in c.lower()), 'Participant name')
|
63 |
+
for participant, group in fix_df.groupby(self.et_id_col):
|
64 |
+
self.fixation_data[(str(participant), q)] = group.reset_index(drop=True)
|
65 |
+
|
66 |
+
if not all_metrics_dfs: raise ValueError("No aggregated metrics files were found.")
|
67 |
+
self.combined_data = pd.concat(all_metrics_dfs, ignore_index=True)
|
68 |
self.combined_data.columns = self.combined_data.columns.str.strip()
|
69 |
+
|
70 |
+
# Merge response data with aggregated metrics
|
71 |
resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')
|
72 |
for pair, ans in self.correct_answers.items():
|
73 |
if pair in self.response_data.columns:
|
74 |
self.response_data[f'{pair}_Correct'] = (self.response_data[pair].astype(str).str.strip().str.upper() == ans)
|
75 |
+
|
76 |
response_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=self.correct_answers.keys(), var_name='Pair')
|
77 |
correctness_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=[f'{p}_Correct' for p in self.correct_answers.keys()], var_name='Pair_Correct_Col', value_name='Correct')
|
78 |
correctness_long['Pair'] = correctness_long['Pair_Correct_Col'].str.replace('_Correct', '')
|
|
|
81 |
self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
|
82 |
self.combined_data = self.combined_data.merge(response_long, left_on=[self.et_id_col, 'Pair'], right_on=[resp_id_col, 'Pair'], how='left')
|
83 |
self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})
|
84 |
+
|
85 |
self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
|
86 |
self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
|
87 |
+
# Convert all participant IDs to strings to prevent sorting errors
|
88 |
self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
|
|
|
89 |
print("Data loading complete.")
|
90 |
return self
|
91 |
|
92 |
def analyze_rq1_metric(self, metric):
|
|
|
93 |
if not metric: return None, "Metric not found."
|
94 |
correct = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Correct', metric].dropna()
|
95 |
incorrect = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Incorrect', metric].dropna()
|
|
|
99 |
return fig, summary
|
100 |
|
101 |
def run_prediction_model(self, test_size, n_estimators):
|
|
|
102 |
leaky_features = ['Total_Correct', 'Overall_Accuracy', 'Correct', self.et_id_col]
|
103 |
self.feature_names = [col for col in self.numeric_cols if col not in leaky_features and col in self.combined_data.columns]
|
104 |
features = self.combined_data[self.feature_names].copy()
|
|
|
117 |
feature_importance = pd.DataFrame({'Feature': self.feature_names, 'Importance': self.model.feature_importances_}).sort_values('Importance', ascending=False).head(15)
|
118 |
fig, ax = plt.subplots(figsize=(10, 8)); sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis'); ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14); plt.tight_layout()
|
119 |
return summary_md, report_df, fig
|
120 |
+
|
121 |
+
def _recalculate_features_from_fixations(self, fixations_df):
|
122 |
+
"""Helper to dynamically create a feature vector from a list of fixations."""
|
123 |
+
feature_vector = pd.Series(0.0, index=self.feature_names)
|
124 |
+
if fixations_df.empty:
|
125 |
+
return feature_vector.values.reshape(1, -1)
|
|
|
|
|
|
|
126 |
|
127 |
+
# Recalculate duration-based features
|
128 |
+
if 'AOI name' in fixations_df.columns:
|
129 |
+
for aoi_name, group in fixations_df.groupby('AOI name'):
|
130 |
+
col_name = f'Total fixation duration on {aoi_name}'
|
131 |
+
if col_name in feature_vector.index:
|
132 |
+
feature_vector[col_name] = group['Gaze event duration (ms)'].sum()
|
133 |
+
|
134 |
+
feature_vector['Total Recording Duration'] = fixations_df['Gaze event duration (ms)'].sum()
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
+
return feature_vector.fillna(0).values.reshape(1, -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
138 |
+
def generate_gaze_playback(self, participant, question, fixation_num):
|
139 |
+
"""Generates the gaze playback and real-time prediction dashboard."""
|
140 |
+
trial_key = (str(participant), question)
|
141 |
+
if not participant or not question or trial_key not in self.fixation_data:
|
142 |
+
return "Please select a valid trial with fixation data.", None, gr.Slider(interactive=False)
|
|
|
|
|
|
|
143 |
|
144 |
+
all_fixations = self.fixation_data[trial_key]
|
145 |
+
fixation_num = int(fixation_num)
|
|
|
146 |
|
147 |
+
slider_max = len(all_fixations)
|
148 |
+
if fixation_num > slider_max: fixation_num = slider_max
|
149 |
+
current_fixations = all_fixations.iloc[:fixation_num]
|
|
|
|
|
|
|
|
|
150 |
|
151 |
+
partial_features = self._recalculate_features_from_fixations(current_fixations)
|
152 |
+
prediction_prob = self.model.predict_proba(self.scaler.transform(partial_features))[0]
|
153 |
+
prob_correct = prediction_prob[1]
|
154 |
+
|
155 |
+
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), gridspec_kw={'height_ratios': [4, 1]})
|
156 |
+
fig.suptitle(f"Gaze Playback for {participant} - {question}", fontsize=16, weight='bold')
|
157 |
+
|
158 |
+
ax1.set_title(f"Displaying Fixations 1 through {fixation_num}/{slider_max}")
|
159 |
+
ax1.set_xlim(0, 1920); ax1.set_ylim(1080, 0)
|
160 |
+
ax1.set_aspect('equal'); ax1.tick_params(left=False, right=False, bottom=False, top=False, labelleft=False, labelbottom=False)
|
161 |
+
ax1.add_patch(patches.Rectangle((0, 0), 1920/2, 1080, facecolor='black', alpha=0.05))
|
162 |
+
ax1.add_patch(patches.Rectangle((1920/2, 0), 1920/2, 1080, facecolor='blue', alpha=0.05))
|
163 |
+
ax1.text(1920*0.25, 50, "Image A", ha='center', fontsize=14, alpha=0.5)
|
164 |
+
ax1.text(1920*0.75, 50, "Image B", ha='center', fontsize=14, alpha=0.5)
|
165 |
+
|
166 |
+
if not current_fixations.empty:
|
167 |
+
points = current_fixations[['Fixation point X', 'Fixation point Y']]
|
168 |
+
ax1.plot(points['Fixation point X'], points['Fixation point Y'], marker='o', color='grey', alpha=0.5, linestyle='-')
|
169 |
+
ax1.scatter(points.iloc[-1]['Fixation point X'], points.iloc[-1]['Fixation point Y'], s=150, c='red', zorder=10, edgecolors='black')
|
170 |
+
|
171 |
+
ax2.set_xlim(0, 1); ax2.set_yticks([])
|
172 |
+
ax2.set_title("Live Prediction Confidence (Answer is 'Correct')")
|
173 |
+
bar_color = 'green' if prob_correct > 0.5 else 'red'
|
174 |
+
ax2.barh([0], [prob_correct], color=bar_color, height=0.5)
|
175 |
+
ax2.axvline(0.5, color='black', linestyle='--', linewidth=1)
|
176 |
+
ax2.text(prob_correct, 0, f" {prob_correct:.1%} ", va='center', ha='left' if prob_correct < 0.9 else 'right', color='white', weight='bold')
|
177 |
+
|
178 |
+
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
179 |
|
180 |
+
trial_info = self.combined_data[(self.combined_data[self.et_id_col].astype(str) == str(participant)) & (self.combined_data['Question'] == question)].iloc[0]
|
181 |
+
summary_text = f"**Actual Answer:** `{trial_info['Answer_Correctness']}`"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
+
return summary_text, fig, gr.Slider(maximum=slider_max, value=fixation_num, interactive=True)
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
# --- DATA SETUP & GRADIO APP ---
|
186 |
def setup_and_load_data():
|
187 |
repo_url = "https://github.com/RextonRZ/GenAIEyeTrackingCleanedDataset"
|
188 |
repo_dir = "GenAIEyeTrackingCleanedDataset"
|
189 |
if not os.path.exists(repo_dir): git.Repo.clone_from(repo_url, repo_dir)
|
190 |
+
else: print("Data repository already exists.")
|
191 |
base_path = repo_dir
|
192 |
response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
|
193 |
analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
|
194 |
return analyzer
|
195 |
|
|
|
196 |
analyzer = setup_and_load_data()
|
|
|
|
|
|
|
|
|
|
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
199 |
+
gr.Markdown("# Interactive Dashboard: AI vs. Real Gaze Analysis")
|
200 |
with gr.Tabs():
|
201 |
with gr.TabItem("π RQ1: Viewing Time vs. Correctness"):
|
|
|
202 |
with gr.Row():
|
203 |
with gr.Column(scale=1):
|
204 |
+
rq1_metric_dropdown=gr.Dropdown(choices=analyzer.time_metrics, label="Select a Time-Based Metric", value=analyzer.time_metrics[0] if analyzer.time_metrics else None)
|
205 |
+
rq1_summary_output=gr.Markdown(label="Statistical Summary")
|
206 |
with gr.Column(scale=2):
|
207 |
+
rq1_plot_output=gr.Plot(label="Metric Comparison")
|
208 |
|
209 |
with gr.TabItem("π€ RQ2: Predicting Correctness from Gaze"):
|
|
|
210 |
with gr.Row():
|
211 |
with gr.Column(scale=1):
|
212 |
gr.Markdown("#### Tune Model Hyperparameters")
|
213 |
+
rq2_test_size_slider=gr.Slider(minimum=0.1, maximum=0.5, step=0.05, value=0.3, label="Test Set Size")
|
214 |
+
rq2_estimators_slider=gr.Slider(minimum=10, maximum=200, step=10, value=100, label="Number of Trees")
|
215 |
with gr.Column(scale=2):
|
216 |
+
rq2_summary_output=gr.Markdown(label="Model Performance Summary")
|
217 |
+
rq2_table_output=gr.Dataframe(label="Classification Report", interactive=False)
|
218 |
+
rq2_plot_output=gr.Plot(label="Feature Importance")
|
219 |
+
|
220 |
+
with gr.TabItem("ποΈ Gaze Playback & Real-Time Prediction"):
|
221 |
+
gr.Markdown("### See the Prediction Evolve with Every Glance!")
|
|
|
222 |
with gr.Row():
|
223 |
with gr.Column(scale=1):
|
224 |
+
playback_participant=gr.Dropdown(choices=analyzer.participant_list, label="Select Participant")
|
225 |
+
playback_question=gr.Dropdown(choices=analyzer.questions, label="Select Question")
|
226 |
+
gr.Markdown("Use the slider to play back fixations one by one.")
|
227 |
+
playback_slider=gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Fixation Number", interactive=False)
|
228 |
+
playback_summary=gr.Markdown(label="Trial Info")
|
229 |
with gr.Column(scale=2):
|
230 |
+
playback_plot=gr.Plot(label="Gaze Playback & Live Prediction")
|
|
|
|
|
|
|
|
|
231 |
|
232 |
+
outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
|
233 |
+
outputs_playback = [playback_summary, playback_plot, playback_slider]
|
|
|
234 |
|
235 |
+
rq1_metric_dropdown.change(fn=analyzer.analyze_rq1_metric, inputs=rq1_metric_dropdown, outputs=[rq1_plot_output, rq1_summary_output])
|
236 |
+
rq2_test_size_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
237 |
+
rq2_estimators_slider.release(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
238 |
+
|
239 |
+
playback_inputs = [playback_participant, playback_question, playback_slider]
|
240 |
+
# Reset slider to 0 when a new trial is selected
|
241 |
+
playback_participant.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
|
242 |
+
playback_question.change(lambda: 0, None, playback_slider).then(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
|
243 |
+
playback_slider.release(fn=analyzer.generate_gaze_playback, inputs=playback_inputs, outputs=outputs_playback)
|
244 |
+
|
245 |
+
demo.load(fn=analyzer.analyze_rq1_metric, inputs=rq1_metric_dropdown, outputs=[rq1_plot_output, rq1_summary_output])
|
246 |
+
demo.load(fn=analyzer.run_prediction_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
247 |
|
248 |
if __name__ == "__main__":
|
249 |
demo.launch()
|