import torch
from PIL import Image
from torchvision import transforms
import mediapipe as mp
import numpy as np
import math
import requests

import gradio as gr

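# Download the pretrained static facial expression recognition checkpoint from
# the Hugging Face Hub. Judging by its name, this is a ResNet50 trained on
# AffectNet and exported as a TorchScript module, so it can be loaded with
# torch.jit.load without the original model class definition.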
model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth" |
|
model_path = "FER_static_ResNet50_AffectNet.pth" |
|
|
|
response = requests.get(model_url, stream=True) |
|
with open(model_path, 'wb') as file: |
|
for chunk in response.iter_content(chunk_size=8192): |
|
file.write(chunk) |
|
|
|
pth_model = torch.jit.load(model_path) |
|
pth_model.eval() |
|
|
|
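# Index-to-label mapping for the model's seven emotion classes, and MediaPipe
# Face Mesh, which is used here only to locate the face before classification.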
DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}

mp_face_mesh = mp.solutions.face_mesh

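# Preprocessing for the TorchScript model: resize the PIL image to 224x224,
# convert it to a tensor, flip the channel order (RGB -> BGR), subtract
# per-channel means, and add a batch dimension. The constants appear to be the
# standard VGGFace2 channel means; the model is assumed to expect this format.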
def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def __init__(self):
            super(PreprocessInput, self).__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            x = torch.flip(x, dims=(0,))
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img):
        ttransform = transforms.Compose([
            transforms.PILToTensor(),
            PreprocessInput()
        ])
        img = img.resize((224, 224), Image.Resampling.NEAREST)
        img = ttransform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)

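# Convert MediaPipe's normalized landmark coordinates (in [0, 1]) to integer
# pixel coordinates, clamped to the image bounds.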
def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px

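# Compute a face bounding box as the min/max pixel coordinates over all
# Face Mesh landmarks, clipped to the image borders.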
def get_box(fl, w, h):
    idx_to_coors = {}
    for idx, landmark in enumerate(fl.landmark):
        landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)
        if landmark_px:
            idx_to_coors[idx] = landmark_px

    coords = np.asarray(list(idx_to_coors.values()))
    x_min = np.min(coords[:, 0])
    y_min = np.min(coords[:, 1])
    endX = np.max(coords[:, 0])
    endY = np.max(coords[:, 1])

    (startX, startY) = (max(0, x_min), max(0, y_min))
    (endX, endY) = (min(w - 1, endX), min(h - 1, endY))

    return startX, startY, endX, endY

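# Gradio inference callback: detect a single face with Face Mesh, crop it,
# run the emotion classifier, and return the crop together with a
# label -> probability dictionary for the gr.Label component.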
def predict(inp):
    if inp is None:
        # Nothing uploaded yet.
        return None, None

    inp = np.array(inp)
    h, w = inp.shape[:2]

    with mp_face_mesh.FaceMesh(
            max_num_faces=1,
            refine_landmarks=False,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5) as face_mesh:
        results = face_mesh.process(inp)
        if results.multi_face_landmarks:
            for fl in results.multi_face_landmarks:
                startX, startY, endX, endY = get_box(fl, w, h)
                cur_face = inp[startY:endY, startX:endX]
                cur_face_n = pth_processing(Image.fromarray(cur_face))
                prediction = torch.nn.functional.softmax(pth_model(cur_face_n), dim=1).detach().numpy()[0]
                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
                return cur_face, confidences

    # No face detected: return empty outputs instead of raising a NameError.
    return None, None

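# Reset the input image, the face crop and the label outputs to empty values.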
def clear():
    return (
        gr.Image(value=None, type="pil"),
        gr.Image(value=None, scale=1, elem_classes="dl2"),
        gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
    )

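# Custom CSS: constrain the heights of the image components and style the
# Submit button using Gradio's theme variables.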
style = """ |
|
div.dl1 div.upload-container { |
|
height: 350px; |
|
max-height: 350px; |
|
} |
|
|
|
div.dl2 { |
|
max-height: 200px; |
|
} |
|
|
|
div.dl2 img { |
|
max-height: 200px; |
|
} |
|
|
|
.submit { |
|
display: inline-block; |
|
padding: 10px 20px; |
|
font-size: 16px; |
|
font-weight: bold; |
|
text-align: center; |
|
text-decoration: none; |
|
cursor: pointer; |
|
border: var(--button-border-width) solid var(--button-primary-border-color); |
|
background: var(--button-primary-background-fill); |
|
color: var(--button-primary-text-color); |
|
border-radius: 8px; |
|
transition: all 0.3s ease; |
|
} |
|
|
|
.submit[disabled] { |
|
cursor: not-allowed; |
|
opacity: 0.6; |
|
} |
|
|
|
.submit:hover:not([disabled]) { |
|
border-color: var(--button-primary-border-color-hover); |
|
background: var(--button-primary-background-fill-hover); |
|
color: var(--button-primary-text-color-hover); |
|
} |
|
|
|
.submit:active:not([disabled]) { |
|
transform: scale(0.98); |
|
} |
|
""" |
|
|
|
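# UI layout: an input column with the source image and Submit/Clear buttons,
# an output column with the detected face crop and the top-3 emotion scores,
# plus a row of example images expected under a local images/ directory.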
with gr.Blocks(css=style) as demo:
    with gr.Row():
        with gr.Column(scale=2, elem_classes="dl1"):
            input_image = gr.Image(type="pil")
            with gr.Row():
                submit = gr.Button(
                    value="Submit", interactive=True, scale=1, elem_classes="submit"
                )
                clear_btn = gr.Button(
                    value="Clear", interactive=True, scale=1
                )
        with gr.Column(scale=1, elem_classes="dl4"):
            output_image = gr.Image(scale=1, elem_classes="dl2")
            output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")
    gr.Examples(
        ["images/fig7.jpg", "images/fig1.jpg", "images/fig2.jpg", "images/fig3.jpg",
         "images/fig4.jpg", "images/fig5.jpg", "images/fig6.jpg"],
        [input_image],
    )

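    # Wire the buttons: Submit runs prediction, Clear resets all components.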
    submit.click(
        fn=predict,
        inputs=[input_image],
        outputs=[
            output_image,
            output_label,
        ],
        queue=True,
    )
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[
            input_image,
            output_image,
            output_label,
        ],
        queue=True,
    )

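# Enable request queuing (keeping the queue's API routes closed) and launch
# the demo locally without a public share link.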
if __name__ == "__main__": |
|
demo.queue(api_open=False).launch(share=False) |