# Repository page header: niks-salodkar - "added code and files" - 7faf1c4
import os
from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
class AgePredictResnet(nn.Module):
    """Multi-task classifier with a shared ResNet-101 trunk and three heads.

    The trunk is ``torchvision.models.resnet101`` with its final ``fc`` layer
    replaced by a 2048 -> 512 projection. Each head (age, gender, race) is a
    512 -> 256 -> 128 -> n_classes stack of linear layers, with dropout
    followed by ReLU after each of the two hidden layers.

    The layers are stored as ``<head>_linear1``, ``<head>_linear2`` and
    ``<head>_out``; these attribute names become the state-dict keys.
    """

    # (head name, number of output classes), in layer-creation order.
    _HEADS = (('age', 9), ('gender', 2), ('race', 5))

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet101()
        backbone.fc = nn.Linear(2048, 512)
        self.model = backbone

        for head, n_classes in self._HEADS:
            setattr(self, f'{head}_linear1', nn.Linear(512, 256))
            setattr(self, f'{head}_linear2', nn.Linear(256, 128))
            setattr(self, f'{head}_out', nn.Linear(128, n_classes))

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def _run_head(self, features, head):
        """Apply the two hidden layers and the output layer of one head."""
        hidden = features
        for layer_name in (f'{head}_linear1', f'{head}_linear2'):
            hidden = getattr(self, layer_name)(hidden)
            hidden = self.activation(self.dropout(hidden))
        return getattr(self, f'{head}_out')(hidden)

    def forward(self, x):
        """Return the raw ``(age, gender, race)`` outputs for the batch ``x``.

        No softmax is applied here; each element is the direct output of the
        corresponding head's final linear layer.
        """
        shared = self.activation(self.model(x))
        return tuple(self._run_head(shared, head) for head, _ in self._HEADS)
def main():
    """Classify one sample image and print age, sex and race predictions.

    Loads the checkpoint into ``AgePredictResnet`` on the CPU, preprocesses a
    single image (resize to 256x256, tensor conversion, per-channel
    normalisation) and runs one forward pass. Prints the input tensor shape,
    the raw prediction tuple, and a labelled dictionary with the top-2 age
    ranges, the most likely sex and the top-2 races, each with its softmax
    probability.
    """
    # A single-argument os.path.join() returns its argument unchanged, so the
    # path is written directly.
    trained_model_path = './final-models/resnet_101_weigthed.pt'
    sample_image_path = '../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg'

    model = AgePredictResnet()
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be loaded here (recent PyTorch offers weights_only=True).
    state_dict = torch.load(trained_model_path, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches; report them instead of silently
    # running with partially random weights.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print('Warning: checkpoint keys do not fully match the model:', load_result)
    model.eval()

    transforms = Compose([
        Resize((256, 256)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Force three channels: grayscale or RGBA files would otherwise break the
    # 3-channel Normalize. The context manager closes the file handle.
    with Image.open(sample_image_path) as sample_image:
        transformed_image = transforms(sample_image.convert('RGB'))
    # Add the batch dimension expected by the model.
    transformed_image = torch.unsqueeze(transformed_image, 0)
    print(transformed_image.shape)

    with torch.inference_mode():
        age_logits, sex_logits, race_logits = model(transformed_image)
        age_prob = F.softmax(age_logits, dim=1)
        sex_prob = F.softmax(sex_logits, dim=1)
        race_prob = F.softmax(race_logits, dim=1)

        top2_age = torch.topk(age_prob, 2, dim=1)
        top2_race = torch.topk(race_prob, 2, dim=1)
        sex_index = torch.argmax(sex_prob, dim=1).item()
        sex_probability = sex_prob[0][sex_index].item()

    # tolist() yields plain Python floats/ints rather than numpy scalars.
    age_probabilities = top2_age.values.reshape(-1).tolist()
    age_indices = top2_age.indices.reshape(-1).tolist()
    race_probabilities = top2_race.values.reshape(-1).tolist()
    race_indices = top2_race.indices.reshape(-1).tolist()

    all_predictions = (
        (age_probabilities, age_indices),
        (sex_index, sex_probability),
        (race_probabilities, race_indices),
    )
    print(all_predictions)

    # Class-index -> human-readable label tables for the three heads.
    age_dict = {
        0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
        6: '60 to 70', 7: '70 to 80', 8: 'Above 80',
    }
    sex_dict = {0: 'Male', 1: 'Female'}
    race_dict = {
        0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)',
    }

    pred_dict = {
        'Predicted Age range': (age_dict[age_indices[0]], age_dict[age_indices[1]]),
        'Age Probability': age_probabilities,
        'Predicted Sex': sex_dict[sex_index],
        'Sex Probability': sex_probability,
        'Predicted Race': (race_dict[race_indices[0]], race_dict[race_indices[1]]),
        'Race Probability': race_probabilities,
    }
    print(pred_dict)


if __name__ == '__main__':
    main()