# fer-demo-1 / fer.py
# (uploaded by jlynxdev, commit b066d77, "Upload 13 files")
from face_detector import FaceDetector
from model_small import ResNet18
import numpy as np
import torch
from torch import nn
from PIL import Image
from util import draw_bboxes, draw_label_on_bbox
import torchvision.transforms as T
class FaceExpressionRecognizer:
    """Detects faces in an image and labels each with a predicted facial expression."""

    # Grayscale pixel statistics of the FER training dataset; used to
    # normalize face crops before classification.
    _DATASET_MEAN = 0.5077385902404785
    _DATASET_STD = 0.255077600479126

    def __init__(self):
        self.face_detector = FaceDetector()
        self.fer_classifier = _make_fer_classifier()
        # Pipeline applied to extracted face tensors: resize to the 48x48
        # input the classifier expects, collapse to a single channel,
        # convert to float32, and normalize with the dataset statistics.
        self.post_process = T.Compose([
            T.Resize((48, 48)),
            T.Grayscale(),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(FaceExpressionRecognizer._DATASET_MEAN,
                        FaceExpressionRecognizer._DATASET_STD),
        ])
        # Maps classifier output indices to human-readable expression labels.
        self.idx_to_label = {
            0: 'angry',
            1: 'disgust',
            2: 'fear',
            3: 'happy',
            4: 'neutral',
            5: 'sad',
            6: 'surprise',
        }

    def handle_frame(self, image: Image.Image) -> Image.Image:
        """Annotate ``image`` with a bounding box and expression label per detected face.

        Returns the input image unchanged when no face is detected.
        """
        bboxes = self.face_detector.detect_bboxes(image)
        if bboxes is None:
            return image
        extracted_faces = self.face_detector.extract_faces(image, bboxes)
        extracted_faces = self.post_process(extracted_faces)
        # Inference only: disable autograd so no computation graph is built
        # (and no memory wasted) on every frame.
        with torch.no_grad():
            preds = self.fer_classifier(extracted_faces).argmax(dim=1)
        preds = preds.tolist()
        img_w_boxes = draw_bboxes(image.copy(), bboxes, (255, 0, 0))
        image_w_boxes_arr = np.array(img_w_boxes)
        for bbox, pred in zip(bboxes, preds):
            image_w_boxes_arr = draw_label_on_bbox(image_w_boxes_arr, bbox, self.idx_to_label[pred])
        return Image.fromarray(image_w_boxes_arr)
def _make_fer_classifier() -> nn.Module:
    """Build the ResNet-18 expression classifier (1 input channel, 7 classes)
    and load its trained weights for CPU inference.

    Returns:
        The classifier in evaluation mode, ready for inference.
    """
    model = ResNet18(1, 7)
    model.load_state_dict(torch.load(
        './saved_models/weighted_sampler200_fer_model.pth',
        map_location=torch.device('cpu'),
    ))
    # Switch BatchNorm/Dropout layers to inference behavior; without this
    # the network runs in training mode and produces unstable predictions.
    model.eval()
    return model