"""
File: app_utils.py
Author: Elena Ryumina and Dmitry Ryumin
Description: This module contains utility functions for the facial expression recognition application.
License: MIT License
"""
import torch
import numpy as np
import mediapipe as mp
from PIL import Image
import cv2

# Importing necessary components for the Gradio app
from app.model import pth_model_static, pth_model_dynamic, pth_processing
from app.face_utils import get_box, display_info
from app.config import DICT_EMO, config_data
from app.plot import statistics_plot

mp_face_mesh = mp.solutions.face_mesh

def preprocess_image_and_predict(inp):
    # Guard against empty input before converting to a NumPy array
    if inp is None:
        return None, None

    inp = np.array(inp)

    try:
        h, w = inp.shape[:2]
    except Exception:
        return None, None

    cur_face = None
    confidences = None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(inp)
        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                # Crop the detected face and classify it with the static model
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = inp[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face)).to(config_data.DEVICE)
                prediction = (
                    torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
                    .cpu()
                    .detach()
                    .numpy()[0]
                )
                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}

    return cur_face, confidences
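
# Example usage (a minimal sketch; it assumes this function is wired into a Gradio
# Interface elsewhere in the app, so the component choices are illustrative only):
#
#   import gradio as gr
#   demo = gr.Interface(
#       fn=preprocess_image_and_predict,
#       inputs=gr.Image(type="pil"),
#       outputs=[gr.Image(label="Detected face"), gr.Label(label="Emotion probabilities")],
#   )
#   demo.launch()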

def preprocess_video_and_predict(video):
    cap = cv2.VideoCapture(video)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = np.round(cap.get(cv2.CAP_PROP_FPS))

    path_save_video = 'result.mp4'
    vid_writer = cv2.VideoWriter(
        path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224)
    )

    # Sliding window of per-frame features fed to the dynamic (temporal) model
    lstm_features = []
    count_frame = 1
    probs = []
    frames = []
    last_output = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or frame is None:
                break

            frame_copy = frame.copy()
            frame_copy.flags.writeable = False
            frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_copy)
            frame_copy.flags.writeable = True

            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(fl, w, h)
                    cur_face = frame_copy[startY:endY, startX:endX]

                    # Run the models only on every 5th frame; the frames in
                    # between reuse the most recent prediction
                    if (count_frame - 1) % 5 == 0:
                        cur_face_copy = pth_processing(Image.fromarray(cur_face)).to(config_data.DEVICE)
                        features = (
                            torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy))
                            .cpu()
                            .detach()
                            .numpy()
                        )

                        # Keep a window of the 10 most recent feature vectors
                        if len(lstm_features) == 0:
                            lstm_features = [features] * 10
                        else:
                            lstm_features = lstm_features[1:] + [features]

                        lstm_f = torch.from_numpy(np.vstack(lstm_features))
                        lstm_f = torch.unsqueeze(lstm_f, 0).to(config_data.DEVICE)
                        output = pth_model_dynamic(lstm_f).cpu().detach().numpy()
                        last_output = output
                    else:
                        if last_output is not None:
                            output = last_output
                        else:
                            # No prediction yet; shape matches the model output
                            output = np.zeros((1, 7))

                    probs.append(output[0])
                    frames.append(count_frame)
            else:
                # No face detected: log NaNs for this frame
                empty = np.full(7, np.nan)
                probs.append(empty)
                frames.append(count_frame)

            # Annotate the current face crop (or the last one seen) and write it
            # to the output video; skipped until the first face has been detected
            if cur_face is not None:
                cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
                vid_writer.write(cur_face)

            count_frame += 1

    vid_writer.release()
    cap.release()

    stat = statistics_plot(frames, probs)

    if not stat:
        return None, None, None

    return video, path_save_video, stat
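
# The block below is a minimal sketch of how these utilities might be exposed with
# Gradio for local testing. The actual Space defines its own app entry point, so the
# component wiring here is an assumption, not the app's real interface.
if __name__ == "__main__":
    import gradio as gr  # assumed available, since the module targets a Gradio app

    image_demo = gr.Interface(
        fn=preprocess_image_and_predict,
        inputs=gr.Image(type="pil"),
        outputs=[gr.Image(label="Detected face"), gr.Label(label="Emotion probabilities")],
    )
    video_demo = gr.Interface(
        fn=preprocess_video_and_predict,
        inputs=gr.Video(),
        outputs=[
            gr.Video(label="Original video"),
            gr.Video(label="Processed video"),
            gr.Plot(label="Emotion statistics"),
        ],
    )
    gr.TabbedInterface([image_demo, video_demo], ["Image", "Video"]).launch()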