AshanGimhana commited on
Commit
13a83ac
1 Parent(s): 07fffb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -8
app.py CHANGED
@@ -36,18 +36,33 @@ from models.psp import pSp
36
  # Huggingface login
37
  login(token=os.getenv("TOKENKEY"))
38
 
 
 
 
39
  # If 'mse' is a custom function needed,
40
  #custom_objects = {'mse': MeanSquaredError()}
41
  #new_age_model = load_model("age_prediction_model.h5")
 
 
 
 
 
 
 
 
42
 
43
  # Download models from Huggingface
44
  age_prototxt = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="age.prototxt")
45
  caffe_model = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="dex_imdb_wiki.caffemodel")
46
  sam_ffhq_aging = hf_hub_download(repo_id="AshanGimhana/Face_Agin_model", filename="sam_ffhq_aging.pt")
47
 
 
 
48
 
49
  # Age prediction model setup
50
- age_net = cv2.dnn.readNetFromCaffe(age_prototxt, caffe_model)
 
 
51
 
52
  # Face detection and landmarks predictor setup
53
  detector = dlib.get_frontal_face_detector()
@@ -116,16 +131,37 @@ def get_mouth_region(image):
116
 
117
  # Function to predict age
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  def predict_age(image):
120
- image = np.array(image.resize((64, 64)))
121
- image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
122
- image = image / 255.0
123
- image = np.expand_dims(image, axis=0)
 
 
 
 
 
 
 
124
 
125
  # Predict age
126
- val = new_age_model.predict(np.array(image))
127
- age = val[0][0]
128
- return int(age)
 
129
 
130
  # Function for color correction
131
  def color_correct(source, target):
 
36
  # Huggingface login
37
  login(token=os.getenv("TOKENKEY"))
38
 
39
+ ########################################################################
40
+ ############## tensorflow model for age calculation #######################
41
+
42
  # If 'mse' is a custom function needed,
43
  #custom_objects = {'mse': MeanSquaredError()}
44
  #new_age_model = load_model("age_prediction_model.h5")
45
+ ########################################################################
46
+
47
+
48
+ ########################################################################
49
+ ############## pytorch model for age calculation #######################
50
+ age_calc_model = torch.load('Custom_Age_prediction_model.pth')
51
+
52
+ ########################################################################
53
 
54
  # Download models from Huggingface
55
  age_prototxt = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="age.prototxt")
56
  caffe_model = hf_hub_download(repo_id="AshanGimhana/Age_Detection_caffe", filename="dex_imdb_wiki.caffemodel")
57
  sam_ffhq_aging = hf_hub_download(repo_id="AshanGimhana/Face_Agin_model", filename="sam_ffhq_aging.pt")
58
 
59
+ ########################################################################
60
+ ############## caffe model for age calculation #######################
61
 
62
  # Age prediction model setup
63
+ #age_net = cv2.dnn.readNetFromCaffe(age_prototxt, caffe_model)
64
+ ########################################################################
65
+
66
 
67
  # Face detection and landmarks predictor setup
68
  detector = dlib.get_frontal_face_detector()
 
131
 
132
  # Function to predict age
133
 
134
+ # old tensorflow function for age predict
135
+
136
+ #def predict_age(image):
137
+ #image = np.array(image.resize((64, 64)))
138
+ #image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
139
+ #image = image / 255.0
140
+ #image = np.expand_dims(image, axis=0)
141
+
142
+ ##### Predict age
143
+ #val = new_age_model.predict(np.array(image))
144
+ #age = val[0][0]
145
+ #return int(age)
146
+
147
  def predict_age(image):
148
+ age_calc_model.eval()
149
+ # Load and preprocess the image
150
+ image = cv2.imread(image, cv2.IMREAD_GRAYSCALE) # Load as grayscale
151
+ image = cv2.resize(image, (64, 64)) # Resize to 64x64
152
+ image = image / 255.0 # Normalize pixel values to [0, 1]
153
+ image = np.expand_dims(image, axis=0) # Add batch dimension
154
+ image = np.expand_dims(image, axis=0) # Add channel dimension
155
+ image = torch.tensor(image, dtype=torch.float32).to(device)
156
+
157
+ # Convert to tensor
158
+ image_tensor = torch.tensor(image, dtype=torch.float32)
159
 
160
  # Predict age
161
+ with torch.no_grad():
162
+ predicted_age = age_calc_model(image_tensor)
163
+
164
+ return int(predicted_age.item())
165
 
166
  # Function for color correction
167
  def color_correct(source, target):