Update app.py
app.py
CHANGED
@@ -18,9 +18,7 @@ from deepface import DeepFace
 import base64
 import io
 from pathlib import Path
-import torch
-from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
-from io import BytesIO
+import traceback
 
 # Suppress warnings for cleaner output
 warnings.filterwarnings('ignore')
@@ -40,8 +38,8 @@ try:
         raise ValueError("GOOGLE_API_KEY environment variable not set.")
 
     genai.configure(api_key=GOOGLE_API_KEY)
-    # Use gemini-
-    model = genai.GenerativeModel('gemini-
+    # Use gemini-1.5-flash for quick responses
+    model = genai.GenerativeModel('gemini-1.5-flash')
     GEMINI_ENABLED = True
     print("Google Gemini API configured successfully.")
 except Exception as e:
@@ -49,45 +47,6 @@ except Exception as e:
     print("Running with simulated Gemini API responses.")
     GEMINI_ENABLED = False
 
-# --- Initialize LLaVA Vision Model ---
-print("Initializing LLaVA Vision Model...")
-LLAVA_ENABLED = False
-try:
-    # Check if GPU is available
-    if torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
-
-    # Use a smaller LLaVA model for better performance
-    model_id = "llava-hf/llava-1.5-7b-hf"
-
-    # Initialize the model
-    processor = AutoProcessor.from_pretrained(model_id)
-    llava_model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-        low_cpu_mem_usage=True if device == "cuda" else False,
-    ).to(device)
-
-    # Create a pipeline
-    vision_llm = pipeline(
-        "image-to-text",
-        model=llava_model,
-        tokenizer=processor.tokenizer,
-        image_processor=processor.image_processor,
-        device=device,
-        max_new_tokens=512,
-    )
-
-    LLAVA_ENABLED = True
-    print(f"LLaVA Vision Model initialized successfully on {device.upper()}")
-
-except Exception as e:
-    print(f"WARNING: Failed to initialize LLaVA Vision Model: {e}")
-    print("Running with DeepFace only (no LLaVA vision features).")
-    vision_llm = None
-
 # --- Initialize OpenCV face detector for backup ---
 print("Initializing OpenCV face detector...")
 try:
@@ -123,63 +82,10 @@ emotion_mapping = {
 }
 
 ad_context_columns = ["ad_description", "ad_detail", "ad_type", "gemini_ad_analysis"]
