Depreesion / tabs /FACS_analysis_sad.py
vitorcalvi's picture
pre-launch
fc286f6
import gradio as gr
import cv2
import numpy as np
import matplotlib.pyplot as plt
from app.app_utils import preprocess_frame_and_predict_aus
# Facial Action Units (FACS numbering) that this tab associates with each
# emotional state. Each state's score is derived from the average intensity
# of the AUs in its group.
# NOTE(review): the provenance of these groupings is not cited in this file.
STRESS_AUS = [
    4,   # Brow Lowerer
    7,   # Lid Tightener
    17,  # Chin Raiser
    23,  # Lip Tightener
    24,  # Lip Pressor
]
ANXIETY_AUS = [
    1,   # Inner Brow Raiser
    2,   # Outer Brow Raiser
    4,   # Brow Lowerer
    5,   # Upper Lid Raiser
    20,  # Lip Stretcher
]
DEPRESSION_AUS = [
    1,   # Inner Brow Raiser
    4,   # Brow Lowerer
    15,  # Lip Corner Depressor
    17,  # Chin Raiser
]

# FACS names for every AU referenced above; used to label the per-AU chart.
AU_DESCRIPTIONS = {
    1: "Inner Brow Raiser",
    2: "Outer Brow Raiser",
    4: "Brow Lowerer",
    5: "Upper Lid Raiser",
    7: "Lid Tightener",
    15: "Lip Corner Depressor",
    17: "Chin Raiser",
    20: "Lip Stretcher",
    23: "Lip Tightener",
    24: "Lip Pressor",
}
def normalize_score(score, lower=-1.5, upper=1.5):
    """Linearly map *score* from [lower, upper] onto [0, 1], clamping outliers.

    The defaults reproduce the original fixed mapping ``(score + 1.5) / 3``;
    pass *lower*/*upper* to match the range the AU model actually emits.

    Args:
        score: Raw (typically averaged) AU intensity.
        lower: Raw value that maps to 0.0.
        upper: Raw value that maps to 1.0; must be greater than *lower*.

    Returns:
        A float in [0.0, 1.0]. A NaN score (e.g. the mean of an empty
        selection) returns 0.0; without this check ``min(1, nan)`` evaluates
        to 1 and missing data would be reported as the maximum score.

    Raises:
        ValueError: If ``upper <= lower``.
    """
    if upper <= lower:
        raise ValueError(f"upper ({upper}) must be greater than lower ({lower})")
    if np.isnan(score):
        return 0.0
    return float(max(0.0, min(1.0, (score - lower) / (upper - lower))))
def process_video_for_facs(video_path):
    """Run per-frame AU prediction over a video and chart the aggregated result.

    Each frame of *video_path* is passed to ``preprocess_frame_and_predict_aus``;
    frames for which it returns both a processed frame and AU intensities
    contribute to a per-AU average over the clip. That average is reduced to
    stress / anxiety / depression scores using the module's AU groupings and
    ``normalize_score``, then plotted next to the individual AU averages.

    Args:
        video_path: Path to a video readable by ``cv2.VideoCapture`` (as supplied
            by the Gradio ``Video`` input). ``None`` or empty means no upload.

    Returns:
        ``(last_processed_frame, figure)``, where ``figure`` is a matplotlib
        ``Figure`` holding two bar charts; or ``(None, None)`` when no video was
        given or no frame yielded AU intensities.
    """
    if not video_path:
        # Analyze clicked with nothing uploaded: there is nothing to read.
        return None, None

    last_frame, au_intensities_list = _collect_au_intensities(video_path)
    if last_frame is None:
        return None, None

    # Element-wise mean over frames.
    avg_au_intensities = np.mean(au_intensities_list, axis=0)

    scores = [
        _group_score(avg_au_intensities, aus)
        for aus in (STRESS_AUS, ANXIETY_AUS, DEPRESSION_AUS)
    ]
    fig = _build_facs_figure(avg_au_intensities, scores)
    return last_frame, fig


