File size: 4,750 Bytes
d9f5274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05793f6
d9f5274
05793f6
d9f5274
05793f6
d9f5274
05793f6
d9f5274
 
 
 
 
 
05793f6
 
 
d9f5274
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from PIL import Image
import faiss
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

from model import FashionPrediction
# Preprocessing pipeline applied to every input image before it reaches the
# model: resize to 232x232, convert to a tensor, then normalise per channel.
# NOTE(review): the mean/std values match the commonly used ImageNet channel
# statistics - confirm they are the ones the model was trained with.
transforms = Compose(
    [
        Resize((232, 232)),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Checkpoint handed to load_trained_model (loaded there as a state_dict).
path = './data/final-models/resnet_152_classification.pt'
# FAISS index file read by load_index.
index_path = './data/index_files/resnet152_unweighted_flat.index'

# Lookup tables mapping a predicted class index to its human-readable label.
# Each label's position in the tuple is its class index.
gender_dict = dict(enumerate(('Boys', 'Girls', 'Men', 'Unisex', 'Women')))

master_dict = dict(enumerate(('Accessories', 'Apparel', 'Footwear', 'Personal Care')))

subcat_dict = dict(enumerate((
    'Accessories', 'Apparel Set', 'Bags', 'Belts',
    'Bottomwear', 'Cufflinks', 'Dress', 'Eyes',
    'Eyewear', 'Flip Flops', 'Fragrance', 'Headwear',
    'Innerwear', 'Jewellery', 'Lips', 'Loungewear and Nightwear',
    'Makeup', 'Mufflers', 'Nails', 'Sandal',
    'Saree', 'Scarves', 'Shoe Accessories', 'Shoes',
    'Skin', 'Skin Care', 'Socks', 'Stoles',
    'Ties', 'Topwear', 'Wallets', 'Watches',
)))

color_dict = dict(enumerate((
    'Beige', 'Black', 'Blue', 'Bronze',
    'Brown', 'Burgundy', 'Charcoal', 'Coffee Brown',
    'Copper', 'Cream', 'Gold', 'Green',
    'Grey', 'Grey Melange', 'Khaki', 'Lavender',
    'Magenta', 'Maroon', 'Mauve', 'Metallic',
    'Multi', 'Mushroom Brown', 'Mustard', 'Navy Blue',
    'Nude', 'Off White', 'Olive', 'Orange',
    'Peach', 'Pink', 'Purple', 'Red',
    'Rose', 'Rust', 'Sea Green', 'Silver',
    'Skin', 'Steel', 'Tan', 'Taupe',
    'Teal', 'Turquoise Blue', 'White', 'Yellow',
)))


@st.experimental_singleton
def load_trained_model(model_path):
    """Load the trained ``FashionPrediction`` network for CPU inference.

    The result is cached with ``st.experimental_singleton`` rather than
    ``st.experimental_memo``: the memo cache pickles the return value and
    hands every caller a freshly deserialised copy, which is wasteful for a
    neural network. The singleton cache builds the model once per
    ``model_path`` and returns that same object on later calls.

    Args:
        model_path: Path to the saved ``state_dict`` checkpoint.

    Returns:
        The model with the checkpoint weights loaded (mapped to CPU) and
        switched to evaluation mode.
    """
    model = FashionPrediction()
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # eval() disables training-only behaviour such as dropout.
    model.eval()
    return model


@st.experimental_singleton
def load_index():
    """Read the FAISS index from ``index_path`` and cache it.

    ``st.experimental_singleton`` is used instead of ``st.experimental_memo``
    because the memo cache pickles the returned object and deserialises a copy
    on every call. The index is a shared, read-only resource, so a single
    cached instance is what is wanted here.

    Returns:
        The index object produced by ``faiss.read_index``.
    """
    return faiss.read_index(index_path)


def get_nearest_k(input_image, top_k=5):
    """Search the FAISS index for the items closest to ``input_image``.

    Args:
        input_image: PIL image used as the query.
        top_k: Number of nearest neighbours to retrieve.

    Returns:
        The ``(dist, indexes)`` pair returned by the index's ``search``
        method for the single query image.
    """
    model = load_trained_model(path)
    # Normalize is configured for three channels; greyscale, palette or RGBA
    # images would otherwise fail in the pipeline, so force RGB first.
    image_tensor = transforms(input_image.convert('RGB'))
    batch = torch.unsqueeze(image_tensor, 0)  # add the batch dimension
    with torch.inference_mode():
        # NOTE(review): the second argument presumably selects the embedding
        # output rather than the class logits - confirm against the model.
        query_embeddings = model(batch, True)
    # FAISS works on C-contiguous float32 NumPy arrays; convert explicitly
    # instead of relying on implicit handling of a torch tensor.
    if isinstance(query_embeddings, torch.Tensor):
        query_embeddings = (
            query_embeddings.detach().cpu().to(torch.float32).contiguous().numpy()
        )
    the_index = load_index()
    dist, indexes = the_index.search(query_embeddings, top_k)
    return dist, indexes


def _top2_labels(logits, label_map):
    """Return the two most probable labels and their rounded probabilities.

    Args:
        logits: Raw scores for one classification head; softmax and top-k are
            taken along dimension 1.
        label_map: Dict mapping class index to label name.

    Returns:
        ``((best_label, second_label), [best_prob, second_prob])`` with the
        probabilities rounded to three decimals as builtin floats.
    """
    probabilities = F.softmax(logits, dim=1)
    top2 = torch.topk(probabilities, 2, dim=1)
    # tolist() yields builtin ints/floats, which are valid dict keys and
    # serialise cleanly (unlike NumPy scalars).
    indices = top2.indices.reshape(-1).tolist()
    values = top2.values.reshape(-1).tolist()
    labels = (label_map[indices[0]], label_map[indices[1]])
    return labels, [round(value, 3) for value in values]


def get_predictions(input_image):
    """Predict the top-2 attributes of a fashion image.

    Args:
        input_image: PIL image to classify.

    Returns:
        A dict with, for each of master category, sub category, person type
        and colour, a tuple of the two most probable labels and a list of
        their probabilities rounded to three decimals.
    """
    model = load_trained_model(path)
    # Normalize is configured for three channels; greyscale, palette or RGBA
    # images would otherwise fail in the pipeline, so force RGB first.
    image_tensor = transforms(input_image.convert('RGB'))
    batch = torch.unsqueeze(image_tensor, 0)  # add the batch dimension
    with torch.inference_mode():
        logits = model(batch, False)

    # The output heads are read in this order: gender, master category,
    # sub category, colour.
    gender_labels, gender_probs = _top2_labels(logits[0], gender_dict)
    master_labels, master_probs = _top2_labels(logits[1], master_dict)
    subcat_labels, subcat_probs = _top2_labels(logits[2], subcat_dict)
    color_labels, color_probs = _top2_labels(logits[3], color_dict)

    pred_dict = {
        'Predicted Master Category': master_labels,
        'Master Category Probability': master_probs,
        'Predicted Sub Category': subcat_labels,
        'Sub Category Probability': subcat_probs,
        'Predicted person type': gender_labels,
        'Person Type Probability': gender_probs,
        'Predicted Color': color_labels,
        'Color Probability': color_probs,
    }

    return pred_dict


if __name__ == '__main__':
    # Manual smoke test: classify one sample image and print the predictions.
    # The context manager closes the image's file handle once inference is
    # done, instead of leaving it open until garbage collection.
    with Image.open('./data/small_images_0_9999/0.jpg') as sample_image:
        output = get_predictions(sample_image)
    print(output)