File size: 6,464 Bytes
d9f5274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os

from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


class FashionResnet(nn.Module):
    """ResNet-152 backbone used purely as a feature extractor.

    The network is built without pretrained weights (they are expected to come
    from a loaded checkpoint), and its final fully connected layer is replaced
    by an identity mapping, so ``forward`` returns the backbone's features
    rather than class scores. Presumably these are 2048-dim vectors, matching
    the ``nn.Linear(2048, ...)`` inputs of ``FashionClassifictions``.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet152()
        # Strip the ImageNet classifier; the task-specific heads live elsewhere.
        backbone.fc = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return the backbone features for the image batch ``x``."""
        features = self.model(x)
        return features


class FashionClassifictions(nn.Module):
    """Four independent MLP heads that classify a shared image embedding.

    Every head maps its 2048-dim input through 2048 -> 1024 -> 256 -> classes,
    applying dropout (p=0.3) followed by ReLU after each hidden layer. Output
    sizes per head: gender (5), master category (4), sub category (32) and
    colour (44).

    Attribute names form the ``state_dict`` key layout, so they must not be
    renamed or saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        in_dim, mid_dim, low_dim = 2048, 1024, 256

        # Gender head.
        self.gender_linear1 = nn.Linear(in_dim, mid_dim)
        self.gender_linear2 = nn.Linear(mid_dim, low_dim)
        self.gender_out = nn.Linear(low_dim, 5)

        # Master-category head.
        self.mastercat_linear1 = nn.Linear(in_dim, mid_dim)
        self.mastercat_linear2 = nn.Linear(mid_dim, low_dim)
        self.mastercat_out = nn.Linear(low_dim, 4)

        # Sub-category head.
        self.subcat_linear1 = nn.Linear(in_dim, mid_dim)
        self.subcat_linear2 = nn.Linear(mid_dim, low_dim)
        self.subcat_out = nn.Linear(low_dim, 32)

        # Colour head.
        self.color_linear1 = nn.Linear(in_dim, mid_dim)
        self.color_linear2 = nn.Linear(mid_dim, low_dim)
        self.color_out = nn.Linear(low_dim, 44)

        # Parameter-free layers shared by all heads.
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def _head(self, features, first, second, final):
        """Run one head: two dropout+ReLU hidden layers, then the output layer."""
        hidden = self.activation(self.dropout(first(features)))
        hidden = self.activation(self.dropout(second(hidden)))
        return final(hidden)

    def forward(self, out):
        """Return ``(gender, master, subcat, color)`` logits for embedding ``out``."""
        gender_logits = self._head(out, self.gender_linear1, self.gender_linear2, self.gender_out)
        master_logits = self._head(out, self.mastercat_linear1, self.mastercat_linear2, self.mastercat_out)
        subcat_logits = self._head(out, self.subcat_linear1, self.subcat_linear2, self.subcat_out)
        color_logits = self._head(out, self.color_linear1, self.color_linear2, self.color_out)
        return gender_logits, master_logits, subcat_logits, color_logits


class FashionPrediction(nn.Module):
    """Full model: ResNet-152 feature extractor followed by the classification heads."""

    def __init__(self):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys.
        self.feature_model = FashionResnet()
        self.classification_model = FashionClassifictions()

    def forward(self, x, only_embedding=False):
        """Run the model on image batch ``x``.

        Returns only the backbone embedding when ``only_embedding`` is true;
        otherwise returns ``(gender, master, subcat, color, embedding)`` where
        the first four are the per-head logits.
        """
        embedding = self.feature_model(x)
        if not only_embedding:
            gender, master, subcat, color = self.classification_model(embedding)
            return gender, master, subcat, color, embedding
        return embedding


# Index -> label lookup tables, one per classification head. Index i must
# correspond to output unit i of the matching head in FashionClassifictions
# (5 gender, 4 master-category, 32 sub-category and 44 colour outputs).
GENDER_LABELS = {0: 'Boys', 1: 'Girls', 2: 'Men', 3: 'Unisex', 4: 'Women'}
MASTER_LABELS = {0: 'Accessories', 1: 'Apparel', 2: 'Footwear', 3: 'Personal Care'}
SUBCAT_LABELS = {0: 'Accessories', 1: 'Apparel Set', 2: 'Bags', 3: 'Belts', 4: 'Bottomwear', 5: 'Cufflinks',
                 6: 'Dress', 7: 'Eyes', 8: 'Eyewear', 9: 'Flip Flops', 10: 'Fragrance', 11: 'Headwear',
                 12: 'Innerwear', 13: 'Jewellery', 14: 'Lips', 15: 'Loungewear and Nightwear', 16: 'Makeup',
                 17: 'Mufflers', 18: 'Nails', 19: 'Sandal', 20: 'Saree', 21: 'Scarves', 22: 'Shoe Accessories',
                 23: 'Shoes', 24: 'Skin', 25: 'Skin Care', 26: 'Socks', 27: 'Stoles', 28: 'Ties', 29: 'Topwear',
                 30: 'Wallets', 31: 'Watches'}
COLOR_LABELS = {0: 'Beige', 1: 'Black', 2: 'Blue', 3: 'Bronze', 4: 'Brown', 5: 'Burgundy', 6: 'Charcoal',
                7: 'Coffee Brown', 8: 'Copper', 9: 'Cream', 10: 'Gold', 11: 'Green', 12: 'Grey', 13: 'Grey Melange',
                14: 'Khaki', 15: 'Lavender', 16: 'Magenta', 17: 'Maroon', 18: 'Mauve', 19: 'Metallic', 20: 'Multi',
                21: 'Mushroom Brown', 22: 'Mustard', 23: 'Navy Blue', 24: 'Nude', 25: 'Off White', 26: 'Olive',
                27: 'Orange', 28: 'Peach', 29: 'Pink', 30: 'Purple', 31: 'Red', 32: 'Rose', 33: 'Rust',
                34: 'Sea Green', 35: 'Silver', 36: 'Skin', 37: 'Steel', 38: 'Tan', 39: 'Taupe', 40: 'Teal',
                41: 'Turquoise Blue', 42: 'White', 43: 'Yellow'}


def top_k_predictions(logits, labels, k=2):
    """Return the ``k`` most probable labels for a single sample.

    Args:
        logits: Raw scores of shape ``(1, num_classes)`` for one sample
            (a larger batch would be flattened together, so pass one sample).
        labels: Mapping from class index to human-readable label.
        k: Number of top classes to report; must not exceed ``num_classes``.

    Returns:
        A ``(names, probabilities)`` pair. ``names`` is a tuple of ``k`` labels
        ordered from most to least probable. ``probabilities`` is a list of the
        matching softmax probabilities as Python floats.
    """
    probs = F.softmax(logits, dim=1)
    top = torch.topk(probs, k, dim=1)
    # tolist() yields plain Python ints/floats, so label lookups and printed
    # output do not involve numpy scalar types.
    indices = top.indices.reshape(-1).tolist()
    values = top.values.reshape(-1).tolist()
    return tuple(labels[i] for i in indices), values


def main():
    """Classify one sample image with the trained checkpoint and print top-2 predictions per head."""
    trained_model_path = './data/final-models/resnet_152_classification.pt'
    model = FashionPrediction()
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load trusted checkpoints. Because only a state_dict is needed
    # here, weights_only=True (torch >= 1.13) would be the safer choice.
    state_dict = torch.load(trained_model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()

    transforms = Compose([Resize((232, 232)), ToTensor(),
                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # convert('RGB') guarantees the three channels that the 3-element
    # Normalize mean/std and the ResNet stem expect. Without it, grayscale,
    # RGBA or palette images would fail. The context manager closes the file.
    with Image.open('./data/small_images/0.jpg') as sample_image:
        transformed_image = transforms(sample_image.convert('RGB'))
    transformed_image = torch.unsqueeze(transformed_image, 0)  # add a batch dimension
    print(transformed_image.shape)

    with torch.inference_mode():
        gender_logits, master_logits, subcat_logits, color_logits, _embedding = model(
            transformed_image, only_embedding=False)
        gender = top_k_predictions(gender_logits, GENDER_LABELS)
        master = top_k_predictions(master_logits, MASTER_LABELS)
        subcat = top_k_predictions(subcat_logits, SUBCAT_LABELS)
        color = top_k_predictions(color_logits, COLOR_LABELS)

    print("All predictions:", {'gender': gender, 'master': master, 'subcat': subcat, 'color': color})

    pred_dict = {
        'Predicted Master Category': master[0],
        'Master Category Probability': master[1],
        'Predicted Sub Category': subcat[0],
        'Sub Category Probability': subcat[1],
        'Predicted person type': gender[0],
        'Person Type Probability': gender[1],
        'Predicted Color': color[0],
        'Color Probability': color[1]
    }
    print(pred_dict)


if __name__ == '__main__':
    main()