ssyok committed
Commit c2b98b5 Β· 1 Parent(s): 3d22019

upgrade the UI and fix warning of standard scaler

Files changed (2)
  1. app.py +252 -110
  2. requirements.txt +1 -1
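The "warning of standard scaler" named in the commit message is scikit-learn's feature-name check: an estimator fitted on a pandas DataFrame records its column names and emits `UserWarning: X does not have valid feature names, but StandardScaler was fitted with feature names` when later asked to predict on a bare NumPy array. A minimal sketch of the problem and of the fix this commit applies (the pipeline wiring and training data here are stand-ins; only the column names are taken from the diff below):

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

cols = ['dist_to_left', 'dist_to_right', 'bias_to_left']

# Fit on a DataFrame so the scaler records feature names
# (random stand-in data; the real models are trained on fixation features).
X_train = pd.DataFrame(np.random.rand(40, 3) * 500, columns=cols)
y_train = np.random.randint(0, 2, 40)
model = make_pipeline(StandardScaler(), LogisticRegression()).fit(X_train, y_train)

x_new = [[120.0, 340.0, 220.0]]
model.predict(np.array(x_new))                    # warns: X does not have valid feature names
model.predict(pd.DataFrame(x_new, columns=cols))  # silent: names match the fit-time columns
```

Accordingly, the diff below swaps each `np.array(...)` feature matrix for a `pd.DataFrame(...)` carrying the original training column names.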
app.py CHANGED
@@ -23,6 +23,16 @@ trained_models_folder = 'Optical Illusion - Trained Models'
 DISPLAY_WIDTH = 1920
 DISPLAY_HEIGHT = 1080
 
+# Image descriptions for better user understanding
+IMAGE_DESCRIPTIONS = {
+    'duck-rabbit': 'A classic ambiguous figure that can be seen as either a duck or a rabbit',
+    'face-vase': 'The famous Rubin\'s vase - you might see two faces in profile or a vase',
+    'young-old': 'This image can appear as either a young woman or an old woman',
+    'princess-oldMan': 'Can be perceived as either a princess or an old man',
+    'lily-woman': 'This ambiguous image shows either a lily flower or a woman',
+    'tiger-monkey': 'You might see either a tiger or a monkey in this image'
+}
+
 # Load all saved models at startup
 def load_all_models():
     """Load all saved models into memory"""
@@ -65,16 +75,17 @@ def create_placeholder_image(image_name):
 
     # Handle None or empty image_name
     if image_name is None:
-        display_text = 'NO IMAGE SELECTED\n\nSelect an image from the dropdown'
+        display_text = 'πŸ–ΌοΈ NO IMAGE SELECTED\n\nπŸ‘† Select an image from the dropdown above'
     else:
-        display_text = f'{image_name.upper()}\n\nClick where you first look\n\n(Image not found)'
+        display_text = f'πŸ–ΌοΈ {image_name.upper()}\n\nπŸ‘† Click where you first look\n\n⚠️ (Image not found)'
 
     ax.text(0.5, 0.5, display_text,
             transform=ax.transAxes, ha='center', va='center',
-            fontsize=32, fontweight='bold')
+            fontsize=28, fontweight='bold', color='#666666')
     ax.set_xlim(0, DISPLAY_WIDTH)
     ax.set_ylim(0, DISPLAY_HEIGHT)
     ax.axis('off')
+    ax.set_facecolor('#f8f9fa')  # Light gray background for placeholder
 
     buf = io.BytesIO()
     plt.savefig(buf, format='png', dpi=100, bbox_inches='tight', pad_inches=0)
@@ -90,10 +101,10 @@ def process_click(image_name, model_type, evt: gr.SelectData):
     """Process click on image and return prediction"""
 
     if evt is None:
-        return "Please click on the image where you first looked!", None, None
+        return "❗ Please click on the image where you first looked!", None, None
 
     if image_name is None:
-        return "Please select an image first!", None, None
+        return "❗ Please select an image first!", None, None
 
     # Get click coordinates (Gradio provides them in image coordinates)
     click_x_img, click_y_img = evt.index
@@ -107,7 +118,7 @@ def process_click(image_name, model_type, evt: gr.SelectData):
 
     # Get model data
     if image_name not in all_models:
-        return f"No model found for {image_name}", None, None
+        return f"❗ No model found for {image_name}", None, None
 
     model_data = all_models[image_name]
 
@@ -121,7 +132,8 @@ def process_click(image_name, model_type, evt: gr.SelectData):
     bias = dist_right - dist_left
 
     # Make prediction
-    X = np.array([[dist_left, dist_right, bias]])
+    X = pd.DataFrame([[dist_left, dist_right, bias]],
+                     columns=['dist_to_left', 'dist_to_right', 'bias_to_left'])
     model = model_data[f'{model_type}_model']
     prediction = model.predict(X)[0]
     probability = model.predict_proba(X)[0]
@@ -130,18 +142,34 @@ def process_click(image_name, model_type, evt: gr.SelectData):
     predicted_class = model_data['label_classes'][prediction]
     confidence = probability[prediction]
 
+    # Create confidence level description
+    if confidence >= 0.8:
+        confidence_level = "Very High 🟒"
+    elif confidence >= 0.65:
+        confidence_level = "High 🟑"
+    elif confidence >= 0.5:
+        confidence_level = "Moderate 🟠"
+    else:
+        confidence_level = "Low πŸ”΄"
+
     # Create detailed message
     message = f"""
-    ## Prediction Results
+    <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 1.5rem; border-radius: 10px; color: white; margin: 0.5rem 0;">
+    <h2 style="color: white; margin-top: 0;">πŸ” Prediction Results</h2>
 
-    **You clicked at:** ({click_x_img}, {click_y_img}) in image coordinates
-    **Normalized position:** ({click_x_norm:.1f}, {click_y_norm:.1f})
-
-    **Distance to left centroid:** {dist_left:.1f} pixels
-    **Distance to right centroid:** {dist_right:.1f} pixels
+    <p><strong>πŸ‘† Click Location:</strong> ({click_x_img}, {click_y_img}) pixels from top-left<br>
+    <strong>🎯 Normalized Position:</strong> ({click_x_norm:.1f}, {click_y_norm:.1f}) from center</p>
+
+    <hr style="border-color: rgba(255,255,255,0.3);">
+
+    <p><strong>πŸ“ Distance to Left Region:</strong> {dist_left:.1f} pixels<br>
+    <strong>πŸ“ Distance to Right Region:</strong> {dist_right:.1f} pixels<br>
+    <strong>βš–οΈ Bias Score:</strong> {bias:.1f}</p>
+
+    <hr style="border-color: rgba(255,255,255,0.3);">
 
-    ### Prediction: You likely see the **{predicted_class.upper()}** interpretation
-    ### Confidence: {confidence:.1%}
+    <h3 style="color: white;">🧠 Prediction: You likely see the {predicted_class.upper()} interpretation</h3>
+    <h3 style="color: white;">πŸ“Š Confidence: {confidence:.1%} ({confidence_level})</h3>
     """
 
     # Create visualization
@@ -150,24 +178,26 @@ def process_click(image_name, model_type, evt: gr.SelectData):
 
     # Get example interpretations
     interpretations = {
-        'duck-rabbit': {'left': 'Duck', 'right': 'Rabbit'},
-        'face-vase': {'left': 'Faces', 'right': 'Vase'},
-        'young-old': {'left': 'Young Woman', 'right': 'Old Woman'},
-        'princess-oldMan': {'left': 'Princess', 'right': 'Old Man'},
-        'lily-woman': {'left': 'Lily', 'right': 'Woman'},
-        'tiger-monkey': {'left': 'Tiger', 'right': 'Monkey'}
+        'duck-rabbit': {'left': 'Duck πŸ¦†', 'right': 'Rabbit 🐰'},
+        'face-vase': {'left': 'Two Faces πŸ‘₯', 'right': 'Vase 🏺'},
+        'young-old': {'left': 'Young Woman πŸ‘©', 'right': 'Old Woman πŸ‘΅'},
+        'princess-oldMan': {'left': 'Princess πŸ‘Έ', 'right': 'Old Man πŸ‘΄'},
+        'lily-woman': {'left': 'Lily 🌸', 'right': 'Woman πŸ‘©'},
+        'tiger-monkey': {'left': 'Tiger πŸ…', 'right': 'Monkey πŸ’'}
     }
 
     if image_name in interpretations:
         specific = interpretations[image_name][predicted_class]
-        message += f"\n**Specific interpretation:** You see the **{specific}**"
+        message += f"<p><strong>🎨 What you see:</strong> {specific}</p>"
+
+    message += "</div>"
 
     return message, viz, create_stats_table(image_name, model_type)
 
 def create_visualization(image_name, click_x, click_y, prediction, confidence, model_type='rf'):
     """Create a visualization showing the click point, centroids, and prediction"""
 
-    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), facecolor='#f8f9fa')
 
     # Get model data
     model_data = all_models[image_name]
@@ -189,15 +219,15 @@ def create_visualization(image_name, click_x, click_y, prediction, confidence, model_type='rf'):
         bias = dist_right - dist_left
         features.append([dist_left, dist_right, bias])
 
-    X = np.array(features)
+    X = pd.DataFrame(features, columns=['dist_to_left', 'dist_to_right', 'bias_to_left'])
     model = model_data[f'{model_type}_model']
     Z = model.predict(X)
     Z = Z.reshape(xx.shape)
 
     # Plot decision boundary
     from matplotlib.colors import ListedColormap
-    colors = ListedColormap(['lightblue', 'lightcoral'])
-    ax1.contourf(xx, yy, Z, alpha=0.6, cmap=colors)
+    colors = ListedColormap(['#a8d5ff', '#ffb3b3'])  # Softer blue and red
+    ax1.contourf(xx, yy, Z, alpha=0.7, cmap=colors)
 
     # Plot centroids
     ax1.scatter(centroid_left[0], centroid_left[1],
@@ -219,24 +249,33 @@ def create_visualization(image_name, click_x, click_y, prediction, confidence, model_type='rf'):
     ax1.set_ylabel('Y (pixels from center)')
     ax1.set_title(f'Decision Space - {model_type.upper()} Model')
     ax1.grid(True, alpha=0.3)
-    ax1.legend()
+    ax1.legend(loc='upper right', framealpha=0.9)
     ax1.set_xlim(-960, 960)  # Full width range
     ax1.set_ylim(-540, 540)  # Full height range
     ax1.set_aspect('equal')
+    ax1.set_facecolor('#f8f9fa')  # Light background
 
     # Right plot: Statistics
     image_df = master_df[master_df['image_type'] == image_name]
 
     # Create bar chart of choices
     choice_counts = image_df['choice'].value_counts()
-    ax2.bar(choice_counts.index, choice_counts.values,
-            color=['blue' if x == 'left' else 'red' for x in choice_counts.index])
+    bars = ax2.bar(choice_counts.index, choice_counts.values,
+                   color=['#4b86db' if x == 'left' else '#db4b4b' for x in choice_counts.index])
+
+    # Add values on top of bars
+    for bar in bars:
+        height = bar.get_height()
+        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.5,
+                 f'{height:.0f}',
+                 ha='center', va='bottom', fontsize=10)
 
     # Add prediction annotation
     ax2.text(0.5, 0.95, f'Your Predicted Choice: {prediction.upper()}',
              transform=ax2.transAxes, ha='center', va='top',
              fontsize=16, fontweight='bold',
-             bbox=dict(boxstyle='round', facecolor='lightgreen' if prediction == 'left' else 'lightcoral'))
+             bbox=dict(boxstyle='round,pad=0.5', facecolor='#c2f0c2' if prediction == 'left' else '#f0c2c2',
+                       alpha=0.9, edgecolor='gray'))
 
     ax2.text(0.5, 0.85, f'Confidence: {confidence:.1%}',
              transform=ax2.transAxes, ha='center', va='top', fontsize=14)
@@ -249,7 +288,9 @@ def create_visualization(image_name, click_x, click_y, prediction, confidence, model_type='rf'):
     ax2.text(0.5, 0.05, f'Model CV Accuracy: {model_data[f"cv_accuracy_{model_type}"]:.1%}',
              transform=ax2.transAxes, ha='center', va='bottom', fontsize=12,
              style='italic', alpha=0.7)
-
+
+    ax2.set_facecolor('#f8f9fa')  # Light background
+
     plt.tight_layout()
 
     # Convert plot to image
@@ -266,132 +307,222 @@ def create_stats_table(image_name, model_type):
     image_df = master_df[master_df['image_type'] == image_name]
 
     stats = {
-        'Metric': ['Total Participants', 'Left Choices', 'Right Choices',
-                   'Model Accuracy', 'Class Balance'],
+        'Metric': ['πŸ‘₯ Total Participants', '⬅️ Left Choices', '➑️ Right Choices',
+                   f'🎯 {model_type.upper()} Accuracy', 'βš–οΈ Class Balance', 'πŸ“Š Majority Choice'],
         'Value': [
             len(image_df),
             model_data['class_distribution'].get('left', 0),
             model_data['class_distribution'].get('right', 0),
             f"{model_data[f'cv_accuracy_{model_type}']:.1%}",
-            f"{min(model_data['class_distribution'].values()) / len(image_df):.1%}"
+            f"{min(model_data['class_distribution'].values()) / len(image_df):.1%}",
+            f"{image_df['choice'].mode()[0].title()} ({image_df['choice'].value_counts().max()}/{len(image_df)})"
         ]
     }
 
     return pd.DataFrame(stats)
 
-# Create Gradio Interface
-with gr.Blocks(title="Optical Illusion First Fixation Predictor") as demo:
-    gr.Markdown("""
-    # 🧠 Optical Illusion First Fixation Predictor
-
-    This tool predicts which interpretation of an ambiguous image you'll see based on where you first look!
-
-    ## How to use:
-    1. Select an optical illusion from the dropdown
-    2. Choose a model type (Random Forest usually performs better)
-    3. Click on the image where your eyes first landed
-    4. See the prediction of what you're likely to perceive!
+# Custom CSS for better styling
+css = """
+.gradio-container {
+    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+}
+
+.main-header {
+    text-align: center;
+    margin-bottom: 2rem;
+    padding: 1.5rem;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    border-radius: 15px;
+    color: white;
+    box-shadow: 0 4px 15px rgba(0,0,0,0.1);
+}
+
+.instruction-box {
+    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+    padding: 1rem;
+    border-radius: 10px;
+    color: white;
+    margin: 1rem 0;
+}
+
+.stats-highlight {
+    background-color: #f8f9fa;
+    border-left: 4px solid #007bff;
+    padding: 1rem;
+    margin: 0.5rem 0;
+}
+"""
 
-    **Note:** Images are displayed at 1920x1080 resolution. Click where you naturally first looked at the image.
+# Create Gradio Interface
+with gr.Blocks(title="🧠 Optical Illusion First Fixation Predictor",
+               theme=gr.themes.Soft(), css=css) as demo:
+
+    gr.HTML("""
+    <div class="main-header">
+        <h1>🧠 Optical Illusion First Fixation Predictor</h1>
+        <h3>Can we predict what you see based on where you look?</h3>
+        <p>This AI-powered tool analyzes your first fixation point to predict which interpretation of an ambiguous image you'll perceive!</p>
+    </div>
    """)
 
    with gr.Row():
        with gr.Column(scale=2):
-            # Image selection
+            # Image selection with description
            available_images = list(all_models.keys()) if all_models else []
            default_image = available_images[0] if available_images else None
 
            image_choice = gr.Dropdown(
                choices=available_images,
                value=default_image,
-                label="Select Optical Illusion",
+                label="πŸ–ΌοΈ Select Optical Illusion",
                info="Choose which ambiguous image to analyze"
            )
+
+            # Display image description
+            image_description = gr.Markdown(
+                value=IMAGE_DESCRIPTIONS.get(default_image, "Select an image to see its description.") if default_image else "Select an image to see its description.",
+                label="πŸ“– Image Description"
+            )
 
-            # Model selection
            model_type = gr.Radio(
-                choices=["rf", "lr"],
+            # Model selection with enhanced info
+                choices=[("Random Forest (Recommended)", "rf"), ("Logistic Regression", "lr")],
                value="rf",
-                label="Model Type",
-                info="Random Forest (rf) usually performs better"
+                label="πŸ” Prediction Model",
+                info="Random Forest typically provides better accuracy for this task",
+                container=True
            )
 
-            # Display image with proper dimensions
+            # Display image with better styling
            image_display = gr.Image(
-                label="Click where you first look",
+                label="πŸ‘† Click where your eyes first landed on the image",
                interactive=True,
                type="pil",
-                height=1080,  # Set to full height
-                width=1920  # Set to full width
+                height=540,  # Reduced for better mobile compatibility
+                width=960,
+                elem_classes="main-image"
            )
 
        with gr.Column(scale=1):
-            # Results section
-            prediction_output = gr.Markdown(label="Prediction Results")
-            stats_table = gr.DataFrame(label="Image Statistics")
+            # Results section with enhanced styling
+            prediction_output = gr.Markdown(
+                label="🧠 Prediction Results",
+                value="""<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 1rem; border-radius: 10px; color: white;">
+                <strong>πŸ‘† Click on the image to get your prediction!</strong><br><br>
+                The AI will analyze where you looked first and predict what you're likely to see.
+                </div>""",
+                elem_classes="stats-highlight"
+            )
+            stats_table = gr.DataFrame(label="πŸ“Š Image Statistics")
 
-    # Visualization output
+    # Visualization output with better layout
    with gr.Row():
        visualization_output = gr.Image(
-            label="Analysis Visualization",
+            label="πŸ“ˆ Analysis Visualization",
            type="pil"
        )
 
-    # Information section
-    with gr.Accordion("ℹ️ About the Models", open=False):
+    # Enhanced information sections
+    with gr.Accordion("ℹ️ How It Works", open=False):
        gr.Markdown("""
-        ### Model Information
-        - **Random Forest (RF)**: An ensemble learning method that creates multiple decision trees
-        - **Logistic Regression (LR)**: A linear model for binary classification
-
-        ### Features Used
-        - Distance from first fixation to left interpretation centroid
-        - Distance from first fixation to right interpretation centroid
-        - Bias (difference between distances)
-
-        ### Coordinate System
-        - Image dimensions: 1920x1080 pixels
-        - Normalized coordinates: X(-960 to 960), Y(-540 to 540)
-        - Center point (0,0) is the middle of the image
-
-        ### Training Method
-        - Leave-One-Participant-Out Cross-Validation (LOPO-CV)
-        - Ensures model generalizes to new participants
+        ### πŸ€– The Science Behind the Prediction
+
+        **🎯 Feature Extraction:**
+        - We calculate the distance from your click point to the centroid of each interpretation region
+        - A "bias score" measures which region you're closer to
+
+        **🧠 Machine Learning Models:**
+        - **Random Forest:** Uses multiple decision trees for robust predictions
+        - **Logistic Regression:** A linear approach that's fast and interpretable
+
+        **πŸ“Š Training Process:**
+        - Trained on eye-tracking data from multiple participants
+        - Uses Leave-One-Participant-Out Cross-Validation for unbiased evaluation
+        - Ensures the model generalizes to new users
+
+        **🎨 Coordinate System:**
+        - Center of image = (0, 0)
+        - X-axis: -960 to +960 pixels (left to right)
+        - Y-axis: -540 to +540 pixels (bottom to top)
        """)
 
-    # Summary statistics
-    with gr.Accordion("πŸ“Š Model Performance Summary", open=False):
-        summary_data = []
-        for img_name, model_data in all_models.items():
-            summary_data.append({
-                'Image': img_name,
-                'RF Accuracy': f"{model_data['cv_accuracy_rf']:.1%}",
-                'LR Accuracy': f"{model_data['cv_accuracy_lr']:.1%}",
-                'Total Samples': model_data['total_samples']
-            })
-
-        gr.DataFrame(
-            value=pd.DataFrame(summary_data),
-            label="Cross-Validation Accuracies"
-        )
+    with gr.Accordion("πŸ“Š Model Performance", open=False):
+        if all_models:
+            summary_data = []
+            for img_name, model_data in all_models.items():
+                summary_data.append({
+                    'Image': img_name.replace('-', ' ').title(),
+                    'RF Accuracy': f"{model_data['cv_accuracy_rf']:.1%}",
+                    'LR Accuracy': f"{model_data['cv_accuracy_lr']:.1%}",
+                    'Participants': model_data['total_samples'],
+                    'Best Model': 'RF' if model_data['cv_accuracy_rf'] > model_data['cv_accuracy_lr'] else 'LR'
+                })
+
+            gr.DataFrame(
+                value=pd.DataFrame(summary_data),
+                label="Cross-Validation Performance Summary"
+            )
 
-    # Function to update image when selection changes
-    def update_image(image_name):
+    # Function to update image and description
+    def update_image_and_description(image_name):
        # Handle None case
        if image_name is None:
-            return create_placeholder_image(None), "Please select an image first!", None
+            empty_stats = pd.DataFrame({
+                'Metric': ['Select an image to see statistics'],
+                'Value': ['']
+            })
+            return (create_placeholder_image(None),
+                    "Select an image to see its description.",
+                    """<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 1rem; border-radius: 10px; color: white;">
+                    <strong>πŸ‘† Please select an image first!</strong>
+                    </div>""",
+                    empty_stats)
+
+        # Update description
+        description = IMAGE_DESCRIPTIONS.get(image_name, "Description not available.")
 
        # Use real images if available, otherwise use placeholder
        if image_name in illusion_images:
-            return illusion_images[image_name], "Click on the image to get a prediction!", None
+            # Create initial stats table with proper data
+            model_data = all_models[image_name]
+            image_df = master_df[master_df['image_type'] == image_name]
+
+            stats = {
+                'Metric': ['πŸ‘₯ Total Participants', '⬅️ Left Choices', '➑️ Right Choices',
+                           '🎯 RF Accuracy', 'βš–οΈ Class Balance', 'πŸ“Š Majority Choice'],
+                'Value': [
+                    len(image_df),
+                    model_data['class_distribution'].get('left', 0),
+                    model_data['class_distribution'].get('right', 0),
+                    f"{model_data['cv_accuracy_rf']:.1%}",
+                    f"{min(model_data['class_distribution'].values()) / len(image_df):.1%}",
+                    f"{image_df['choice'].mode()[0].title()} ({image_df['choice'].value_counts().max()}/{len(image_df)})"
+                ]
+            }
+            return (illusion_images[image_name],
+                    f"**{description}**",
+                    """<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 1rem; border-radius: 10px; color: white;">
+                    <strong>πŸ‘† Click on the image to get your prediction!</strong><br><br>
+                    The AI will analyze where you looked first and predict what you're likely to see.
+                    </div>""",
+                    pd.DataFrame(stats))
        else:
-            return create_placeholder_image(image_name), "Click on the image to get a prediction!", None
+            empty_stats = pd.DataFrame({
+                'Metric': ['Image not found'],
+                'Value': ['']
+            })
+            return (create_placeholder_image(image_name),
+                    f"**{description}**",
+                    """<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 1rem; border-radius: 10px; color: white;">
+                    <strong>⚠️ Image file not found!</strong>
+                    </div>""",
+                    empty_stats)
 
    # Connect events
    image_choice.change(
-        fn=update_image,
+        fn=update_image_and_description,
        inputs=[image_choice],
-        outputs=[image_display, prediction_output, stats_table]
+        outputs=[image_display, image_description, prediction_output, stats_table]
    )
 
    # Handle click event
@@ -403,16 +534,15 @@ with gr.Blocks(title="Optical Illusion First Fixation Predictor") as demo:
 
    # Load initial image
    demo.load(
-        fn=update_image,
+        fn=update_image_and_description,
        inputs=[image_choice],
-        outputs=[image_display, prediction_output, stats_table]
+        outputs=[image_display, image_description, prediction_output, stats_table]
    )
 
-    # Examples section
+    # Enhanced examples section
    if available_images:
-        gr.Markdown("## πŸ“Œ Example Interpretations")
+        gr.Markdown("## πŸ“Œ Quick Examples")
        with gr.Row():
-            # Only show examples for available images
            example_list = []
            for img in ["duck-rabbit", "face-vase", "young-old", "tiger-monkey"]:
                if img in available_images:
@@ -422,8 +552,17 @@ with gr.Blocks(title="Optical Illusion First Fixation Predictor") as demo:
            gr.Examples(
                examples=example_list,
                inputs=[image_choice, model_type],
-                label="Try these examples"
+                label="Try these popular illusions"
            )
+
+    # Enhanced footer
+    gr.HTML("""
+    <div style="text-align: center; margin-top: 2rem; padding: 1.5rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
+        <h4>πŸ”¬ WID2003 Cognitive Science Group Assignment - OCC 2 Group 2</h4>
+        <p><strong>Universiti Malaya</strong> | 2025</p>
+        <p style="font-size: 0.9em; opacity: 0.8;">Vote for Us!</p>
+    </div>
+    """)
 
 # Debug info
 print(f"\nImage folder: {image_folder}")
@@ -433,4 +572,7 @@ print(f"Image dimensions: {DISPLAY_WIDTH}x{DISPLAY_HEIGHT}")
 
 # Launch the app
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(
+        # share=True,
+        # debug=True
+    )
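A note on the decision-surface change in `create_visualization`: the grid of candidate fixation points is scored with the same three distance features as a real click, so the grid matrix also needs the named columns. A self-contained sketch of that step, with stand-ins for the centroids and classifier the app loads from disk (`DummyClassifier` is only a substitute for the real RF/LR models, which return encoded class indices):

```python
import numpy as np
import pandas as pd
from sklearn.dummy import DummyClassifier

cols = ['dist_to_left', 'dist_to_right', 'bias_to_left']

# Stand-ins for values the app loads from its saved model bundles.
centroid_left, centroid_right = (-300.0, 0.0), (300.0, 0.0)
model = DummyClassifier(strategy='most_frequent').fit(
    pd.DataFrame([[0.0, 0.0, 0.0]], columns=cols), [0])

# Grid over the normalized coordinate range used by the plot (-960..960, -540..540).
xx, yy = np.meshgrid(np.linspace(-960, 960, 100), np.linspace(-540, 540, 100))
rows = []
for gx, gy in zip(xx.ravel(), yy.ravel()):
    d_left = np.hypot(gx - centroid_left[0], gy - centroid_left[1])
    d_right = np.hypot(gx - centroid_right[0], gy - centroid_right[1])
    rows.append([d_left, d_right, d_right - d_left])  # bias = dist_right - dist_left

X = pd.DataFrame(rows, columns=cols)    # named columns, matching the scaler fix above
Z = model.predict(X).reshape(xx.shape)  # one label per grid cell, ready for ax1.contourf
```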
requirements.txt CHANGED
@@ -3,5 +3,5 @@ numpy
 pandas
 joblib
 matplotlib
-scikit-learn
+scikit-learn==1.6.1
 pillow
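Pinning `scikit-learn==1.6.1` matters here because the app unpickles its pre-trained models with `joblib`: a model serialized under one scikit-learn version may warn or misbehave when loaded under another. A hedged sanity check, assuming a per-image `.joblib` bundle (the filename is hypothetical; `InconsistentVersionWarning` exists in recent scikit-learn releases):

```python
import warnings

import joblib
from sklearn.exceptions import InconsistentVersionWarning

with warnings.catch_warnings():
    # Escalate the version-mismatch warning so it can be caught like an error.
    warnings.simplefilter("error", InconsistentVersionWarning)
    try:
        # Hypothetical path; the app loads its bundles from 'Optical Illusion - Trained Models'.
        bundle = joblib.load("Optical Illusion - Trained Models/duck-rabbit.joblib")
    except InconsistentVersionWarning as warn:
        print(f"Version mismatch, re-train or pin scikit-learn: {warn}")
```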