Commit 7faf1c4
Parent(s): 7505f8b

added code and files
- .gitignore +1 -0
- app.py +71 -0
- data/sample_images/100_1_0_20170110183726390.jpg +0 -0
- data/sample_images/1_0_0_20170109193841675.jpg +0 -0
- data/sample_images/21_1_3_20170104222105039.jpg +0 -0
- data/sample_images/22_1_3_20170104231706746.jpg +0 -0
- data/sample_images/27_0_3_20170104214555317.jpg +0 -0
- data/sample_images/49_0_0_20170104184239893.jpg +0 -0
- data/sample_images/4_1_3_20161220220636202.jpg +0 -0
- data/sample_images/55_0_0_20170111195801050.jpg +0 -0
- data/sample_images/58_0_3_20170104220928390.jpg +0 -0
- data/sample_images/74_1_0_20170110153238490.jpg +0 -0
- data/sample_images/75_0_0_20170111200151404.jpg +0 -0
- final-models/resnet_101_weigthed.pt +3 -0
- inference.py +52 -0
- model.py +84 -0
- requirements.txt +53 -0
.gitignore
ADDED
@@ -0,0 +1 @@
/__pycache__/
app.py
ADDED
@@ -0,0 +1,71 @@
import os
import streamlit as st
from PIL import Image

from inference import get_predictions


st.title('Person characteristic prediction Demo')

sample_files = os.listdir('./data/sample_images')
tot_index = len(sample_files)
sample_path = './data/sample_images'

if 'image_index' not in st.session_state:
    st.session_state['image_index'] = 4

if 'which_button' not in st.session_state:
    st.session_state['which_button'] = 'sample_button'

stream_col, upload_col, sample_col = st.tabs(['Take picture', 'Upload file', 'Select from sample images'])
with stream_col:
    picture = st.camera_input("Take a picture")
    if picture is not None:
        captured_img = Image.open(picture)
        st.image(captured_img, caption='Captured Image')
        use_captured_image = st.button('Use this captured image')
        if use_captured_image is True:
            st.session_state['which_button'] = 'captured_button'
with upload_col:
    uploaded_file = st.file_uploader("Select a picture from your computer(png/jpg) :", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is not None:
        img = Image.open(uploaded_file)
        st.image(img, caption='Uploaded Image')
        use_uploaded_image = st.button("Use uploaded image")
        if use_uploaded_image is True:
            st.session_state['which_button'] = 'upload_button'

with sample_col:
    st.write("Select one from these available samples: ")
    current_index = st.session_state['image_index']
    current_image = Image.open(os.path.join(sample_path, sample_files[current_index]))

    # next = st.button('next_image')
    prev_button, next_button = st.columns(2)
    with prev_button:
        prev = st.button('prev_image')
    with next_button:
        next = st.button('next_image')
    if prev:
        current_index = (current_index - 1) % tot_index
    if next:
        current_index = (current_index + 1) % tot_index
    st.session_state['image_index'] = current_index
    sample_image = Image.open(os.path.join(sample_path, sample_files[current_index]))
    st.image(sample_image, caption='Chosen image')

    use_sample_image = st.button("Use this Sample")
    if use_sample_image is True:
        st.session_state['which_button'] = 'sample_button'

predict_clicked = st.button("Get prediction")
if predict_clicked:
    which_button = st.session_state['which_button']
    if which_button == 'sample_button':
        predictions = get_predictions(sample_image)
    elif which_button == 'upload_button':
        predictions = get_predictions(img)
    elif which_button == 'captured_button':
        predictions = get_predictions(captured_img)
    st.markdown('**The model predictions along with their probabilities are :**')
    st.table(predictions)
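To try the demo locally, the usual Streamlit workflow should apply (assuming the packages from requirements.txt are installed and the Git LFS model weights have been fetched): install with pip install -r requirements.txt, then launch with streamlit run app.py.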
data/sample_images/100_1_0_20170110183726390.jpg
ADDED
data/sample_images/1_0_0_20170109193841675.jpg
ADDED
data/sample_images/21_1_3_20170104222105039.jpg
ADDED
data/sample_images/22_1_3_20170104231706746.jpg
ADDED
data/sample_images/27_0_3_20170104214555317.jpg
ADDED
data/sample_images/49_0_0_20170104184239893.jpg
ADDED
data/sample_images/4_1_3_20161220220636202.jpg
ADDED
data/sample_images/55_0_0_20170111195801050.jpg
ADDED
data/sample_images/58_0_3_20170104220928390.jpg
ADDED
data/sample_images/74_1_0_20170110153238490.jpg
ADDED
data/sample_images/75_0_0_20170111200151404.jpg
ADDED
final-models/resnet_101_weigthed.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50caa9aa9f00f30c0ff6ae5e16c487bc3ba3db59ffd57e7010358cd165848252
size 176824799
inference.py
ADDED
@@ -0,0 +1,52 @@
import streamlit as st
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

from model import AgePredictResnet

path = './final-models/resnet_101_weigthed.pt'
age_dict = {
    0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
    6: '60 to 70', 7: '70 to 80', 8: 'Above 80'
}
sex_dict = {0: 'Male', 1: 'Female'}
race_dict = {
    0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)'
}

@st.experimental_memo
def load_trained_model(model_path):
    model = AgePredictResnet()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
    model.eval()
    return model


def get_predictions(input_image):
    model = load_trained_model(path)
    transforms = Compose([Resize((256, 256)), ToTensor(),
                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transformed_image = transforms(input_image)
    transformed_image = torch.unsqueeze(transformed_image, 0)
    with torch.inference_mode():
        logits = model(transformed_image)
    age_prob = F.softmax(logits[0], dim=1)
    sex_prob = F.softmax(logits[1], dim=1)
    race_prob = F.softmax(logits[2], dim=1)
    top2_age = torch.topk(age_prob, 2, dim=1)
    sex = torch.argmax(sex_prob, dim=1)
    top2_race = torch.topk(race_prob, 2, dim=1)
    all_predictions = (list(top2_age.values.numpy().reshape(-1)), list(top2_age.indices.numpy().reshape(-1))), (
        sex.item(), sex_prob[0][sex.item()].item()), \
        (list(top2_race.values.numpy().reshape(-1)), list(top2_race.indices.numpy().reshape(-1)))

    pred_dict = {
        'Predicted Age range': (age_dict[all_predictions[0][1][0]], age_dict[all_predictions[0][1][1]]),
        'Age Probability': all_predictions[0][0],
        'Predicted Sex': sex_dict[all_predictions[1][0]],
        'Sex Probability': all_predictions[1][1],
        'Predicted Race': (race_dict[all_predictions[2][1][0]], race_dict[all_predictions[2][1][1]]),
        'Race Probability': all_predictions[2][0],
    }
    return pred_dict
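For reference, a minimal sketch of calling get_predictions outside the Streamlit UI (assumes the LFS weights are available locally; the image path is one of the sample images added in this commit, and Streamlit may log a caching warning because load_trained_model is memoized with st.experimental_memo):

from PIL import Image

from inference import get_predictions

# One of the sample images added in this commit; any RGB face photo should work.
image = Image.open('./data/sample_images/1_0_0_20170109193841675.jpg').convert('RGB')
predictions = get_predictions(image)
print(predictions['Predicted Age range'], predictions['Predicted Sex'], predictions['Predicted Race'])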
model.py
ADDED
@@ -0,0 +1,84 @@
import os

from PIL import Image
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


class AgePredictResnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet101()
        self.model.fc = nn.Linear(2048, 512)
        self.age_linear1 = nn.Linear(512, 256)
        self.age_linear2 = nn.Linear(256, 128)
        self.age_out = nn.Linear(128, 9)
        self.gender_linear1 = nn.Linear(512, 256)
        self.gender_linear2 = nn.Linear(256, 128)
        self.gender_out = nn.Linear(128, 2)
        self.race_linear1 = nn.Linear(512, 256)
        self.race_linear2 = nn.Linear(256, 128)
        self.race_out = nn.Linear(128, 5)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        out = self.activation(self.model(x))
        age_out = self.activation(self.dropout(self.age_linear1(out)))
        age_out = self.activation(self.dropout(self.age_linear2(age_out)))
        age_out = self.age_out(age_out)

        gender_out = self.activation(self.dropout(self.gender_linear1(out)))
        gender_out = self.activation(self.dropout(self.gender_linear2(gender_out)))
        gender_out = self.gender_out(gender_out)

        race_out = self.activation(self.dropout(self.race_linear1(out)))
        race_out = self.activation(self.dropout(self.race_linear2(race_out)))
        race_out = self.race_out(race_out)
        return age_out, gender_out, race_out


if __name__ == '__main__':
    trained_model_path = os.path.join('./final-models/resnet_101_weigthed.pt')
    model = AgePredictResnet()
    model.load_state_dict(torch.load(trained_model_path, map_location=torch.device('cpu')), strict=False)
    model.eval()
    sample_image = Image.open('../../age_prediction/data/wild_images/part1/50_1_1_20170110120147003.jpg')
    transforms = Compose([Resize((256, 256)), ToTensor(),
                          Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    transformed_image = transforms(sample_image)
    transformed_image = torch.unsqueeze(transformed_image, 0)
    print(transformed_image.shape)
    with torch.inference_mode():
        logits = model(transformed_image)
    age_prob = F.softmax(logits[0], dim=1)
    sex_prob = F.softmax(logits[1], dim=1)
    race_prob = F.softmax(logits[2], dim=1)
    top2_age = torch.topk(age_prob, 2, dim=1)
    sex = torch.argmax(sex_prob, dim=1)
    top2_race = torch.topk(race_prob, 2, dim=1)
    all_predictions = (list(top2_age.values.numpy().reshape(-1)), list(top2_age.indices.numpy().reshape(-1))), (
        sex.item(), sex_prob[0][sex.item()].item()), \
        (list(top2_race.values.numpy().reshape(-1)), list(top2_race.indices.numpy().reshape(-1)))
    print(all_predictions)
    age_dict = {
        0: '0 to 10', 1: '10 to 20', 2: '20 to 30', 3: '30 to 40', 4: '40 to 50', 5: '50 to 60',
        6: '60 to 70', 7: '70 to 80', 8: 'Above 80'
    }
    sex_dict = {0: 'Male', 1: 'Female'}
    race_dict = {
        0: 'White', 1: 'Black', 2: 'Asian', 3: 'Indian', 4: 'Others (like Hispanic, Latino, Middle Eastern etc)'
    }
    pred_dict = {
        'Predicted Age range': (age_dict[all_predictions[0][1][0]], age_dict[all_predictions[0][1][1]]),
        'Age Probability': all_predictions[0][0],
        'Predicted Sex': sex_dict[all_predictions[1][0]],
        'Sex Probability': all_predictions[1][1],
        'Predicted Race': (race_dict[all_predictions[2][1][0]], race_dict[all_predictions[2][1][1]]),
        'Race Probability': all_predictions[2][0],
    }
    print(pred_dict)
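A quick shape check of the three output heads with random input (a sketch, not part of the commit; no trained weights are needed, since torchvision.models.resnet101() initializes randomly by default):

import torch

from model import AgePredictResnet

model = AgePredictResnet()
model.eval()
with torch.inference_mode():
    age_logits, gender_logits, race_logits = model(torch.randn(1, 3, 256, 256))
# Expected shapes: (1, 9) age bins, (1, 2) sexes, (1, 5) race groups.
print(age_logits.shape, gender_logits.shape, race_logits.shape)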
requirements.txt
ADDED
@@ -0,0 +1,53 @@
--extra-index-url https://download.pytorch.org/whl/cpu
altair==4.2.0
attrs==22.2.0
backports.zoneinfo==0.2.1
blinker==1.5
cachetools==5.2.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.3
commonmark==0.9.1
decorator==5.1.1
entrypoints==0.4
gitdb==4.0.10
GitPython==3.1.29
idna==3.4
importlib-metadata==5.2.0
importlib-resources==5.10.1
Jinja2==3.1.2
jsonschema==4.17.3
MarkupSafe==2.1.1
numpy==1.24.0
packaging==22.0
pandas==1.5.2
Pillow==9.3.0
pkgutil_resolve_name==1.3.10
protobuf==3.20.3
pyarrow==10.0.1
pydeck==0.8.0
Pygments==2.13.0
Pympler==1.0.1
pyrsistent==0.19.2
python-dateutil==2.8.2
pytz==2022.7
pytz-deprecation-shim==0.1.0.post0
requests==2.28.1
rich==12.6.0
semver==2.13.0
six==1.16.0
smmap==5.0.0
streamlit==1.16.0
toml==0.10.2
toolz==0.12.0
torch==1.13.1
torchaudio==0.13.1
torchvision==0.14.1
tornado==6.2
typing_extensions==4.4.0
tzdata==2022.7
tzlocal==4.2
urllib3==1.26.13
validators==0.20.0
watchdog==2.2.0
zipp==3.11.0