"""Gradio tab: FACS-based stress / anxiety / depression scoring for a video.

Frames are run through ``preprocess_frame_and_predict_aus``; the per-frame
Action Unit (AU) intensities are averaged and summarized in a matplotlib
figure shown next to the last processed frame.
"""

import gradio as gr
import cv2
import numpy as np
import matplotlib.pyplot as plt

from app.app_utils import preprocess_frame_and_predict_aus

# Define the AUs associated with stress, anxiety, and depression
STRESS_AUS = [4, 7, 17, 23, 24]
ANXIETY_AUS = [1, 2, 4, 5, 20]
DEPRESSION_AUS = [1, 4, 15, 17]

AU_DESCRIPTIONS = {
    1: "Inner Brow Raiser",
    2: "Outer Brow Raiser",
    4: "Brow Lowerer",
    5: "Upper Lid Raiser",
    7: "Lid Tightener",
    15: "Lip Corner Depressor",
    17: "Chin Raiser",
    20: "Lip Stretcher",
    23: "Lip Tightener",
    24: "Lip Pressor",
}


def normalize_score(score):
    """Linearly map ``score`` from [-1.5, 1.5] onto [0, 1], clamping the result.

    Args:
        score: Raw (un-normalized) score.

    Returns:
        ``(score + 1.5) / 3`` limited to the closed interval [0, 1].
    """
    return max(0, min(1, (score + 1.5) / 3))  # Adjust the range as needed


def _state_score(avg_au_intensities, aus):
    """Return the normalized mean intensity of ``aus``, or 0.0 if none exist."""
    # AU number ``n`` is read from index ``n - 1``; AUs beyond the end of the
    # vector are skipped.
    # NOTE(review): this assumes the model output is indexed by AU number -
    # confirm against preprocess_frame_and_predict_aus.
    available = [
        avg_au_intensities[au - 1]
        for au in aus
        if au <= len(avg_au_intensities)
    ]
    if not available:
        # np.mean([]) is NaN, which the clamp in normalize_score would turn
        # into 1 (a maximal score with no data). Report 0.0 instead.
        return 0.0
    return normalize_score(np.mean(available))


def _build_facs_figure(avg_au_intensities):
    """Build the two-panel figure: emotional state scores and AU intensities."""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Emotional state scores
    states = ['Stress', 'Anxiety', 'Depression']
    scores = [
        _state_score(avg_au_intensities, STRESS_AUS),
        _state_score(avg_au_intensities, ANXIETY_AUS),
        _state_score(avg_au_intensities, DEPRESSION_AUS),
    ]
    bars = ax1.bar(states, scores)
    ax1.set_ylim(0, 1)
    ax1.set_title('Emotional State Scores')
    for bar in bars:
        height = bar.get_height()
        ax1.text(
            bar.get_x() + bar.get_width() / 2.,
            height,
            f'{height:.2f}',
            ha='center',
            va='bottom',
        )

    # AU intensities
    all_aus = sorted(set(STRESS_AUS + ANXIETY_AUS + DEPRESSION_AUS))
    all_aus = [au for au in all_aus if au <= len(avg_au_intensities)]
    au_labels = [f"AU{au}\n{AU_DESCRIPTIONS.get(au, '')}" for au in all_aus]
    au_values = [avg_au_intensities[au - 1] for au in all_aus]
    positions = range(len(au_labels))
    ax2.bar(positions, au_values)
    ax2.set_xticks(positions)
    ax2.set_xticklabels(au_labels, rotation=45, ha='right')
    # NOTE(review): the axis is fixed to [0, 1]; raw intensities outside that
    # range would be clipped - confirm the model's output range.
    ax2.set_ylim(0, 1)
    ax2.set_title('Average Action Unit Intensities')

    plt.tight_layout()
    return fig


def process_video_for_facs(video_path):
    """Analyze a video and summarize its average Action Unit intensities.

    Args:
        video_path: Path of the video file to analyze. A falsy value (for
            example ``None`` when no video was supplied) yields
            ``(None, None)``.

    Returns:
        A tuple ``(last_processed_frame, figure)``. Both elements are ``None``
        when no frame could be read or none produced a prediction.
    """
    if not video_path:
        return None, None

    cap = cv2.VideoCapture(video_path)
    last_frame = None  # only the final processed frame is ever returned
    au_intensities_list = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            processed_frame, au_intensities, _ = (
                preprocess_frame_and_predict_aus(frame)
            )
            if processed_frame is not None and au_intensities is not None:
                last_frame = processed_frame
                au_intensities_list.append(au_intensities)
    finally:
        # Release the capture even if frame processing raises.
        cap.release()

    if last_frame is None:
        return None, None

    # Calculate average AU intensities
    avg_au_intensities = np.mean(au_intensities_list, axis=0)

    fig = _build_facs_figure(avg_au_intensities)

    return last_frame, fig  # Return the last processed frame and the plot


def create_facs_analysis_sad_tab():
    """Create the "FACS Analysis for SAD" widgets and wire the Analyze button.

    Presumably called inside an active ``gr.Blocks`` context - confirm against
    the caller.

    Returns:
        A tuple ``(input_video, output_image, facs_chart, analyze_btn)`` of
        the created components.
    """
    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video()
            analyze_btn = gr.Button("Analyze")
            gr.Examples(["./assets/videos/fitness.mp4"], inputs=[input_video])
        with gr.Column(scale=2):
            output_image = gr.Image(label="Processed Frame")
            facs_chart = gr.Plot(label="FACS Analysis for SAD")

    analyze_btn.click(
        fn=process_video_for_facs,
        inputs=input_video,
        outputs=[output_image, facs_chart],
    )

    return input_video, output_image, facs_chart, analyze_btn