# Spaces:
# Runtime error
# Runtime error
# Last commit not found
import os
import warnings

import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch import nn
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
class AgePredictResnet(nn.Module):
    """Multi-task ResNet-101 predicting age bucket, gender and race from an image.

    A torchvision ResNet-101 backbone, whose ``fc`` layer is replaced by a
    2048 -> 512 projection, feeds three independent MLP heads. ``forward``
    returns raw (un-softmaxed) logits with 9 age classes, 2 gender classes
    and 5 race classes.

    The sub-module attribute names (``age_linear1``, ``gender_out``, ...)
    define the ``state_dict`` keys. Do not rename them, or saved checkpoints
    will no longer line up with the model.
    """

    def __init__(self):
        super().__init__()
        # Backbone built without pretrained weights; its final fc becomes a
        # 2048 -> 512 shared embedding.
        self.model = torchvision.models.resnet101()
        self.model.fc = nn.Linear(2048, 512)

        # Age head: 512 -> 256 -> 128 -> 9 logits.
        self.age_linear1 = nn.Linear(512, 256)
        self.age_linear2 = nn.Linear(256, 128)
        self.age_out = nn.Linear(128, 9)

        # Gender head: 512 -> 256 -> 128 -> 2 logits.
        self.gender_linear1 = nn.Linear(512, 256)
        self.gender_linear2 = nn.Linear(256, 128)
        self.gender_out = nn.Linear(128, 2)

        # Race head: 512 -> 256 -> 128 -> 5 logits.
        self.race_linear1 = nn.Linear(512, 256)
        self.race_linear2 = nn.Linear(256, 128)
        self.race_out = nn.Linear(128, 5)

        self.activation = nn.ReLU()
        # A single parameter-free Dropout(p=0.4) is reused by every head.
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return ``(age_logits, gender_logits, race_logits)`` for input batch ``x``.

        ``x`` goes straight into the torchvision ResNet stem, which takes
        (N, 3, H, W) float input.
        """
        # ReLU on the shared 512-d embedding before it is fanned out to the heads.
        shared = self.activation(self.model(x))
        age_out = self._head(shared, self.age_linear1, self.age_linear2, self.age_out)
        gender_out = self._head(shared, self.gender_linear1, self.gender_linear2, self.gender_out)
        race_out = self._head(shared, self.race_linear1, self.race_linear2, self.race_out)
        return age_out, gender_out, race_out

    def _head(self, features, hidden1, hidden2, output):
        """Apply one task head to *features*.

        The head is two Linear -> Dropout -> ReLU stages followed by a final
        Linear that produces logits.
        """
        hidden = self.activation(self.dropout(hidden1(features)))
        hidden = self.activation(self.dropout(hidden2(hidden)))
        return output(hidden)
# Human-readable label maps for the three heads, keyed by class index.
# NOTE(review): these presumably mirror the label encoding used at training
# time (the sample filename used in main() looks like
# "<age>_<gender>_<race>_<timestamp>.jpg"). Confirm against the training code.
AGE_LABELS = {
    0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
    6: '60 to 70', 7: '70 to 80', 8: 'Above 80',
}
SEX_LABELS = {0: 'Male', 1: 'Female'}
RACE_LABELS = {
    0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)',
}


def load_model(weights_path):
    """Build an AgePredictResnet, load weights from *weights_path* on CPU, and set eval mode.

    ``strict=False`` is kept so partially matching checkpoints still load.
    However, any missing or unexpected keys are now reported with a
    ``RuntimeWarning`` instead of being silently ignored. With ``strict=False``
    a wholesale key mismatch (e.g. a ``module.`` prefix) would otherwise leave
    the network at its random initialisation without any sign of failure.
    """
    model = AgePredictResnet()
    # NOTE(security): torch.load unpickles the file. Only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f'Checkpoint {weights_path!r} does not fully match the model: '
            f'{len(incompatible.missing_keys)} missing key(s) {incompatible.missing_keys[:5]}, '
            f'{len(incompatible.unexpected_keys)} unexpected key(s) {incompatible.unexpected_keys[:5]}',
            RuntimeWarning,
        )
    model.eval()
    return model


def load_image_tensor(image_path, size=(256, 256)):
    """Load an image file and return a normalised float tensor of shape (1, 3, *size*).

    The image is converted to RGB first. Grayscale or RGBA files would
    otherwise reach ``Normalize`` (3 per-channel statistics) and the
    3-channel backbone with the wrong number of channels.
    """
    transforms = Compose([
        Resize(size),
        ToTensor(),
        # The usual torchvision/ImageNet per-channel normalisation constants.
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager closes the file; convert() returns a loaded in-memory copy.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    return torch.unsqueeze(transforms(rgb_image), 0)


def summarize_predictions(age_logits, sex_logits, race_logits, top_k=2):
    """Turn the model's three logit tensors into compact top-k predictions.

    Expects a batch of exactly one sample: ``sex.item()`` raises for larger
    batches, and the top-k results are flattened with ``reshape(-1)``.

    Returns a 3-tuple:
        * ``(age_probs, age_indices)``: lists of the ``top_k`` highest age
          softmax probabilities and their class indices.
        * ``(sex_index, sex_prob)``: the arg-max sex class and its probability.
        * ``(race_probs, race_indices)``: the same shape of data as for age,
          but for race.
    """
    age_prob = F.softmax(age_logits, dim=1)
    sex_prob = F.softmax(sex_logits, dim=1)
    race_prob = F.softmax(race_logits, dim=1)

    top_age = torch.topk(age_prob, top_k, dim=1)
    sex_index = torch.argmax(sex_prob, dim=1).item()
    top_race = torch.topk(race_prob, top_k, dim=1)

    def _flat(tensor):
        # detach/cpu so this also works on grad-tracking or non-CPU tensors.
        return list(tensor.detach().cpu().numpy().reshape(-1))

    return (
        (_flat(top_age.values), _flat(top_age.indices)),
        (sex_index, sex_prob[0][sex_index].item()),
        (_flat(top_race.values), _flat(top_race.indices)),
    )


def build_prediction_dict(all_predictions):
    """Map the class indices from ``summarize_predictions`` to readable labels."""
    (age_probs, age_indices), (sex_index, sex_prob), (race_probs, race_indices) = all_predictions
    return {
        'Predicted Age range': tuple(AGE_LABELS[int(i)] for i in age_indices),
        'Age Probability': age_probs,
        'Predicted Sex': SEX_LABELS[sex_index],
        'Sex Probability': sex_prob,
        'Predicted Race': tuple(RACE_LABELS[int(i)] for i in race_indices),
        'Race Probability': race_probs,
    }


def main():
    """Classify the sample image with the trained checkpoint and print the results."""
    # Both paths are relative to the current working directory. The checkpoint
    # name's spelling ('weigthed') must match the file on disk.
    trained_model_path = './final-models/resnet_101_weigthed.pt'
    sample_image_path = '../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg'

    model = load_model(trained_model_path)
    transformed_image = load_image_tensor(sample_image_path)
    print(transformed_image.shape)

    with torch.inference_mode():
        logits = model(transformed_image)
        all_predictions = summarize_predictions(*logits)
    print(all_predictions)
    print(build_prediction_dict(all_predictions))


if __name__ == '__main__':
    main()