"""
File: app_utils.py
Author: Elena Ryumina and Dmitry Ryumin
Description: This module contains utility functions for the facial expression recognition application.
License: MIT License
"""

import torch
import numpy as np
import mediapipe as mp
from PIL import Image
import cv2

# Importing necessary components for the Gradio app
from app.model import pth_model_static, pth_model_dynamic, pth_processing
from app.face_utils import get_box, display_info
from app.config import DICT_EMO, config_data
from app.plot import statistics_plot


# MediaPipe Face Mesh is used to locate the face before cropping.
mp_face_mesh = mp.solutions.face_mesh


def preprocess_image_and_predict(inp):
    # Static (single-image) pipeline: detect the face, crop it, and predict emotion probabilities.
    if inp is None:
        return None, None

    inp = np.array(inp)

    try:
        h, w = inp.shape[:2]
    except Exception:
        return None, None

    cur_face = None
    confidences = None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(inp)
        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                # Crop the detected face and run the static emotion model on it.
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = inp[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face)).to(config_data.DEVICE)
                prediction = (
                    torch.nn.functional.softmax(pth_model_static(cur_face_n), dim=1)
                    .cpu()
                    .detach()
                    .numpy()[0]
                )
                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}

    return cur_face, confidences


def preprocess_video_and_predict(video):
    # Dynamic (video) pipeline: crop the face in every frame, extract static features,
    # and feed a sliding window of 10 feature vectors to the dynamic (temporal) model.
    cap = cv2.VideoCapture(video)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = np.round(cap.get(cv2.CAP_PROP_FPS))

    path_save_video = 'result.mp4'
    vid_writer = cv2.VideoWriter(path_save_video, cv2.VideoWriter_fourcc(*'mp4v'), fps, (224, 224))

    lstm_features = []
    count_frame = 1
    probs = []
    frames = []
    last_output = None
    cur_face = None

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or frame is None:
                break

            # Convert BGR (OpenCV) to RGB for MediaPipe processing.
            frame_copy = frame.copy()
            frame_copy.flags.writeable = False
            frame_copy = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
            results = face_mesh.process(frame_copy)
            frame_copy.flags.writeable = True

            if results.multi_face_landmarks:
                for fl in results.multi_face_landmarks:
                    startX, startY, endX, endY = get_box(fl, w, h)
                    cur_face = frame_copy[startY:endY, startX:endX]

                    # Run the static feature extractor and the dynamic model once every 5 frames;
                    # intermediate frames reuse the last prediction.
                    if (count_frame - 1) % 5 == 0:
                        cur_face_copy = pth_processing(Image.fromarray(cur_face)).to(config_data.DEVICE)
                        features = (
                            torch.nn.functional.relu(pth_model_static.extract_features(cur_face_copy))
                            .cpu()
                            .detach()
                            .numpy()
                        )

                        # Maintain a sliding window of the 10 most recent feature vectors.
                        if len(lstm_features) == 0:
                            lstm_features = [features] * 10
                        else:
                            lstm_features = lstm_features[1:] + [features]

                        lstm_f = torch.from_numpy(np.vstack(lstm_features))
                        lstm_f = torch.unsqueeze(lstm_f, 0).to(config_data.DEVICE)
                        output = pth_model_dynamic(lstm_f).cpu().detach().numpy()
                        last_output = output
                    elif last_output is not None:
                        output = last_output
                    else:
                        # No prediction yet: fall back to zeros with the same (1, 7) shape.
                        output = np.zeros((1, 7))

                    probs.append(output[0])
                    frames.append(count_frame)
            else:
                # No face detected in this frame: record NaNs so the plot shows a gap.
                probs.append(np.full(7, np.nan))
                frames.append(count_frame)

            # Write the annotated face crop for this frame to the output video
            # (skipped until the first face has been detected).
            if cur_face is not None:
                cur_face = cv2.cvtColor(cur_face, cv2.COLOR_RGB2BGR)
                cur_face = cv2.resize(cur_face, (224, 224), interpolation=cv2.INTER_AREA)
                cur_face = display_info(cur_face, 'Frame: {}'.format(count_frame), box_scale=.3)
                vid_writer.write(cur_face)

            count_frame += 1

        cap.release()
        vid_writer.release()

        stat = statistics_plot(frames, probs)

        if not stat:
            return None, None, None

    return video, path_save_video, stat
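

# Usage sketch (hypothetical file names; assumes the app package, model weights, and
# configuration from app.config are available):
#
#   from PIL import Image
#   face_crop, confidences = preprocess_image_and_predict(Image.open("photo.jpg"))
#   source_path, annotated_path, stat_figure = preprocess_video_and_predict("clip.mp4")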