# Spaces:
# Runtime error
# Runtime error
import numpy as np | |
import matplotlib.pyplot as plt | |
import cv2 | |
import torch | |
from PIL import Image | |
from app.model import pth_model_static, cam, pth_processing | |
from app.face_utils import get_box | |
import mediapipe as mp | |
# Shorthand for MediaPipe's Face Mesh solution module; used below to construct
# the FaceMesh landmark detector.
mp_face_mesh = mp.solutions.face_mesh
def _show_cam_on_image(img, mask, use_rgb=False, image_weight=0.5):
    """Blend a CAM activation mask over an image as a JET-coloured heatmap.

    Local stand-in for ``pytorch_grad_cam.utils.image.show_cam_on_image``,
    which this module called without ever importing it (NameError).

    Args:
        img: float image with values in [0, 1], shape (H, W, 3).
        mask: 2-D activation map with values in [0, 1]; assumed to have the
            same H, W as ``img`` — TODO confirm against the CAM output size.
        use_rgb: convert the OpenCV (BGR) heatmap to RGB before blending.
        image_weight: weight of ``img`` in the blend; the heatmap gets
            ``1 - image_weight``.

    Returns:
        uint8 array (H, W, 3) containing the overlay.

    Raises:
        ValueError: if ``img`` has values above 1.
    """
    if np.max(img) > 1:
        raise ValueError("The input image should be np.float32 in the range [0, 1]")
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255
    blended = (1 - image_weight) * heatmap + image_weight * img
    # Rescale so the brightest pixel maps to 255.
    blended = blended / np.max(blended)
    return np.uint8(255 * blended)


def preprocess_frame_and_predict_aus(frame):
    """Detect a face in ``frame`` and predict its AU intensities plus a CAM overlay.

    Args:
        frame: image as a NumPy array. Grayscale (H, W) and 4-channel inputs
            are converted to 3-channel RGB; 3-channel input is used as-is and
            presumably must already be RGB (MediaPipe expects RGB) — confirm
            with callers.

    Returns:
        ``(cur_face, au_intensities, heatmap)`` for the first detected face:
        the cropped face region, the output of ``features_to_au_intensities``,
        and a 224x224 uint8 overlay from ``_show_cam_on_image``.
        ``(None, None, None)`` when no face is found or its box is empty.
    """
    if len(frame.shape) == 2:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[2] == 4:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(frame)
        if not results.multi_face_landmarks:
            return None, None, None

        h, w = frame.shape[:2]
        for face_landmarks in results.multi_face_landmarks:
            start_x, start_y, end_x, end_y = get_box(face_landmarks, w, h)
            # Clamp to the frame: a negative start would wrap around in NumPy
            # slicing and yield a wrong or empty crop.
            start_x, end_x = max(0, start_x), min(w, end_x)
            start_y, end_y = max(0, start_y), min(h, end_y)
            cur_face = frame[start_y:end_y, start_x:end_x]
            if cur_face.size == 0:
                continue

            cur_face_n = pth_processing(Image.fromarray(cur_face))
            with torch.no_grad():
                features = pth_model_static(cur_face_n)
            au_intensities = features_to_au_intensities(features)

            # `cam` presumably back-propagates (Grad-CAM style), so it must
            # run outside torch.no_grad().
            grayscale_cam = cam(input_tensor=cur_face_n)
            grayscale_cam = grayscale_cam[0, :]
            # Assumes the CAM map is 224x224 (model input size) — TODO confirm.
            cur_face_hm = cv2.resize(cur_face, (224, 224))
            cur_face_hm = np.float32(cur_face_hm) / 255
            heatmap = _show_cam_on_image(cur_face_hm, grayscale_cam, use_rgb=True)
            return cur_face, au_intensities, heatmap

    return None, None, None
def features_to_au_intensities(features, n_aus=24):
    """Min-max normalise the first sample's features and keep the first ``n_aus``.

    Args:
        features: tensor whose first index selects the sample; only
            ``features[0]`` is used (assumed shape (batch, n_features) —
            TODO confirm with the model output).
        n_aus: number of leading normalised values to return (default 24).

    Returns:
        NumPy array of up to ``n_aus`` values in [0, 1]. The min/max are taken
        over the whole feature vector, not only the returned slice. A constant
        feature vector yields zeros instead of NaN.
    """
    features_np = features.detach().cpu().numpy()[0]
    lo = features_np.min()
    span = features_np.max() - lo
    if span > 0:
        au_intensities = (features_np - lo) / span
    else:
        # Avoid 0/0 -> NaN when every feature has the same value.
        au_intensities = np.zeros_like(features_np)
    return au_intensities[:n_aus]
def au_statistics_plot(frames, au_intensities_list):
    """Plot every Action Unit's intensity against frame position.

    Args:
        frames: x-axis values, one per entry of ``au_intensities_list``.
        au_intensities_list: sequence of equal-length per-frame AU vectors
            (e.g. outputs of ``features_to_au_intensities``). A flat sequence
            of scalars is treated as a single AU track.

    Returns:
        The matplotlib Figure. With no data, an empty labelled plot is
        returned instead of raising.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    au_intensities_array = np.array(au_intensities_list)
    if au_intensities_array.size:
        # Force 2-D (frames x AUs) so a 1-D input doesn't fail on .shape[1].
        au_intensities_array = au_intensities_array.reshape(len(au_intensities_array), -1)
        for i in range(au_intensities_array.shape[1]):
            ax.plot(frames, au_intensities_array[:, i], label=f'AU{i+1}')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.set_xlabel('Frame')
    ax.set_ylabel('AU Intensity')
    ax.set_title('Action Unit Intensities Over Time')
    # Lay out this figure explicitly; plt.tight_layout() would target the
    # pyplot "current" figure, which may not be `fig`.
    fig.tight_layout()
    return fig