""" File: app_utils.py Author: Elena Ryumina and Dmitry Ryumin Description: This module contains utility functions for facial expression recognition application. License: MIT License """ import torch import numpy as np import mediapipe as mp from PIL import Image import cv2 # Importing necessary components for the Gradio app from app.model import pth_model_static, pth_model_dynamic, pth_processing from app.face_utils import get_box, display_info from app.config import DICT_EMO from app.plot import statistics_plot mp_face_mesh = mp.solutions.face_mesh def preprocess_image_and_predict(inp): inp = np.array(inp) if inp is None: return None, None try: h, w = inp.shape[:2] except Exception: return None, None with mp_face_mesh.FaceMesh( max_num_faces=1, refine_landmarks=False, min_detection_confidence=0.5, min_tracking_confidence=0.5, ) as face_mesh: results = face_mesh.process(inp) if results.multi_face_landmarks: for fl in results.multi_face_landmarks: startX, startY, endX, endY = get_box(fl, w, h) cur_face = inp[startY:endY, startX:endX] cur_face_n = pth_processing(Image.fromarray(cur_face)) prediction = ( torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1) .detach() .numpy()[0] ) confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)} return cur_face, confidences def preprocess_video_and_predict(video): cap = cv2.VideoCapture(video) w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = np.round(cap.get(cv2.CAP_PROP_FPS)) path_save_video = 'result.mp4' vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224)) lstm_features = [] count_frame = 1 probs = [] frames = [] last_output = None with mp_face_mesh.FaceMesh( max_num_faces=1, refine_landmarks=False, min_detection_confidence=0.5, min_tracking_confidence=0.5) as face_mesh: while cap.isOpened(): _, frame = cap.read() if frame is None: break frame_copy = frame.copy() frame_copy.flags.writeable = False frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB) results = face_mesh.process(frame_copy) frame_copy.flags.writeable = True if results.multi_face_landmarks: for fl in results.multi_face_landmarks: startX, startY, endX, endY = get_box(fl, w, h) cur_face = frame_copy[startY:endY, startX: endX] if (count_frame-1)%5 == 0: cur_face_copy = pth_processing(Image.fromarray(cur_face)) features = torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy)).detach().numpy() if len(lstm_features) == 0: lstm_features = [features]*10 else: lstm_features = lstm_features[1:] + [features] lstm_f = torch.from_numpy(np.vstack(lstm_features)) lstm_f = torch.unsqueeze(lstm_f, 0) output = pth_model_dynamic(lstm_f).detach().numpy() last_output = output else: if last_output is not None: output = last_output elif last_output is None: output = np.zeros((7)) probs.append(output[0]) frames.append(count_frame) else: empty = np.empty((7)) empty[:] = np.nan probs.append(empty) frames.append(count_frame) cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR) cur_face = cv2.resize(cur_face, (224,224), interpolation = cv2.INTER_AREA) cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3) count_frame += 1 vid_writer.write(cur_face) vid_writer.release() stat = statistics_plot(frames, probs) if not stat: return None, None, None return video, path_save_video, stat