Spaces:
Runtime error
Runtime error
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure

from app.app_utils import preprocess_frame_and_predict_aus
# Action-unit (AU) numbers that this module groups under each emotional state.
# NOTE(review): these groupings are heuristics chosen by the author; nothing in
# this file cites a source for them.
STRESS_AUS = [
    4,   # Brow Lowerer
    7,   # Lid Tightener
    17,  # Chin Raiser
    23,  # Lip Tightener
    24,  # Lip Pressor
]
ANXIETY_AUS = [
    1,   # Inner Brow Raiser
    2,   # Outer Brow Raiser
    4,   # Brow Lowerer
    5,   # Upper Lid Raiser
    20,  # Lip Stretcher
]
DEPRESSION_AUS = [
    1,   # Inner Brow Raiser
    4,   # Brow Lowerer
    15,  # Lip Corner Depressor
    17,  # Chin Raiser
]

# Human-readable name for every AU referenced above, used as chart tick labels.
AU_DESCRIPTIONS = dict(
    [
        (1, "Inner Brow Raiser"),
        (2, "Outer Brow Raiser"),
        (4, "Brow Lowerer"),
        (5, "Upper Lid Raiser"),
        (7, "Lid Tightener"),
        (15, "Lip Corner Depressor"),
        (17, "Chin Raiser"),
        (20, "Lip Stretcher"),
        (23, "Lip Tightener"),
        (24, "Lip Pressor"),
    ]
)
def normalize_score(score, low=-1.5, high=1.5):
    """Linearly rescale ``score`` from ``[low, high]`` to ``[0.0, 1.0]``, clamping outliers.

    With the default bounds this is the mapping ``(score + 1.5) / 3``.
    Adjust ``low``/``high`` if the AU model's output range differs.

    Args:
        score: Raw score, e.g. the mean of several AU intensities.
        low: Input value that maps to 0.0.
        high: Input value that maps to 1.0; must be greater than ``low``.

    Returns:
        A float in ``[0.0, 1.0]``. A NaN input (e.g. the mean of an empty
        selection) yields 0.0.

    Raises:
        ValueError: If ``high <= low``.
    """
    if high <= low:
        raise ValueError("high must be greater than low")
    # min()/max() do not treat NaN consistently (min(1, nan) == 1), so a NaN
    # would otherwise surface as the maximal score; handle it explicitly.
    if np.isnan(score):
        return 0.0
    return float(min(1.0, max(0.0, (score - low) / (high - low))))
def process_video_for_facs(video_path):
    """Predict AUs on every frame of a video and chart stress/anxiety/depression scores.

    Each frame is passed to ``preprocess_frame_and_predict_aus``. Frames for
    which it returns ``None`` (processed frame or AU intensities) are skipped.
    The AU intensities of the remaining frames are averaged. Each emotional
    state's score is the normalized mean of the averaged intensities of that
    state's AUs.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
            ``None``/empty is accepted (presumably what the Gradio video
            input yields when nothing was uploaded — confirm).

    Returns:
        ``(last_processed_frame, figure)``. ``figure`` is a matplotlib
        ``Figure`` with two panels: state scores and per-AU averages.
        Returns ``(None, None)`` when no path was given or no frame produced
        a prediction.
    """
    if not video_path:
        return None, None

    last_frame, au_intensities_list = _collect_au_intensities(video_path)
    if not au_intensities_list:
        return None, None

    avg_au_intensities = np.mean(au_intensities_list, axis=0)
    scores = [
        _state_score(avg_au_intensities, aus)
        for aus in (STRESS_AUS, ANXIETY_AUS, DEPRESSION_AUS)
    ]
    return last_frame, _build_facs_figure(scores, avg_au_intensities)


