File size: 2,206 Bytes
7d1312d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms

from configs import path_ckpt_fairface

# code adapted from https://github.com/dchen236/FairFace

def init_fair_model(device, path_ckpt=None):
    """Build the FairFace ResNet-34 classifier and load its checkpoint.

    The network is a torchvision ResNet-34 (no ImageNet weights) whose final
    ``fc`` layer is replaced by an 18-output linear head. The weights come
    from the state dict stored at ``path_ckpt``.

    Args:
        device: Target device (``str`` or ``torch.device``) for the model.
        path_ckpt: Path to the state-dict checkpoint. Defaults to
            ``configs.path_ckpt_fairface`` when ``None``.

    Returns:
        The loaded model, moved to ``device`` and switched to eval mode.
    """
    if path_ckpt is None:
        path_ckpt = path_ckpt_fairface
    model_fair_7 = torchvision.models.resnet34(pretrained=False)
    model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
    # Deserialize straight onto the target device. Otherwise a checkpoint
    # saved on CUDA fails to load on a CPU-only host.
    state_dict = torch.load(path_ckpt, map_location=device)
    model_fair_7.load_state_dict(state_dict)
    model_fair_7 = model_fair_7.to(device)
    model_fair_7.eval()
    return model_fair_7


# ImageNet normalization statistics used by the FairFace preprocessing.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
# Spatial size expected by the classifier.
_FAIRFACE_INPUT_SIZE = (224, 224)
# Label for each of the 7 race logits (outputs[0:7]), in index order.
_RACE_LABELS = ('White', 'Black', 'Latino_Hispanic', 'East Asian',
                'Southeast Asian', 'Indian', 'Middle Eastern')


def _load_image_tensor(path_img):
    """Load an image file as a normalized (1, 3, 224, 224) float tensor."""
    trans = transforms.Compose([
        transforms.Resize(_FAIRFACE_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    # convert('RGB') guarantees 3 channels for grayscale/RGBA/palette files.
    # The context manager closes the underlying file handle.
    with Image.open(path_img) as img:
        image = trans(img.convert('RGB'))
    return image.unsqueeze(0)


def _prepare_tensor(img):
    """Resize, rescale and normalize an image tensor to (1, 3, 224, 224)."""
    # F.interpolate with a 2-element size requires a 4-D (N, C, H, W) input.
    # The default mode is 'nearest'.
    image = F.interpolate(img, _FAIRFACE_INPUT_SIZE)
    # Maps [-1, 1] to [0, 1]. Presumes the caller's tensor uses the
    # [-1, 1] convention (e.g. generator output) -- TODO confirm.
    image = image * 0.5 + 0.5
    # Same arithmetic and broadcasting as torchvision's transforms.Normalize.
    mean = torch.as_tensor(_IMAGENET_MEAN, dtype=image.dtype,
                           device=image.device).view(-1, 1, 1)
    std = torch.as_tensor(_IMAGENET_STD, dtype=image.dtype,
                          device=image.device).view(-1, 1, 1)
    image = (image - mean) / std
    # Fails loudly unless the batch is a single 3-channel image.
    return image.view(1, 3, *_FAIRFACE_INPUT_SIZE)


def predict_race(model_fair_7, path_img, device):
    """Run the FairFace classifier on one image and return its predictions.

    Args:
        model_fair_7: Model (e.g. from ``init_fair_model``) mapping a
            (1, 3, 224, 224) tensor to 18 logits: ``[0:7]`` race,
            ``[7:9]`` gender, ``[9:18]`` age group.
        path_img: Either a file path (``str`` or ``os.PathLike``) or a
            ``torch.Tensor`` of shape (1, C, H, W). The tensor is rescaled
            with ``x * 0.5 + 0.5`` before normalization (assumed to be in
            [-1, 1]).
        device: Device the preprocessed image is moved to before inference.

    Returns:
        Tuple ``(race_label, race_pred, gender_pred, age_pred)``:
        ``race_label`` is the string name of the predicted race class. The
        other three are NumPy integer indices within their logit groups.

    Raises:
        TypeError: If ``path_img`` is neither a path nor a tensor.
    """
    if isinstance(path_img, (str, os.PathLike)):
        image = _load_image_tensor(path_img)
    elif isinstance(path_img, torch.Tensor):
        image = _prepare_tensor(path_img)
    else:
        raise TypeError(
            'path_img must be a file path or a torch.Tensor, got '
            f'{type(path_img).__name__}')

    # Predictions are returned as NumPy values, so no autograd graph is needed.
    with torch.no_grad():
        outputs = model_fair_7(image.to(device))
    outputs = np.squeeze(outputs.cpu().numpy())

    # Softmax is monotone, so argmax over the raw logits equals argmax over
    # the probabilities. It also avoids exp() overflowing to inf/inf = nan.
    race_pred = np.argmax(outputs[:7])
    gender_pred = np.argmax(outputs[7:9])
    age_pred = np.argmax(outputs[9:18])
    return _RACE_LABELS[race_pred], race_pred, gender_pred, age_pred