import gradio as gr
import cv2
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.collections import LineCollection
import os
import datetime
import tempfile
from typing import Dict, List, Tuple, Optional, Union, Any
import google.generativeai as genai
from PIL import Image
import json
import warnings
from deepface import DeepFace
import base64
import io
from pathlib import Path
import traceback

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

# --- Constants ---
VIDEO_FPS = 30  # Target FPS for saved video (also used as a fallback)
CSV_FILENAME_TEMPLATE = "facial_analysis_{timestamp}.csv"
VIDEO_FILENAME_TEMPLATE = "processed_{timestamp}.mp4"
TEMP_DIR = Path("temp_frames")
TEMP_DIR.mkdir(exist_ok=True)

# --- Configure Google Gemini API ---
print("Configuring Google Gemini API...")
try:
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
    if not GOOGLE_API_KEY:
        raise ValueError("GOOGLE_API_KEY environment variable not set.")
    genai.configure(api_key=GOOGLE_API_KEY)
    # Use gemini-1.5-flash for quick responses
    model = genai.GenerativeModel('gemini-1.5-flash')
    GEMINI_ENABLED = True
    print("Google Gemini API configured successfully.")
except Exception as e:
    print(f"WARNING: Failed to configure Google Gemini API: {e}")
    print("Running with simulated Gemini API responses.")
    GEMINI_ENABLED = False

# --- Initialize OpenCV face detector for backup ---
print("Initializing OpenCV face detector...")
try:
    # Use OpenCV's built-in Haar cascade face detector as a backup
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    # Check if the face detector loaded successfully
    if face_cascade.empty():
        print("WARNING: Failed to load face cascade classifier")
    else:
        print("OpenCV face detector initialized successfully.")
except Exception as e:
    print(f"ERROR initializing OpenCV face detector: {e}")
    face_cascade = None

# --- Metrics Definition ---
metrics = [
    "valence", "arousal", "dominance", "cognitive_load",
    "emotional_stability", "openness", "agreeableness",
    "neuroticism", "conscientiousness", "extraversion",
    "stress_index", "engagement_level"
]

# DeepFace emotion mapping to base valence/arousal/dominance values
emotion_mapping = {
    "angry":    {"valence": 0.2, "arousal": 0.8, "dominance": 0.7},
    "disgust":  {"valence": 0.2, "arousal": 0.6, "dominance": 0.5},
    "fear":     {"valence": 0.2, "arousal": 0.8, "dominance": 0.3},
    "happy":    {"valence": 0.9, "arousal": 0.7, "dominance": 0.6},
    "sad":      {"valence": 0.3, "arousal": 0.4, "dominance": 0.3},
    "surprise": {"valence": 0.6, "arousal": 0.9, "dominance": 0.5},
    "neutral":  {"valence": 0.5, "arousal": 0.5, "dominance": 0.5}
}

ad_context_columns = ["ad_description", "ad_detail", "ad_type", "gemini_ad_analysis"]
user_state_columns = ["user_state", "enhanced_user_state"]
all_columns = ['timestamp', 'frame_number'] + metrics + ad_context_columns + user_state_columns
initial_metrics_df = pd.DataFrame(columns=all_columns)

# --- Gemini API Functions ---
def call_gemini_api_for_ad(description, detail, ad_type):
    """
    Uses Google Gemini to analyze ad context.
    """
    print(f"Analyzing ad context: '{description}' ({ad_type})")
    if not GEMINI_ENABLED:
        # Simulated response
        analysis = f"Simulated analysis: Ad='{description or 'N/A'}' ({ad_type}), Focus='{detail or 'N/A'}'."
        if not description and not detail:
            analysis = "No ad context provided."
        print(f"Simulated Gemini Result: {analysis}")
        return analysis
    else:
        try:
            prompt = f"""
            Please analyze this advertisement context:
            - Description: {description}
            - Detail focus: {detail}
            - Type/Genre: {ad_type}

            Provide a concise analysis of what emotional and cognitive responses might be expected from viewers.
            Limit your response to 100 words.
            """
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            print(f"Error calling Gemini for ad context: {e}")
            return f"Error analyzing ad context: {str(e)}"

def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, ad_context=None):
    """
    Uses Google Gemini to interpret facial metrics and DeepFace results
    to determine the user state.
    """
    if not metrics_dict and not deepface_results:
        return "No metrics", "No facial data detected"

    if not GEMINI_ENABLED:
        # Basic rule-based simulation for user state
        valence = metrics_dict.get('valence', 0.5) if metrics_dict else 0.5
        arousal = metrics_dict.get('arousal', 0.5) if metrics_dict else 0.5

        # Extract emotion from DeepFace if available
        dominant_emotion = "neutral"
        if deepface_results and "emotion" in deepface_results:
            emotion_dict = deepface_results["emotion"]
            dominant_emotion = max(emotion_dict.items(), key=lambda x: x[1])[0]

        # Simple rule-based simulation
        state = dominant_emotion.capitalize() if dominant_emotion != "neutral" else "Neutral"
        if valence > 0.65 and arousal > 0.55:
            state = "Positive, Engaged"
        elif valence < 0.4 and arousal > 0.6:
            state = "Stressed, Negative"
        enhanced_state = f"The viewer appears {state.lower()} while watching this content."
        return state, enhanced_state
    else:
        try:
            # Format metrics for Gemini
            metrics_formatted = ""
            if metrics_dict:
                metrics_formatted = "\nMetrics (0-1 scale):\n" + "\n".join(
                    [f"- {k.replace('_', ' ').title()}: {v:.2f}" for k, v in metrics_dict.items()
                     if k not in ('timestamp', 'frame_number')])

            # Format DeepFace results
            deepface_formatted = ""
            if deepface_results and "emotion" in deepface_results:
                emotion_dict = deepface_results["emotion"]
                deepface_formatted = "\nDeepFace emotions:\n" + "\n".join(
                    [f"- {k.title()}: {v:.2f}" for k, v in emotion_dict.items()])

            # Include ad context if available
            ad_info = ""
            if ad_context:
                ad_desc = ad_context.get('ad_description', 'N/A')
                ad_type = ad_context.get('ad_type', 'N/A')
                ad_info = f"\nThey are watching an advertisement: {ad_desc} (Type: {ad_type})"

            prompt = f"""
            Analyze the facial expression and emotion of a person watching an advertisement.{ad_info}
            Use these combined inputs:{metrics_formatted}{deepface_formatted}

            Provide two outputs:
            1. User State: A short 1-3 word description of their emotional/cognitive state
            2. Enhanced Analysis: A detailed 1-2 sentence interpretation of their reaction to the content

            Format as JSON: {{"user_state": "STATE", "enhanced_user_state": "DETAILED ANALYSIS"}}
            """
            response = model.generate_content(prompt)
            try:
                # Try to parse as JSON
                result = json.loads(response.text)
                return result.get("user_state", "Uncertain"), result.get("enhanced_user_state", "Analysis unavailable")
            except json.JSONDecodeError:
                # If not valid JSON, try to extract the fields manually
                text = response.text
                if "user_state" in text and "enhanced_user_state" in text:
                    parts = text.split("enhanced_user_state")
                    user_state = parts[0].split("user_state")[1].replace('"', '').replace(':', '').replace(',', '').strip()
                    enhanced = parts[1].replace('"', '').replace(':', '').replace('}', '').strip()
                    return user_state, enhanced
                else:
                    # Just return the raw text as the enhanced state
                    return "Analyzed", text
        except Exception as e:
            print(f"Error calling Gemini for metric interpretation: {e}")
            traceback.print_exc()
            return "Error", f"Error analyzing facial metrics: {str(e)}"

# --- DeepFace Analysis Function ---
def analyze_face_with_deepface(image):
    """Analyze facial emotions and attributes using DeepFace."""
    if image is None:
        return None
    try:
        # Normalize the input to an RGB array
        if len(image.shape) == 3 and image.shape[2] == 3:
            # Rough heuristic: if the first channel (blue in BGR) is dimmer than
            # the third (red in BGR), assume BGR and convert to RGB
            if np.mean(image[:, :, 0]) < np.mean(image[:, :, 2]):
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image_rgb = image
        else:
            # Handle grayscale input
            image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # Save the image to a temp file (DeepFace sometimes works better with files).
        # cv2.imwrite expects BGR channel order, so convert back before writing.
        temp_img = str(TEMP_DIR / f"temp_analysis_{time.time()}.jpg")
        cv2.imwrite(temp_img, cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))

        # Analyze with DeepFace
        analysis = DeepFace.analyze(
            img_path=temp_img,
            actions=['emotion'],
            enforce_detection=False,   # Don't throw an error if no face is detected
            detector_backend='opencv'  # Faster detection
        )

        # Remove the temporary file
        try:
            os.remove(temp_img)
        except OSError:
            pass

        # Return the first face analysis (assuming a single face)
        if isinstance(analysis, list) and len(analysis) > 0:
            return analysis[0]
        else:
            return analysis
    except Exception as e:
        print(f"DeepFace analysis error: {e}")
        return None
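
# Sketch of the result shape this function returns when DeepFace finds a face
# (keys follow DeepFace's 'emotion' action output; the numbers are made up for
# illustration and may vary by DeepFace version):
#
#   {
#       "emotion": {"angry": 0.1, "happy": 92.3, "neutral": 5.4, ...},
#       "dominant_emotion": "happy",
#       "region": {"x": 120, "y": 80, "w": 180, "h": 180},
#   }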

# --- Face Detection Backup with OpenCV ---
def detect_face_opencv(image):
    """Detect faces using the OpenCV cascade classifier as a backup."""
    if image is None or face_cascade is None:
        return None
    try:
        # Convert to grayscale for detection
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Detect faces
        faces = face_cascade.detectMultiScale(
            gray,
            scaleFactor=1.1,
            minNeighbors=5,
            minSize=(30, 30)
        )
        if len(faces) == 0:
            return None

        # Get the largest face by area
        largest_face = max(faces, key=lambda rect: rect[2] * rect[3])
        return {"rect": largest_face}
    except Exception as e:
        print(f"Error in OpenCV face detection: {e}")
        return None

# --- Calculate Metrics from DeepFace Results ---
def calculate_metrics_from_deepface(deepface_results, ad_context=None):
    """
    Calculate psychometric metrics from DeepFace analysis results.
    """
    if ad_context is None:
        ad_context = {}

    # Initialize default metrics
    default_metrics = {m: 0.5 for m in metrics}

    # If no facial data, return defaults
    if not deepface_results or "emotion" not in deepface_results:
        return default_metrics

    # Extract emotion data from DeepFace
    emotion_dict = deepface_results["emotion"]

    # Find the dominant emotion and its confidence
    dominant_emotion, dominant_raw = max(emotion_dict.items(), key=lambda x: x[1])
    dominant_score = dominant_raw / 100.0  # Convert to 0-1 scale (currently informational only)

    # Get base values from the emotion mapping
    base_vals = emotion_mapping.get(dominant_emotion, {"valence": 0.5, "arousal": 0.5, "dominance": 0.5})

    # Primary metrics
    val = base_vals["valence"]
    arsl = base_vals["arousal"]
    dom = base_vals["dominance"]

    # Directional adjustments based on specific emotions
    if dominant_emotion == "happy":
        val += 0.1
    elif dominant_emotion == "sad":
        val -= 0.1
    elif dominant_emotion == "angry":
        arsl += 0.1
        dom += 0.1
    elif dominant_emotion == "fear":
        arsl += 0.1
        dom -= 0.1

    # Illustrative context adjustments from the ad
    ad_type = ad_context.get('ad_type', 'Unknown')
    gem_txt = str(ad_context.get('gemini_ad_analysis', '')).lower()
    val_adj = 0.1 if ad_type == 'Funny' or 'humor' in gem_txt else 0.0
    arsl_adj = 0.1 if ad_type == 'Action' or 'exciting' in gem_txt else 0.0

    # Apply adjustments, clamped to [0, 1]
    val = max(0, min(1, val + val_adj))
    arsl = max(0, min(1, arsl + arsl_adj))

    # Estimate cognitive load based on emotional intensity
    cl = 0.5  # Default
    if dominant_emotion == "neutral":
        cl = 0.3  # Lower cognitive load for a neutral expression
    elif dominant_emotion in ("surprise", "fear"):
        cl = 0.7  # Higher cognitive load for surprise/fear

    # Secondary metrics ('openness' is named to avoid shadowing the builtin open())
    neur = max(0, min(1, (cl * 0.6) + ((1.0 - val) * 0.4)))
    em_stab = 1.0 - neur
    extr = max(0, min(1, (arsl * 0.5) + (val * 0.5)))
    openness = max(0, min(1, 0.5 + (val - 0.5) * 0.5))
    agree = max(0, min(1, (val * 0.7) + ((1.0 - arsl) * 0.3)))
    consc = max(0, min(1, (1.0 - abs(arsl - 0.5)) * 0.7 + (em_stab * 0.3)))
    stress = max(0, min(1, (cl * 0.5) + ((1.0 - val) * 0.5)))
    engag = max(0, min(1, arsl * 0.7 + (val * 0.3)))

    # Assemble the metrics dictionary
    calculated_metrics = {
        'valence': val,
        'arousal': arsl,
        'dominance': dom,
        'cognitive_load': cl,
        'emotional_stability': em_stab,
        'openness': openness,
        'agreeableness': agree,
        'neuroticism': neur,
        'conscientiousness': consc,
        'extraversion': extr,
        'stress_index': stress,
        'engagement_level': engag
    }
    return calculated_metrics
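
# Illustrative example (kept as a comment): a DeepFace-style result dominated
# by "happy" should yield high valence and engagement. The numbers simply
# follow the hand-tuned weights above.
#
#   sample = {"emotion": {"happy": 92.0, "neutral": 5.0, "sad": 3.0}}
#   m = calculate_metrics_from_deepface(sample, {"ad_type": "Funny"})
#   # m["valence"]           -> 1.0  (0.9 base + 0.1 happy boost + 0.1 ad boost, clamped)
#   # m["arousal"]           -> 0.7
#   # m["engagement_level"]  -> 0.79 (0.7 * 0.7 + 1.0 * 0.3)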

def update_metrics_visualization(metrics_values):
    """Create a gauge-style visualization of the metrics."""
    if not metrics_values:
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(0.5, 0.5, "Waiting for facial metrics...", ha='center', va='center')
        ax.axis('off')
        fig.patch.set_facecolor('#FFFFFF')
        ax.set_facecolor('#FFFFFF')
        return fig

    # Filter out non-metric keys
    filtered_metrics = {k: v for k, v in metrics_values.items()
                        if k in metrics and isinstance(v, (int, float))}
    if not filtered_metrics:
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(0.5, 0.5, "No valid metrics available", ha='center', va='center')
        ax.axis('off')
        return fig

    num_metrics = len(filtered_metrics)
    nrows = (num_metrics + 2) // 3
    fig, axs = plt.subplots(nrows, 3, figsize=(10, nrows * 2.5), facecolor='#FFFFFF')
    axs = axs.flatten()

    colors = [(0.1, 0.1, 0.9), (0.9, 0.9, 0.1), (0.9, 0.1, 0.1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors, N=100)
    norm = plt.Normalize(0, 1)

    metric_idx = 0
    for key, value in filtered_metrics.items():
        value = max(0.0, min(1.0, value))  # Clip value for safety
        ax = axs[metric_idx]
        ax.set_title(key.replace('_', ' ').title(), fontsize=10)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 0.5)
        ax.set_aspect('equal')
        ax.axis('off')
        ax.set_facecolor('#FFFFFF')

        r = 0.4
        theta = np.linspace(np.pi, 0, 100)
        x_bg = 0.5 + r * np.cos(theta)
        y_bg = 0.1 + r * np.sin(theta)
        ax.plot(x_bg, y_bg, 'k-', linewidth=3, alpha=0.2)

        value_angle = np.pi * (1 - value)
        num_points = max(2, int(100 * value))
        value_theta = np.linspace(np.pi, value_angle, num_points)
        x_val = 0.5 + r * np.cos(value_theta)
        y_val = 0.1 + r * np.sin(value_theta)
        if len(x_val) > 1:
            points = np.array([x_val, y_val]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            segment_values = np.linspace(0, value, len(segments))
            lc = LineCollection(segments, cmap=cmap, norm=norm)
            lc.set_array(segment_values)
            lc.set_linewidth(5)
            ax.add_collection(lc)

        ax.text(0.5, 0.15, f"{value:.2f}", ha='center', va='center', fontsize=11,
                fontweight='bold', bbox=dict(facecolor='white', alpha=0.7, boxstyle='round,pad=0.2'))
        metric_idx += 1

    for i in range(metric_idx, len(axs)):
        axs[i].axis('off')

    plt.tight_layout(pad=0.5)
    return fig
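
# Quick way to preview the gauge layout with neutral values (illustrative; kept
# as a comment so it never runs with the app, and assumes a writable cwd):
#
#   fig = update_metrics_visualization({m: 0.5 for m in metrics})
#   fig.savefig("metrics_preview.png")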

def annotate_frame(frame, face_data=None, deepface_results=None, metrics=None, enhanced_state=None):
    """
    Add facial annotations and a metrics summary to a frame.

    Note: the `metrics` parameter here is a dict of calculated values and
    shadows the module-level `metrics` list of metric names.
    """
    if frame is None:
        return None
    annotated = frame.copy()

    # Draw a face rectangle if available
    if face_data and "rect" in face_data:
        x, y, w, h = face_data["rect"]
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
    elif deepface_results and "region" in deepface_results:
        region = deepface_results["region"]
        x, y, w, h = region["x"], region["y"], region["w"], region["h"]
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)

    # Add emotion and metrics summary
    if deepface_results or metrics:
        y_pos = 30  # Starting Y position for the text overlay

        # Add emotion info if available from DeepFace
        if deepface_results and "dominant_emotion" in deepface_results:
            emotion_text = f"Emotion: {deepface_results['dominant_emotion'].capitalize()}"
            text_size = cv2.getTextSize(emotion_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
            cv2.rectangle(annotated, (10, y_pos - 20), (10 + text_size[0], y_pos + 5), (0, 0, 0), -1)
            cv2.putText(annotated, emotion_text, (10, y_pos),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            y_pos += 30

        # Add the enhanced user state if available
        if enhanced_state:
            # Truncate if too long
            if len(enhanced_state) > 60:
                enhanced_state = enhanced_state[:57] + "..."
            # Draw a background box, then the text
            text_size = cv2.getTextSize(enhanced_state, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0]
            cv2.rectangle(annotated, (10, y_pos - 20), (10 + text_size[0], y_pos + 5), (0, 0, 0), -1)
            cv2.putText(annotated, enhanced_state, (10, y_pos),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            y_pos += 30

        # Show the top 3 metrics by value
        if metrics:
            top_metrics = sorted(metrics.items(), key=lambda x: x[1], reverse=True)[:3]
            for name, value in top_metrics:
                metric_text = f"{name.replace('_', ' ').title()}: {value:.2f}"
                text_size = cv2.getTextSize(metric_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                cv2.rectangle(annotated, (10, y_pos - 15), (10 + text_size[0], y_pos + 5), (0, 0, 0), -1)
                cv2.putText(annotated, metric_text, (10, y_pos),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                y_pos += 25
    return annotated

# --- API 1: Video File Processing ---
def process_video_file(
    video_file: Union[str, np.ndarray],
    ad_description: str = "",
    ad_detail: str = "",
    ad_type: str = "Video",
    sampling_rate: int = 5,  # Process every Nth frame
    save_processed_video: bool = True,
    show_progress: bool = True
) -> Tuple[str, str, pd.DataFrame, List[np.ndarray]]:
    """
    Process a video file and analyze facial expressions frame by frame.

    Args:
        video_file: Path to a video file, or a video/frame array
        ad_description: Description of the ad being watched
        ad_detail: Detail focus of the ad
        ad_type: Type of ad (Video, Image, Audio, Text, Funny, etc.)
        sampling_rate: Process every Nth frame
        save_processed_video: Whether to save the processed video with annotations
        show_progress: Whether to print processing progress

    Returns:
        Tuple of (csv_path, processed_video_path, metrics_dataframe, processed_frames_list)
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
    video_path = VIDEO_FILENAME_TEMPLATE.format(timestamp=timestamp) if save_processed_video else None

    # Setup ad context
    gemini_result = call_gemini_api_for_ad(ad_description, ad_detail, ad_type)
    ad_context = {
        "ad_description": ad_description,
        "ad_detail": ad_detail,
        "ad_type": ad_type,
        "gemini_ad_analysis": gemini_result
    }

    # Initialize capture
    if isinstance(video_file, str):
        cap = cv2.VideoCapture(video_file)
    else:
        # Create a temporary file for the video array
        temp_dir = tempfile.mkdtemp()
        temp_path = os.path.join(temp_dir, "temp_video.mp4")

        # Convert the video array to a file
        if isinstance(video_file, np.ndarray) and len(video_file.shape) == 4:  # Multiple frames
            h, w = video_file[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            temp_writer = cv2.VideoWriter(temp_path, fourcc, 30, (w, h))
            for frame in video_file:
                temp_writer.write(frame)
            temp_writer.release()
            cap = cv2.VideoCapture(temp_path)
        elif isinstance(video_file, np.ndarray) and len(video_file.shape) == 3:  # Single frame
            # For a single frame, process it directly
            metrics_data = []
            processed_frames = []

            # Process the single frame
            deepface_results = analyze_face_with_deepface(video_file)
            face_data = None
            # Fall back to OpenCV face detection if DeepFace didn't detect a face
            if not deepface_results or "region" not in deepface_results:
                face_data = detect_face_opencv(video_file)

            # Calculate metrics if a face was detected
            if deepface_results or face_data:
                calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
                user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)

                # Create a row for the dataframe
                row = {
                    'timestamp': 0.0,
                    'frame_number': 0,
                    **calculated_metrics,
                    **ad_context,
                    'user_state': user_state,
                    'enhanced_user_state': enhanced_state
                }
                metrics_data.append(row)

                # Annotate the frame
                annotated_frame = annotate_frame(video_file, face_data, deepface_results, calculated_metrics, enhanced_state)
                processed_frames.append(annotated_frame)

                # Save the processed image
                if save_processed_video:
                    cv2.imwrite(video_path.replace('.mp4', '.jpg'), annotated_frame)

            # Create the DataFrame and save to CSV
            metrics_df = pd.DataFrame(metrics_data)
            if not metrics_df.empty:
                metrics_df.to_csv(csv_path, index=False)
            return csv_path, video_path.replace('.mp4', '.jpg') if save_processed_video else None, metrics_df, processed_frames
        else:
            print("Error: Invalid video input format")
            return None, None, None, []

    if not cap.isOpened():
        print("Error: Could not open video.")
        return None, None, None, []

    # Get video properties
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        fps = VIDEO_FPS  # Fall back to the default FPS if the container reports none
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Initialize the video writer if saving the processed video
    if save_processed_video:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_path, fourcc, fps / sampling_rate, (frame_width, frame_height))

    # Process video frames
    metrics_data = []
    processed_frames = []
    frame_count = 0

    if show_progress:
        print(f"Processing video with {total_frames} frames at {fps} FPS")
        print(f"Ad Context: {ad_description} ({ad_type})")

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Only process every Nth frame (according to sampling_rate)
        if frame_count % sampling_rate == 0:
            if show_progress and total_frames > 0 and frame_count % (sampling_rate * 10) == 0:
                print(f"Processing frame {frame_count}/{total_frames} ({frame_count/total_frames*100:.1f}%)")

            # Analyze with DeepFace
            deepface_results = analyze_face_with_deepface(frame)
            face_data = None
            # Fall back to OpenCV face detection if DeepFace didn't detect a face
            if not deepface_results or "region" not in deepface_results:
                face_data = detect_face_opencv(frame)

            # Calculate metrics if a face was detected
            if deepface_results or face_data:
                calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
                user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)

                # Create a row for the dataframe
                row = {
                    'timestamp': frame_count / fps,
                    'frame_number': frame_count,
                    **calculated_metrics,
                    **ad_context,
                    'user_state': user_state,
                    'enhanced_user_state': enhanced_state
                }
                metrics_data.append(row)

                # Annotate the frame
                annotated_frame = annotate_frame(frame, face_data, deepface_results, calculated_metrics, enhanced_state)
                if save_processed_video:
                    out.write(annotated_frame)
                processed_frames.append(annotated_frame)
            else:
                # No face detected: add a note to the frame
                no_face_frame = frame.copy()
                cv2.putText(no_face_frame, "No face detected", (30, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                if save_processed_video:
                    out.write(no_face_frame)
                processed_frames.append(no_face_frame)

        frame_count += 1

    # Release resources
    cap.release()
    if save_processed_video:
        out.release()

    # Create the DataFrame and save to CSV
    metrics_df = pd.DataFrame(metrics_data)
    if not metrics_df.empty:
        metrics_df.to_csv(csv_path, index=False)

    if show_progress:
        print(f"Video processing complete. Analyzed {len(metrics_data)} frames.")
        print(f"Results saved to {csv_path}")
        if save_processed_video:
            print(f"Processed video saved to {video_path}")

    # Return results
    return csv_path, video_path, metrics_df, processed_frames
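
# Minimal offline usage sketch (kept as a comment; "sample_ad_reaction.mp4" is a
# placeholder path, not part of this repo):
#
#   csv_file, vid_file, df, frames = process_video_file(
#       "sample_ad_reaction.mp4",
#       ad_description="Holiday discount spot",
#       ad_type="Funny",
#       sampling_rate=10,
#   )
#   print(df[["timestamp", "valence", "engagement_level", "user_state"]].head())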

# --- API 2: Webcam Processing Function ---
def process_webcam_frame(
    frame: np.ndarray,
    ad_context: Dict[str, Any],
    metrics_data: pd.DataFrame,
    frame_count: int,
    start_time: float
) -> Tuple[np.ndarray, Dict[str, float], str, pd.DataFrame]:
    """
    Process a single webcam frame.

    Args:
        frame: Input frame from the webcam
        ad_context: Ad context dictionary
        metrics_data: DataFrame accumulating metrics
        frame_count: Current frame count
        start_time: Start time of the session

    Returns:
        Tuple of (annotated_frame, metrics_dict, enhanced_state, updated_metrics_df)
    """
    if frame is None:
        return None, None, None, metrics_data

    # Analyze with DeepFace
    deepface_results = analyze_face_with_deepface(frame)
    face_data = None
    # Fall back to OpenCV face detection if DeepFace didn't detect a face
    if not deepface_results or "region" not in deepface_results:
        face_data = detect_face_opencv(frame)

    # Calculate metrics if a face was detected
    if deepface_results or face_data:
        calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
        user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)

        # Create a row for the dataframe
        current_time = time.time()
        row = {
            'timestamp': current_time - start_time,
            'frame_number': frame_count,
            **calculated_metrics,
            **ad_context,
            'user_state': user_state,
            'enhanced_user_state': enhanced_state
        }

        # Add the row to the DataFrame
        new_row_df = pd.DataFrame([row], columns=all_columns)
        metrics_data = pd.concat([metrics_data, new_row_df], ignore_index=True)

        # Annotate the frame
        annotated_frame = annotate_frame(frame, face_data, deepface_results, calculated_metrics, enhanced_state)
        return annotated_frame, calculated_metrics, enhanced_state, metrics_data
    else:
        # No face detected
        no_face_frame = frame.copy()
        cv2.putText(no_face_frame, "No face detected", (30, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        return no_face_frame, None, "No face detected", metrics_data

def start_webcam_session(
    ad_description: str = "",
    ad_detail: str = "",
    ad_type: str = "Video",
    save_interval: int = 100,  # Save CSV every N frames
    record_video: bool = True
) -> Dict[str, Any]:
    """
    Initialize a webcam session for facial analysis.

    Args:
        ad_description: Description of the ad being watched
        ad_detail: Detail focus of the ad
        ad_type: Type of ad
        save_interval: How often to save data to CSV
        record_video: Whether to record processed frames for later saving

    Returns:
        Session context dictionary
    """
    # Generate a timestamp for file naming
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = CSV_FILENAME_TEMPLATE.format(timestamp=timestamp)
    video_path = VIDEO_FILENAME_TEMPLATE.format(timestamp=timestamp) if record_video else None

    # Setup ad context
    gemini_result = call_gemini_api_for_ad(ad_description, ad_detail, ad_type)
    ad_context = {
        "ad_description": ad_description,
        "ad_detail": ad_detail,
        "ad_type": ad_type,
        "gemini_ad_analysis": gemini_result
    }

    # Initialize the session context
    session = {
        "start_time": time.time(),
        "frame_count": 0,
        "metrics_data": initial_metrics_df.copy(),
        "ad_context": ad_context,
        "csv_path": csv_path,
        "video_path": video_path,
        "save_interval": save_interval,
        "last_saved": 0,
        "record_video": record_video,
        "recorded_frames": [] if record_video else None,
        "timestamps": [] if record_video else None
    }
    return session

def update_webcam_session(
    session: Dict[str, Any],
    frame: np.ndarray
) -> Tuple[np.ndarray, Dict[str, float], str, Dict[str, Any]]:
    """
    Update a webcam session with a new frame.

    Args:
        session: Session context dictionary
        frame: New frame from the webcam

    Returns:
        Tuple of (annotated_frame, metrics_dict, enhanced_state, updated_session)
    """
    # Process the frame
    annotated_frame, metrics, enhanced_state, updated_df = process_webcam_frame(
        frame,
        session["ad_context"],
        session["metrics_data"],
        session["frame_count"],
        session["start_time"]
    )

    # Update the session
    session["frame_count"] += 1
    session["metrics_data"] = updated_df

    # Record the frame if enabled
    if session["record_video"] and annotated_frame is not None:
        session["recorded_frames"].append(annotated_frame)
        session["timestamps"].append(time.time() - session["start_time"])

    # Save the CSV periodically
    if session["frame_count"] - session["last_saved"] >= session["save_interval"]:
        if not updated_df.empty:
            updated_df.to_csv(session["csv_path"], index=False)
        session["last_saved"] = session["frame_count"]

    return annotated_frame, metrics, enhanced_state, session

def end_webcam_session(session: Dict[str, Any]) -> Tuple[str, str]:
    """
    End a webcam session and save the final results.

    Args:
        session: Session context dictionary

    Returns:
        Tuple of (csv_path, video_path)
    """
    # Save the final metrics to CSV
    if not session["metrics_data"].empty:
        session["metrics_data"].to_csv(session["csv_path"], index=False)

    # Save the recorded video if available
    video_path = None
    if session["record_video"] and session["recorded_frames"]:
        try:
            frames = session["recorded_frames"]
            if frames:
                # Get frame dimensions
                height, width = frames[0].shape[:2]

                # Calculate FPS based on actual timestamps
                if len(session["timestamps"]) > 1:
                    # Average time between frames
                    time_diffs = np.diff(session["timestamps"])
                    avg_frame_time = np.mean(time_diffs)
                    fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 15.0
                else:
                    fps = 15.0  # Default FPS

                # Create the video writer
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video_path = session["video_path"]
                out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

                # Write frames
                for frame in frames:
                    out.write(frame)
                out.release()
                print(f"Recorded video saved to {video_path}")
            else:
                print("No frames recorded")
        except Exception as e:
            print(f"Error saving video: {e}")

    print(f"Session ended. Data saved to {session['csv_path']}")
    return session["csv_path"], video_path
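
# Minimal sketch of driving the webcam session functions without the Gradio UI
# (kept as a comment; assumes a local camera at index 0 that OpenCV can open):
#
#   session = start_webcam_session(ad_description="Demo ad", ad_type="Funny")
#   cap = cv2.VideoCapture(0)
#   for _ in range(100):
#       ok, frame = cap.read()
#       if not ok:
#           break
#       annotated, live_metrics, state, session = update_webcam_session(session, frame)
#   cap.release()
#   csv_file, video_file = end_webcam_session(session)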

# --- Create Gradio Interface ---
def create_api_interface():
    with gr.Blocks(title="Facial Analysis APIs") as iface:
        gr.Markdown(
            """
            # Enhanced Facial Analysis APIs (DeepFace)

            This interface provides two API endpoints:

            1. **Video File API**: Upload and analyze pre-recorded videos
            2. **Webcam API**: Analyze a live webcam feed in real time

            Both APIs use DeepFace for emotion analysis and Google's Gemini API for enhanced interpretations.
            """
        )

        with gr.Tab("Video File API"):
            with gr.Row():
                with gr.Column(scale=1):
                    video_input = gr.Video(label="Upload Video")
                    vid_ad_desc = gr.Textbox(label="Ad Description", placeholder="Enter a description of the advertisement being watched...")
                    vid_ad_detail = gr.Textbox(label="Ad Detail Focus", placeholder="Enter specific aspects to focus on...")
                    vid_ad_type = gr.Radio(
                        ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
                        label="Ad Type/Genre",
                        value="Video"
                    )
                    sampling_rate = gr.Slider(
                        minimum=1, maximum=30, step=1, value=5,
                        label="Sampling Rate (process every N frames)"
                    )
                    save_video = gr.Checkbox(label="Save Processed Video", value=True)
                    process_btn = gr.Button("Process Video", variant="primary")
                with gr.Column(scale=2):
                    output_text = gr.Textbox(label="Processing Results", lines=3)
                    with gr.Row():
                        with gr.Column():
                            output_video = gr.Video(label="Processed Video")
                        with gr.Column():
                            frame_gallery = gr.Gallery(label="Processed Frames",
                                                       show_label=True, columns=2,
                                                       height=400)
                    with gr.Row():
                        with gr.Column():
                            output_plot = gr.Plot(label="Sample Frame Metrics")
                        with gr.Column():
                            output_csv = gr.File(label="Download CSV Results")

            # Handle video processing and build the frame gallery
            def handle_video_processing(video, desc, detail, ad_type, rate, save_vid):
                if video is None:
                    return "No video uploaded", None, None, [], None
                try:
                    result_text = "Starting video processing...\n"

                    # Process the video
                    csv_path, video_path, metrics_df, processed_frames = process_video_file(
                        video,
                        ad_description=desc,
                        ad_detail=detail,
                        ad_type=ad_type,
                        sampling_rate=rate,
                        save_processed_video=save_vid,
                        show_progress=True
                    )

                    if metrics_df is None or metrics_df.empty:
                        return "No facial data detected in video", None, None, [], None

                    # Generate a sample metrics visualization from the first analyzed frame
                    sample_row = metrics_df.iloc[0].to_dict()
                    metrics_fig = update_metrics_visualization(sample_row)

                    # Build a gallery of processed frames
                    # (take a subset if there are too many -- at most ~20 for display)
                    display_frames = []
                    step = max(1, len(processed_frames) // 20)
                    for i in range(0, len(processed_frames), step):
                        if i < len(processed_frames):
                            # Convert BGR to RGB for display
                            rgb_frame = cv2.cvtColor(processed_frames[i], cv2.COLOR_BGR2RGB)
                            display_frames.append(rgb_frame)

                    # Summarize the results
                    processed_count = metrics_df.shape[0]
                    total_count = len(processed_frames)
                    result_text = f"✅ Processed {processed_count} frames out of {total_count} total frames.\n"
                    result_text += f"📊 CSV saved with {len(metrics_df.columns)} metrics columns.\n"
                    if video_path:
                        result_text += f"🎬 Processed video saved to: {video_path}"

                    return result_text, video_path, metrics_fig, display_frames, csv_path
                except Exception as e:
                    return f"❌ Error processing video: {str(e)}", None, None, [], None

            process_btn.click(
                handle_video_processing,
                inputs=[video_input, vid_ad_desc, vid_ad_detail, vid_ad_type, sampling_rate, save_video],
                outputs=[output_text, output_video, output_plot, frame_gallery, output_csv]
            )

        with gr.Tab("Webcam API"):
            with gr.Row():
                with gr.Column(scale=2):
                    webcam_input = gr.Image(sources="webcam", streaming=True, label="Webcam Input", type="numpy")
                    with gr.Row():
                        with gr.Column():
                            web_ad_desc = gr.Textbox(label="Ad Description", placeholder="Enter a description of the advertisement being watched...")
                            web_ad_detail = gr.Textbox(label="Ad Detail Focus", placeholder="Enter specific aspects to focus on...")
                            web_ad_type = gr.Radio(
                                ["Video", "Image", "Audio", "Text", "Funny", "Serious", "Action", "Informative"],
                                label="Ad Type/Genre",
                                value="Video"
                            )
                        with gr.Column():
                            record_video_chk = gr.Checkbox(label="Record Video", value=True)
                            start_session_btn = gr.Button("Start Session", variant="primary")
                            end_session_btn = gr.Button("End Session", variant="stop")
                            session_status = gr.Textbox(label="Session Status", placeholder="Session not started...")
                with gr.Column(scale=2):
                    processed_output = gr.Image(label="Processed Feed", type="numpy", height=360)
                    with gr.Row():
                        with gr.Column():
                            metrics_plot = gr.Plot(label="Current Metrics")
                        with gr.Column():
                            enhanced_state_txt = gr.Textbox(label="Enhanced State Analysis", lines=3)
                    with gr.Row():
                        download_csv = gr.File(label="Download Session Data")
                        download_video = gr.Video(label="Recorded Session")

            # Session state
            session_data = gr.State(value=None)

            # Define session handlers
            def start_session(desc, detail, ad_type, record_video):
                session = start_webcam_session(
                    ad_description=desc,
                    ad_detail=detail,
                    ad_type=ad_type,
                    record_video=record_video
                )
                return (
                    session,
                    f"Session started at {datetime.datetime.now().strftime('%H:%M:%S')}.\n"
                    f"Ad context: {desc} ({ad_type}).\n"
                    f"Data will be saved to {session['csv_path']}"
                )

            def process_frame(frame, session):
                if session is None:
                    return frame, None, "No active session. Click 'Start Session' to begin.", session

                # Process the frame
                annotated_frame, metrics, enhanced_state, updated_session = update_webcam_session(session, frame)

                # Update the metrics plot if metrics are available
                if metrics:
                    metrics_fig = update_metrics_visualization(metrics)
                    return annotated_frame, metrics_fig, enhanced_state, updated_session
                else:
                    # Return the annotated frame (likely marked "No face detected")
                    return annotated_frame, None, enhanced_state or "No metrics available", updated_session

            def end_session(session):
                if session is None:
                    return "No active session", None, None

                csv_path, video_path = end_webcam_session(session)
                end_time = datetime.datetime.now().strftime('%H:%M:%S')

                result = f"Session ended at {end_time}.\n"
                if csv_path:
                    result += f"CSV data saved to: {csv_path}\n"
                if video_path:
                    result += f"Video saved to: {video_path}"
                return result, csv_path, video_path

            start_session_btn.click(
                start_session,
                inputs=[web_ad_desc, web_ad_detail, web_ad_type, record_video_chk],
                outputs=[session_data, session_status]
            )

            webcam_input.stream(
                process_frame,
                inputs=[webcam_input, session_data],
                outputs=[processed_output, metrics_plot, enhanced_state_txt, session_data]
            )

            end_session_btn.click(
                end_session,
                inputs=[session_data],
                outputs=[session_status, download_csv, download_video]
            )

    return iface

# Entry point
if __name__ == "__main__":
    print("Starting Enhanced Facial Analysis API (DeepFace)...")
    print(f"Gemini API {'enabled' if GEMINI_ENABLED else 'disabled (using simulation)'}")
    iface = create_api_interface()
    iface.launch(debug=True)