Spaces:
Sleeping
Sleeping
from face_detector import FaceDetector | |
from model_small import ResNet18 | |
import numpy as np | |
import torch | |
from torch import nn | |
from PIL import Image | |
from util import draw_bboxes, draw_label_on_bbox | |
import torchvision.transforms as T | |
class FaceExpressionRecognizer:
    """Detect faces in a frame and annotate each one with a predicted expression.

    Faces are located and cropped by ``FaceDetector``, resized to 48x48,
    converted to grayscale, normalized, and classified by the model returned
    from ``_make_fer_classifier`` into one of the labels in ``idx_to_label``.
    """

    # Scalar normalization statistics passed to T.Normalize after dtype
    # conversion; presumably the training-set pixel mean/std — TODO confirm.
    _DATASET_MEAN = 0.5077385902404785
    _DATASET_STD = 0.255077600479126

    def __init__(self):
        self.face_detector = FaceDetector()
        # eval() puts BatchNorm/Dropout layers (if the model has any) into
        # inference mode. Without it, predictions depend on the batch, and a
        # batch of one face can fail in train-mode BatchNorm. Module.eval()
        # returns the module itself.
        self.fer_classifier = _make_fer_classifier().eval()
        # Transform applied to the cropped faces before classification.
        self.post_process = T.Compose([
            T.Resize((48, 48)),
            T.Grayscale(),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(FaceExpressionRecognizer._DATASET_MEAN, FaceExpressionRecognizer._DATASET_STD),
        ])
        # Maps the classifier's argmax index to a human-readable label.
        self.idx_to_label = {
            0: 'angry',
            1: 'disgust',
            2: 'fear',
            3: 'happy',
            4: 'neutral',
            5: 'sad',
            6: 'surprise',
        }

    def handle_frame(self, image: Image.Image) -> Image.Image:
        """Annotate *image* with a box and a predicted expression label per face.

        Args:
            image: The input frame.

        Returns:
            ``image`` itself, unchanged, when ``detect_bboxes`` returns ``None``.
            Otherwise, a new image built from a copy of ``image``, with a box
            (color ``(255, 0, 0)``) and a predicted label drawn for each
            bounding box.
        """
        bboxes = self.face_detector.detect_bboxes(image)
        if bboxes is None:
            return image

        extracted_faces = self.face_detector.extract_faces(image, bboxes)
        extracted_faces = self.post_process(extracted_faces)
        # Pure inference: do not build an autograd graph for every frame.
        with torch.no_grad():
            logits = self.fer_classifier(extracted_faces)
        preds = logits.argmax(dim=1)
        print(f'Preds: {preds}')
        preds = preds.tolist()

        img_w_boxes = draw_bboxes(image.copy(), bboxes, (255, 0, 0))
        image_w_boxes_arr = np.array(img_w_boxes)
        # Assumes bboxes and preds are in the same (detection) order.
        for bbox, pred in zip(bboxes, preds):
            image_w_boxes_arr = draw_label_on_bbox(image_w_boxes_arr, bbox, self.idx_to_label[pred])
        return Image.fromarray(image_w_boxes_arr)
def _make_fer_classifier(
    weights_path: str = './saved_models/weighted_sampler200_fer_model.pth',
) -> nn.Module:
    """Build the expression classifier and load its trained weights onto the CPU.

    Args:
        weights_path: Path to the saved ``state_dict``. Relative paths resolve
            against the current working directory. The default is the
            location previously hard-coded here.

    Returns:
        The ``ResNet18`` model with the loaded weights, switched to evaluation
        mode so it is ready for inference.
    """
    # Presumably 1 input channel (grayscale) and 7 output classes — TODO
    # confirm against ResNet18's constructor signature.
    model = ResNet18(1, 7)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 1.13, weights_only=True would restrict this.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference-only use: disable train-mode behaviour of BatchNorm/Dropout.
    model.eval()
    return model