def _collect_au_intensities(video_path):
    """Read a video and return (last usable processed frame, list of per-frame AU intensities)."""
    cap = cv2.VideoCapture(video_path)
    last_frame = None
    au_intensities_list = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame, au_intensities, _ = preprocess_frame_and_predict_aus(frame)
            if processed_frame is not None and au_intensities is not None:
                # Only the last frame is returned, so don't hold every frame in memory.
                last_frame = processed_frame
                au_intensities_list.append(au_intensities)
    finally:
        # Release the capture even if prediction raises part-way through.
        cap.release()
    return last_frame, au_intensities_list


def _state_score(avg_au_intensities, aus):
    """Return the normalized mean intensity of ``aus``, or 0.0 if none are available."""
    # NOTE(review): assumes output index i holds AU i+1 (so AU n is at n-1);
    # confirm against preprocess_frame_and_predict_aus's output layout.
    values = [avg_au_intensities[au - 1] for au in aus if au <= len(avg_au_intensities)]
    if not values:
        # np.mean([]) is NaN (with a RuntimeWarning); report 0.0 instead.
        return 0.0
    return normalize_score(np.mean(values))


def _build_facs_figure(scores, avg_au_intensities):
    """Draw state scores (top) and per-AU average intensities (bottom) on a new Figure."""
    # Create the Figure directly instead of via pyplot. Pyplot keeps every
    # figure alive until it is closed, which leaks in a long-running server.
    # Also, plt.tight_layout() acts on whichever figure is "current", which is
    # unsafe when requests run concurrently.
    fig = Figure(figsize=(12, 10))
    ax1, ax2 = fig.subplots(2, 1)

    # Panel 1: emotional state scores, each bar annotated with its value.
    bars = ax1.bar(['Stress', 'Anxiety', 'Depression'], scores)
    ax1.set_ylim(0, 1)
    ax1.set_title('Emotional State Scores')
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.2f}', ha='center', va='bottom')

    # Panel 2: average intensity of every AU used by any state and present in the model output.
    all_aus = [
        au for au in sorted(set(STRESS_AUS + ANXIETY_AUS + DEPRESSION_AUS))
        if au <= len(avg_au_intensities)
    ]
    au_labels = [f"AU{au}\n{AU_DESCRIPTIONS.get(au, '')}" for au in all_aus]
    au_values = [avg_au_intensities[au - 1] for au in all_aus]
    positions = range(len(au_labels))
    ax2.bar(positions, au_values)
    ax2.set_xticks(positions)
    ax2.set_xticklabels(au_labels, rotation=45, ha='right')
    # NOTE(review): this axis is fixed to [0, 1]. The state scores, however,
    # treat raw means as spanning about [-1.5, 1.5], so out-of-range bars
    # would be clipped. Confirm the model's intensity range.
    ax2.set_ylim(0, 1)
    ax2.set_title('Average Action Unit Intensities')

    fig.tight_layout()
    return fig
def create_facs_analysis_sad_tab():
    """Lay out the FACS stress/anxiety/depression (SAD) analysis controls.

    Layout is one row with two columns:
    - Left (scale 1): the video input, an "Analyze" button and an example video.
    - Right (scale 2): the processed-frame image and the analysis chart.

    Clicking the button runs ``process_video_for_facs`` on the video and fills
    the two output components. NOTE(review): presumably called inside a
    ``gr.Blocks``/tab context by the app — confirm at the call site.

    Returns:
        ``(input_video, output_image, facs_chart, analyze_btn)`` components.
    """
    example_videos = ["./assets/videos/fitness.mp4"]

    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video()
            run_button = gr.Button("Analyze")
            gr.Examples(example_videos, inputs=[video_in])
        with gr.Column(scale=2):
            frame_out = gr.Image(label="Processed Frame")
            chart_out = gr.Plot(label="FACS Analysis for SAD")

    # Wire the button to the analysis pipeline.
    run_button.click(
        fn=process_video_for_facs,
        inputs=video_in,
        outputs=[frame_out, chart_out],
    )

    components = (video_in, frame_out, chart_out, run_button)
    return components