Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
@@ -81,40 +81,37 @@ class EnhancedAIvsRealGazeAnalyzer:
|
|
81 |
valid_indices = target.notna()
|
82 |
features, target = features[valid_indices], target[valid_indices]
|
83 |
features = features.fillna(features.median()).fillna(0)
|
84 |
-
if len(target.unique()) < 2: return "Not enough classes to train
|
85 |
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=42, stratify=target)
|
86 |
scaler = StandardScaler()
|
87 |
X_train_scaled, X_test_scaled = scaler.fit_transform(X_train), scaler.transform(X_test)
|
88 |
model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
|
89 |
model.fit(X_train_scaled, y_train)
|
90 |
-
y_pred_proba = model.predict_proba(X_test_scaled)[:, 1]
|
91 |
y_pred = model.predict(X_test_scaled)
|
92 |
report = classification_report(y_test, y_pred, target_names=['Incorrect', 'Correct'], output_dict=True)
|
93 |
-
auc_score = roc_auc_score(y_test,
|
94 |
|
95 |
# --- THIS IS THE KEY FIX ---
|
96 |
-
# 1.
|
97 |
-
|
98 |
-
# 2. Use the built-in .to_markdown() method for perfect formatting
|
99 |
-
report_table = report_df.to_markdown()
|
100 |
-
|
101 |
-
report_md = f"""
|
102 |
### Model Performance
|
103 |
- **AUC Score:** **{auc_score:.4f}**
|
104 |
- **Overall Accuracy:** {report['accuracy']:.3f}
|
105 |
-
|
106 |
-
**Classification Report:**
|
107 |
-
{report_table}
|
108 |
"""
|
109 |
-
#
|
110 |
-
|
|
|
|
|
111 |
feature_importance = pd.DataFrame({'Feature': features.columns, 'Importance': model.feature_importances_})
|
112 |
feature_importance = feature_importance.sort_values('Importance', ascending=False).head(15)
|
113 |
fig, ax = plt.subplots(figsize=(10, 8))
|
114 |
sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis')
|
115 |
ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14)
|
116 |
plt.tight_layout()
|
117 |
-
|
|
|
|
|
|
|
118 |
|
119 |
# --- DATA SETUP ---
|
120 |
def setup_and_load_data():
|
@@ -124,7 +121,7 @@ def setup_and_load_data():
|
|
124 |
print(f"Cloning data repository from {repo_url}...")
|
125 |
git.Repo.clone_from(repo_url, repo_dir)
|
126 |
else:
|
127 |
-
print("Data repository already
|
128 |
base_path = repo_dir
|
129 |
response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
|
130 |
analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
|
@@ -142,8 +139,9 @@ def update_rq1_visuals(metric_choice):
|
|
142 |
|
143 |
def update_rq2_model(test_size, n_estimators):
|
144 |
n_estimators = int(n_estimators)
|
145 |
-
|
146 |
-
|
|
|
147 |
|
148 |
# --- GRADIO INTERFACE DEFINITION ---
|
149 |
description = """
|
@@ -161,23 +159,35 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
161 |
rq1_summary_output = gr.Markdown(label="Statistical Summary")
|
162 |
with gr.Column(scale=2):
|
163 |
rq1_plot_output = gr.Plot(label="Metric Comparison")
|
|
|
164 |
with gr.TabItem("RQ2: Predicting Correctness from Gaze"):
|
165 |
with gr.Row():
|
166 |
with gr.Column(scale=1):
|
167 |
gr.Markdown("#### Tune Model Hyperparameters")
|
168 |
rq2_test_size_slider = gr.Slider(minimum=0.1, maximum=0.5, step=0.05, value=0.3, label="Test Set Size")
|
169 |
rq2_estimators_slider = gr.Slider(minimum=10, maximum=200, step=10, value=100, label="Number of Trees (n_estimators)")
|
|
|
|
|
170 |
with gr.Column(scale=2):
|
171 |
-
#
|
172 |
-
|
|
|
|
|
|
|
173 |
rq2_plot_output = gr.Plot(label="Feature Importance")
|
|
|
174 |
|
|
|
|
|
|
|
|
|
175 |
rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
176 |
-
rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=
|
177 |
-
rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=
|
178 |
|
179 |
demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
180 |
-
demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=
|
|
|
181 |
|
182 |
if __name__ == "__main__":
|
183 |
demo.launch()
|
|
|
1 |
+
# app.py
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import matplotlib.pyplot as plt
|
|
|
81 |
valid_indices = target.notna()
|
82 |
features, target = features[valid_indices], target[valid_indices]
|
83 |
features = features.fillna(features.median()).fillna(0)
|
84 |
+
if len(target.unique()) < 2: return "Not enough classes to train.", None, None
|
85 |
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=42, stratify=target)
|
86 |
scaler = StandardScaler()
|
87 |
X_train_scaled, X_test_scaled = scaler.fit_transform(X_train), scaler.transform(X_test)
|
88 |
model = RandomForestClassifier(n_estimators=n_estimators, random_state=42, class_weight='balanced')
|
89 |
model.fit(X_train_scaled, y_train)
|
|
|
90 |
y_pred = model.predict(X_test_scaled)
|
91 |
report = classification_report(y_test, y_pred, target_names=['Incorrect', 'Correct'], output_dict=True)
|
92 |
+
auc_score = roc_auc_score(y_test, model.predict_proba(X_test_scaled)[:, 1])
|
93 |
|
94 |
# --- THIS IS THE KEY FIX ---
|
95 |
+
# 1. Create the summary text separately.
|
96 |
+
summary_md = f"""
|
|
|
|
|
|
|
|
|
97 |
### Model Performance
|
98 |
- **AUC Score:** **{auc_score:.4f}**
|
99 |
- **Overall Accuracy:** {report['accuracy']:.3f}
|
|
|
|
|
|
|
100 |
"""
|
101 |
+
# 2. Create the report DataFrame.
|
102 |
+
report_df = pd.DataFrame(report).transpose().round(3)
|
103 |
+
|
104 |
+
# 3. Create the feature importance plot.
|
105 |
feature_importance = pd.DataFrame({'Feature': features.columns, 'Importance': model.feature_importances_})
|
106 |
feature_importance = feature_importance.sort_values('Importance', ascending=False).head(15)
|
107 |
fig, ax = plt.subplots(figsize=(10, 8))
|
108 |
sns.barplot(data=feature_importance, x='Importance', y='Feature', ax=ax, palette='viridis')
|
109 |
ax.set_title(f'Top 15 Predictive Features (n_estimators={n_estimators})', fontsize=14)
|
110 |
plt.tight_layout()
|
111 |
+
|
112 |
+
# 4. Return the three items separately.
|
113 |
+
return summary_md, report_df, fig
|
114 |
+
# --- END OF FIX ---
|
115 |
|
116 |
# --- DATA SETUP ---
|
117 |
def setup_and_load_data():
|
|
|
121 |
print(f"Cloning data repository from {repo_url}...")
|
122 |
git.Repo.clone_from(repo_url, repo_dir)
|
123 |
else:
|
124 |
+
print("Data repository already. Skipping clone.")
|
125 |
base_path = repo_dir
|
126 |
response_file = os.path.join(repo_dir, "GenAI Response.xlsx")
|
127 |
analyzer = EnhancedAIvsRealGazeAnalyzer().load_and_process_data(base_path, response_file)
|
|
|
139 |
|
140 |
def update_rq2_model(test_size, n_estimators):
|
141 |
n_estimators = int(n_estimators)
|
142 |
+
# The function now returns three items
|
143 |
+
summary, report_df, plot = analyzer.run_prediction_model(test_size, n_estimators)
|
144 |
+
return summary, report_df, plot
|
145 |
|
146 |
# --- GRADIO INTERFACE DEFINITION ---
|
147 |
description = """
|
|
|
159 |
rq1_summary_output = gr.Markdown(label="Statistical Summary")
|
160 |
with gr.Column(scale=2):
|
161 |
rq1_plot_output = gr.Plot(label="Metric Comparison")
|
162 |
+
|
163 |
with gr.TabItem("RQ2: Predicting Correctness from Gaze"):
|
164 |
with gr.Row():
|
165 |
with gr.Column(scale=1):
|
166 |
gr.Markdown("#### Tune Model Hyperparameters")
|
167 |
rq2_test_size_slider = gr.Slider(minimum=0.1, maximum=0.5, step=0.05, value=0.3, label="Test Set Size")
|
168 |
rq2_estimators_slider = gr.Slider(minimum=10, maximum=200, step=10, value=100, label="Number of Trees (n_estimators)")
|
169 |
+
|
170 |
+
# --- THIS IS THE KEY UI FIX ---
|
171 |
with gr.Column(scale=2):
|
172 |
+
# 1. A Markdown component for the summary text.
|
173 |
+
rq2_summary_output = gr.Markdown(label="Model Performance Summary")
|
174 |
+
# 2. A Dataframe component for the table.
|
175 |
+
rq2_table_output = gr.Dataframe(label="Classification Report", interactive=False)
|
176 |
+
# 3. A Plot component for the chart.
|
177 |
rq2_plot_output = gr.Plot(label="Feature Importance")
|
178 |
+
# --- END OF UI FIX ---
|
179 |
|
180 |
+
# --- THIS IS THE KEY WIRING FIX ---
|
181 |
+
# The outputs list now has 3 items to match the 3 components
|
182 |
+
outputs_rq2 = [rq2_summary_output, rq2_table_output, rq2_plot_output]
|
183 |
+
|
184 |
rq1_metric_dropdown.change(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
185 |
+
rq2_test_size_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
186 |
+
rq2_estimators_slider.release(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
187 |
|
188 |
demo.load(fn=update_rq1_visuals, inputs=[rq1_metric_dropdown], outputs=[rq1_plot_output, rq1_summary_output])
|
189 |
+
demo.load(fn=update_rq2_model, inputs=[rq2_test_size_slider, rq2_estimators_slider], outputs=outputs_rq2)
|
190 |
+
# --- END OF WIRING FIX ---
|
191 |
|
192 |
if __name__ == "__main__":
|
193 |
demo.launch()
|