# niks-salodkar's picture
# added code and files
# d9f5274
import os
from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
class FashionResnet(nn.Module):
    """ResNet-152 backbone whose final ``fc`` layer is replaced by an identity.

    Because the last fully connected layer is a pass-through, ``forward``
    returns the activations that would normally be fed into ``fc`` rather
    than class scores.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet152()
        # Replace the classification layer so the network emits features.
        backbone.fc = torch.nn.Identity()
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        features = self.model(x)
        return features
class FashionClassifictions(nn.Module):
    """Four independent classification heads over a shared 2048-wide input.

    Every head is a ``2048 -> 1024 -> 256 -> n_classes`` stack of linear
    layers; each of the two hidden layers is followed by dropout (p=0.3)
    and then ReLU. The heads and their output widths are: gender (5),
    master category (4), sub category (32) and color (44).
    """

    def __init__(self):
        super().__init__()
        # The attribute names double as state_dict keys, and the creation
        # order fixes both the key order and the parameter-init order, so
        # the layers are registered head by head in a fixed sequence.
        head_sizes = (('gender', 5), ('mastercat', 4), ('subcat', 32), ('color', 44))
        for head_name, n_classes in head_sizes:
            setattr(self, f'{head_name}_linear1', nn.Linear(2048, 1024))
            setattr(self, f'{head_name}_linear2', nn.Linear(1024, 256))
            setattr(self, f'{head_name}_out', nn.Linear(256, n_classes))
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def _run_head(self, features, first_hidden, second_hidden, output_layer):
        """Apply one head: two (linear -> dropout -> ReLU) stages, then the output layer."""
        hidden = self.activation(self.dropout(first_hidden(features)))
        hidden = self.activation(self.dropout(second_hidden(hidden)))
        return output_layer(hidden)

    def forward(self, out):
        """Return ``(gender, master category, sub category, color)`` outputs for ``out``.

        The values come straight from the final linear layers; no softmax is
        applied here.
        """
        gender_out = self._run_head(
            out, self.gender_linear1, self.gender_linear2, self.gender_out)
        master_out = self._run_head(
            out, self.mastercat_linear1, self.mastercat_linear2, self.mastercat_out)
        subcat_out = self._run_head(
            out, self.subcat_linear1, self.subcat_linear2, self.subcat_out)
        color_out = self._run_head(
            out, self.color_linear1, self.color_linear2, self.color_out)
        return gender_out, master_out, subcat_out, color_out
class FashionPrediction(nn.Module):
    """Feature extractor and classification heads combined into one module.

    ``forward`` either returns the extractor's output alone or the four
    head outputs followed by that same extractor output.
    """

    def __init__(self):
        super().__init__()
        self.feature_model = FashionResnet()
        self.classification_model = FashionClassifictions()

    def forward(self, x, only_embedding=False):
        """Compute predictions (and/or the embedding) for ``x``.

        Args:
            x: Input passed to the feature model.
            only_embedding: When true, skip the classification heads and
                return just the feature model's output.

        Returns:
            The embedding alone if ``only_embedding`` is true; otherwise the
            tuple ``(gender, master category, sub category, color, embedding)``.
        """
        embedding = self.feature_model(x)
        if not only_embedding:
            gender, master, subcat, color = self.classification_model(embedding)
            return gender, master, subcat, color, embedding
        return embedding
if __name__ == '__main__':
    # Demo: load the trained checkpoint on CPU, classify one sample image and
    # print the two most probable labels (with probabilities) for each head.
    trained_model_path = './data/final-models/resnet_152_classification.pt'
    model = FashionPrediction()
    # print(model)
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints that come from a trusted source.
    model.load_state_dict(torch.load(trained_model_path, map_location=torch.device('cpu')))
    model.eval()

    transforms = Compose([Resize((232, 232)), ToTensor(),
                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # Normalize is configured with three-channel statistics, so force RGB:
    # grayscale, RGBA or palette images would otherwise fail in the pipeline.
    # The context manager releases the file handle once pixels are decoded.
    with Image.open('./data/small_images/0.jpg') as image_file:
        sample_image = image_file.convert('RGB')
    transformed_image = transforms(sample_image)
    transformed_image = torch.unsqueeze(transformed_image, 0)
    print(transformed_image.shape)

    with torch.inference_mode():
        logits = model(transformed_image, False)

    # The model returns (gender, master, subcat, color, embedding); turn each
    # of the four logit tensors into ([top-2 probabilities], [top-2 indices]).
    top2_per_head = []
    for head_logits in logits[:4]:
        head_prob = F.softmax(head_logits, dim=1)
        head_top2 = torch.topk(head_prob, 2, dim=1)
        top2_per_head.append((list(head_top2.values.numpy().reshape(-1)),
                              list(head_top2.indices.numpy().reshape(-1))))
    top2_gender, top2_master, top2_subcat, top2_color = top2_per_head
    # Printed order is gender, master, color, subcat (color before subcat).
    all_predictions = (top2_gender, top2_master, top2_color, top2_subcat)

    gender_dict = {0: 'Boys', 1: 'Girls', 2: 'Men', 3: 'Unisex', 4: 'Women'}
    master_dict = {0: 'Accessories', 1: 'Apparel', 2: 'Footwear', 3: 'Personal Care'}
    subcat_dict = {0: 'Accessories', 1: 'Apparel Set', 2: 'Bags', 3: 'Belts', 4: 'Bottomwear', 5: 'Cufflinks',
                   6: 'Dress', 7: 'Eyes', 8: 'Eyewear', 9: 'Flip Flops', 10: 'Fragrance', 11: 'Headwear',
                   12: 'Innerwear', 13: 'Jewellery', 14: 'Lips', 15: 'Loungewear and Nightwear', 16: 'Makeup',
                   17: 'Mufflers', 18: 'Nails', 19: 'Sandal', 20: 'Saree', 21: 'Scarves', 22: 'Shoe Accessories',
                   23: 'Shoes', 24: 'Skin', 25: 'Skin Care', 26: 'Socks', 27: 'Stoles', 28: 'Ties', 29: 'Topwear',
                   30: 'Wallets', 31: 'Watches'}
    color_dict = {0: 'Beige', 1: 'Black', 2: 'Blue', 3: 'Bronze', 4: 'Brown', 5: 'Burgundy', 6: 'Charcoal',
                  7: 'Coffee Brown', 8: 'Copper', 9: 'Cream', 10: 'Gold', 11: 'Green', 12: 'Grey', 13: 'Grey Melange',
                  14: 'Khaki', 15: 'Lavender', 16: 'Magenta', 17: 'Maroon', 18: 'Mauve', 19: 'Metallic', 20: 'Multi',
                  21: 'Mushroom Brown', 22: 'Mustard', 23: 'Navy Blue', 24: 'Nude', 25: 'Off White', 26: 'Olive',
                  27: 'Orange', 28: 'Peach', 29: 'Pink', 30: 'Purple', 31: 'Red', 32: 'Rose', 33: 'Rust',
                  34: 'Sea Green', 35: 'Silver', 36: 'Skin', 37: 'Steel', 38: 'Tan', 39: 'Taupe', 40: 'Teal',
                  41: 'Turquoise Blue', 42: 'White', 43: 'Yellow'}
    print("All predictions:", all_predictions)

    # Unpack by name instead of indexing into all_predictions positionally.
    gender_probs, gender_ids = top2_gender
    master_probs, master_ids = top2_master
    subcat_probs, subcat_ids = top2_subcat
    color_probs, color_ids = top2_color
    pred_dict = {
        'Predicted Master Category': (master_dict[master_ids[0]], master_dict[master_ids[1]]),
        'Master Category Probability': master_probs,
        'Predicted Sub Category': (subcat_dict[subcat_ids[0]], subcat_dict[subcat_ids[1]]),
        'Sub Category Probability': subcat_probs,
        'Predicted person type': (gender_dict[gender_ids[0]], gender_dict[gender_ids[1]]),
        'Person Type Probability': gender_probs,
        'Predicted Color': (color_dict[color_ids[0]], color_dict[color_ids[1]]),
        'Color Probability': color_probs,
    }
    print(pred_dict)