Update app.py
Browse files
app.py
CHANGED
@@ -2,9 +2,10 @@
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
|
|
5 |
import seaborn as sns
|
6 |
from scipy import stats
|
7 |
-
from sklearn.preprocessing import StandardScaler
|
8 |
from sklearn.ensemble import RandomForestClassifier
|
9 |
from sklearn.model_selection import train_test_split
|
10 |
from sklearn.metrics import classification_report, roc_auc_score
|
@@ -22,7 +23,7 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
22 |
"""
|
23 |
A comprehensive class to load, process, and analyze eye-tracking data.
|
24 |
It supports statistical analysis (RQ1), predictive modeling (RQ2),
|
25 |
-
and deep-dive
|
26 |
"""
|
27 |
def __init__(self):
|
28 |
self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
|
@@ -33,57 +34,43 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
33 |
self.scaler = None
|
34 |
self.feature_names = []
|
35 |
self.group_means = None
|
36 |
-
self.et_id_col = 'Participant name'
|
37 |
|
38 |
def load_and_process_data(self, base_path, response_file):
|
39 |
-
"""Loads all data from files and preprocesses it for analysis."""
|
40 |
print("Loading and processing data...")
|
41 |
self.response_data = pd.read_excel(response_file) if response_file.endswith('.xlsx') else pd.read_csv(response_file)
|
42 |
self.response_data.columns = self.response_data.columns.str.strip()
|
43 |
-
|
44 |
all_data = {}
|
45 |
-
for
|
46 |
-
file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{
|
47 |
if os.path.exists(file_path):
|
48 |
-
|
49 |
-
all_data[question] = {sheet_name: pd.read_excel(xls, sheet_name) for sheet_name in xls.sheet_names}
|
50 |
-
|
51 |
all_dfs = [df.copy().assign(Question=q, Metric_Type=m) for q, qd in all_data.items() for m, df in qd.items()]
|
52 |
if not all_dfs: raise ValueError("No eye-tracking data files were found.")
|
53 |
-
|
54 |
self.combined_data = pd.concat(all_dfs, ignore_index=True)
|
55 |
self.combined_data.columns = self.combined_data.columns.str.strip()
|
56 |
-
|
57 |
self.et_id_col = next((c for c in self.combined_data.columns if 'participant' in c.lower()), 'Participant name')
|
58 |
resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')
|
59 |
-
|
60 |
-
for pair, correct_answer in self.correct_answers.items():
|
61 |
if pair in self.response_data.columns:
|
62 |
-
self.response_data[f'{pair}_Correct'] = (self.response_data[pair].astype(str).str.strip().str.upper() ==
|
63 |
-
|
64 |
-
response_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=self.correct_answers.keys(), var_name='Pair', value_name='Response')
|
65 |
correctness_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=[f'{p}_Correct' for p in self.correct_answers.keys()], var_name='Pair_Correct_Col', value_name='Correct')
|
66 |
correctness_long['Pair'] = correctness_long['Pair_Correct_Col'].str.replace('_Correct', '')
|
67 |
response_long = response_long.merge(correctness_long[[resp_id_col, 'Pair', 'Correct']], on=[resp_id_col, 'Pair'])
|
68 |
-
|
69 |
q_to_pair = {f'Q{i+1}': f'Pair{i+1}' for i in range(6)}
|
70 |
self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
|
71 |
self.combined_data = self.combined_data.merge(response_long, left_on=[self.et_id_col, 'Pair'], right_on=[resp_id_col, 'Pair'], how='left')
|
72 |
self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})
|
73 |
-
|
74 |
self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
|
75 |
self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]
|
76 |
-
|
77 |
-
# --- THIS IS THE CORRECTED LINE ---
|
78 |
-
# Convert all participant IDs to strings before sorting to handle mixed types.
|
79 |
self.participant_list = sorted([str(p) for p in self.combined_data[self.et_id_col].unique()])
|
80 |
-
# --- END OF CORRECTION ---
|
81 |
-
|
82 |
self.group_means = self.combined_data.groupby('Answer_Correctness')[self.numeric_cols].mean()
|
83 |
print("Data loading complete.")
|
84 |
return self
|
85 |
|
86 |
def analyze_rq1_metric(self, metric):
|
|
|
87 |
if not metric: return None, "Metric not found."
|
88 |
correct = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Correct', metric].dropna()
|
89 |
incorrect = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Incorrect', metric].dropna()
|
@@ -93,6 +80,7 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
93 |
return fig, summary
|
94 |
|
95 |
def run_prediction_model(self, test_size, n_estimators):
|
|
|
96 |
leaky_features = ['Total_Correct', 'Overall_Accuracy', 'Correct', self.et_id_col]
|
97 |
self.feature_names = [col for col in self.numeric_cols if col not in leaky_features and col in self.combined_data.columns]
|
98 |
features = self.combined_data[self.feature_names].copy()
|
@@ -102,64 +90,96 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
102 |
features = features.fillna(features.median()).fillna(0)
|
103 |
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=42, stratify=target)
|
104 |
self.scaler = StandardScaler().fit(X_train)
|
105 |
-
X_train_scaled
|
106 |
self.model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced').fit(X_train_scaled, y_train)
|
107 |
-
report = classification_report(y_test, self.model.predict(
|
108 |
-
auc_score = roc_auc_score(y_test, self.model.predict_proba(
|
109 |
summary_md = f"### Model Performance\n- **AUC Score:** **{auc_score:.4f}**\n- **Overall Accuracy:** {report['accuracy']:.3f}"
|
110 |
report_df = pd.DataFrame(report).transpose().round(3)
|
111 |
feature_importance = pd.DataFrame({'Feature': self.feature_names, 'Importance': self.model.feature_importances_}).sort_values('Importance', ascending=False).head(15)
|
112 |
fig, ax = plt.subplots(figsize=(10, 8)); sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis'); ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14); plt.tight_layout()
|
113 |
return summary_md, report_df, fig
|
114 |
|
115 |
-
def
|
|
|
116 |
if not participant or not question:
|
117 |
-
return "Please select a participant and a question.", None
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
return f"No data found for {participant} on {question}.", None, None
|
123 |
|
124 |
-
trial_data =
|
125 |
actual_answer = trial_data['Answer_Correctness']
|
126 |
-
|
127 |
-
trial_features_scaled = self.scaler.transform(trial_features)
|
128 |
-
prediction_prob = self.model.predict_proba(trial_features_scaled)[0]
|
129 |
predicted_answer = "Correct" if prediction_prob[1] > 0.5 else "Incorrect"
|
130 |
-
summary_md = f"
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
# --- DATA SETUP ---
|
155 |
def setup_and_load_data():
|
156 |
repo_url = "https://github.com/RextonRZ/GenAIEyeTrackingCleanedDataset"
|
157 |
repo_dir = "GenAIEyeTrackingCleanedDataset"
|
158 |
-
if not os.path.exists(repo_dir):
|
159 |
-
|
160 |
-
git.Repo.clone_from(repo_url, repo_dir)
|
161 |
-
else:
|
162 |
-
print("Data repository already exists. Skipping clone.")
|
163 |
base_path = repo_dir
|
164 |
response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
|
165 |
analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
|
@@ -171,31 +191,29 @@ print("Application setup complete. Ready for interaction.")
|
|
171 |
|
172 |
# --- GRADIO INTERACTIVE FUNCTIONS ---
|
173 |
def update_rq1_visuals(metric_choice):
|
174 |
-
|
175 |
-
plot, summary = analyzer.analyze_rq1_metric(metric_choice)
|
176 |
-
return plot, summary
|
177 |
|
178 |
def update_rq2_model(test_size, n_estimators):
|
179 |
-
|
180 |
-
summary, report_df, plot = analyzer.run_prediction_model(test_size, n_estimators)
|
181 |
-
return summary, report_df, plot
|
182 |
|
183 |
def update_explorer_view(participant, question):
|
184 |
-
|
185 |
-
return summary, plot, report_card
|
186 |
|
187 |
# --- GRADIO INTERFACE DEFINITION ---
|
188 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
189 |
-
gr.Markdown("# Interactive Dashboard: AI vs. Real Gaze Analysis\nExplore the eye-tracking dataset by interacting with the controls below.
|
190 |
with gr.Tabs():
|
191 |
with gr.TabItem("π RQ1: Viewing Time vs. Correctness"):
|
|
|
192 |
with gr.Row():
|
193 |
with gr.Column(scale=1):
|
194 |
rq1_metric_dropdown = gr.Dropdown(choices=analyzer.time_metrics, label="Select a Time-Based Metric", value=analyzer.time_metrics[0] if analyzer.time_metrics else None)
|
195 |
rq1_summary_output = gr.Markdown(label="Statistical Summary")
|
196 |
with gr.Column(scale=2):
|
197 |
rq1_plot_output = gr.Plot(label="Metric Comparison")
|
|
|
198 |
with gr.TabItem("π€ RQ2: Predicting Correctness from Gaze"):
|
|
|
199 |
with gr.Row():
|
200 |
with gr.Column(scale=1):
|
201 |
gr.Markdown("#### Tune Model Hyperparameters")
|
@@ -205,24 +223,32 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
205 |
rq2_summary_output = gr.Markdown(label="Model Performance Summary")
|
206 |
rq2_table_output = gr.Dataframe(label="Classification Report", interactive=False)
|
207 |
rq2_plot_output = gr.Plot(label="Feature Importance")
|
208 |
-
|
209 |
-
|
|
|
|
|
210 |
with gr.Row():
|
211 |
with gr.Column(scale=1):
|
212 |
explorer_participant = gr.Dropdown(choices=analyzer.participant_list, label="Select Participant")
|
213 |
explorer_question = gr.Dropdown(choices=analyzer.questions, label="Select Question")
|
214 |
explorer_summary = gr.Markdown(label="Trial Summary")
|
215 |
-
explorer_report_card = gr.Dataframe(label="Feature Report Card", interactive=False)
|
216 |
with gr.Column(scale=2):
|
217 |
-
explorer_plot = gr.Plot(label="Gaze
|
218 |
|
|
|
219 |
outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
|
220 |
-
outputs_explorer = [explorer_summary, explorer_plot
|
|
|
221 |
rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
222 |
rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
223 |
rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
224 |
-
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
226 |
demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
227 |
demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
228 |
|
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
import matplotlib.patches as patches
|
6 |
import seaborn as sns
|
7 |
from scipy import stats
|
8 |
+
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
9 |
from sklearn.ensemble import RandomForestClassifier
|
10 |
from sklearn.model_selection import train_test_split
|
11 |
from sklearn.metrics import classification_report, roc_auc_score
|
|
|
23 |
"""
|
24 |
A comprehensive class to load, process, and analyze eye-tracking data.
|
25 |
It supports statistical analysis (RQ1), predictive modeling (RQ2),
|
26 |
+
and an innovative deep-dive AOI Focus Dashboard.
|
27 |
"""
|
28 |
def __init__(self):
|
29 |
self.questions = ['Q1', 'Q2', 'Q3', 'Q4', 'Q5', 'Q6']
|
|
|
34 |
self.scaler = None
|
35 |
self.feature_names = []
|
36 |
self.group_means = None
|
37 |
+
self.et_id_col = 'Participant name'
|
38 |
|
39 |
def load_and_process_data(self, base_path, response_file):
    """Load the response sheet and the per-question workbooks, then merge them.

    Reads ``response_file`` (Excel when the name ends in ``.xlsx``, CSV
    otherwise) and, for every entry of ``self.questions``, the workbook
    ``{base_path}/Filtered_GenAI_Metrics_cleaned_{question}.xlsx`` when it
    exists. Every sheet of every workbook becomes rows of
    ``self.combined_data``, tagged with ``Question`` and ``Metric_Type``
    (the sheet name). Responses are scored against ``self.correct_answers``
    and joined onto the eye-tracking rows by participant and pair.

    Populates ``response_data``, ``combined_data``, ``et_id_col``,
    ``numeric_cols``, ``time_metrics``, ``participant_list`` and
    ``group_means`` on the instance.

    Args:
        base_path: Directory holding the per-question workbooks.
        response_file: Path to the participants' response file.

    Returns:
        The analyzer itself, so the call can be chained.

    Raises:
        ValueError: If no eye-tracking workbook is found, or if none of the
            pairs in ``self.correct_answers`` is a column of the response
            file.
    """
    print("Loading and processing data...")
    if response_file.endswith('.xlsx'):
        self.response_data = pd.read_excel(response_file)
    else:
        self.response_data = pd.read_csv(response_file)
    self.response_data.columns = self.response_data.columns.str.strip()

    all_data = {}
    for q in self.questions:
        file_path = f"{base_path}/Filtered_GenAI_Metrics_cleaned_{q}.xlsx"
        if os.path.exists(file_path):
            # sheet_name=None parses the workbook once and returns
            # {sheet name: DataFrame}; the previous per-sheet read_excel
            # calls reopened the file for every sheet and left the
            # ExcelFile handle unclosed.
            all_data[q] = pd.read_excel(file_path, sheet_name=None)

    all_dfs = [
        df.assign(Question=q, Metric_Type=m)
        for q, sheets in all_data.items()
        for m, df in sheets.items()
    ]
    if not all_dfs:
        raise ValueError("No eye-tracking data files were found.")

    self.combined_data = pd.concat(all_dfs, ignore_index=True)
    self.combined_data.columns = self.combined_data.columns.str.strip()

    # First column whose name mentions "participant" is taken as the ID column.
    self.et_id_col = next((c for c in self.combined_data.columns if 'participant' in c.lower()), 'Participant name')
    resp_id_col = next((c for c in self.response_data.columns if 'participant' in c.lower()), 'Participant name')

    # Only pairs that really exist in the response file are scored and
    # melted; melting a missing column would raise a KeyError.
    answered_pairs = [pair for pair in self.correct_answers if pair in self.response_data.columns]
    if not answered_pairs:
        raise ValueError("None of the expected answer columns were found in the response file.")
    for pair in answered_pairs:
        given = self.response_data[pair].astype(str).str.strip().str.upper()
        self.response_data[f'{pair}_Correct'] = (given == self.correct_answers[pair])

    response_long = self.response_data.melt(id_vars=[resp_id_col], value_vars=answered_pairs, var_name='Pair')
    correctness_long = self.response_data.melt(
        id_vars=[resp_id_col],
        value_vars=[f'{p}_Correct' for p in answered_pairs],
        var_name='Pair_Correct_Col',
        value_name='Correct',
    )
    # Literal replacement: the default of `regex` differs across pandas versions.
    correctness_long['Pair'] = correctness_long['Pair_Correct_Col'].str.replace('_Correct', '', regex=False)
    response_long = response_long.merge(correctness_long[[resp_id_col, 'Pair', 'Correct']], on=[resp_id_col, 'Pair'])

    q_to_pair = {f'Q{i+1}': f'Pair{i+1}' for i in range(6)}
    self.combined_data['Pair'] = self.combined_data['Question'].map(q_to_pair)
    self.combined_data = self.combined_data.merge(
        response_long,
        left_on=[self.et_id_col, 'Pair'],
        right_on=[resp_id_col, 'Pair'],
        how='left',
    )
    self.combined_data['Answer_Correctness'] = self.combined_data['Correct'].map({True: 'Correct', False: 'Incorrect'})

    self.numeric_cols = self.combined_data.select_dtypes(include=np.number).columns.tolist()
    self.time_metrics = [c for c in self.numeric_cols if any(k in c.lower() for k in ['time', 'duration', 'fixation'])]

    # IDs are compared as strings elsewhere, so sort them as strings; rows
    # without an ID are dropped so no 'nan' entry reaches the dropdown.
    self.participant_list = sorted(str(p) for p in self.combined_data[self.et_id_col].dropna().unique())

    self.group_means = self.combined_data.groupby('Answer_Correctness')[self.numeric_cols].mean()
    print("Data loading complete.")
    return self
|
71 |
|
72 |
def analyze_rq1_metric(self, metric):
|
73 |
+
# (Unchanged)
|
74 |
if not metric: return None, "Metric not found."
|
75 |
correct = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Correct', metric].dropna()
|
76 |
incorrect = self.combined_data.loc[self.combined_data['Answer_Correctness'] == 'Incorrect', metric].dropna()
|
|
|
80 |
return fig, summary
|
81 |
|
82 |
def run_prediction_model(self, test_size, n_estimators):
|
83 |
+
# (Unchanged)
|
84 |
leaky_features = ['Total_Correct', 'Overall_Accuracy', 'Correct', self.et_id_col]
|
85 |
self.feature_names = [col for col in self.numeric_cols if col not in leaky_features and col in self.combined_data.columns]
|
86 |
features = self.combined_data[self.feature_names].copy()
|
|
|
90 |
features = features.fillna(features.median()).fillna(0)
|
91 |
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=42, stratify=target)
|
92 |
self.scaler = StandardScaler().fit(X_train)
|
93 |
+
X_train_scaled = self.scaler.transform(X_train)
|
94 |
self.model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced').fit(X_train_scaled, y_train)
|
95 |
+
report = classification_report(y_test, self.model.predict(self.scaler.transform(X_test)), target_names=['Incorrect', 'Correct'], output_dict=True)
|
96 |
+
auc_score = roc_auc_score(y_test, self.model.predict_proba(self.scaler.transform(X_test))[:, 1])
|
97 |
summary_md = f"### Model Performance\n- **AUC Score:** **{auc_score:.4f}**\n- **Overall Accuracy:** {report['accuracy']:.3f}"
|
98 |
report_df = pd.DataFrame(report).transpose().round(3)
|
99 |
feature_importance = pd.DataFrame({'Feature': self.feature_names, 'Importance': self.model.feature_importances_}).sort_values('Importance', ascending=False).head(15)
|
100 |
fig, ax = plt.subplots(figsize=(10, 8)); sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis'); ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14); plt.tight_layout()
|
101 |
return summary_md, report_df, fig
|
102 |
|
103 |
+
def generate_aoi_focus_dashboard(self, participant, question):
    """Build the single-trial AOI Focus Dashboard for one participant and question.

    Produces a markdown summary (actual answer vs. the trained model's
    prediction) and a two-panel figure: a schematic heatmap of the
    "Total fixation duration on <AOI> A/B" columns, and a radar chart that
    compares the trial with the mean "Correct" and "Incorrect" profiles on
    the model's five most important features.

    When several rows match the participant/question, the first one is used.

    Args:
        participant: Participant identifier; compared with the ID column as
            a string.
        question: Question label matched against the ``Question`` column.

    Returns:
        tuple: ``(summary_markdown, figure)``. ``figure`` is ``None`` when a
        selection is missing, no matching row exists, or the model has not
        been trained yet.
    """
    if not participant or not question:
        return "Please select a participant and a question.", None

    data = self.combined_data
    is_trial = (data[self.et_id_col].astype(str) == str(participant)) & (data['Question'] == question)
    trial_rows = data[is_trial]
    if trial_rows.empty:
        return f"No data found for {participant} on {question}.", None

    # The prediction and the radar chart both need a fitted model and scaler.
    model = getattr(self, 'model', None)
    if model is None or self.scaler is None or not self.feature_names:
        return "The prediction model has not been trained yet; run the RQ2 model first.", None

    trial_data = trial_rows.iloc[0]
    actual_answer = trial_data['Answer_Correctness']

    # Keep a one-row DataFrame so the scaler sees the column names it was
    # fitted with, and impute gaps (column median, then 0) because the model
    # was trained on imputed features.
    # NOTE(review): medians are taken over all of combined_data; confirm this
    # matches the rows used when the model was trained.
    trial_features = trial_rows.iloc[[0]][self.feature_names]
    trial_features = trial_features.fillna(data[self.feature_names].median()).fillna(0)
    prediction_prob = model.predict_proba(self.scaler.transform(trial_features))[0]
    predicted_answer = "Correct" if prediction_prob[1] > 0.5 else "Incorrect"
    summary_md = f"### Trial Breakdown: **{participant}** on **{question}**\n- **Actual Answer:** `{actual_answer}`\n- **Model Prediction:** `{predicted_answer}` (Confidence: {max(prediction_prob)*100:.1f}%)"

    # Create both axes on the figure directly; calling plt.subplot(122,
    # polar=True) after plt.subplots(1, 2) stacks a second axes on top of
    # the Cartesian one already occupying that slot.
    fig = plt.figure(figsize=(16, 7))
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2, projection='polar')
    fig.suptitle(f'AOI Focus & Gaze Profile for {participant} - {question}', fontsize=18, weight='bold')

    # --- Plot 1: Schematic Heatmap ---
    ax1.set_title('Schematic AOI Heatmap (Fixation Duration)', fontsize=14)
    ax1.set_xlim(0, 10)
    ax1.set_ylim(0, 10)
    ax1.set_aspect('equal')
    ax1.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False, labelbottom=False, labelleft=False)

    aoi_layout = {
        '2 Face': (2.5, 7.5), '2 Hair': (2.5, 5), '1 Eyes': (7.5, 8.5), '2 Body': (7.5, 6),
        '5 Sky': (2.5, 2.5), '5 Sea': (7.5, 3.5), '6 Background': (7.5, 1.5)
    }

    # Collect durations keyed by the full AOI label ("2 Face A"). The layout
    # is keyed by the base name, so the trailing " A"/" B" image marker is
    # stripped before the membership test; otherwise no column ever matches
    # the f'{name} A' / f'{name} B' lookups below.
    aoi_values = {}
    for column in trial_data.index:
        if not isinstance(column, str) or 'Total fixation duration' not in column:
            continue
        _, marker, aoi_label = column.partition('duration on ')
        if not marker or not aoi_label.endswith((' A', ' B')):
            continue
        if aoi_label[:-2] not in aoi_layout:
            continue
        duration = pd.to_numeric(trial_data[column], errors='coerce')
        if pd.notna(duration):
            aoi_values[aoi_label] = float(duration)

    max_val = max(aoi_values.values(), default=0.0)
    if max_val <= 0:
        # Nothing to normalise against; avoid dividing by zero.
        max_val = 1.0

    for name, pos in aoi_layout.items():
        val_a = aoi_values.get(f'{name} A', 0)
        val_b = aoi_values.get(f'{name} B', 0)
        color_a = plt.cm.viridis(val_a / max_val)
        color_b = plt.cm.viridis(val_b / max_val)
        ax1.add_patch(patches.Rectangle((pos[0]-2, pos[1]-0.5), 2, 1, facecolor=color_a))
        ax1.add_patch(patches.Rectangle((pos[0], pos[1]-0.5), 2, 1, facecolor=color_b))
        ax1.text(pos[0]-1, pos[1], "A", ha='center', va='center', color='white', weight='bold')
        ax1.text(pos[0]+1, pos[1], "B", ha='center', va='center', color='white', weight='bold')
        ax1.text(pos[0], pos[1]+0.7, name, ha='center', va='center')

    # --- Plot 2: Gaze Profile Radar Chart ---
    top_features_indices = model.feature_importances_.argsort()[-5:][::-1]
    top_features = [self.feature_names[i] for i in top_features_indices]
    labels = [feature.replace('Total fixation duration on ', '') for feature in top_features]

    profiles = pd.concat([
        self.group_means.loc['Correct', top_features],
        self.group_means.loc['Incorrect', top_features],
        trial_data[top_features],
    ], axis=1).T
    profiles = profiles.apply(pd.to_numeric, errors='coerce')
    # Each feature is min-max scaled across the three profiles; values that
    # are missing for this trial are drawn at the centre instead of breaking
    # the polygon.
    profiles_scaled = np.nan_to_num(MinMaxScaler().fit_transform(profiles))

    angles = np.linspace(0, 2*np.pi, len(labels), endpoint=False).tolist()
    angles += angles[:1]

    ax2.set_title('Gaze Profile Comparison', fontsize=14)

    def add_to_radar(profile, color, label):
        # Repeat the first point so the outline closes.
        values = profile.tolist()
        values += values[:1]
        ax2.plot(angles, values, color=color, linewidth=2, linestyle='solid', label=label)
        ax2.fill(angles, values, color=color, alpha=0.25)

    add_to_radar(profiles_scaled[0, :], 'green', 'Avg. Correct Profile')
    add_to_radar(profiles_scaled[1, :], 'red', 'Avg. Incorrect Profile')
    add_to_radar(profiles_scaled[2, :], 'blue', 'This Trial Profile')

    ax2.set_thetagrids(np.degrees(angles[:-1]), labels)
    ax2.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))

    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    return summary_md, fig
|
176 |
|
177 |
# --- DATA SETUP ---
|
178 |
def setup_and_load_data():
|
179 |
repo_url = "https://github.com/RextonRZ/GenAIEyeTrackingCleanedDataset"
|
180 |
repo_dir = "GenAIEyeTrackingCleanedDataset"
|
181 |
+
if not os.path.exists(repo_dir): git.Repo.clone_from(repo_url, repo_dir)
|
182 |
+
else: print("Data repository already exists. Skipping clone.")
|
|
|
|
|
|
|
183 |
base_path = repo_dir
|
184 |
response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
|
185 |
analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
|
|
|
191 |
|
192 |
# --- GRADIO INTERACTIVE FUNCTIONS ---
|
193 |
def update_rq1_visuals(metric_choice):
    """Gradio callback for the RQ1 tab: hand the chosen metric to the shared analyzer.

    Returns whatever ``analyzer.analyze_rq1_metric`` returns for
    ``metric_choice``, unchanged.
    """
    rq1_outputs = analyzer.analyze_rq1_metric(metric_choice)
    return rq1_outputs
|
|
|
|
|
195 |
|
196 |
def update_rq2_model(test_size, n_estimators):
    """Gradio callback for the RQ2 tab: retrain the model with the slider values.

    Args:
        test_size: Test-set fraction from the slider; normalised to two
            decimals.
        n_estimators: Number of trees from the slider; normalised to an int.

    Returns:
        Whatever ``analyzer.run_prediction_model`` returns for the
        normalised arguments.
    """
    # Round instead of truncating: int(0.29 * 100) / 100 yields 0.28 because
    # 0.29 * 100 evaluates to 28.999999999999996 in binary floating point.
    normalised_test_size = round(float(test_size), 2)
    normalised_estimators = int(round(float(n_estimators)))
    return analyzer.run_prediction_model(normalised_test_size, normalised_estimators)
|
|
|
|
|
198 |
|
199 |
def update_explorer_view(participant, question):
    """Gradio callback for the explorer tab: build the dashboard for one trial.

    Passes the selected participant and question to
    ``analyzer.generate_aoi_focus_dashboard`` and returns its result unchanged.
    """
    dashboard_outputs = analyzer.generate_aoi_focus_dashboard(participant, question)
    return dashboard_outputs
|
|
|
201 |
|
202 |
# --- GRADIO INTERFACE DEFINITION ---
|
203 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
204 |
+
gr.Markdown("# Interactive Dashboard: AI vs. Real Gaze Analysis\nExplore the eye-tracking dataset by interacting with the controls below.")
|
205 |
with gr.Tabs():
|
206 |
with gr.TabItem("π RQ1: Viewing Time vs. Correctness"):
|
207 |
+
# (UI Unchanged)
|
208 |
with gr.Row():
|
209 |
with gr.Column(scale=1):
|
210 |
rq1_metric_dropdown = gr.Dropdown(choices=analyzer.time_metrics, label="Select a Time-Based Metric", value=analyzer.time_metrics[0] if analyzer.time_metrics else None)
|
211 |
rq1_summary_output = gr.Markdown(label="Statistical Summary")
|
212 |
with gr.Column(scale=2):
|
213 |
rq1_plot_output = gr.Plot(label="Metric Comparison")
|
214 |
+
|
215 |
with gr.TabItem("π€ RQ2: Predicting Correctness from Gaze"):
|
216 |
+
# (UI Unchanged)
|
217 |
with gr.Row():
|
218 |
with gr.Column(scale=1):
|
219 |
gr.Markdown("#### Tune Model Hyperparameters")
|
|
|
223 |
rq2_summary_output = gr.Markdown(label="Model Performance Summary")
|
224 |
rq2_table_output = gr.Dataframe(label="Classification Report", interactive=False)
|
225 |
rq2_plot_output = gr.Plot(label="Feature Importance")
|
226 |
+
|
227 |
+
with gr.TabItem("π¬ AOI Focus Dashboard"):
|
228 |
+
# --- NEW INNOVATIVE UI ---
|
229 |
+
gr.Markdown("### Deep Dive into a Single Trial\nSelect a participant and a question to visualize their unique gaze pattern.")
|
230 |
with gr.Row():
|
231 |
with gr.Column(scale=1):
|
232 |
explorer_participant = gr.Dropdown(choices=analyzer.participant_list, label="Select Participant")
|
233 |
explorer_question = gr.Dropdown(choices=analyzer.questions, label="Select Question")
|
234 |
explorer_summary = gr.Markdown(label="Trial Summary")
|
|
|
235 |
with gr.Column(scale=2):
|
236 |
+
explorer_plot = gr.Plot(label="Gaze Focus & Profile Analysis")
|
237 |
|
238 |
+
# --- WIRING FOR ALL TABS ---
|
239 |
outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
|
240 |
+
outputs_explorer = [explorer_summary, explorer_plot]
|
241 |
+
|
242 |
rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
243 |
rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
244 |
rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
245 |
+
|
246 |
+
# Wiring for the new explorer tab
|
247 |
+
explorer_inputs = [explorer_participant, explorer_question]
|
248 |
+
explorer_participant.change(fn=update_explorer_view, inputs=explorer_inputs, outputs=outputs_explorer)
|
249 |
+
explorer_question.change(fn=update_explorer_view, inputs=explorer_inputs, outputs=outputs_explorer)
|
250 |
+
|
251 |
+
# Load initial state
|
252 |
demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
253 |
demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
254 |
|