-user_state_columns = ["user_state", "enhanced_user_state"
+user_state_columns = ["user_state", "enhanced_user_state"]
 all_columns = ['timestamp', 'frame_number'] + metrics + ad_context_columns + user_state_columns
 initial_metrics_df = pd.DataFrame(columns=all_columns)
 
-# --- LLaVA Vision Analysis Function ---
-def analyze_image_with_llava(image, ad_context=None):
-    """
-    Use LLaVA vision model to analyze facial expression and emotion in image
-    """
-    if not LLAVA_ENABLED or vision_llm is None or image is None:
-        return "LLaVA analysis not available"
-
-    try:
-        # Convert OpenCV image (BGR) to PIL Image (RGB)
-        if len(image.shape) == 3 and image.shape[2] == 3:
-            # Check if BGR and convert to RGB if needed
-            if np.mean(image[:,:,0]) < np.mean(image[:,:,2]):  # Rough BGR check
-                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            else:
-                image_rgb = image
-        else:
-            # Handle grayscale or other formats
-            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-        # Convert to PIL Image
-        pil_image = Image.fromarray(image_rgb)
-
-        # Create prompt based on ad context
-        ad_info = ""
-        if ad_context:
-            ad_desc = ad_context.get('ad_description', '')
-            ad_type = ad_context.get('ad_type', '')
-            if ad_desc:
-                ad_info = f" while watching an ad about {ad_desc} (type: {ad_type})"
-
-        prompt = f"""Analyze this person's facial expression and emotion{ad_info}.
-        Describe their emotional state, engagement level, and cognitive state in detail.
-        Focus on: valence (positive/negative emotion), arousal (excitement level),
-        attention, stress indicators, and overall reaction to what they're seeing.
-        """
-
-        # Process with Vision LLM
-        outputs = vision_llm(pil_image, prompt=prompt)
-
-        # Extract the generated text
-        if isinstance(outputs, list) and len(outputs) > 0:
-            if isinstance(outputs[0], dict) and "generated_text" in outputs[0]:
-                return outputs[0]["generated_text"]
-            elif isinstance(outputs[0], str):
-                return outputs[0]
-
-        return str(outputs) if outputs else "No results from LLaVA analysis"
-
-    except Exception as e:
-        print(f"Error in LLaVA analysis: {e}")
-        return f"LLaVA analysis error: {str(e)}"
-
 # --- Gemini API Functions ---
 def call_gemini_api_for_ad(description, detail, ad_type):
     """
@@ -212,12 +118,12 @@ def call_gemini_api_for_ad(description, detail, ad_type):
         print(f"Error calling Gemini for ad context: {e}")
         return f"Error analyzing ad context: {str(e)}"
 
-def interpret_metrics_with_gemini(metrics_dict, deepface_results=None,
+def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, ad_context=None):
     """
-    Uses Google Gemini to interpret facial metrics
+    Uses Google Gemini to interpret facial metrics and DeepFace results
     to determine user state.
     """
-    if not metrics_dict and not deepface_results
+    if not metrics_dict and not deepface_results:
         return "No metrics", "No facial data detected"
 
     if not GEMINI_ENABLED:
@@ -239,10 +145,6 @@ def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, llava_ana
             state = "Stressed, Negative"
 
         enhanced_state = f"The viewer appears {state.lower()} while watching this content."
-        if llava_analysis and llava_analysis != "LLaVA analysis not available":
-            # Extract a brief summary from LLaVA analysis (first sentence)
-            first_sentence = llava_analysis.split('.')[0] + '.'
-            enhanced_state += f" {first_sentence}"
 
         return state, enhanced_state
     else:
@@ -259,11 +161,6 @@ def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, llava_ana
             emotion_dict = deepface_results["emotion"]
             deepface_formatted = "\nDeepFace emotions:\n" + "\n".join([f"- {k.title()}: {v:.2f}" for k, v in emotion_dict.items()])
 
-        # Format LLaVA analysis
-        llava_formatted = ""
-        if llava_analysis and llava_analysis != "LLaVA analysis not available":
-            llava_formatted = f"\nLLaVA Vision Analysis:\n{llava_analysis}"
-
         # Include ad context if available
         ad_info = ""
         if ad_context:
@@ -274,7 +171,7 @@ def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, llava_ana
         prompt = f"""
         Analyze the facial expression and emotion of a person watching an advertisement{ad_info}.
 
-        Use these combined inputs:{metrics_formatted}{deepface_formatted}
+        Use these combined inputs:{metrics_formatted}{deepface_formatted}
 
         Provide two outputs:
         1. User State: A short 1-3 word description of their emotional/cognitive state
@@ -303,6 +200,7 @@ def interpret_metrics_with_gemini(metrics_dict, deepface_results=None, llava_ana
 
     except Exception as e:
         print(f"Error calling Gemini for metric interpretation: {e}")
+        traceback.print_exc()
         return "Error", f"Error analyzing facial metrics: {str(e)}"
 
 # --- DeepFace Analysis Function ---
@@ -330,7 +228,7 @@ def analyze_face_with_deepface(image):
         # Analyze with DeepFace
         analysis = DeepFace.analyze(
             img_path=temp_img,
-            actions=['emotion'
+            actions=['emotion'],
             enforce_detection=False,  # Don't throw error if face not detected
             detector_backend='opencv'  # Faster detection
         )
@@ -422,16 +320,6 @@ def calculate_metrics_from_deepface(deepface_results, ad_context=None):
         arsl += 0.1
         dom -= 0.1
 
-    # Adjust for gender and age if available (just examples of potential factors)
-    if "gender" in deepface_results:
-        gender = deepface_results["gender"]
-        gender_score = deepface_results.get("gender_score", 0.5)
-        # No real adjustment needed, this is just an example
-
-    if "age" in deepface_results:
-        age = deepface_results["age"]
-        # No real adjustment needed, this is just an example
-
     # Illustrative Context Adjustments from ad
     ad_type = ad_context.get('ad_type', 'Unknown')
     gem_txt = str(ad_context.get('gemini_ad_analysis', '')).lower()
@@ -682,16 +570,10 @@ def process_video_file(
         if not deepface_results or "region" not in deepface_results:
             face_data = detect_face_opencv(video_file)
 
-        # Use LLaVA for additional analysis (once per frame)
-        llava_analysis = "LLaVA analysis not available"
-        if face_data is not None or (deepface_results and "region" in deepface_results):
-            # Only use LLaVA if a face was detected
-            llava_analysis = analyze_image_with_llava(video_file, ad_context)
-
         # Calculate metrics if face detected
         if deepface_results or face_data:
             calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
-            user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results,
+            user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)
 
             # Create a row for the dataframe
             row = {
@@ -700,8 +582,7 @@ def process_video_file(
                 **calculated_metrics,
                 **ad_context,
                 'user_state': user_state,
-                'enhanced_user_state': enhanced_state
-                'llava_analysis': llava_analysis
+                'enhanced_user_state': enhanced_state
             }
             metrics_data.append(row)
 
@@ -742,13 +623,10 @@ def process_video_file(
     metrics_data = []
     processed_frames = []
     frame_count = 0
-    llava_counter = 0  # To limit LLaVA analysis (it's slow)
-    llava_interval = sampling_rate * 10  # Run LLaVA every X frames
 
     if show_progress:
         print(f"Processing video with {total_frames} frames at {fps} FPS")
         print(f"Ad Context: {ad_description} ({ad_type})")
-        print(f"LLaVA Vision Model: {'Enabled' if LLAVA_ENABLED else 'Disabled'}")
 
     while True:
         ret, frame = cap.read()
@@ -768,17 +646,10 @@ def process_video_file(
         if not deepface_results or "region" not in deepface_results:
             face_data = detect_face_opencv(frame)
 
-        # Use LLaVA for additional analysis (periodically to save time)
-        llava_analysis = "LLaVA analysis not available"
-        if (face_data is not None or (deepface_results and "region" in deepface_results)) and llava_counter % llava_interval == 0:
-            # Only use LLaVA if a face was detected and on the right interval
-            llava_analysis = analyze_image_with_llava(frame, ad_context)
-        llava_counter += 1
-
         # Calculate metrics if face detected
         if deepface_results or face_data:
             calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
-            user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results,
+            user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)
 
             # Create a row for the dataframe
             row = {
@@ -787,8 +658,7 @@ def process_video_file(
                 **calculated_metrics,
                 **ad_context,
                 'user_state': user_state,
-                'enhanced_user_state': enhanced_state
-                'llava_analysis': llava_analysis
+                'enhanced_user_state': enhanced_state
             }
             metrics_data.append(row)
 
@@ -835,9 +705,8 @@ def process_webcam_frame(
     ad_context: Dict[str, Any],
     metrics_data: pd.DataFrame,
     frame_count: int,
-    start_time: float
-
-) -> Tuple[np.ndarray, Dict[str, float], str, str, pd.DataFrame, int]:
+    start_time: float
+) -> Tuple[np.ndarray, Dict[str, float], str, pd.DataFrame]:
     """
     Process a single webcam frame
 
@@ -847,13 +716,12 @@
         metrics_data: DataFrame to accumulate metrics
         frame_count: Current frame count
         start_time: Start time of the session
-        llava_counter: Counter to limit LLaVA calls
 
     Returns:
-        Tuple of (annotated_frame, metrics_dict, enhanced_state,
+        Tuple of (annotated_frame, metrics_dict, enhanced_state, updated_metrics_df)
     """
     if frame is None:
-        return None, None, None,
+        return None, None, None, metrics_data
 
     # Analyze with DeepFace
     deepface_results = analyze_face_with_deepface(frame)
@@ -863,19 +731,10 @@
     if not deepface_results or "region" not in deepface_results:
         face_data = detect_face_opencv(frame)
 
-    # Use LLaVA for periodic analysis (it's slow)
-    llava_analysis = "LLaVA analysis not available"
-    llava_interval = 30  # Run LLaVA every X frames
-
-    if (face_data is not None or (deepface_results and "region" in deepface_results)) and llava_counter % llava_interval == 0:
-        # Only use LLaVA if a face was detected and on the right interval
-        llava_analysis = analyze_image_with_llava(frame, ad_context)
-    llava_counter += 1
-
     # Calculate metrics if face detected
     if deepface_results or face_data:
         calculated_metrics = calculate_metrics_from_deepface(deepface_results, ad_context)
-        user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results,
+        user_state, enhanced_state = interpret_metrics_with_gemini(calculated_metrics, deepface_results, ad_context)
 
         # Create a row for the dataframe
         current_time = time.time()
@@ -885,8 +744,7 @@
             **calculated_metrics,
             **ad_context,
             'user_state': user_state,
-            'enhanced_user_state': enhanced_state
-            'llava_analysis': llava_analysis
+            'enhanced_user_state': enhanced_state
         }
 
         # Add row to DataFrame
@@ -896,13 +754,13 @@
         # Annotate the frame
         annotated_frame = annotate_frame(frame, face_data, deepface_results, calculated_metrics, enhanced_state)
 
-        return annotated_frame, calculated_metrics, enhanced_state,
+        return annotated_frame, calculated_metrics, enhanced_state, metrics_data
     else:
         # No face detected
         no_face_frame = frame.copy()
         cv2.putText(no_face_frame, "No face detected", (30, 30),
                     cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
-        return no_face_frame, None, "No face detected",
+        return no_face_frame, None, "No face detected", metrics_data
 
 def start_webcam_session(
     ad_description: str = "",
@@ -950,8 +808,7 @@
         "last_saved": 0,
         "record_video": record_video,
         "recorded_frames": [] if record_video else None,
-        "timestamps": [] if record_video else None
-        "llava_counter": 0  # Counter to limit LLaVA calls
+        "timestamps": [] if record_video else None
     }
 
     return session
@@ -959,7 +816,7 @@
 def update_webcam_session(
     session: Dict[str, Any],
     frame: np.ndarray
-) -> Tuple[np.ndarray, Dict[str, float], str,
+) -> Tuple[np.ndarray, Dict[str, float], str, Dict[str, Any]]:
     """
     Update webcam session with a new frame
 
@@ -968,22 +825,20 @@
         frame: New frame from webcam
 
     Returns:
-        Tuple of (annotated_frame, metrics_dict, enhanced_state,
+        Tuple of (annotated_frame, metrics_dict, enhanced_state, updated_session)
     """
     # Process the frame
-    annotated_frame, metrics, enhanced_state,
+    annotated_frame, metrics, enhanced_state, updated_df = process_webcam_frame(
         frame,
         session["ad_context"],
         session["metrics_data"],
         session["frame_count"],
-        session["start_time"]
-        session["llava_counter"]
+        session["start_time"]
     )
 
     # Update session
     session["frame_count"] += 1
     session["metrics_data"] = updated_df
-    session["llava_counter"] = updated_llava_counter
 
     # Record frame if enabled
     if session["record_video"] and annotated_frame is not None:
@@ -996,7 +851,7 @@
         updated_df.to_csv(session["csv_path"], index=False)
         session["last_saved"] = session["frame_count"]
 
-    return annotated_frame, metrics, enhanced_state,
+    return annotated_frame, metrics, enhanced_state, session
 
 def end_webcam_session(session: Dict[str, Any]) -> Tuple[str, str]:
     """
@@ -1053,7 +908,7 @@
 def create_api_interface():
     with gr.Blocks(title="Facial Analysis APIs") as iface:
         gr.Markdown(f"""
-        # Enhanced Facial Analysis APIs (
+        # Enhanced Facial Analysis APIs (DeepFace)
 
         This interface provides two API endpoints:
 
@@ -1061,8 +916,6 @@ def create_api_interface():
         2. **Webcam API**: Analyze live webcam feed in real-time
 
         Both APIs use DeepFace for emotion analysis and Google's Gemini API for enhanced interpretations.
-
-        **LLaVA Vision Model: {'✅ Enabled' if LLAVA_ENABLED else '❌ Disabled'}**
         """)
 
         with gr.Tab("Video File API"):
@@ -1181,9 +1034,6 @@
             with gr.Column():
                 enhanced_state_txt = gr.Textbox(label="Enhanced State Analysis", lines=3)
 
-            with gr.Row():
-                llava_analysis_txt = gr.Textbox(label="LLaVA Vision Analysis", lines=6)
-
             with gr.Row():
                 download_csv = gr.File(label="Download Session Data")
                 download_video = gr.Video(label="Recorded Session")
@@ -1208,18 +1058,18 @@
 
         def process_frame(frame, session):
             if session is None:
-                return frame, None, "No active session. Click 'Start Session' to begin.",
+                return frame, None, "No active session. Click 'Start Session' to begin.", session
 
             # Process the frame
-            annotated_frame, metrics, enhanced_state,
+            annotated_frame, metrics, enhanced_state, updated_session = update_webcam_session(session, frame)
 
             # Update the metrics plot if metrics available
             if metrics:
                 metrics_plot = update_metrics_visualization(metrics)
-                return annotated_frame, metrics_plot, enhanced_state,
+                return annotated_frame, metrics_plot, enhanced_state, updated_session
             else:
                 # Return the annotated frame (likely with "No face detected")
-                return annotated_frame, None, enhanced_state or "No metrics available",
+                return annotated_frame, None, enhanced_state or "No metrics available", updated_session
 
         def end_session(session):
             if session is None:
@@ -1245,7 +1095,7 @@
         webcam_input.stream(
             process_frame,
             inputs=[webcam_input, session_data],
-            outputs=[processed_output, metrics_plot, enhanced_state_txt,
+            outputs=[processed_output, metrics_plot, enhanced_state_txt, session_data]
         )
 
         end_session_btn.click(
@@ -1258,8 +1108,7 @@ def create_api_interface():
 
 # Entry point
 if __name__ == "__main__":
-    print("Starting Enhanced Facial Analysis API (
+    print("Starting Enhanced Facial Analysis API (DeepFace)...")
     print(f"Gemini API {'enabled' if GEMINI_ENABLED else 'disabled (using simulation)'}")
-    print(f"LLaVA Vision Model {'enabled' if LLAVA_ENABLED else 'disabled (using DeepFace only)'}")
     iface = create_api_interface()
     iface.launch(debug=True)