Spaces:
Runtime error
Runtime error
File size: 6,464 Bytes
d9f5274 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
class FashionResnet(nn.Module):
    """ResNet-152 backbone used purely as a feature extractor.

    The backbone's final ``fc`` layer is swapped for an identity, so ``forward``
    returns the features that would have been fed into that layer rather than
    class logits. Weights are randomly initialised here; a trained
    ``state_dict`` is expected to be loaded afterwards.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet152()
        # Strip the classifier so only the embedding comes out of the network.
        backbone.fc = nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Return backbone features for the image batch ``x``."""
        features = self.model(x)
        return features
class FashionClassifictions(nn.Module):
    """Four independent MLP heads that classify a shared 2048-dim embedding.

    Heads and their number of output logits: gender (5), master category (4),
    sub category (32) and color (44). Every head is
    ``Linear(2048, 1024) -> Dropout -> ReLU -> Linear(1024, 256) -> Dropout -> ReLU -> Linear(256, C)``.
    The Dropout and ReLU modules are shared by all heads.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become state_dict keys of saved checkpoints,
        # and creation order fixes the random initialisation sequence: keep both.
        self.gender_linear1 = nn.Linear(2048, 1024)
        self.gender_linear2 = nn.Linear(1024, 256)
        self.gender_out = nn.Linear(256, 5)

        self.mastercat_linear1 = nn.Linear(2048, 1024)
        self.mastercat_linear2 = nn.Linear(1024, 256)
        self.mastercat_out = nn.Linear(256, 4)

        self.subcat_linear1 = nn.Linear(2048, 1024)
        self.subcat_linear2 = nn.Linear(1024, 256)
        self.subcat_out = nn.Linear(256, 32)

        self.color_linear1 = nn.Linear(2048, 1024)
        self.color_linear2 = nn.Linear(1024, 256)
        self.color_out = nn.Linear(256, 44)

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def _run_head(self, features, first, second, final):
        """Run one head: two (Linear -> Dropout -> ReLU) stages, then the output Linear."""
        hidden = self.activation(self.dropout(first(features)))
        hidden = self.activation(self.dropout(second(hidden)))
        return final(hidden)

    def forward(self, out):
        """Apply every head to the embedding ``out``.

        Returns:
            Tuple ``(gender_logits, master_logits, subcat_logits, color_logits)``.
        """
        heads = (
            (self.gender_linear1, self.gender_linear2, self.gender_out),
            (self.mastercat_linear1, self.mastercat_linear2, self.mastercat_out),
            (self.subcat_linear1, self.subcat_linear2, self.subcat_out),
            (self.color_linear1, self.color_linear2, self.color_out),
        )
        # Evaluated strictly in the order above, so dropout RNG use matches a
        # sequential gender -> master -> subcat -> color evaluation.
        return tuple(self._run_head(out, *layers) for layers in heads)
class FashionPrediction(nn.Module):
    """End-to-end model: ResNet-152 feature extractor followed by four classification heads."""

    def __init__(self):
        super().__init__()
        # These names are the 'feature_model.' / 'classification_model.'
        # prefixes of the checkpoint's state_dict keys.
        self.feature_model = FashionResnet()
        self.classification_model = FashionClassifictions()

    def forward(self, x, only_embedding=False):
        """Embed ``x`` and, unless ``only_embedding`` is true, classify the embedding.

        Returns:
            The embedding alone when ``only_embedding`` is true; otherwise the tuple
            ``(gender_logits, master_logits, subcat_logits, color_logits, embedding)``.
        """
        embedding = self.feature_model(x)
        if not only_embedding:
            gender, master, subcat, color = self.classification_model(embedding)
            return gender, master, subcat, color, embedding
        return embedding
if __name__ == '__main__':
    # Demo: classify one sample image with the trained checkpoint and print the
    # top-2 predictions (labels and probabilities) for every head.
    trained_model_path = os.path.join('.', 'data', 'final-models', 'resnet_152_classification.pt')
    model = FashionPrediction()
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True where the torch version supports it).
    model.load_state_dict(torch.load(trained_model_path, map_location=torch.device('cpu')))
    model.eval()  # switches the heads' Dropout layers to inference behaviour

    # Open with a context manager so the file handle is released, and force RGB:
    # grayscale / RGBA / CMYK images would otherwise give a channel count that
    # does not match the 3-element Normalize mean/std below.
    with Image.open('./data/small_images/0.jpg') as raw_image:
        sample_image = raw_image.convert('RGB')
    transforms = Compose([Resize((232, 232)), ToTensor(),
                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # Prepend a batch dimension of size 1.
    transformed_image = torch.unsqueeze(transforms(sample_image), 0)
    print(transformed_image.shape)

    with torch.inference_mode():
        gender_logits, master_logits, subcat_logits, color_logits, _ = model(transformed_image, False)

    def _top2(logits):
        """Return ([p1, p2], [i1, i2]): the two highest softmax probabilities and their class indices."""
        top = torch.topk(F.softmax(logits, dim=1), 2, dim=1)
        # .tolist() gives plain Python floats/ints, which print cleanly.
        return top.values.reshape(-1).tolist(), top.indices.reshape(-1).tolist()

    def _labels(mapping, indices):
        """Translate predicted class indices into their label strings."""
        return tuple(mapping[int(i)] for i in indices)

    # Positional order is (gender, master, color, subcat).
    all_predictions = (_top2(gender_logits), _top2(master_logits),
                       _top2(color_logits), _top2(subcat_logits))

    gender_dict = {0: 'Boys', 1: 'Girls', 2: 'Men', 3: 'Unisex', 4: 'Women'}
    master_dict = {0: 'Accessories', 1: 'Apparel', 2: 'Footwear', 3: 'Personal Care'}
    subcat_dict = {0: 'Accessories', 1: 'Apparel Set', 2: 'Bags', 3: 'Belts', 4: 'Bottomwear', 5: 'Cufflinks',
                   6: 'Dress', 7: 'Eyes', 8: 'Eyewear', 9: 'Flip Flops', 10: 'Fragrance', 11: 'Headwear',
                   12: 'Innerwear', 13: 'Jewellery', 14: 'Lips', 15: 'Loungewear and Nightwear', 16: 'Makeup',
                   17: 'Mufflers', 18: 'Nails', 19: 'Sandal', 20: 'Saree', 21: 'Scarves', 22: 'Shoe Accessories',
                   23: 'Shoes', 24: 'Skin', 25: 'Skin Care', 26: 'Socks', 27: 'Stoles', 28: 'Ties', 29: 'Topwear',
                   30: 'Wallets', 31: 'Watches'}
    color_dict = {0: 'Beige', 1: 'Black', 2: 'Blue', 3: 'Bronze', 4: 'Brown', 5: 'Burgundy', 6: 'Charcoal',
                  7: 'Coffee Brown', 8: 'Copper', 9: 'Cream', 10: 'Gold', 11: 'Green', 12: 'Grey', 13: 'Grey Melange',
                  14: 'Khaki', 15: 'Lavender', 16: 'Magenta', 17: 'Maroon', 18: 'Mauve', 19: 'Metallic', 20: 'Multi',
                  21: 'Mushroom Brown', 22: 'Mustard', 23: 'Navy Blue', 24: 'Nude', 25: 'Off White', 26: 'Olive',
                  27: 'Orange', 28: 'Peach', 29: 'Pink', 30: 'Purple', 31: 'Red', 32: 'Rose', 33: 'Rust',
                  34: 'Sea Green', 35: 'Silver', 36: 'Skin', 37: 'Steel', 38: 'Tan', 39: 'Taupe', 40: 'Teal',
                  41: 'Turquoise Blue', 42: 'White', 43: 'Yellow'}
    print("All predictions:", all_predictions)

    gender_preds, master_preds, color_preds, subcat_preds = all_predictions
    pred_dict = {
        'Predicted Master Category': _labels(master_dict, master_preds[1][:2]),
        'Master Category Probability': master_preds[0],
        'Predicted Sub Category': _labels(subcat_dict, subcat_preds[1][:2]),
        'Sub Category Probability': subcat_preds[0],
        'Predicted person type': _labels(gender_dict, gender_preds[1][:2]),
        'Person Type Probability': gender_preds[0],
        'Predicted Color': _labels(color_dict, color_preds[1][:2]),
        'Color Probability': color_preds[0],
    }
    print(pred_dict)
|