def _collect_au_intensities(video_path):
    """Read every frame; return (last processed frame or None, list of AU results)."""
    cap = cv2.VideoCapture(video_path)
    last_frame = None
    au_intensities_list = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # NOTE(review): cv2 yields BGR frames; confirm the helper returns a
            # frame in the colour order gr.Image expects.
            processed_frame, au_intensities, _ = preprocess_frame_and_predict_aus(frame)
            if processed_frame is not None and au_intensities is not None:
                # Only the final frame is displayed, so don't hold all of them.
                last_frame = processed_frame
                au_intensities_list.append(au_intensities)
    finally:
        # Release the capture even if frame processing raises.
        cap.release()
    return last_frame, au_intensities_list


def _group_score(avg_au_intensities, aus):
    """Normalized mean intensity of the AUs in *aus* present in the AU vector.

    Returns 0.0 when none of *aus* index into the vector, instead of averaging
    an empty list (which yields NaN plus a RuntimeWarning).
    """
    # Assumes position ``au - 1`` holds AU number ``au`` -- TODO confirm against
    # the output layout of preprocess_frame_and_predict_aus.
    values = [avg_au_intensities[au - 1] for au in aus if au <= len(avg_au_intensities)]
    if not values:
        return 0.0
    return normalize_score(np.mean(values))


def _build_facs_figure(avg_au_intensities, scores):
    """Draw the state-score and per-AU bar charts onto a new standalone Figure."""
    # Build the Figure directly instead of via pyplot: pyplot keeps every figure
    # it creates alive until plt.close(), so calling plt.subplots() once per
    # request in a long-running app accumulates figures.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 10))
    ax1, ax2 = fig.subplots(2, 1)

    # Emotional state scores, annotated with their values.
    states = ['Stress', 'Anxiety', 'Depression']
    bars = ax1.bar(states, scores)
    ax1.set_ylim(0, 1)
    ax1.set_title('Emotional State Scores')
    for bar in bars:
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.2f}', ha='center', va='bottom')

    # Average intensity of every AU used by any group (that the vector covers).
    all_aus = sorted(set(STRESS_AUS + ANXIETY_AUS + DEPRESSION_AUS))
    all_aus = [au for au in all_aus if au <= len(avg_au_intensities)]
    au_labels = [f"AU{au}\n{AU_DESCRIPTIONS.get(au, '')}" for au in all_aus]
    au_values = [avg_au_intensities[au - 1] for au in all_aus]
    ax2.bar(range(len(au_labels)), au_values)
    ax2.set_xticks(range(len(au_labels)))
    ax2.set_xticklabels(au_labels, rotation=45, ha='right')
    # NOTE(review): normalize_score treats raw scores as spanning about
    # [-1.5, 1.5], but this axis is fixed to [0, 1]; if raw AU values can fall
    # outside it they are drawn clipped -- confirm the model's output range.
    ax2.set_ylim(0, 1)
    ax2.set_title('Average Action Unit Intensities')

    fig.tight_layout()
    return fig
def create_facs_analysis_sad_tab():
    """Lay out the FACS/SAD analysis tab and hook its Analyze button up.

    Left column: video input, Analyze button and a bundled example clip.
    Right column: the last processed frame and the FACS chart. Clicking
    Analyze runs ``process_video_for_facs`` on the video and fills both outputs.
    Presumably called from inside the app's ``gr.Blocks``/tab context, which
    is not visible here.

    Returns:
        ``(input_video, output_image, facs_chart, analyze_btn)`` components.
    """
    with gr.Row():
        with gr.Column(scale=1):
            video_in = gr.Video()
            run_button = gr.Button("Analyze")
            # Sample clip so the tab can be tried without uploading anything.
            gr.Examples(["./assets/videos/fitness.mp4"], inputs=[video_in])
        with gr.Column(scale=2):
            frame_out = gr.Image(label="Processed Frame")
            chart_out = gr.Plot(label="FACS Analysis for SAD")

    # Output order mirrors the (frame, figure) pair the handler returns.
    run_button.click(
        fn=process_video_for_facs,
        inputs=video_in,
        outputs=[frame_out, chart_out],
    )
    return video_in, frame_out, chart_out, run_button