File size: 3,691 Bytes
7faf1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os

from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


class AgePredictResnet(nn.Module):
    """ResNet-101 backbone with three classification heads: age, gender and race.

    The backbone's final ``fc`` layer is replaced by a projection to a shared
    512-d feature vector. Each head applies two hidden ``Linear`` layers
    (512 -> 256 -> 128), each followed by dropout and ReLU, and then an output
    ``Linear`` layer that produces raw (unnormalised) logits.

    The submodule attribute names (``age_linear1``, ``gender_out``, ...) define
    the ``state_dict`` layout. Renaming them would stop existing checkpoints
    from loading.

    Args:
        num_age_classes: Number of logits produced by the age head (default 9).
        num_gender_classes: Number of logits produced by the gender head (default 2).
        num_race_classes: Number of logits produced by the race head (default 5).
        dropout: Dropout probability used inside every head (default 0.4).
    """

    def __init__(self, num_age_classes=9, num_gender_classes=2, num_race_classes=5, dropout=0.4):
        super().__init__()
        # Constructed without requesting pretrained weights.
        self.model = torchvision.models.resnet101()
        # Read the backbone's feature width instead of hard-coding 2048, so the
        # projection stays consistent with whatever backbone is used.
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

        self.age_linear1 = nn.Linear(512, 256)
        self.age_linear2 = nn.Linear(256, 128)
        self.age_out = nn.Linear(128, num_age_classes)

        self.gender_linear1 = nn.Linear(512, 256)
        self.gender_linear2 = nn.Linear(256, 128)
        self.gender_out = nn.Linear(128, num_gender_classes)

        self.race_linear1 = nn.Linear(512, 256)
        self.race_linear2 = nn.Linear(256, 128)
        self.race_out = nn.Linear(128, num_race_classes)

        # Shared, parameter-free modules reused by every head.
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def _head(self, features, hidden1, hidden2, output):
        """Run one head: (Linear -> Dropout -> ReLU) twice, then the output Linear; return logits."""
        hidden = self.activation(self.dropout(hidden1(features)))
        hidden = self.activation(self.dropout(hidden2(hidden)))
        return output(hidden)

    def forward(self, x):
        """Compute age, gender and race logits for a batch of inputs.

        Args:
            x: Input batch passed straight to the ResNet backbone. Presumably
                this is a ``(batch, 3, H, W)`` normalised image tensor;
                TODO confirm against callers.

        Returns:
            Tuple ``(age_logits, gender_logits, race_logits)``. Each is raw
            logits whose last dimension is that head's class count. No softmax
            is applied here.
        """
        # ReLU on the shared 512-d projection before it is fed to all heads.
        features = self.activation(self.model(x))
        age_out = self._head(features, self.age_linear1, self.age_linear2, self.age_out)
        gender_out = self._head(features, self.gender_linear1, self.gender_linear2, self.gender_out)
        race_out = self._head(features, self.race_linear1, self.race_linear2, self.race_out)
        return age_out, gender_out, race_out


if __name__ == '__main__':
    # Demo: load a trained checkpoint on CPU and print age/sex/race predictions for one image.
    import warnings

    trained_model_path = os.path.join('./final-models', 'resnet_101_weigthed.pt')
    model = AgePredictResnet()
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from a
    # trusted source; consider weights_only=True on torch versions that support it.
    state_dict = torch.load(trained_model_path, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches. Any mismatch, however, means some layers
    # keep their random initialisation, so report it instead of failing silently.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        warnings.warn(f'Checkpoint is missing {len(incompatible.missing_keys)} key(s); '
                      f'these layers stay randomly initialised: {incompatible.missing_keys}')
    if incompatible.unexpected_keys:
        warnings.warn(f'Checkpoint has {len(incompatible.unexpected_keys)} unexpected key(s) '
                      f'that were ignored: {incompatible.unexpected_keys}')
    model.eval()

    # Force 3-channel RGB: grayscale/RGBA images would not match the 3-channel Normalize.
    # The context manager closes the file handle once the converted copy exists.
    with Image.open('../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg') as img:
        sample_image = img.convert('RGB')
    transforms = Compose([Resize((256, 256)), ToTensor(),
                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # Prepend a batch dimension of size 1.
    transformed_image = torch.unsqueeze(transforms(sample_image), 0)
    print(transformed_image.shape)

    with torch.inference_mode():
        age_logits, sex_logits, race_logits = model(transformed_image)
        age_prob = F.softmax(age_logits, dim=1)
        sex_prob = F.softmax(sex_logits, dim=1)
        race_prob = F.softmax(race_logits, dim=1)

    # Batch size is 1, so row 0 holds the only image's results. .tolist()/.item()
    # yield plain Python numbers rather than numpy scalars.
    top2_age = torch.topk(age_prob, 2, dim=1)
    top2_race = torch.topk(race_prob, 2, dim=1)
    age_probs, age_indices = top2_age.values[0].tolist(), top2_age.indices[0].tolist()
    race_probs, race_indices = top2_race.values[0].tolist(), top2_race.indices[0].tolist()
    sex = torch.argmax(sex_prob, dim=1).item()
    sex_confidence = sex_prob[0][sex].item()

    all_predictions = (
        (age_probs, age_indices),
        (sex, sex_confidence),
        (race_probs, race_indices),
    )
    print(all_predictions)

    age_dict = {
        0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
        6: '60 to 70', 7: '70 to 80', 8: 'Above 80'
    }
    sex_dict = {0: 'Male', 1: 'Female'}
    race_dict = {
        0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)'
    }
    pred_dict = {
        'Predicted Age range': tuple(age_dict[i] for i in age_indices),
        'Age Probability': age_probs,
        'Predicted Sex': sex_dict[sex],
        'Sex Probability': sex_confidence,
        'Predicted Race': tuple(race_dict[i] for i in race_indices),
        'Race Probability': race_probs,
    }
    print(pred_dict)