niks-salodkar commited on
Commit
7faf1c4
·
1 Parent(s): 7505f8b

added code and files

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ /__pycache__/
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from PIL import Image
4
+
5
+ from inference import get_predictions
6
+
7
+
8
+ st.title('Person characteristic prediction Demo')
9
+
10
+ sample_files = os.listdir('./data/sample_images')
11
+ tot_index = len(sample_files)
12
+ sample_path = './data/sample_images'
13
+
14
+ if 'image_index' not in st.session_state:
15
+ st.session_state['image_index'] = 4
16
+
17
+ if 'which_button' not in st.session_state:
18
+ st.session_state['which_button'] = 'sample_button'
19
+
20
+ stream_col, upload_col, sample_col = st.tabs(['Take picture', 'Upload file', 'Select from sample images'])
21
+ with stream_col:
22
+ picture = st.camera_input("Take a picture")
23
+ if picture is not None:
24
+ captured_img = Image.open(picture)
25
+ st.image(captured_img, caption='Captured Image')
26
+ use_captured_image = st.button('Use this captured image')
27
+ if use_captured_image is True:
28
+ st.session_state['which_button'] = 'captured_button'
29
+ with upload_col:
30
+ uploaded_file = st.file_uploader("Select a picture from your computer(png/jpg) :", type=['png', 'jpg', 'jpeg'])
31
+ if uploaded_file is not None:
32
+ img = Image.open(uploaded_file)
33
+ st.image(img, caption='Uploaded Image')
34
+ use_uploaded_image = st.button("Use uploaded image")
35
+ if use_uploaded_image is True:
36
+ st.session_state['which_button'] = 'upload_button'
37
+
38
+ with sample_col:
39
+ st.write("Select one from these available samples: ")
40
+ current_index = st.session_state['image_index']
41
+ current_image = Image.open(os.path.join(sample_path, sample_files[current_index]))
42
+
43
+ # next = st.button('next_image')
44
+ prev_button, next_button = st.columns(2)
45
+ with prev_button:
46
+ prev = st.button('prev_image')
47
+ with next_button:
48
+ next = st.button('next_image')
49
+ if prev:
50
+ current_index = (current_index - 1) % tot_index
51
+ if next:
52
+ current_index = (current_index + 1) % tot_index
53
+ st.session_state['image_index'] = current_index
54
+ sample_image = Image.open(os.path.join(sample_path, sample_files[current_index]))
55
+ st.image(sample_image, caption='Chosen image')
56
+
57
+ use_sample_image = st.button("Use this Sample")
58
+ if use_sample_image is True:
59
+ st.session_state['which_button'] = 'sample_button'
60
+
61
+ predict_clicked = st.button("Get prediction")
62
+ if predict_clicked:
63
+ which_button = st.session_state['which_button']
64
+ if which_button == 'sample_button':
65
+ predictions = get_predictions(sample_image)
66
+ elif which_button == 'upload_button':
67
+ predictions = get_predictions(img)
68
+ elif which_button == 'captured_button':
69
+ predictions = get_predictions(captured_img)
70
+ st.markdown('**The model predictions along with their probabilities are :**')
71
+ st.table(predictions)
data/sample_images/100_1_0_20170110183726390.jpg ADDED
data/sample_images/1_0_0_20170109193841675.jpg ADDED
data/sample_images/21_1_3_20170104222105039.jpg ADDED
data/sample_images/22_1_3_20170104231706746.jpg ADDED
data/sample_images/27_0_3_20170104214555317.jpg ADDED
data/sample_images/49_0_0_20170104184239893.jpg ADDED
data/sample_images/4_1_3_20161220220636202.jpg ADDED
data/sample_images/55_0_0_20170111195801050.jpg ADDED
data/sample_images/58_0_3_20170104220928390.jpg ADDED
data/sample_images/74_1_0_20170110153238490.jpg ADDED
data/sample_images/75_0_0_20170111200151404.jpg ADDED
final-models/resnet_101_weigthed.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50caa9aa9f00f30c0ff6ae5e16c487bc3ba3db59ffd57e7010358cd165848252
3
+ size 176824799
inference.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
5
+
6
+ from model import AgePredictResnet
7
+
8
+ path = './final-models/resnet_101_weigthed.pt'
9
+ age_dict = {
10
+ 0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
11
+ 6: '60 to 70', 7: '70 to 80', 8: 'Above 80'
12
+ }
13
+ sex_dict = {0: 'Male', 1: 'Female'}
14
+ race_dict = {
15
+ 0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)'
16
+ }
17
+
18
+ @st.experimental_memo
19
+ def load_trained_model(model_path):
20
+ model = AgePredictResnet()
21
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
22
+ model.eval()
23
+ return model
24
+
25
+
26
+ def get_predictions(input_image):
27
+ model = load_trained_model(path)
28
+ transforms = Compose([Resize((256, 256)), ToTensor(),
29
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
30
+ transformed_image = transforms(input_image)
31
+ transformed_image = torch.unsqueeze(transformed_image, 0)
32
+ with torch.inference_mode():
33
+ logits = model(transformed_image)
34
+ age_prob = F.softmax(logits[0], dim=1)
35
+ sex_prob = F.softmax(logits[1], dim=1)
36
+ race_prob = F.softmax(logits[2], dim=1)
37
+ top2_age = torch.topk(age_prob, 2, dim=1)
38
+ sex = torch.argmax(sex_prob, dim=1)
39
+ top2_race = torch.topk(race_prob, 2, dim=1)
40
+ all_predictions = (list(top2_age.values.numpy().reshape(-1)), list(top2_age.indices.numpy().reshape(-1))), (
41
+ sex.item(), sex_prob[0][sex.item()].item()), \
42
+ (list(top2_race.values.numpy().reshape(-1)), list(top2_race.indices.numpy().reshape(-1)))
43
+
44
+ pred_dict = {
45
+ 'Predicted Age range': (age_dict[all_predictions[0][1][0]], age_dict[all_predictions[0][1][1]]),
46
+ 'Age Probability': all_predictions[0][0],
47
+ 'Predicted Sex': sex_dict[all_predictions[1][0]],
48
+ 'Sex Probability': all_predictions[1][1],
49
+ 'Predicted Race': (race_dict[all_predictions[2][1][0]], race_dict[all_predictions[2][1][1]]),
50
+ 'Race Probability': all_predictions[2][0],
51
+ }
52
+ return pred_dict
model.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from PIL import Image
4
+ import torch
5
+ import torchvision
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
9
+
10
+
11
+ class AgePredictResnet(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.model = torchvision.models.resnet101()
15
+ self.model.fc = nn.Linear(2048, 512)
16
+ self.age_linear1 = nn.Linear(512, 256)
17
+ self.age_linear2 = nn.Linear(256, 128)
18
+ self.age_out = nn.Linear(128, 9)
19
+ self.gender_linear1 = nn.Linear(512, 256)
20
+ self.gender_linear2 = nn.Linear(256, 128)
21
+ self.gender_out = nn.Linear(128, 2)
22
+ self.race_linear1 = nn.Linear(512, 256)
23
+ self.race_linear2 = nn.Linear(256, 128)
24
+ self.race_out = nn.Linear(128, 5)
25
+ self.activation = nn.ReLU()
26
+ self.dropout = nn.Dropout(0.4)
27
+
28
+ def forward(self, x):
29
+ out = self.activation(self.model(x))
30
+ age_out = self.activation(self.dropout((self.age_linear1(out))))
31
+ age_out = self.activation(self.dropout(self.age_linear2(age_out)))
32
+ age_out = self.age_out(age_out)
33
+
34
+ gender_out = self.activation(self.dropout((self.gender_linear1(out))))
35
+ gender_out = self.activation(self.dropout(self.gender_linear2(gender_out)))
36
+ gender_out = self.gender_out(gender_out)
37
+
38
+ race_out = self.activation(self.dropout((self.race_linear1(out))))
39
+ race_out = self.activation(self.dropout(self.race_linear2(race_out)))
40
+ race_out = self.race_out(race_out)
41
+ return age_out, gender_out, race_out
42
+
43
+
44
+ if __name__ == '__main__':
45
+ trained_model_path = os.path.join('./final-models/resnet_101_weigthed.pt')
46
+ model = AgePredictResnet()
47
+ model.load_state_dict(torch.load(trained_model_path, map_location=torch.device('cpu')), strict=False)
48
+ model.eval()
49
+ sample_image = Image.open('../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg')
50
+ transforms = Compose([Resize((256, 256)), ToTensor(),
51
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
52
+ transformed_image = transforms(sample_image)
53
+ transformed_image = torch.unsqueeze(transformed_image, 0)
54
+ print(transformed_image.shape)
55
+ with torch.inference_mode():
56
+ logits = model(transformed_image)
57
+ age_prob = F.softmax(logits[0], dim=1)
58
+ sex_prob = F.softmax(logits[1], dim=1)
59
+ race_prob = F.softmax(logits[2], dim=1)
60
+ top2_age = torch.topk(age_prob, 2, dim=1)
61
+ sex = torch.argmax(sex_prob, dim=1)
62
+ top2_race = torch.topk(race_prob, 2, dim=1)
63
+ all_predictions = (list(top2_age.values.numpy().reshape(-1)), list(top2_age.indices.numpy().reshape(-1))), (
64
+ sex.item(), sex_prob[0][sex.item()].item()), \
65
+ (list(top2_race.values.numpy().reshape(-1)), list(top2_race.indices.numpy().reshape(-1)))
66
+ print(all_predictions)
67
+ age_dict = {
68
+ 0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
69
+ 6: '60 to 70', 7: '70 to 80', 8: 'Above 80'
70
+ }
71
+ sex_dict = {0: 'Male', 1: 'Female'}
72
+ race_dict = {
73
+ 0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)'
74
+ }
75
+ #
76
+ pred_dict = {
77
+ 'Predicted Age range': (age_dict[all_predictions[0][1][0]], age_dict[all_predictions[0][1][1]]),
78
+ 'Age Probability': all_predictions[0][0],
79
+ 'Predicted Sex': sex_dict[all_predictions[1][0]],
80
+ 'Sex Probability': all_predictions[1][1],
81
+ 'Predicted Race': (race_dict[all_predictions[2][1][0]], race_dict[all_predictions[2][1][1]]),
82
+ 'Race Probability': all_predictions[2][0],
83
+ }
84
+ print(pred_dict)
requirements.txt ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ altair==4.2.0
3
+ attrs==22.2.0
4
+ backports.zoneinfo==0.2.1
5
+ blinker==1.5
6
+ cachetools==5.2.0
7
+ certifi==2022.12.7
8
+ charset-normalizer==2.1.1
9
+ click==8.1.3
10
+ commonmark==0.9.1
11
+ decorator==5.1.1
12
+ entrypoints==0.4
13
+ gitdb==4.0.10
14
+ GitPython==3.1.29
15
+ idna==3.4
16
+ importlib-metadata==5.2.0
17
+ importlib-resources==5.10.1
18
+ Jinja2==3.1.2
19
+ jsonschema==4.17.3
20
+ MarkupSafe==2.1.1
21
+ numpy==1.24.0
22
+ packaging==22.0
23
+ pandas==1.5.2
24
+ Pillow==9.3.0
25
+ pkgutil_resolve_name==1.3.10
26
+ protobuf==3.20.3
27
+ pyarrow==10.0.1
28
+ pydeck==0.8.0
29
+ Pygments==2.13.0
30
+ Pympler==1.0.1
31
+ pyrsistent==0.19.2
32
+ python-dateutil==2.8.2
33
+ pytz==2022.7
34
+ pytz-deprecation-shim==0.1.0.post0
35
+ requests==2.28.1
36
+ rich==12.6.0
37
+ semver==2.13.0
38
+ six==1.16.0
39
+ smmap==5.0.0
40
+ streamlit==1.16.0
41
+ toml==0.10.2
42
+ toolz==0.12.0
43
+ torch==1.13.1
44
+ torchaudio==0.13.1
45
+ torchvision==0.14.1
46
+ tornado==6.2
47
+ typing_extensions==4.4.0
48
+ tzdata==2022.7
49
+ tzlocal==4.2
50
+ urllib3==1.26.13
51
+ validators==0.20.0
52
+ watchdog==2.2.0
53
+ zipp==3.11.0