ans123 committed
Commit 579018b · verified · 1 Parent(s): ade09dd

Update app.py

Files changed (1)
  1. app.py +866 -695
app.py CHANGED
@@ -3,6 +3,7 @@ import cv2
3
  import numpy as np
4
  import pandas as pd
5
  import time
 
6
  import matplotlib.pyplot as plt
7
  from matplotlib.colors import LinearSegmentedColormap
8
  from matplotlib.collections import LineCollection
@@ -10,57 +11,80 @@ import os
10
  import datetime
11
  import tempfile
12
  from typing import Dict, List, Tuple, Optional, Union, Any
 
 
13
  import google.generativeai as genai
14
- from PIL import Image
15
- import json
16
- import warnings
17
- from deepface import DeepFace
18
- import base64
19
- import io
20
- from pathlib import Path
21
- import traceback
22
-
23
- # Suppress warnings for cleaner output
24
- warnings.filterwarnings('ignore')
25
 
26
  # --- Constants ---
27
- VIDEO_FPS = 30 # Target FPS for saved video
28
  CSV_FILENAME_TEMPLATE = "facial_analysis_{timestamp}.csv"
29
  VIDEO_FILENAME_TEMPLATE = "processed_{timestamp}.mp4"
30
- TEMP_DIR = Path("temp_frames")
31
- TEMP_DIR.mkdir(exist_ok=True)
32
 
33
- # --- Configure Google Gemini API ---
34
- print("Configuring Google Gemini API...")
35
- try:
36
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
37
- if not GOOGLE_API_KEY:
38
- raise ValueError("GOOGLE_API_KEY environment variable not set.")
39
-
40
- genai.configure(api_key=GOOGLE_API_KEY)
41
- # Use gemini-1.5-flash for quick responses
42
- model = genai.GenerativeModel('gemini-1.5-flash')
43
- GEMINI_ENABLED = True
44
- print("Google Gemini API configured successfully.")
45
- except Exception as e:
46
- print(f"WARNING: Failed to configure Google Gemini API: {e}")
47
- print("Running with simulated Gemini API responses.")
48
- GEMINI_ENABLED = False
49
 
50
- # --- Initialize OpenCV face detector for backup ---
51
- print("Initializing OpenCV face detector...")
52
- try:
53
- # Use OpenCV's built-in face detector as backup
54
- face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
55
-
56
- # Check if the face detector loaded successfully
57
- if face_cascade.empty():
58
- print("WARNING: Failed to load face cascade classifier")
59
- else:
60
- print("OpenCV face detector initialized successfully.")
61
- except Exception as e:
62
- print(f"ERROR initializing OpenCV face detector: {e}")
63
- face_cascade = None
 
 
64
 
65
  # --- Metrics Definition ---
66
  metrics = [
@@ -69,336 +93,408 @@ metrics = [
69
  "neuroticism", "conscientiousness", "extraversion",
70
  "stress_index", "engagement_level"
71
  ]
72
-
73
- # DeepFace emotion mapping
74
- emotion_mapping = {
75
- "angry": {"valence": 0.2, "arousal": 0.8, "dominance": 0.7},
76
- "disgust": {"valence": 0.2, "arousal": 0.6, "dominance": 0.5},
77
- "fear": {"valence": 0.2, "arousal": 0.8, "dominance": 0.3},
78
- "happy": {"valence": 0.9, "arousal": 0.7, "dominance": 0.6},
79
- "sad": {"valence": 0.3, "arousal": 0.4, "dominance": 0.3},
80
- "surprise": {"valence": 0.6, "arousal": 0.9, "dominance": 0.5},
81
- "neutral": {"valence": 0.5, "arousal": 0.5, "dominance": 0.5}
82
- }
83
-
84
  ad_context_columns = ["ad_description", "ad_detail", "ad_type", "gemini_ad_analysis"]
85
- user_state_columns = ["user_state", "enhanced_user_state"]
86
- all_columns = ['timestamp', 'frame_number'] + metrics + ad_context_columns + user_state_columns
87
  initial_metrics_df = pd.DataFrame(columns=all_columns)
88
 
89
- # --- Gemini API Functions ---
90
- def call_gemini_api_for_ad(description, detail, ad_type):
91
- """
92
- Uses Google Gemini to analyze ad context.
93
- """
94
- print(f"Analyzing ad context: '{description}' ({ad_type})")
95
-
96
- if not GEMINI_ENABLED:
97
- # Simulated response
98
- analysis = f"Simulated analysis: Ad='{description or 'N/A'}' ({ad_type}), Focus='{detail or 'N/A'}'."
99
- if not description and not detail:
100
- analysis = "No ad context provided."
101
- print(f"Simulated Gemini Result: {analysis}")
102
- return analysis
103
- else:
104
- try:
105
- prompt = f"""
106
- Please analyze this advertisement context:
107
- - Description: {description}
108
- - Detail focus: {detail}
109
- - Type/Genre: {ad_type}
110
-
111
- Provide a concise analysis of what emotional and cognitive responses might be expected from viewers.
112
- Limit your response to 100 words.
113
- """
114
-
115
- response = model.generate_content(prompt)
116
- return response.text
117
- except Exception as e:
118
- print(f"Error calling Gemini for ad context: {e}")
119
- return f"Error analyzing ad context: {str(e)}"
120
 
121
- def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, ad_context=None):
 
 
122
  """
123
- Uses Google Gemini to interpret facial metrics and DeepFace results
124
- to determine user state.
 
 
125
  """
126
- if not metrics_dict and not deepface_results:
127
- return "No metrics", "No facial data detected"
128
 
129
- if not GEMINI_ENABLED:
130
- # Basic rule-based simulation for user state
131
- valence = metrics_dict.get('valence', 0.5) if metrics_dict else 0.5
132
- arousal = metrics_dict.get('arousal', 0.5) if metrics_dict else 0.5
133
-
134
- # Extract emotion from DeepFace if available
135
- dominant_emotion = "neutral"
136
- if deepface_results and "emotion" in deepface_results:
137
- emotion_dict = deepface_results["emotion"]
138
- dominant_emotion = max(emotion_dict.items(), key=lambda x: x[1])[0]
139
 
140
- # Simple rule-based simulation
141
- state = dominant_emotion.capitalize() if dominant_emotion != "neutral" else "Neutral"
142
- if valence > 0.65 and arousal > 0.55:
143
- state = "Positive, Engaged"
144
- elif valence < 0.4 and arousal > 0.6:
145
- state = "Stressed, Negative"
 
 
146
 
147
- enhanced_state = f"The viewer appears {state.lower()} while watching this content."
 
 
148
 
149
- return state, enhanced_state
150
- else:
151
- try:
152
- # Format metrics for Gemini
153
- metrics_formatted = ""
154
- if metrics_dict:
155
- metrics_formatted = "\nMetrics (0-1 scale):\n" + "\n".join([f"- {k.replace('_', ' ').title()}: {v:.2f}" for k, v in metrics_dict.items()
156
- if k not in ('timestamp', 'frame_number')])
157
-
158
- # Format DeepFace results
159
- deepface_formatted = ""
160
- if deepface_results and "emotion" in deepface_results:
161
- emotion_dict = deepface_results["emotion"]
162
- deepface_formatted = "\nDeepFace emotions:\n" + "\n".join([f"- {k.title()}: {v:.2f}" for k, v in emotion_dict.items()])
163
-
164
- # Include ad context if available
165
- ad_info = ""
166
- if ad_context:
167
- ad_desc = ad_context.get('ad_description', 'N/A')
168
- ad_type = ad_context.get('ad_type', 'N/A')
169
- ad_info = f"\nThey are watching an advertisement: {ad_desc} (Type: {ad_type})"
170
-
171
- prompt = f"""
172
- Analyze the facial expression and emotion of a person watching an advertisement{ad_info}.
173
-
174
- Use these combined inputs:{metrics_formatted}{deepface_formatted}
175
-
176
- Provide two outputs:
177
- 1. User State: A short 1-3 word description of their emotional/cognitive state
178
- 2. Enhanced Analysis: A detailed 1-2 sentence interpretation of their reaction to the content
179
-
180
- Format as JSON: {{"user_state": "STATE", "enhanced_user_state": "DETAILED ANALYSIS"}}
181
- """
182
-
183
- response = model.generate_content(prompt)
184
-
185
- try:
186
- # Try to parse as JSON
187
- result = json.loads(response.text)
188
- return result.get("user_state", "Uncertain"), result.get("enhanced_user_state", "Analysis unavailable")
189
- except json.JSONDecodeError:
190
- # If not valid JSON, try to extract manually
191
- text = response.text
192
- if "user_state" in text and "enhanced_user_state" in text:
193
- parts = text.split("enhanced_user_state")
194
- user_state = parts[0].split("user_state")[1].replace('"', '').replace(':', '').replace(',', '').strip()
195
- enhanced = parts[1].replace('"', '').replace(':', '').replace('}', '').strip()
196
- return user_state, enhanced
197
- else:
198
- # Just return the raw text as enhanced state
199
- return "Analyzed", text
200
-
201
- except Exception as e:
202
- print(f"Error calling Gemini for metric interpretation: {e}")
203
- traceback.print_exc()
204
- return "Error", f"Error analyzing facial metrics: {str(e)}"
205
 
206
- # --- DeepFace Analysis Function ---
207
- def analyze_face_with_deepface(image):
208
- """Analyze facial emotions and attributes using DeepFace"""
209
- if image is None:
 
 
210
  return None
 
 
 
 
211
 
212
  try:
213
- # Convert to RGB for DeepFace if needed
214
- if len(image.shape) == 3 and image.shape[2] == 3:
215
- # Check if BGR and convert to RGB if needed
216
- if np.mean(image[:,:,0]) < np.mean(image[:,:,2]): # Rough BGR check
217
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
218
- else:
219
- image_rgb = image
220
- else:
221
- # Handle grayscale or other formats
222
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
223
 
224
- # Save image to temp file (DeepFace sometimes works better with files)
225
- temp_img = f"temp_frames/temp_analysis_{time.time()}.jpg"
226
- cv2.imwrite(temp_img, image_rgb)
227
 
228
- # Analyze with DeepFace
229
- analysis = DeepFace.analyze(
230
- img_path=temp_img,
231
- actions=['emotion'],
232
- enforce_detection=False, # Don't throw error if face not detected
233
- detector_backend='opencv' # Faster detection
234
- )
235
 
236
- # Remove temporary file
237
- try:
238
- os.remove(temp_img)
239
- except:
240
- pass
 
 
 
 
241
 
242
- # Return the first face analysis (assuming single face)
243
- if isinstance(analysis, list) and len(analysis) > 0:
244
- return analysis[0]
245
- else:
246
- return analysis
 
 
247
 
 
 
 
 
 
 
 
 
248
  except Exception as e:
249
- print(f"DeepFace analysis error: {e}")
250
  return None
251
 
252
- # --- Face Detection Backup with OpenCV ---
253
- def detect_face_opencv(image):
254
- """Detect faces using OpenCV cascade classifier as backup"""
255
- if image is None or face_cascade is None:
256
  return None
257
-
258
  try:
259
- # Convert to grayscale for detection
260
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
 
 
261
 
262
- # Detect faces
263
- faces = face_cascade.detectMultiScale(
264
- gray,
265
- scaleFactor=1.1,
266
- minNeighbors=5,
267
- minSize=(30, 30)
268
- )
269
 
270
- if len(faces) == 0:
271
- return None
272
 
273
- # Get the largest face by area
274
- largest_face = max(faces, key=lambda rect: rect[2] * rect[3])
 
 
 
275
 
276
- return {"rect": largest_face}
277
-
278
- except Exception as e:
279
- print(f"Error in OpenCV face detection: {e}")
280
- return None
281
 
282
- # --- Calculate Metrics from DeepFace Results ---
283
- def calculate_metrics_from_deepface(deepface_results, ad_context=None):
284
- """
285
- Calculate psychometric metrics from DeepFace analysis results
286
- """
 
 
287
  if ad_context is None:
288
  ad_context = {}
289
-
290
- # Initialize default metrics
291
- default_metrics = {m: 0.5 for m in metrics}
292
-
293
- # If no facial data, return defaults
294
- if not deepface_results or "emotion" not in deepface_results:
295
- return default_metrics
296
-
297
- # Extract emotion data from DeepFace
298
- emotion_dict = deepface_results["emotion"]
299
- # Find dominant emotion
300
- dominant_emotion = max(emotion_dict.items(), key=lambda x: x[1])[0]
301
- dominant_score = max(emotion_dict.items(), key=lambda x: x[1])[1] / 100.0 # Convert to 0-1 scale
302
-
303
- # Get base values from emotion mapping
304
- base_vals = emotion_mapping.get(dominant_emotion, {"valence": 0.5, "arousal": 0.5, "dominance": 0.5})
305
-
306
- # Calculate primary metrics with confidence weighting
307
- val = base_vals["valence"]
308
- arsl = base_vals["arousal"]
309
- dom = base_vals["dominance"]
310
-
311
- # Add directional adjustments based on specific emotions
312
- if dominant_emotion == "happy":
313
- val += 0.1
314
- elif dominant_emotion == "sad":
315
- val -= 0.1
316
- elif dominant_emotion == "angry":
317
- arsl += 0.1
318
- dom += 0.1
319
- elif dominant_emotion == "fear":
320
- arsl += 0.1
321
- dom -= 0.1
322
-
323
- # Illustrative Context Adjustments from ad
324
- ad_type = ad_context.get('ad_type', 'Unknown')
325
  gem_txt = str(ad_context.get('gemini_ad_analysis', '')).lower()
326
-
327
- # Adjust based on ad context
328
- val_adj = 0.1 if ad_type == 'Funny' or 'humor' in gem_txt else 0.0
329
- arsl_adj = 0.1 if ad_type == 'Action' or 'exciting' in gem_txt else 0.0
330
-
331
- # Apply adjustments
332
- val = max(0, min(1, val + val_adj))
333
- arsl = max(0, min(1, arsl + arsl_adj))
334
-
335
- # Estimate cognitive load based on emotional intensity
336
- cl = 0.5 # Default
337
- if dominant_emotion in ["neutral"]:
338
- cl = 0.3 # Lower cognitive load for neutral expression
339
- elif dominant_emotion in ["surprise", "fear"]:
340
- cl = 0.7 # Higher cognitive load for surprise/fear
341
-
342
- # Calculate secondary metrics
343
  neur = max(0, min(1, (cl * 0.6) + ((1.0 - val) * 0.4)))
344
  em_stab = 1.0 - neur
345
  extr = max(0, min(1, (arsl * 0.5) + (val * 0.5)))
346
- open = max(0, min(1, 0.5 + (val - 0.5) * 0.5))
347
  agree = max(0, min(1, (val * 0.7) + ((1.0 - arsl) * 0.3)))
348
  consc = max(0, min(1, (1.0 - abs(arsl - 0.5)) * 0.7 + (em_stab * 0.3)))
349
- stress = max(0, min(1, (cl * 0.5) + ((1.0 - val) * 0.5)))
350
- engag = max(0, min(1, arsl * 0.7 + (val * 0.3)))
351
-
352
- # Create metrics dictionary
353
- calculated_metrics = {
354
- 'valence': val,
355
- 'arousal': arsl,
356
- 'dominance': dom,
357
- 'cognitive_load': cl,
358
- 'emotional_stability': em_stab,
359
- 'openness': open,
360
- 'agreeableness': agree,
361
- 'neuroticism': neur,
362
- 'conscientiousness': consc,
363
- 'extraversion': extr,
364
- 'stress_index': stress,
365
- 'engagement_level': engag
366
  }
367
-
368
- return calculated_metrics
369
 
370
- def update_metrics_visualization(metrics_values):
371
- """Create a visualization of metrics"""
372
  if not metrics_values:
373
  fig, ax = plt.subplots(figsize=(10, 8))
374
- ax.text(0.5, 0.5, "Waiting for facial metrics...", ha='center', va='center')
375
  ax.axis('off')
376
  fig.patch.set_facecolor('#FFFFFF')
377
  ax.set_facecolor('#FFFFFF')
378
  return fig
379
 
380
- # Filter out non-metric keys
381
- filtered_metrics = {k: v for k, v in metrics_values.items()
382
- if k in metrics and isinstance(v, (int, float))}
 
 
383
 
384
- if not filtered_metrics:
385
- fig, ax = plt.subplots(figsize=(10, 8))
386
- ax.text(0.5, 0.5, "No valid metrics available", ha='center', va='center')
387
- ax.axis('off')
388
- return fig
389
 
390
- num_metrics = len(filtered_metrics)
391
  nrows = (num_metrics + 2) // 3
392
  fig, axs = plt.subplots(nrows, 3, figsize=(10, nrows * 2.5), facecolor='#FFFFFF')
393
  axs = axs.flatten()
394
 
 
 
 
395
  colors = [(0.1, 0.1, 0.9), (0.9, 0.9, 0.1), (0.9, 0.1, 0.1)]
396
  cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=100)
397
  norm = plt.Normalize(0, 1)
398
  metric_idx = 0
399
 
400
- for key, value in filtered_metrics.items():
401
- value = max(0.0, min(1.0, value)) # Clip value for safety
 
 
402
 
403
  ax = axs[metric_idx]
404
  ax.set_title(key.replace('_', ' ').title(), fontsize=10)
@@ -440,69 +536,97 @@ def update_metrics_visualization(metrics_values):
440
  plt.tight_layout(pad=0.5)
441
  return fig
442
 
443
- def annotate_frame(frame, face_data=None, deepface_results=None, metrics=None, enhanced_state=None):
444
- """
445
- Add facial annotations and metrics to a frame
446
- """
 
 
447
  if frame is None:
448
  return None
449
 
450
  annotated = frame.copy()
451
 
452
- # Draw face rectangle if available
453
- if face_data and "rect" in face_data:
454
- x, y, w, h = face_data["rect"]
455
- cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
456
- elif deepface_results and "region" in deepface_results:
457
- region = deepface_results["region"]
458
- x, y, w, h = region["x"], region["y"], region["w"], region["h"]
459
- cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
460
-
461
- # Add emotion and metrics summary
462
- if deepface_results or metrics:
463
- # Format for display
464
- h, w = annotated.shape[:2]
465
- y_pos = 30 # Starting Y position
466
-
467
- # Add emotion info if available from DeepFace
468
- if deepface_results and "dominant_emotion" in deepface_results:
469
- emotion_text = f"Emotion: {deepface_results['dominant_emotion'].capitalize()}"
470
- text_size = cv2.getTextSize(emotion_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
471
- cv2.rectangle(annotated, (10, y_pos - 20), (10 + text_size[0], y_pos + 5), (0, 0, 0), -1)
472
- cv2.putText(annotated, emotion_text, (10, y_pos),
473
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
474
- y_pos += 30
475
-
476
- # Add enhanced user state if available
477
- if enhanced_state:
478
- # Truncate if too long
479
- if len(enhanced_state) > 60:
480
- enhanced_state = enhanced_state[:57] + "..."
481
-
482
- # Draw background for text
483
- text_size = cv2.getTextSize(enhanced_state, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
484
- cv2.rectangle(annotated, (10, y_pos - 20), (10 + text_size[0], y_pos + 5), (0, 0, 0), -1)
485
- # Draw text
486
- cv2.putText(annotated, enhanced_state, (10, y_pos),
487
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
488
- y_pos += 30
489
-
490
- # Show top 3 metrics
491
- if metrics:
492
- top_metrics = sorted([(k, v) for k, v in metrics.items() if k in metrics],
493
- key=lambda x: x[1], reverse=True)[:3]
494
-
495
- for name, value in top_metrics:
496
- metric_text = f"{name.replace('_', ' ').title()}: {value:.2f}"
497
- text_size = cv2.getTextSize(metric_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
498
- cv2.rectangle(annotated, (10, y_pos - 15), (10 + text_size[0], y_pos + 5), (0, 0, 0), -1)
499
- cv2.putText(annotated, metric_text, (10, y_pos),
500
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
501
- y_pos += 25
502
 
503
  return annotated
504
 
505
- # --- API 1: Video File Processing ---
 
 
506
  def process_video_file(
507
  video_file: Union[str, np.ndarray],
508
  ad_description: str = "",
@@ -510,8 +634,8 @@ def process_video_file(
510
  ad_type: str = "Video",
511
  sampling_rate: int = 5, # Process every Nth frame
512
  save_processed_video: bool = True,
513
- show_progress: bool = True
514
- ) -> Tuple[str, str, pd.DataFrame, List[np.ndarray]]:
515
  """
516
  Process a video file and analyze facial expressions frame by frame
517
 
@@ -522,17 +646,21 @@ def process_video_file(
522
  ad_type: Type of ad (Video, Image, Audio, Text, Funny, etc.)
523
  sampling_rate: Process every Nth frame
524
  save_processed_video: Whether to save the processed video with annotations
525
- show_progress: Whether to show processing progress
526
 
527
  Returns:
528
- Tuple of (csv_path, processed_video_path, metrics_dataframe, processed_frames_list)
529
  """
 
 
 
530
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
531
  csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
 
532
  video_path = VIDEO_FILENAME_TEMPLATE.format(timestamp=timestamp) if save_processed_video else None
533
 
534
  # Setup ad context
535
- gemini_result = call_gemini_api_for_ad(ad_description, ad_detail, ad_type)
536
  ad_context = {
537
  "ad_description": ad_description,
538
  "ad_detail": ad_detail,
@@ -540,6 +668,8 @@ def process_video_file(
540
  "gemini_ad_analysis": gemini_result
541
  }
542
 
 
 
543
  # Initialize capture
544
  if isinstance(video_file, str):
545
  cap = cv2.VideoCapture(video_file)
@@ -549,64 +679,24 @@ def process_video_file(
549
  temp_path = os.path.join(temp_dir, "temp_video.mp4")
550
 
551
  # Convert video array to file
552
- if isinstance(video_file, np.ndarray) and len(video_file.shape) == 4: # Multiple frames
553
- h, w = video_file[0].shape[:2]
 
554
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
555
  temp_writer = cv2.VideoWriter(temp_path, fourcc, 30, (w, h))
556
  for frame in video_file:
557
  temp_writer.write(frame)
558
  temp_writer.release()
559
- cap = cv2.VideoCapture(temp_path)
560
- elif isinstance(video_file, np.ndarray) and len(video_file.shape) == 3: # Single frame
561
- # For single frame, just process it directly
562
- metrics_data = []
563
- processed_frames = []
564
-
565
- # Process the single frame
566
- deepface_results = analyze_face_with_deepface(video_file)
567
- face_data = None
568
-
569
- # Fall back to OpenCV face detection if DeepFace didn't detect a face
570
- if not deepface_results or "region" not in deepface_results:
571
- face_data = detect_face_opencv(video_file)
572
-
573
- # Calculate metrics if face detected
574
- if deepface_results or face_data:
575
- calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
576
- user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)
577
-
578
- # Create a row for the dataframe
579
- row = {
580
- 'timestamp': 0.0,
581
- 'frame_number': 0,
582
- **calculated_metrics,
583
- **ad_context,
584
- 'user_state': user_state,
585
- 'enhanced_user_state': enhanced_state
586
- }
587
- metrics_data.append(row)
588
-
589
- # Annotate the frame
590
- annotated_frame = annotate_frame(video_file, face_data, deepface_results, calculated_metrics, enhanced_state)
591
- processed_frames.append(annotated_frame)
592
-
593
- # Save processed image
594
- if save_processed_video:
595
- cv2.imwrite(video_path.replace('.mp4', '.jpg'), annotated_frame)
596
-
597
- # Create DataFrame and save to CSV
598
- metrics_df = pd.DataFrame(metrics_data)
599
- if not metrics_df.empty:
600
- metrics_df.to_csv(csv_path, index=False)
601
-
602
- return csv_path, video_path.replace('.mp4', '.jpg') if save_processed_video else None, metrics_df, processed_frames
603
- else:
604
- print("Error: Invalid video input format")
605
- return None, None, None, []
606
 
607
  if not cap.isOpened():
608
  print("Error: Could not open video.")
609
- return None, None, None, []
 
 
 
610
 
611
  # Get video properties
612
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -617,68 +707,117 @@ def process_video_file(
617
  # Initialize video writer if saving processed video
618
  if save_processed_video:
619
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
620
- out = cv2.VideoWriter(video_path, fourcc, fps / sampling_rate, (frame_width, frame_height))
621
 
622
  # Process video frames
623
  metrics_data = []
624
- processed_frames = []
625
  frame_count = 0
626
 
627
- if show_progress:
628
- print(f"Processing video with {total_frames} frames at {fps} FPS")
629
- print(f"Ad Context: {ad_description} ({ad_type})")
630
-
631
- while True:
632
- ret, frame = cap.read()
633
- if not ret:
634
- break
635
 
636
- # Only process every Nth frame (according to sampling_rate)
637
- if frame_count % sampling_rate == 0:
638
- if show_progress and frame_count % (sampling_rate * 10) == 0:
639
- print(f"Processing frame {frame_count}/{total_frames} ({frame_count/total_frames*100:.1f}%)")
640
-
641
- # Analyze with DeepFace
642
- deepface_results = analyze_face_with_deepface(frame)
643
- face_data = None
644
 
645
- # Fall back to OpenCV face detection if DeepFace didn't detect a face
646
- if not deepface_results or "region" not in deepface_results:
647
- face_data = detect_face_opencv(frame)
648
 
649
- # Calculate metrics if face detected
650
- if deepface_results or face_data:
651
- calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
652
- user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)
653
 
654
- # Create a row for the dataframe
655
- row = {
656
- 'timestamp': frame_count / fps,
657
- 'frame_number': frame_count,
658
- **calculated_metrics,
659
- **ad_context,
660
- 'user_state': user_state,
661
- 'enhanced_user_state': enhanced_state
662
- }
663
- metrics_data.append(row)
664
 
665
- # Annotate the frame
666
- annotated_frame = annotate_frame(frame, face_data, deepface_results, calculated_metrics, enhanced_state)
 
 
 
667
 
668
- if save_processed_video:
669
- out.write(annotated_frame)
670
- processed_frames.append(annotated_frame)
671
- else:
672
- # No face detected
673
- if save_processed_video:
674
- # Add text to frame
675
- no_face_frame = frame.copy()
676
- cv2.putText(no_face_frame, "No face detected", (30, 30),
677
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
678
- out.write(no_face_frame)
679
- processed_frames.append(no_face_frame)
 
 
680
 
681
- frame_count += 1
 
 
 
 
 
682
 
683
  # Release resources
684
  cap.release()
@@ -689,26 +828,25 @@ def process_video_file(
689
  metrics_df = pd.DataFrame(metrics_data)
690
  if not metrics_df.empty:
691
  metrics_df.to_csv(csv_path, index=False)
692
-
693
- if show_progress:
694
- print(f"Video processing complete. Analyzed {len(metrics_data)} frames.")
695
- print(f"Results saved to {csv_path}")
696
- if save_processed_video:
697
- print(f"Processed video saved to {video_path}")
698
 
699
  # Return results
700
- return csv_path, video_path, metrics_df, processed_frames
701
 
702
- # --- API 2: Webcam Processing Function ---
703
  def process_webcam_frame(
704
  frame: np.ndarray,
705
  ad_context: Dict[str, Any],
706
  metrics_data: pd.DataFrame,
707
  frame_count: int,
708
- start_time: float
709
- ) -> Tuple[np.ndarray, Dict[str, float], str, pd.DataFrame]:
 
 
710
  """
711
- Process a single webcam frame
712
 
713
  Args:
714
  frame: Input frame from webcam
@@ -716,68 +854,113 @@ def process_webcam_frame(
716
  metrics_data: DataFrame to accumulate metrics
717
  frame_count: Current frame count
718
  start_time: Start time of the session
 
 
719
 
720
  Returns:
721
- Tuple of (annotated_frame, metrics_dict, enhanced_state, updated_metrics_df)
722
  """
723
  if frame is None:
724
- return None, None, None, metrics_data
725
 
726
- # Analyze with DeepFace
727
- deepface_results = analyze_face_with_deepface(frame)
728
- face_data = None
729
 
730
- # Fall back to OpenCV face detection if DeepFace didn't detect a face
731
- if not deepface_results or "region" not in deepface_results:
732
- face_data = detect_face_opencv(frame)
733
 
734
- # Calculate metrics if face detected
735
- if deepface_results or face_data:
736
- calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
737
- user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)
 
 
 
 
 
 
 
738
 
739
  # Create a row for the dataframe
740
- current_time = time.time()
741
  row = {
742
- 'timestamp': current_time - start_time,
743
  'frame_number': frame_count,
744
- **calculated_metrics,
745
- **ad_context,
746
- 'user_state': user_state,
747
- 'enhanced_user_state': enhanced_state
748
  }
749
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  # Add row to DataFrame
751
  new_row_df = pd.DataFrame([row], columns=all_columns)
752
  metrics_data = pd.concat([metrics_data, new_row_df], ignore_index=True)
753
 
 
 
 
 
 
 
 
 
754
  # Annotate the frame
755
- annotated_frame = annotate_frame(frame, face_data, deepface_results, calculated_metrics, enhanced_state)
 
 
 
 
 
 
 
 
 
 
 
756
 
757
- return annotated_frame, calculated_metrics, enhanced_state, metrics_data
 
 
 
 
 
 
 
 
 
 
 
 
758
  else:
759
  # No face detected
760
- no_face_frame = frame.copy()
761
- cv2.putText(no_face_frame, "No face detected", (30, 30),
762
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
763
- return no_face_frame, None, "No face detected", metrics_data
764
 
 
765
  def start_webcam_session(
766
  ad_description: str = "",
767
  ad_detail: str = "",
768
  ad_type: str = "Video",
769
  save_interval: int = 100, # Save CSV every N frames
770
- record_video: bool = True
771
  ) -> Dict[str, Any]:
772
  """
773
- Initialize a webcam session for facial analysis
774
 
775
  Args:
776
  ad_description: Description of the ad being watched
777
  ad_detail: Detail focus of the ad
778
  ad_type: Type of ad
779
  save_interval: How often to save data to CSV
780
- record_video: Whether to record processed frames for later saving
781
 
782
  Returns:
783
  Session context dictionary
@@ -785,10 +968,13 @@ def start_webcam_session(
785
  # Generate timestamp for file naming
786
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
787
  csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
788
- video_path = VIDEO_FILENAME_TEMPLATE.format(timestamp=timestamp) if record_video else None
 
 
 
789
 
790
  # Setup ad context
791
- gemini_result = call_gemini_api_for_ad(ad_description, ad_detail, ad_type)
792
  ad_context = {
793
  "ad_description": ad_description,
794
  "ad_detail": ad_detail,
@@ -803,20 +989,25 @@ def start_webcam_session(
803
  "metrics_data": initial_metrics_df.copy(),
804
  "ad_context": ad_context,
805
  "csv_path": csv_path,
806
- "video_path": video_path,
807
  "save_interval": save_interval,
808
  "last_saved": 0,
809
- "record_video": record_video,
810
- "recorded_frames": [] if record_video else None,
811
- "timestamps": [] if record_video else None
812
  }
813
 
 
 
 
 
 
 
814
  return session
815
 
816
  def update_webcam_session(
817
  session: Dict[str, Any],
818
  frame: np.ndarray
819
- ) -> Tuple[np.ndarray, Dict[str, float], str, Dict[str, Any]]:
820
  """
821
  Update webcam session with a new frame
822
 
@@ -825,33 +1016,47 @@ def update_webcam_session(
825
  frame: New frame from webcam
826
 
827
  Returns:
828
- Tuple of (annotated_frame, metrics_dict, enhanced_state, updated_session)
829
  """
830
- # Process the frame
831
- annotated_frame, metrics, enhanced_state, updated_df = process_webcam_frame(
832
- frame,
833
- session["ad_context"],
834
- session["metrics_data"],
835
- session["frame_count"],
836
- session["start_time"]
837
- )
838
-
839
- # Update session
 
 
 
 
 
 
840
  session["frame_count"] += 1
841
- session["metrics_data"] = updated_df
842
 
843
- # Record frame if enabled
844
- if session["record_video"] and annotated_frame is not None:
845
- session["recorded_frames"].append(annotated_frame)
846
- session["timestamps"].append(time.time() - session["start_time"])
 
 
 
 
 
 
 
 
 
847
 
848
  # Save CSV periodically
849
  if session["frame_count"] - session["last_saved"] >= session["save_interval"]:
850
- if not updated_df.empty:
851
- updated_df.to_csv(session["csv_path"], index=False)
852
  session["last_saved"] = session["frame_count"]
853
 
854
- return annotated_frame, metrics, enhanced_state, session
855
 
856
  def end_webcam_session(session: Dict[str, Any]) -> Tuple[str, str]:
857
  """
@@ -861,69 +1066,34 @@ def end_webcam_session(session: Dict[str, Any]) -> Tuple[str, str]:
861
  session: Session context dictionary
862
 
863
  Returns:
864
- Tuple of (csv_path, video_path)
865
  """
 
 
 
 
 
 
 
 
866
  # Save final metrics to CSV
867
  if not session["metrics_data"].empty:
868
  session["metrics_data"].to_csv(session["csv_path"], index=False)
869
 
870
- # Save recorded video if available
871
- video_path = None
872
- if session["record_video"] and session["recorded_frames"]:
873
- try:
874
- frames = session["recorded_frames"]
875
- if frames:
876
- # Get frame dimensions
877
- height, width = frames[0].shape[:2]
878
-
879
- # Calculate FPS based on actual timestamps
880
- if len(session["timestamps"]) > 1:
881
- # Calculate average time between frames
882
- time_diffs = np.diff(session["timestamps"])
883
- avg_frame_time = np.mean(time_diffs)
884
- fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 15.0
885
- else:
886
- fps = 15.0 # Default FPS
887
-
888
- # Create video writer
889
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
890
- video_path = session["video_path"]
891
- out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
892
-
893
- # Write frames
894
- for frame in frames:
895
- out.write(frame)
896
-
897
- out.release()
898
- print(f"Recorded video saved to {video_path}")
899
- else:
900
- print("No frames recorded")
901
- except Exception as e:
902
- print(f"Error saving video: {e}")
903
-
904
  print(f"Session ended. Data saved to {session['csv_path']}")
905
- return session["csv_path"], video_path
906
 
907
- # --- Create Gradio Interface ---
908
  def create_api_interface():
909
- with gr.Blocks(title="Facial Analysis APIs") as iface:
910
- gr.Markdown(f"""
911
- # Enhanced Facial Analysis APIs (DeepFace)
912
-
913
- This interface provides two API endpoints:
914
-
915
- 1. **Video File API**: Upload and analyze pre-recorded videos
916
- 2. **Webcam API**: Analyze live webcam feed in real-time
917
-
918
- Both APIs use DeepFace for emotion analysis and Google's Gemini API for enhanced interpretations.
919
- """)
920
 
921
  with gr.Tab("Video File API"):
922
  with gr.Row():
923
  with gr.Column(scale=1):
924
  video_input = gr.Video(label="Upload Video")
925
- vid_ad_desc = gr.Textbox(label="Ad Description", placeholder="Enter a description of the advertisement being watched...")
926
- vid_ad_detail = gr.Textbox(label="Ad Detail Focus", placeholder="Enter specific aspects to focus on...")
927
  vid_ad_type = gr.Radio(
928
  ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
929
  label="Ad Type/Genre",
@@ -934,181 +1104,182 @@ def create_api_interface():
934
  label="Sampling Rate (process every N frames)"
935
  )
936
  save_video = gr.Checkbox(label="Save Processed Video", value=True)
937
- process_btn = gr.Button("Process Video", variant="primary")
938
 
939
  with gr.Column(scale=2):
940
- output_text = gr.Textbox(label="Processing Results", lines=3)
941
  with gr.Row():
942
- with gr.Column():
943
- output_video = gr.Video(label="Processed Video")
944
- with gr.Column():
945
- frame_gallery = gr.Gallery(label="Processed Frames",
946
- show_label=True, columns=2,
947
- height=400)
 
 
948
 
949
  with gr.Row():
950
- with gr.Column():
951
- output_plot = gr.Plot(label="Sample Frame Metrics")
952
- with gr.Column():
953
- output_csv = gr.File(label="Download CSV Results")
954
 
955
- # Define function to handle video processing and show frames
956
- def handle_video_processing(video, desc, detail, ad_type, rate, save_vid):
957
  if video is None:
958
- return "No video uploaded", None, None, [], None
959
 
960
  try:
961
- result_text = "Starting video processing...\n"
962
- # Process the video
963
- csv_path, video_path, metrics_df, processed_frames = process_video_file(
964
  video,
965
  ad_description=desc,
966
  ad_detail=detail,
967
  ad_type=ad_type,
968
  sampling_rate=rate,
969
  save_processed_video=save_vid,
970
- show_progress=True
971
  )
972
 
973
  if metrics_df is None or metrics_df.empty:
974
- return "No facial data detected in video", None, None, [], None
 
 
 
 
 
 
 
 
 
 
 
975
 
976
- # Generate a sample metrics visualization
977
- sample_row = metrics_df.iloc[0].to_dict()
978
- metrics_plot = update_metrics_visualization(sample_row)
 
979
 
980
- # Create a gallery of processed frames
981
- # Take a subset if there are too many frames (maximum ~20 for display)
982
- display_frames = []
983
- step = max(1, len(processed_frames) // 20)
984
- for i in range(0, len(processed_frames), step):
985
- if i < len(processed_frames):
986
- # Convert BGR to RGB for display
987
- rgb_frame = cv2.cvtColor(processed_frames[i], cv2.COLOR_BGR2RGB)
988
- display_frames.append(rgb_frame)
989
 
990
- # Return results summary
991
- processed_count = metrics_df.shape[0]
992
- total_count = len(processed_frames)
993
- result_text = f"✅ Processed {processed_count} frames out of {total_count} total frames.\n"
994
- result_text += f"📊 CSV saved with {len(metrics_df.columns)} metrics columns.\n"
995
  if video_path:
996
- result_text += f"🎬 Processed video saved to: {video_path}"
997
 
998
- return result_text, video_path, metrics_plot, display_frames, csv_path
999
  except Exception as e:
1000
- return f"Error processing video: {str(e)}", None, None, [], None
1001
 
1002
  process_btn.click(
1003
  handle_video_processing,
1004
  inputs=[video_input, vid_ad_desc, vid_ad_detail, vid_ad_type, sampling_rate, save_video],
1005
- outputs=[output_text, output_video, output_plot, frame_gallery, output_csv]
1006
  )
1007
 
1008
  with gr.Tab("Webcam API"):
1009
  with gr.Row():
1010
- with gr.Column(scale=2):
1011
  webcam_input = gr.Image(sources="webcam", streaming=True, label="Webcam Input", type="numpy")
1012
-
1013
- with gr.Row():
1014
- with gr.Column():
1015
- web_ad_desc = gr.Textbox(label="Ad Description", placeholder="Enter a description of the advertisement being watched...")
1016
- web_ad_detail = gr.Textbox(label="Ad Detail Focus", placeholder="Enter specific aspects to focus on...")
1017
- web_ad_type = gr.Radio(
1018
- ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
1019
- label="Ad Type/Genre",
1020
- value="Video"
1021
- )
1022
- with gr.Column():
1023
- record_video_chk = gr.Checkbox(label="Record Video", value=True)
1024
- start_session_btn = gr.Button("Start Session", variant="primary")
1025
- end_session_btn = gr.Button("End Session", variant="stop")
1026
- session_status = gr.Textbox(label="Session Status", placeholder="Session not started...")
1027
 
1028
  with gr.Column(scale=2):
1029
- processed_output = gr.Image(label="Processed Feed", type="numpy", height=360)
 
1030
 
1031
  with gr.Row():
1032
- with gr.Column():
1033
- metrics_plot = gr.Plot(label="Current Metrics", height=300)
1034
- with gr.Column():
1035
- enhanced_state_txt = gr.Textbox(label="Enhanced State Analysis", lines=3)
1036
 
1037
  with gr.Row():
 
1038
  download_csv = gr.File(label="Download Session Data")
1039
- download_video = gr.Video(label="Recorded Session")
1040
 
1041
  # Session state
1042
  session_data = gr.State(value=None)
1043
 
1044
  # Define session handlers
1045
- def start_session(desc, detail, ad_type, record_video):
1046
- session = start_webcam_session(
1047
- ad_description=desc,
1048
- ad_detail=detail,
1049
- ad_type=ad_type,
1050
- record_video=record_video
1051
- )
1052
- return (
1053
- session,
1054
- f"Session started at {datetime.datetime.now().strftime('%H:%M:%S')}.\n"
1055
- f"Ad context: {desc} ({ad_type}).\n"
1056
- f"Data will be saved to {session['csv_path']}"
1057
- )
 
 
 
 
 
1058
 
1059
  def process_frame(frame, session):
1060
- if session is None:
1061
- return frame, None, "No active session. Click 'Start Session' to begin.", session
1062
 
1063
- # Process the frame
1064
- annotated_frame, metrics, enhanced_state, updated_session = update_webcam_session(session, frame)
1065
-
1066
- # Update the metrics plot if metrics available
1067
- if metrics:
1068
- metrics_plot = update_metrics_visualization(metrics)
1069
- return annotated_frame, metrics_plot, enhanced_state, updated_session
1070
- else:
1071
- # Return the annotated frame (likely with "No face detected")
1072
- return annotated_frame, None, enhanced_state or "No metrics available", updated_session
1073
 
1074
  def end_session(session):
1075
  if session is None:
1076
- return "No active session", None, None
1077
-
1078
- csv_path, video_path = end_webcam_session(session)
1079
- end_time = datetime.datetime.now().strftime('%H:%M:%S')
1080
- result = f"Session ended at {end_time}.\n"
1081
 
1082
- if csv_path:
1083
- result += f"CSV data saved to: {csv_path}\n"
1084
- if video_path:
1085
- result += f"Video saved to: {video_path}"
1086
-
1087
- return result, csv_path, video_path
 
 
 
 
 
1088
 
1089
  start_session_btn.click(
1090
  start_session,
1091
- inputs=[web_ad_desc, web_ad_detail, web_ad_type, record_video_chk],
1092
  outputs=[session_data, session_status]
1093
  )
1094
 
1095
  webcam_input.stream(
1096
  process_frame,
1097
  inputs=[webcam_input, session_data],
1098
- outputs=[processed_output, metrics_plot, enhanced_state_txt, session_data]
1099
  )
1100
 
1101
  end_session_btn.click(
1102
  end_session,
1103
  inputs=[session_data],
1104
- outputs=[session_status, download_csv, download_video]
1105
  )
1106
 
1107
  return iface
1108
 
1109
  # Entry point
1110
  if __name__ == "__main__":
1111
- print("Starting Enhanced Facial Analysis API (DeepFace)...")
1112
- print(f"Gemini API {'enabled' if GEMINI_ENABLED else 'disabled (using simulation)'}")
 
1113
  iface = create_api_interface()
1114
  iface.launch(debug=True)
 
3
  import numpy as np
4
  import pandas as pd
5
  import time
6
+ import mediapipe as mp
7
  import matplotlib.pyplot as plt
8
  from matplotlib.colors import LinearSegmentedColormap
9
  from matplotlib.collections import LineCollection
 
11
  import datetime
12
  import tempfile
13
  from typing import Dict, List, Tuple, Optional, Union, Any
14
+ import threading
15
+ import queue
16
+ import asyncio
17
+ import librosa
18
+ import torch
19
+ from moviepy.editor import VideoFileClip
20
+ from transformers import pipeline, AutoFeatureExtractor, AutoModelForAudioClassification
21
  import google.generativeai as genai
22
+ from concurrent.futures import ThreadPoolExecutor
 
 
 
 
 
 
 
 
 
 
23
 
24
  # --- Constants ---
25
+ VIDEO_FPS = 15 # Estimated/Target FPS for saved video
26
  CSV_FILENAME_TEMPLATE = "facial_analysis_{timestamp}.csv"
27
  VIDEO_FILENAME_TEMPLATE = "processed_{timestamp}.mp4"
28
+ AUDIO_FILENAME_TEMPLATE = "audio_{timestamp}.wav"
 
29
 
30
+ # --- MediaPipe Initialization ---
31
+ mp_face_mesh = mp.solutions.face_mesh
32
+ mp_drawing = mp.solutions.drawing_utils
33
+ mp_drawing_styles = mp.solutions.drawing_styles
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ face_mesh = mp_face_mesh.FaceMesh(
36
+ max_num_faces=1,
37
+ refine_landmarks=True,
38
+ min_detection_confidence=0.5,
39
+ min_tracking_confidence=0.5)
40
+
41
+ # --- Audio Model Initialization ---
42
+ # We'll initialize this in a function to avoid loading at startup
43
+ audio_classifier = None
44
+ audio_feature_extractor = None
45
+
46
+ def initialize_audio_model():
47
+ global audio_classifier, audio_feature_extractor
48
+ if audio_classifier is None:
49
+ print("Loading audio classification model...")
50
+ model_name = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
51
+ audio_feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
52
+ audio_classifier = AutoModelForAudioClassification.from_pretrained(model_name)
53
+ print("Audio model loaded successfully")
54
+ return audio_classifier, audio_feature_extractor
55
+
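The wav2vec2 emotion checkpoint above is loaded lazily; since analyze_audio_segment() further down assumes a fixed emotion label order, a quick sanity check against the checkpoint's own labels might look like this (sketch, relying only on the standard transformers config attributes):

# Sketch: confirm the label order the checkpoint actually uses before trusting
# a hard-coded emotion_labels list.
clf, _ = initialize_audio_model()
print(clf.config.id2label)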
56
+ # --- Gemini API Configuration ---
57
+ # Replace with your Gemini API key
58
+ GEMINI_API_KEY = "your-gemini-api-key" # In production, load from environment variable
59
+
60
+ def configure_gemini():
61
+ genai.configure(api_key=GEMINI_API_KEY)
62
+
63
+ # Set up the model
64
+ generation_config = {
65
+ "temperature": 0.2,
66
+ "top_p": 0.8,
67
+ "top_k": 40,
68
+ "max_output_tokens": 256,
69
+ }
70
+
71
+ safety_settings = [
72
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
73
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
74
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
75
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
76
+ ]
77
+
78
+ try:
79
+ model = genai.GenerativeModel(
80
+ model_name="gemini-1.5-flash",
81
+ generation_config=generation_config,
82
+ safety_settings=safety_settings
83
+ )
84
+ return model
85
+ except Exception as e:
86
+ print(f"Error configuring Gemini: {e}")
87
+ return None
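The placeholder key above is only a stand-in; a minimal sketch of the environment-variable approach the inline comment suggests (and which the previous revision used via GOOGLE_API_KEY) could be:

import os

# Illustrative: read the key from the environment instead of hard-coding it.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")

gemini_model = configure_gemini() if GEMINI_API_KEY else None
if gemini_model is None:
    print("Gemini unavailable; rule-based analysis will be used instead.")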
88
 
89
  # --- Metrics Definition ---
90
  metrics = [
 
93
  "neuroticism", "conscientiousness", "extraversion",
94
  "stress_index", "engagement_level"
95
  ]
96
+ audio_metrics = [
97
+ "audio_valence", "audio_arousal", "audio_intensity",
98
+ "audio_emotion", "audio_confidence"
99
+ ]
 
 
 
 
 
 
 
 
100
  ad_context_columns = ["ad_description", "ad_detail", "ad_type", "gemini_ad_analysis"]
101
+ user_state_column = ["user_state", "detailed_user_analysis"]
102
+ all_columns = ['timestamp', 'frame_number'] + metrics + audio_metrics + ad_context_columns + user_state_column
103
  initial_metrics_df = pd.DataFrame(columns=all_columns)
104
 
105
+ # --- Live Processing Queue ---
106
+ processing_queue = queue.Queue()
107
+ results_queue = queue.Queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ # --- Gemini Functions ---
110
+ def call_gemini_api_for_ad(model, description, detail, ad_type):
111
+ """Uses Gemini to analyze ad context."""
112
+ if not model:
113
+ return "Gemini model not available. Using simulated analysis."
114
+
115
+ if not description and not detail:
116
+ return "No ad context provided."
117
+
118
+ prompt = f"""
119
+ Analyze this advertisement context:
120
+ - Description: {description or 'N/A'}
121
+ - Detail/Focus: {detail or 'N/A'}
122
+ - Type/Genre: {ad_type}
123
+
124
+ Provide a concise analysis of how this ad might affect viewer emotions and cognition.
125
+ Focus on potential emotional triggers, cognitive demands, and engagement patterns.
126
+ Keep your analysis under 100 words.
127
  """
128
+
129
+ try:
130
+ response = model.generate_content(prompt)
131
+ return response.text
132
+ except Exception as e:
133
+ print(f"Error calling Gemini API: {e}")
134
+ return f"Simulated analysis: Ad='{description or 'N/A'}' ({ad_type}), Focus='{detail or 'N/A'}'."
135
+
136
+ def interpret_metrics_with_gemini(model, metrics_dict, audio_metrics_dict=None, ad_context=None, timestamp=None):
137
+ """Uses Gemini to interpret facial and audio metrics -> detailed user state."""
138
+ if not model:
139
+ return simple_user_state_analysis(metrics_dict, audio_metrics_dict), "Gemini model not available. Using rule-based analysis."
140
+
141
+ if not metrics_dict:
142
+ return "No response", "No metrics data available"
143
+
144
+ metrics_text = "\n".join([f"- {k}: {v:.3f}" for k, v in metrics_dict.items()])
145
+
146
+ audio_text = ""
147
+ if audio_metrics_dict:
148
+ audio_text = "\n".join([f"- {k}: {v}" for k, v in audio_metrics_dict.items()])
149
+
150
+ ad_text = ""
151
+ if ad_context:
152
+ ad_text = f"""
153
+ Ad Context:
154
+ - Description: {ad_context.get('ad_description', 'N/A')}
155
+ - Detail/Focus: {ad_context.get('ad_detail', 'N/A')}
156
+ - Type/Genre: {ad_context.get('ad_type', 'N/A')}
157
+ """
158
+
159
+ timestamp_text = f"Timestamp: {timestamp:.2f} seconds" if timestamp is not None else ""
160
+
161
+ prompt = f"""
162
+ Analyze the following viewer metrics and provide a detailed assessment of their current state:
163
+
164
+ {timestamp_text}
165
+
166
+ Facial Expression Metrics:
167
+ {metrics_text}
168
+
169
+ {'Audio Expression Metrics:' if audio_text else ''}
170
+ {audio_text}
171
+
172
+ {ad_text}
173
+
174
+ First, provide a short 1-5 word state label that summarizes the viewer's current emotional and cognitive state.
175
+
176
+ Then, provide a more detailed 2-3 sentence analysis explaining what these metrics suggest about the viewer's:
177
+ - Emotional state
178
+ - Cognitive engagement
179
+ - Likely response to the content
180
+ - Any notable patterns or anomalies
181
+
182
+ Format your response as:
183
+ USER STATE: [state label]
184
+
185
+ DETAILED ANALYSIS: [your analysis]
186
  """
 
 
187
 
188
+ try:
189
+ response = model.generate_content(prompt)
190
+ text = response.text.strip()
 
 
 
 
 
 
 
191
 
192
+ # Parse the response
193
+ state_parts = text.split("USER STATE:", 1)
194
+ if len(state_parts) > 1:
195
+ state_text = state_parts[1].split("DETAILED ANALYSIS:", 1)
196
+ if len(state_text) > 1:
197
+ simple_state = state_text[0].strip()
198
+ detailed_analysis = state_text[1].strip()
199
+ return simple_state, detailed_analysis
200
 
201
+ # Fallback if parsing fails
202
+ simple_state = text.split('\n')[0].strip()
203
+ detailed_analysis = ' '.join(text.split('\n')[1:]).strip()
204
+ return simple_state, detailed_analysis
205
+ except Exception as e:
206
+ print(f"Error interpreting metrics with Gemini: {e}")
207
+ return simple_user_state_analysis(metrics_dict, audio_metrics_dict), "Error generating detailed analysis"
208
+
209
+ def simple_user_state_analysis(metrics_dict, audio_metrics_dict=None):
210
+ """Simple rule-based user state analysis as fallback."""
211
+ if not metrics_dict:
212
+ return "No metrics"
213
+
214
+ valence = metrics_dict.get('valence', 0.5)
215
+ arousal = metrics_dict.get('arousal', 0.5)
216
+ cog_load = metrics_dict.get('cognitive_load', 0.5)
217
+ stress = metrics_dict.get('stress_index', 0.5)
218
+ engagement = metrics_dict.get('engagement_level', 0.5)
219
+
220
+ # Include audio metrics when available
221
+ audio_emotion = None
222
+ audio_valence = 0.5
223
+ if audio_metrics_dict:
224
+ audio_emotion = audio_metrics_dict.get('audio_emotion')
225
+ audio_valence = audio_metrics_dict.get('audio_valence', 0.5)
226
 
227
+ # Blend facial and audio valence
228
+ valence = (valence * 0.7) + (audio_valence * 0.3)
229
+
230
+ # Simple rule-based analysis
231
+ state = "Neutral"
232
+ if valence > 0.65 and arousal > 0.55 and engagement > 0.6:
233
+ state = "Positive, Engaged"
234
+ elif valence < 0.4 and stress > 0.6:
235
+ state = "Stressed, Negative"
236
+ elif cog_load > 0.7 and engagement < 0.4:
237
+ state = "Confused, Disengaged"
238
+ elif arousal < 0.4 and engagement < 0.5:
239
+ state = "Calm, Passive"
240
+
241
+ # Override with audio emotion if it's strong
242
+ if audio_emotion in ["happy", "excited"] and audio_metrics_dict.get('audio_confidence', 0) > 0.7:
243
+ state = audio_emotion.capitalize()
244
+ elif audio_emotion in ["angry", "sad", "fearful"] and audio_metrics_dict.get('audio_confidence', 0) > 0.7:
245
+ state = audio_emotion.capitalize()
246
+
247
+ return state
 
 
248
 
249
+ # --- Audio Analysis Functions ---
250
+ def extract_audio_from_video(video_path, output_audio_path=None):
251
+ """Extract audio from video file"""
252
+ if output_audio_path is None:
253
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
254
+ output_audio_path = AUDIO_FILENAME_TEMPLATE.format(timestamp=timestamp)
255
+
256
+ try:
257
+ video = VideoFileClip(video_path)
258
+ video.audio.write_audiofile(output_audio_path, fps=16000, nbytes=2, codec='pcm_s16le')
259
+ return output_audio_path
260
+ except Exception as e:
261
+ print(f"Error extracting audio: {e}")
262
  return None
263
+
264
+ def analyze_audio_segment(audio_path, start_time, duration=1.0):
265
+ """Analyze a segment of audio for emotion"""
266
+ classifier, feature_extractor = initialize_audio_model()
267
 
268
  try:
269
+ # Load audio segment
270
+ y, sr = librosa.load(audio_path, sr=16000, offset=start_time, duration=duration)
 
 
 
 
 
 
 
 
271
 
272
+ if len(y) < 100: # Too short to analyze
273
+ return None
 
274
 
275
+ # Extract features
276
+ inputs = feature_extractor(y, sampling_rate=sr, return_tensors="pt")
 
 
 
 
 
277
 
278
+ # Get predictions
279
+ with torch.no_grad():
280
+ outputs = classifier(**inputs)
281
+ logits = outputs.logits
282
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
283
+
284
+ # Get the predicted class and its probability
285
+ predicted_class_idx = torch.argmax(probabilities, dim=1).item()
286
+ confidence = probabilities[0][predicted_class_idx].item()
287
 
288
+ # Map to emotion labels (verify these match your model's labels)
289
+ emotion_labels = ["angry", "fearful", "happy", "neutral", "sad", "surprised"]
290
+ predicted_emotion = emotion_labels[predicted_class_idx]
291
+
292
+ # Calculate valence and arousal based on emotion
293
+ emotion_mappings = {
294
+ "angry": {"valence": 0.2, "arousal": 0.9, "intensity": 0.8},
295
+ "fearful": {"valence": 0.3, "arousal": 0.8, "intensity": 0.7},
296
+ "happy": {"valence": 0.9, "arousal": 0.7, "intensity": 0.6},
297
+ "neutral": {"valence": 0.5, "arousal": 0.5, "intensity": 0.3},
298
+ "sad": {"valence": 0.2, "arousal": 0.3, "intensity": 0.5},
299
+ "surprised": {"valence": 0.6, "arousal": 0.8, "intensity": 0.7}
300
+ }
301
+
302
+ valence = emotion_mappings.get(predicted_emotion, {"valence": 0.5})["valence"]
303
+ arousal = emotion_mappings.get(predicted_emotion, {"arousal": 0.5})["arousal"]
304
+ intensity = emotion_mappings.get(predicted_emotion, {"intensity": 0.5})["intensity"]
305
 
306
+ # Return audio metrics
307
+ return {
308
+ "audio_valence": valence,
309
+ "audio_arousal": arousal,
310
+ "audio_intensity": intensity,
311
+ "audio_emotion": predicted_emotion,
312
+ "audio_confidence": confidence
313
+ }
314
  except Exception as e:
315
+ print(f"Error analyzing audio segment: {e}")
316
  return None
317
 
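A rough sketch of how these two audio helpers could be chained over a whole clip; the 1 s window and 0.5 s hop are arbitrary illustrative choices, and librosa.get_duration(path=...) assumes librosa 0.10+:

def analyze_clip_audio(video_path, window=1.0, hop=0.5):
    # Extract the soundtrack once, then score overlapping windows.
    audio_path = extract_audio_from_video(video_path)
    if audio_path is None:
        return []
    duration = librosa.get_duration(path=audio_path)
    results, t = [], 0.0
    while t < duration:
        seg = analyze_audio_segment(audio_path, start_time=t, duration=window)
        if seg is not None:
            results.append({"timestamp": t, **seg})
        t += hop
    return results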
318
+ # --- Analysis Functions ---
319
+ def extract_face_landmarks(image, face_mesh_instance):
320
+ if image is None or face_mesh_instance is None:
 
321
  return None
 
322
  try:
323
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
324
+ image_rgb.flags.writeable = False
325
+ results = face_mesh_instance.process(image_rgb)
326
+ image_rgb.flags.writeable = True
327
+ if results.multi_face_landmarks:
328
+ return results.multi_face_landmarks[0]
329
+ except Exception as e:
330
+ print(f"Error in landmark extraction: {e}")
331
+ return None
332
+
333
+ def calculate_ear(landmarks):
334
+ if not landmarks:
335
+ return 0.0
336
+ try:
337
+ LEFT_EYE = [33, 160, 158, 133, 153, 144]
338
+ RIGHT_EYE = [362, 385, 387, 263, 373, 380]
339
 
340
+ def get_coords(idx_list):
341
+ return np.array([(landmarks.landmark[i].x, landmarks.landmark[i].y) for i in idx_list])
 
 
 
 
 
342
 
343
+ left_pts = get_coords(LEFT_EYE)
344
+ right_pts = get_coords(RIGHT_EYE)
345
 
346
+ def ear_aspect(pts):
347
+ v1 = np.linalg.norm(pts[1] - pts[5])
348
+ v2 = np.linalg.norm(pts[2] - pts[4])
349
+ h = np.linalg.norm(pts[0] - pts[3])
350
+ return (v1 + v2) / (2.0 * h) if h > 1e-6 else 0.0
351
 
352
+ return (ear_aspect(left_pts) + ear_aspect(right_pts)) / 2.0
353
+ except (IndexError, AttributeError) as e:
354
+ print(f"Error calculating EAR: {e}")
355
+ return 0.0
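For reference, ear_aspect() matches the usual eye-aspect-ratio formulation EAR = (|p2 - p6| + |p3 - p5|) / (2 * |p1 - p4|), averaged over both eyes. An open eye typically sits roughly in the 0.2-0.35 range and a closing eye drives the value toward 0, which is consistent with calculate_metrics() below treating low EAR as higher cognitive load (cl = 1 - 2.5 * ear, so ear = 0.25 gives cl ≈ 0.38).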
 
356
 
357
+ def calculate_mar(landmarks):
358
+ if not landmarks:
359
+ return 0.0
360
+ try:
361
+ MOUTH = [61, 291, 39, 181, 0, 17, 269, 405]
362
+ pts = np.array([(landmarks.landmark[i].x, landmarks.landmark[i].y) for i in MOUTH])
363
+ h = np.mean([np.linalg.norm(pts[1] - pts[7]), np.linalg.norm(pts[2] - pts[6]), np.linalg.norm(pts[3] - pts[5])])
364
+ w = np.linalg.norm(pts[0] - pts[4])
365
+ return h / w if w > 1e-6 else 0.0
366
+ except (IndexError, AttributeError) as e:
367
+ print(f"Error calculating MAR: {e}")
368
+ return 0.0
369
+
370
+ def calculate_eyebrow_position(landmarks):
371
+ if not landmarks:
372
+ return 0.0
373
+ try:
374
+ L_BROW = 107
375
+ R_BROW = 336
376
+ L_EYE_C = 159
377
+ R_EYE_C = 386
378
+
379
+ l_brow_y = landmarks.landmark[L_BROW].y
380
+ r_brow_y = landmarks.landmark[R_BROW].y
381
+ l_eye_y = landmarks.landmark[L_EYE_C].y
382
+ r_eye_y = landmarks.landmark[R_EYE_C].y
383
+
384
+ l_dist = l_eye_y - l_brow_y
385
+ r_dist = r_eye_y - r_brow_y
386
+ avg_dist = (l_dist + r_dist) / 2.0
387
+ norm = (avg_dist - 0.02) / 0.06
388
+
389
+ return max(0.0, min(1.0, norm))
390
+ except (IndexError, AttributeError) as e:
391
+ print(f"Error calculating Eyebrow Pos: {e}")
392
+ return 0.0
393
+
394
+ def estimate_head_pose(landmarks):
395
+ if not landmarks:
396
+ return 0.0, 0.0
397
+ try:
398
+ NOSE = 4
399
+ L_EYE_C = 159
400
+ R_EYE_C = 386
401
+
402
+ nose_pt = np.array([landmarks.landmark[NOSE].x, landmarks.landmark[NOSE].y])
403
+ l_eye_pt = np.array([landmarks.landmark[L_EYE_C].x, landmarks.landmark[L_EYE_C].y])
404
+ r_eye_pt = np.array([landmarks.landmark[R_EYE_C].x, landmarks.landmark[R_EYE_C].y])
405
+
406
+ eye_mid_y = (l_eye_pt[1] + r_eye_pt[1]) / 2.0
407
+ eye_mid_x = (l_eye_pt[0] + r_eye_pt[0]) / 2.0
408
+
409
+ v_tilt = nose_pt[1] - eye_mid_y
410
+ h_tilt = nose_pt[0] - eye_mid_x
411
+
412
+ v_tilt_norm = max(-1.0, min(1.0, v_tilt * 5.0))
413
+ h_tilt_norm = max(-1.0, min(1.0, h_tilt * 10.0))
414
+
415
+ return v_tilt_norm, h_tilt_norm
416
+ except (IndexError, AttributeError) as e:
417
+ print(f"Error estimating Head Pose: {e}")
418
+ return 0.0, 0.0
419
+
420
+ def calculate_metrics(landmarks, ad_context=None):
421
  if ad_context is None:
422
  ad_context = {}
423
+ if not landmarks:
424
+ return {m: 0.5 for m in metrics} # Return defaults if no landmarks
425
+
426
+ # Calculate base features
427
+ ear = calculate_ear(landmarks)
428
+ mar = calculate_mar(landmarks)
429
+ eb_pos = calculate_eyebrow_position(landmarks)
430
+ v_tilt, h_tilt = estimate_head_pose(landmarks)
431
+
432
+ # Illustrative Context Adjustments
433
+ ad_type = ad_context.get('ad_type', 'Unk')
434
  gem_txt = str(ad_context.get('gemini_ad_analysis', '')).lower()
435
+ val_mar_w = 2.5 if ad_type == 'Funny' or 'humor' in gem_txt else 2.0
436
+ val_eb_w = 0.8 if ad_type == 'Serious' or 'sad' in gem_txt else 1.0
437
+ arsl_base = 0.05 if ad_type == 'Action' or 'exciting' in gem_txt else 0.0
438
+
439
+ # Calculate final metrics using base features and context adjustments
440
+ cl = max(0, min(1, 1.0 - ear * 2.5))
441
+ val = max(0, min(1, mar * val_mar_w * (val_eb_w * (1.0 - eb_pos))))
442
+ arsl = max(0, min(1, arsl_base + (mar + (1.0 - ear) + eb_pos) / 3.0))
443
+ dom = max(0, min(1, 0.5 + v_tilt))
444
  neur = max(0, min(1, (cl * 0.6) + ((1.0 - val) * 0.4)))
445
  em_stab = 1.0 - neur
446
  extr = max(0, min(1, (arsl * 0.5) + (val * 0.5)))
447
+ openness_val = max(0, min(1, 0.5 + ((mar - 0.5) * 0.5)))  # renamed to avoid shadowing built-in open()
448
  agree = max(0, min(1, (val * 0.7) + ((1.0 - arsl) * 0.3)))
449
  consc = max(0, min(1, (1.0 - abs(arsl - 0.5)) * 0.7 + (em_stab * 0.3)))
450
+ stress = max(0, min(1, (cl * 0.5) + (eb_pos * 0.3) + ((1.0 - val) * 0.2)))
451
+ engag = max(0, min(1, (arsl * 0.7) + ((1.0 - abs(h_tilt)) * 0.3)))
452
+
453
+ # Return dictionary of metrics
454
+ return {
455
+ 'valence': val, 'arousal': arsl, 'dominance': dom, 'cognitive_load': cl,
456
+ 'emotional_stability': em_stab, 'openness': openness_val, 'agreeableness': agree,
457
+ 'neuroticism': neur, 'conscientiousness': consc, 'extraversion': extr,
458
+ 'stress_index': stress, 'engagement_level': engag
459
  }
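+
+ # Illustrative usage sketch (not part of the original commit): condense a
+ # calculate_metrics() result into the single strongest signal for quick logging.
+ def dominant_metric(landmarks, ad_context=None) -> str:
+     """Return the name of the highest-scoring metric for the given landmarks."""
+     scores = calculate_metrics(landmarks, ad_context)
+     return max(scores, key=scores.get)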
 
 
460
 
461
+ def update_metrics_visualization(metrics_values, audio_metrics=None, title=None):
 
462
  if not metrics_values:
463
  fig, ax = plt.subplots(figsize=(10, 8))
464
+ ax.text(0.5, 0.5, "Waiting...", ha='center', va='center')
465
  ax.axis('off')
466
  fig.patch.set_facecolor('#FFFFFF')
467
  ax.set_facecolor('#FFFFFF')
468
  return fig
469
 
470
+ # Combine face and audio metrics for visualization
471
+ all_metrics = {}
472
+ for k, v in metrics_values.items():
473
+ if k not in ('timestamp', 'frame_number', 'user_state', 'detailed_user_analysis'):
474
+ all_metrics[k] = v
475
 
476
+ if audio_metrics:
477
+ for k, v in audio_metrics.items():
478
+ if isinstance(v, (int, float)):
479
+ all_metrics[k] = v
 
480
 
481
+ num_metrics = len(all_metrics)
482
  nrows = (num_metrics + 2) // 3
483
  fig, axs = plt.subplots(nrows, 3, figsize=(10, nrows * 2.5), facecolor='#FFFFFF')
484
  axs = axs.flatten()
485
 
486
+ if title:
487
+ fig.suptitle(title, fontsize=12)
488
+
489
  colors = [(0.1, 0.1, 0.9), (0.9, 0.9, 0.1), (0.9, 0.1, 0.1)]
490
  cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=100)
491
  norm = plt.Normalize(0, 1)
492
  metric_idx = 0
493
 
494
+ for key, value in all_metrics.items():
495
+ if not isinstance(value, (int, float)):
496
+ value = 0.5
497
+ value = max(0.0, min(1.0, value))
498
 
499
  ax = axs[metric_idx]
500
  ax.set_title(key.replace('_', ' ').title(), fontsize=10)
 
536
  plt.tight_layout(pad=0.5)
537
  return fig
538
 
539
+ def create_user_state_display(state_text, detailed_analysis=None):
540
+ """Create a visual display of the user state"""
541
+ fig, ax = plt.subplots(figsize=(10, 2.5))
542
+ ax.axis('off')
543
+
544
+ # Display state
545
+ ax.text(0.5, 0.8, f"USER STATE: {state_text}",
546
+ ha='center', va='center', fontsize=14, fontweight='bold',
547
+ bbox=dict(facecolor='#e6f2ff', alpha=0.7, boxstyle='round,pad=0.5'))
548
+
549
+ # Display detailed analysis if available
550
+ if detailed_analysis:
551
+ ax.text(0.5, 0.3, detailed_analysis,
552
+ ha='center', va='center', fontsize=10,
553
+ bbox=dict(facecolor='#f2f2f2', alpha=0.7, boxstyle='round,pad=0.5'))
554
+
555
+ plt.tight_layout()
556
+ return fig
557
+
558
+ def annotate_frame(frame, landmarks):
559
+ """Add facial landmark annotations to a frame"""
560
  if frame is None:
561
  return None
562
 
563
  annotated = frame.copy()
564
 
565
+ if landmarks:
566
+ try:
567
+ mp_drawing.draw_landmarks(
568
+ image=annotated,
569
+ landmark_list=landmarks,
570
+ connections=mp_face_mesh.FACEMESH_TESSELATION,
571
+ landmark_drawing_spec=None,
572
+ connection_drawing_spec=mp_drawing_styles.get_default_face_mesh_tesselation_style()
573
+ )
574
+ mp_drawing.draw_landmarks(
575
+ image=annotated,
576
+ landmark_list=landmarks,
577
+ connections=mp_face_mesh.FACEMESH_CONTOURS,
578
+ landmark_drawing_spec=None,
579
+ connection_drawing_spec=mp_drawing_styles.get_default_face_mesh_contours_style()
580
+ )
581
+ except Exception as e:
582
+ print(f"Error drawing landmarks: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
 
584
  return annotated
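+
+ # Illustrative one-off use (not part of the original commit): annotate a single image
+ # file; assumes the module-level face_mesh instance used elsewhere in app.py.
+ def annotate_image_file(path: str, out_path: str = "annotated.png") -> bool:
+     """Read an image, draw the detected mesh, and save it; returns False if unreadable."""
+     img = cv2.imread(path)
+     if img is None:
+         return False
+     cv2.imwrite(out_path, annotate_frame(img, extract_face_landmarks(img, face_mesh)))
+     return True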
585
 
586
+ # --- Background Processing Functions ---
587
+ def process_frames_in_background(session):
588
+ """Background thread for processing frames and updating metrics"""
589
+ while True:
590
+ try:
591
+ # Get task from queue
592
+ task = processing_queue.get(timeout=1.0)
593
+ if task.get('command') == 'stop':
594
+ break
595
+
596
+ frame = task.get('frame')
597
+ if frame is None:
598
+ continue
599
+
600
+ # Process the frame
601
+ result = process_webcam_frame(
602
+ frame,
603
+ task.get('ad_context', {}),
604
+ task.get('metrics_data', initial_metrics_df.copy()),
605
+ task.get('frame_count', 0),
606
+ task.get('start_time', time.time()),
607
+ task.get('audio_path'),
608
+ task.get('gemini_model')
609
+ )
610
+
611
+ # Put result in results queue
612
+ results_queue.put({
613
+ 'annotated_frame': result[0],
614
+ 'metrics': result[1],
615
+ 'audio_metrics': result[2],
616
+ 'metrics_df': result[3],
617
+ 'state_fig': result[4],
618
+ 'metrics_fig': result[5]
619
+ })
620
+
621
+ # Mark task as done
622
+ processing_queue.task_done()
623
+ except queue.Empty:
624
+ continue
625
+ except Exception as e:
626
+ print(f"Error in background processing: {e}")
627
+ continue
628
+
629
+ # --- Video File Processing with Progress Updates ---
630
  def process_video_file(
631
  video_file: Union[str, np.ndarray],
632
  ad_description: str = "",
 
634
  ad_type: str = "Video",
635
  sampling_rate: int = 5, # Process every Nth frame
636
  save_processed_video: bool = True,
637
+ progress=gr.Progress()
638
+ ) -> Tuple[str, str, str, pd.DataFrame]:
639
  """
640
  Process a video file and analyze facial expressions frame by frame
641
 
 
646
  ad_type: Type of ad (Video, Image, Audio, Text, Funny, etc.)
647
  sampling_rate: Process every Nth frame
648
  save_processed_video: Whether to save the processed video with annotations
649
+ progress: Gradio progress bar
650
 
651
  Returns:
652
+ Tuple of (csv_path, audio_path, processed_video_path, metrics_dataframe)
653
  """
654
+ # Initialize Gemini model
655
+ gemini_model = configure_gemini()
656
+
657
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
658
  csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
659
+ audio_path = AUDIO_FILENAME_TEMPLATE.format(timestamp=timestamp)
660
  video_path = VIDEO_FILENAME_TEMPLATE.format(timestamp=timestamp) if save_processed_video else None
661
 
662
  # Setup ad context
663
+ gemini_result = call_gemini_api_for_ad(gemini_model, ad_description, ad_detail, ad_type)
664
  ad_context = {
665
  "ad_description": ad_description,
666
  "ad_detail": ad_detail,
 
668
  "gemini_ad_analysis": gemini_result
669
  }
670
 
671
+ progress(0, desc="Initializing video processing")
672
+
673
  # Initialize capture
674
  if isinstance(video_file, str):
675
  cap = cv2.VideoCapture(video_file)
 
679
  temp_path = os.path.join(temp_dir, "temp_video.mp4")
680
 
681
  # Convert video array to file
682
+ if isinstance(video_file, np.ndarray):
683
+ # Assuming it's a series of frames
684
+ h, w = video_file[0].shape[:2] if len(video_file) > 0 else (480, 640)
685
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
686
  temp_writer = cv2.VideoWriter(temp_path, fourcc, 30, (w, h))
687
  for frame in video_file:
688
  temp_writer.write(frame)
689
  temp_writer.release()
690
+
691
+ video_file = temp_path
692
+ cap = cv2.VideoCapture(temp_path)
693
 
694
  if not cap.isOpened():
695
  print("Error: Could not open video.")
696
+ return None, None, None, None
697
+
698
+ # Extract audio for analysis
699
+ audio_extracted = extract_audio_from_video(video_file, audio_path)
700
 
701
  # Get video properties
702
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
707
  # Initialize video writer if saving processed video
708
  if save_processed_video:
709
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
710
+ out = cv2.VideoWriter(video_path, fourcc, fps, (frame_width, frame_height))
711
 
712
  # Process video frames
713
  metrics_data = []
 
714
  frame_count = 0
715
 
716
+ # Create a thread pool for audio processing
717
+ with ThreadPoolExecutor(max_workers=2) as executor:
718
+ # Queue for audio analysis results
719
+ audio_futures = {}
 
 
 
 
720
 
721
+ progress(0.1, desc="Starting frame analysis")
722
+
723
+ while True:
724
+ ret, frame = cap.read()
725
+ if not ret:
726
+ break
 
 
727
 
728
+ # Only process every Nth frame (according to sampling_rate)
729
+ process_this_frame = frame_count % sampling_rate == 0
730
+ frame_timestamp = frame_count / fps
731
 
732
+ if process_this_frame:
733
+ progress(min(0.1 + 0.8 * (frame_count / total_frames), 0.9),
734
+ desc=f"Processing frame {frame_count}/{total_frames}")
 
735
 
736
+ # Extract facial landmarks
737
+ landmarks = extract_face_landmarks(frame, face_mesh)
 
 
 
 
 
 
 
 
738
 
739
+ # Submit audio analysis task if audio was extracted
740
+ if process_this_frame and audio_extracted and frame_timestamp not in audio_futures:
741
+ audio_futures[frame_timestamp] = executor.submit(
742
+ analyze_audio_segment, audio_path, frame_timestamp, 1.0
743
+ )
744
 
745
+ # Get audio analysis results if available
746
+ audio_metrics = None
747
+ if frame_timestamp in audio_futures and audio_futures[frame_timestamp].done():
748
+ audio_metrics = audio_futures[frame_timestamp].result()
749
+
750
+ # Calculate metrics if landmarks detected
751
+ if landmarks:
752
+ calculated_metrics = calculate_metrics(landmarks, ad_context)
753
+ user_state, detailed_analysis = interpret_metrics_with_gemini(
754
+ gemini_model, calculated_metrics, audio_metrics, ad_context, frame_timestamp
755
+ )
756
+
757
+ # Create a row for the dataframe
758
+ row = {
759
+ 'timestamp': frame_timestamp,
760
+ 'frame_number': frame_count,
761
+ **calculated_metrics
762
+ }
763
+
764
+ # Add audio metrics if available
765
+ if audio_metrics:
766
+ row.update(audio_metrics)
767
+ else:
768
+ # audio_metrics is None here; iterating it would raise TypeError, so skip audio defaults
769
+ pass
770
+
771
+ # Add context and state
772
+ row.update(ad_context)
773
+ row['user_state'] = user_state
774
+ row['detailed_user_analysis'] = detailed_analysis
775
+
776
+ metrics_data.append(row)
777
+
778
+ # Annotate the frame with facial landmarks
779
+ if save_processed_video:
780
+ annotated_frame = annotate_frame(frame, landmarks)
781
+
782
+ # Add user state text to frame
783
+ cv2.putText(
784
+ annotated_frame,
785
+ f"State: {user_state}",
786
+ (10, 30),
787
+ cv2.FONT_HERSHEY_SIMPLEX,
788
+ 0.7,
789
+ (0, 255, 0),
790
+ 2
791
+ )
792
+
793
+ # Add audio emotion if available
794
+ if audio_metrics and 'audio_emotion' in audio_metrics:
795
+ cv2.putText(
796
+ annotated_frame,
797
+ f"Audio: {audio_metrics['audio_emotion']}",
798
+ (10, 60),
799
+ cv2.FONT_HERSHEY_SIMPLEX,
800
+ 0.7,
801
+ (255, 0, 0),
802
+ 2
803
+ )
804
+
805
+ out.write(annotated_frame)
806
+ elif save_processed_video:
807
+ # If no landmarks detected, still write the original frame to the video
808
+ out.write(frame)
809
+ elif save_processed_video:
810
+ # For frames not being analyzed, still include them in the output video
811
+ out.write(frame)
812
+
813
+ frame_count += 1
814
 
815
+ # Wait for all audio analysis to complete
816
+ for future in audio_futures.values():
817
+ if not future.done():
818
+ future.result() # This will wait for completion
819
+
820
+ progress(0.95, desc="Finalizing results")
821
 
822
  # Release resources
823
  cap.release()
 
828
  metrics_df = pd.DataFrame(metrics_data)
829
  if not metrics_df.empty:
830
  metrics_df.to_csv(csv_path, index=False)
831
+ progress(1.0, desc="Processing complete")
832
+ else:
833
+ progress(1.0, desc="No facial data detected")
 
 
 
834
 
835
  # Return results
836
+ return csv_path, audio_path, video_path, metrics_df
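+
+ # Illustrative call sketch (not part of the original commit): batch-analyze one ad video
+ # from a script. The file name and ad text are placeholders, and a no-op progress
+ # callback is passed so the call does not rely on Gradio's progress tracker.
+ def run_sample_video_analysis(path: str = "sample_ad.mp4"):
+     """Run the full pipeline on one file and return its CSV path and metrics frame."""
+     csv_path, _, _, df = process_video_file(
+         path, ad_description="Sample ad", ad_detail="Overall reaction",
+         ad_type="Funny", sampling_rate=5, save_processed_video=False,
+         progress=lambda *args, **kwargs: None)
+     return csv_path, df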
837
 
838
+ # --- Updated Webcam Processing Function ---
839
  def process_webcam_frame(
840
  frame: np.ndarray,
841
  ad_context: Dict[str, Any],
842
  metrics_data: pd.DataFrame,
843
  frame_count: int,
844
+ start_time: float,
845
+ audio_path: str = None,
846
+ gemini_model = None
847
+ ) -> Tuple[np.ndarray, Dict[str, float], Dict[str, Any], pd.DataFrame, object, object]:
848
  """
849
+ Process a single webcam frame with audio integration
850
 
851
  Args:
852
  frame: Input frame from webcam
 
854
  metrics_data: DataFrame to accumulate metrics
855
  frame_count: Current frame count
856
  start_time: Start time of the session
857
+ audio_path: Path to extracted audio file (if available)
858
+ gemini_model: Configured Gemini model instance
859
 
860
  Returns:
861
+ Tuple of (annotated_frame, metrics_dict, audio_metrics, updated_metrics_df, state_fig, metrics_fig)
862
  """
863
  if frame is None:
864
+ return None, None, None, metrics_data, None, None
865
 
866
+ # Extract facial landmarks
867
+ landmarks = extract_face_landmarks(frame, face_mesh)
 
868
 
869
+ # Get current timestamp
870
+ current_time = time.time()
871
+ elapsed_time = current_time - start_time
872
 
873
+ # Analyze audio segment if available
874
+ audio_metrics = None
875
+ if audio_path and os.path.exists(audio_path):
876
+ audio_metrics = analyze_audio_segment(audio_path, elapsed_time, 1.0)
877
+
878
+ # Calculate metrics if landmarks detected
879
+ if landmarks:
880
+ calculated_metrics = calculate_metrics(landmarks, ad_context)
881
+ user_state, detailed_analysis = interpret_metrics_with_gemini(
882
+ gemini_model, calculated_metrics, audio_metrics, ad_context, elapsed_time
883
+ )
884
 
885
  # Create a row for the dataframe
 
886
  row = {
887
+ 'timestamp': elapsed_time,
888
  'frame_number': frame_count,
889
+ **calculated_metrics
 
 
 
890
  }
891
 
892
+ # Add audio metrics if available
893
+ if audio_metrics:
894
+ row.update(audio_metrics)
895
+ else:
896
+ # audio_metrics is None here; iterating it would raise TypeError, so skip audio defaults
897
+ pass
898
+
899
+ # Add context and state
900
+ row.update(ad_context)
901
+ row['user_state'] = user_state
902
+ row['detailed_user_analysis'] = detailed_analysis
903
+
904
  # Add row to DataFrame
905
  new_row_df = pd.DataFrame([row], columns=all_columns)
906
  metrics_data = pd.concat([metrics_data, new_row_df], ignore_index=True)
907
 
908
+ # Create visualizations
909
+ metrics_plot = update_metrics_visualization(
910
+ calculated_metrics,
911
+ audio_metrics,
912
+ title=f"Frame {frame_count} Metrics"
913
+ )
914
+ state_plot = create_user_state_display(user_state, detailed_analysis)
915
+
916
  # Annotate the frame
917
+ annotated_frame = annotate_frame(frame, landmarks)
918
+
919
+ # Add user state text to frame
920
+ cv2.putText(
921
+ annotated_frame,
922
+ f"State: {user_state}",
923
+ (10, 30),
924
+ cv2.FONT_HERSHEY_SIMPLEX,
925
+ 0.7,
926
+ (0, 255, 0),
927
+ 2
928
+ )
929
 
930
+ # Add audio emotion if available
931
+ if audio_metrics and 'audio_emotion' in audio_metrics:
932
+ cv2.putText(
933
+ annotated_frame,
934
+ f"Audio: {audio_metrics['audio_emotion']}",
935
+ (10, 60),
936
+ cv2.FONT_HERSHEY_SIMPLEX,
937
+ 0.7,
938
+ (255, 0, 0),
939
+ 2
940
+ )
941
+
942
+ return annotated_frame, calculated_metrics, audio_metrics, metrics_data, state_plot, metrics_plot
943
  else:
944
  # No face detected
945
+ return frame, None, None, metrics_data, None, None
 
 
 
946
 
947
+ # --- Updated Webcam Session Functions ---
948
  def start_webcam_session(
949
  ad_description: str = "",
950
  ad_detail: str = "",
951
  ad_type: str = "Video",
952
  save_interval: int = 100, # Save CSV every N frames
953
+ record_audio: bool = False
954
  ) -> Dict[str, Any]:
955
  """
956
+ Initialize a webcam session for facial analysis with audio recording
957
 
958
  Args:
959
  ad_description: Description of the ad being watched
960
  ad_detail: Detail focus of the ad
961
  ad_type: Type of ad
962
  save_interval: How often to save data to CSV
963
+ record_audio: Whether to record audio during session
964
 
965
  Returns:
966
  Session context dictionary
 
968
  # Generate timestamp for file naming
969
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
970
  csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
971
+ audio_path = AUDIO_FILENAME_TEMPLATE.format(timestamp=timestamp) if record_audio else None
972
+
973
+ # Initialize Gemini model
974
+ gemini_model = configure_gemini()
975
 
976
  # Setup ad context
977
+ gemini_result = call_gemini_api_for_ad(gemini_model, ad_description, ad_detail, ad_type)
978
  ad_context = {
979
  "ad_description": ad_description,
980
  "ad_detail": ad_detail,
 
989
  "metrics_data": initial_metrics_df.copy(),
990
  "ad_context": ad_context,
991
  "csv_path": csv_path,
992
+ "audio_path": audio_path,
993
  "save_interval": save_interval,
994
  "last_saved": 0,
995
+ "gemini_model": gemini_model,
996
+ "processing_thread": None
 
997
  }
998
 
999
+ # Start background processing thread
1000
+ processor = threading.Thread(target=process_frames_in_background, args=(session,))
1001
+ processor.daemon = True
1002
+ processor.start()
1003
+ session["processing_thread"] = processor
1004
+
1005
  return session
1006
 
1007
  def update_webcam_session(
1008
  session: Dict[str, Any],
1009
  frame: np.ndarray
1010
+ ) -> Tuple[np.ndarray, object, object, Dict[str, Any]]:
1011
  """
1012
  Update webcam session with a new frame
1013
 
 
1016
  frame: New frame from webcam
1017
 
1018
  Returns:
1019
+ Tuple of (annotated_frame, state_plot, metrics_plot, updated_session)
1020
  """
1021
+ if session is None:
1022
+ return frame, None, None, session
1023
+
1024
+ # Add task to processing queue
1025
+ processing_queue.put({
1026
+ 'command': 'process',
1027
+ 'frame': frame.copy() if frame is not None else None,
1028
+ 'ad_context': session["ad_context"],
1029
+ 'metrics_data': session["metrics_data"],
1030
+ 'frame_count': session["frame_count"],
1031
+ 'start_time': session["start_time"],
1032
+ 'audio_path': session["audio_path"],
1033
+ 'gemini_model': session["gemini_model"]
1034
+ })
1035
+
1036
+ # Update frame count
1037
  session["frame_count"] += 1
 
1038
 
1039
+ # Get result if available
1040
+ try:
1041
+ result = results_queue.get_nowait()
1042
+ annotated_frame = result.get('annotated_frame', frame)
1043
+ state_fig = result.get('state_fig')
1044
+ metrics_fig = result.get('metrics_fig')
1045
+ session["metrics_data"] = result.get('metrics_df', session["metrics_data"])
1046
+ results_queue.task_done()
1047
+ except queue.Empty:
1048
+ # No result yet, return original frame
1049
+ annotated_frame = frame
1050
+ state_fig = None
1051
+ metrics_fig = None
1052
 
1053
  # Save CSV periodically
1054
  if session["frame_count"] - session["last_saved"] >= session["save_interval"]:
1055
+ if not session["metrics_data"].empty:
1056
+ session["metrics_data"].to_csv(session["csv_path"], index=False)
1057
  session["last_saved"] = session["frame_count"]
1058
 
1059
+ return annotated_frame, state_fig, metrics_fig, session
1060
 
1061
  def end_webcam_session(session: Dict[str, Any]) -> Tuple[str, str]:
1062
  """
 
1066
  session: Session context dictionary
1067
 
1068
  Returns:
1069
+ Tuple of (csv_path, audio_path)
1070
  """
1071
+ if session is None:
1072
+ return None, None
1073
+
1074
+ # Stop background processing thread
1075
+ if session["processing_thread"] and session["processing_thread"].is_alive():
1076
+ processing_queue.put({"command": "stop"})
1077
+ session["processing_thread"].join(timeout=2.0)
1078
+
1079
  # Save final metrics to CSV
1080
  if not session["metrics_data"].empty:
1081
  session["metrics_data"].to_csv(session["csv_path"], index=False)
1082
1083
  print(f"Session ended. Data saved to {session['csv_path']}")
1084
+ return session["csv_path"], session["audio_path"]
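+
+ # Illustrative lifecycle sketch (not part of the original commit): how the session
+ # helpers chain together outside the Gradio UI; `frames` is assumed to be an iterable
+ # of BGR frame arrays.
+ def run_headless_session(frames, ad_description: str = "Sample ad"):
+     """Start a session, stream frames through it, then close it and return file paths."""
+     session = start_webcam_session(ad_description=ad_description)
+     for frame in frames:
+         _, _, _, session = update_webcam_session(session, frame)
+     return end_webcam_session(session)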
1085
 
1086
+ # --- Create Enhanced Gradio Interface ---
1087
  def create_api_interface():
1088
+ with gr.Blocks(title="Enhanced Facial Analysis APIs") as iface:
1089
+ gr.Markdown("# Enhanced Facial Analysis APIs\nAnalyze facial expressions and audio in videos or webcam feed")
1090
 
1091
  with gr.Tab("Video File API"):
1092
  with gr.Row():
1093
  with gr.Column(scale=1):
1094
  video_input = gr.Video(label="Upload Video")
1095
+ vid_ad_desc = gr.Textbox(label="Ad Description")
1096
+ vid_ad_detail = gr.Textbox(label="Ad Detail Focus")
1097
  vid_ad_type = gr.Radio(
1098
  ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
1099
  label="Ad Type/Genre",
 
1104
  label="Sampling Rate (process every N frames)"
1105
  )
1106
  save_video = gr.Checkbox(label="Save Processed Video", value=True)
1107
+ process_btn = gr.Button("Process Video")
1108
 
1109
  with gr.Column(scale=2):
 
1110
  with gr.Row():
1111
+ output_text = gr.Textbox(label="Processing Status")
1112
+
1113
+ with gr.Row():
1114
+ output_video = gr.Video(label="Processed Video")
1115
+
1116
+ with gr.Row():
1117
+ output_plot = gr.Plot(label="Metrics Visualization")
1118
+ user_state_plot = gr.Plot(label="User State Analysis")
1119
 
1120
  with gr.Row():
1121
+ output_csv = gr.File(label="Download CSV Results")
1122
+ output_audio = gr.Audio(label="Extracted Audio")
 
 
1123
 
1124
+ # Define function to handle video processing with live updates
1125
+ def handle_video_processing(video, desc, detail, ad_type, rate, save_vid, progress=gr.Progress()):
1126
  if video is None:
1127
+ return "No video uploaded", None, None, None, None, None
1128
 
1129
  try:
1130
+ progress(0.05, "Starting video processing...")
1131
+
1132
+ csv_path, audio_path, video_path, metrics_df = process_video_file(
1133
  video,
1134
  ad_description=desc,
1135
  ad_detail=detail,
1136
  ad_type=ad_type,
1137
  sampling_rate=rate,
1138
  save_processed_video=save_vid,
1139
+ progress=progress
1140
  )
1141
 
1142
  if metrics_df is None or metrics_df.empty:
1143
+ return "No facial data detected in video", None, None, None, None, None
1144
+
1145
+ # Get a sample row for visualization
1146
+ middle_idx = len(metrics_df) // 2
1147
+ sample_row = metrics_df.iloc[middle_idx].to_dict()
1148
+
1149
+ # Generate visualizations
1150
+ metrics_plot = update_metrics_visualization(
1151
+ {k: v for k, v in sample_row.items() if k in metrics},
1152
+ {k: v for k, v in sample_row.items() if k in audio_metrics},
1153
+ title=f"Sample Frame Metrics (Frame {sample_row['frame_number']})"
1154
+ )
1155
 
1156
+ state_plot = create_user_state_display(
1157
+ sample_row.get('user_state', 'No state'),
1158
+ sample_row.get('detailed_user_analysis', '')
1159
+ )
1160
 
1161
+ processed_frames = metrics_df.shape[0]
1162
+ total_duration = metrics_df['timestamp'].max() if not metrics_df.empty else 0
1163
 
1164
+ result_text = f"✅ Processing complete!\n"
1165
+ result_text += f"• Analyzed {processed_frames} frames over {total_duration:.2f} seconds\n"
1166
+ result_text += f"• CSV saved to: {csv_path}\n"
1167
+ if audio_path:
1168
+ result_text += f" Audio extracted to: {audio_path}\n"
1169
  if video_path:
1170
+ result_text += f" Processed video saved to: {video_path}\n"
1171
 
1172
+ return result_text, csv_path, video_path, audio_path, metrics_plot, state_plot
1173
  except Exception as e:
1174
+ return f"Error processing video: {str(e)}", None, None, None, None, None
1175
 
1176
  process_btn.click(
1177
  handle_video_processing,
1178
  inputs=[video_input, vid_ad_desc, vid_ad_detail, vid_ad_type, sampling_rate, save_video],
1179
+ outputs=[output_text, output_csv, output_video, output_audio, output_plot, user_state_plot]
1180
  )
1181
 
1182
  with gr.Tab("Webcam API"):
1183
  with gr.Row():
1184
+ with gr.Column(scale=1):
1185
  webcam_input = gr.Image(sources="webcam", streaming=True, label="Webcam Input", type="numpy")
1186
+ web_ad_desc = gr.Textbox(label="Ad Description")
1187
+ web_ad_detail = gr.Textbox(label="Ad Detail Focus")
1188
+ web_ad_type = gr.Radio(
1189
+ ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
1190
+ label="Ad Type/Genre",
1191
+ value="Video"
1192
+ )
1193
+ record_audio = gr.Checkbox(label="Record Audio", value=True)
1194
+ start_session_btn = gr.Button("Start Session")
1195
+ end_session_btn = gr.Button("End Session")
1196
 
1197
  with gr.Column(scale=2):
1198
+ with gr.Row():
1199
+ processed_output = gr.Image(label="Processed Feed", type="numpy")
1200
 
1201
  with gr.Row():
1202
+ metrics_plot = gr.Plot(label="Live Metrics")
1203
+ state_plot = gr.Plot(label="User State Analysis")
 
 
1204
 
1205
  with gr.Row():
1206
+ session_status = gr.Textbox(label="Session Status")
1207
  download_csv = gr.File(label="Download Session Data")
 
1208
 
1209
  # Session state
1210
  session_data = gr.State(value=None)
1211
 
1212
  # Define session handlers
1213
+ def start_session(desc, detail, ad_type, record_audio):
1214
+ try:
1215
+ session = start_webcam_session(
1216
+ ad_description=desc,
1217
+ ad_detail=detail,
1218
+ ad_type=ad_type,
1219
+ record_audio=record_audio
1220
+ )
1221
+
1222
+ status_text = "Session started successfully!\n\n"
1223
+ status_text += f"Ad Context: {desc} ({ad_type})\n"
1224
+ status_text += f" Focus: {detail}\n"
1225
+ status_text += f"• Audio Recording: {'Enabled' if record_audio else 'Disabled'}\n"
1226
+ status_text += f"• Data will be saved to: {session['csv_path']}"
1227
+
1228
+ return session, status_text
1229
+ except Exception as e:
1230
+ return None, f"Error starting session: {str(e)}"
1231
 
1232
  def process_frame(frame, session):
1233
+ if session is None or frame is None:
1234
+ return frame, None, None, session
1235
 
1236
+ try:
1237
+ annotated_frame, state_fig, metrics_fig, updated_session = update_webcam_session(session, frame)
1238
+ return annotated_frame, state_fig, metrics_fig, updated_session
1239
+ except Exception as e:
1240
+ print(f"Error processing frame: {e}")
1241
+ return frame, None, None, session
 
 
 
 
1242
 
1243
  def end_session(session):
1244
  if session is None:
1245
+ return "No active session", None
 
 
 
 
1246
 
1247
+ try:
1248
+ csv_path, audio_path = end_webcam_session(session)
1249
+
1250
+ status_text = " Session ended successfully!\n\n"
1251
+ status_text += f"• Data saved to: {csv_path}\n"
1252
+ if audio_path:
1253
+ status_text += f"• Audio saved to: {audio_path}"
1254
+
1255
+ return status_text, csv_path
1256
+ except Exception as e:
1257
+ return f"Error ending session: {str(e)}", None
1258
 
1259
  start_session_btn.click(
1260
  start_session,
1261
+ inputs=[web_ad_desc, web_ad_detail, web_ad_type, record_audio],
1262
  outputs=[session_data, session_status]
1263
  )
1264
 
1265
  webcam_input.stream(
1266
  process_frame,
1267
  inputs=[webcam_input, session_data],
1268
+ outputs=[processed_output, state_plot, metrics_plot, session_data]
1269
  )
1270
 
1271
  end_session_btn.click(
1272
  end_session,
1273
  inputs=[session_data],
1274
+ outputs=[session_status, download_csv]
1275
  )
1276
 
1277
  return iface
1278
 
1279
  # Entry point
1280
  if __name__ == "__main__":
1281
+ print("Starting Enhanced Facial Analysis API server...")
1282
+ # Pre-initialize models if needed
1283
+ # initialize_audio_model()
1284
  iface = create_api_interface()
1285
  iface.launch(debug=True)