aminaB9 commited on
Commit
9c86f76
·
1 Parent(s): 7df08f2

Adjusted image processing

Browse files
Files changed (2) hide show
  1. app.py +19 -5
  2. requirements.txt +2 -1
app.py CHANGED
@@ -16,11 +16,25 @@ import subprocess
16
  import os
17
 
18
 
 
19
 
20
- def resizeImage(image):
21
- resized = image.resize((112, 112))
22
- return resized
 
 
 
 
23
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
 
@@ -154,7 +168,7 @@ with gr.Blocks() as demo:
154
  with gr.Row():
155
  with gr.Column():
156
  image_input_enroll = gr.Image(label="Upload a reference facial image.", type="pil", sources="upload")
157
- image_input_enroll.change(fn=resizeImage, inputs=image_input_enroll, outputs=image_input_enroll)
158
  with gr.Column():
159
  example_gallery = gr.Gallery(value=example_images, columns=3)
160
  with gr.Column():
@@ -231,7 +245,7 @@ with gr.Blocks() as demo:
231
  with gr.Row():
232
  with gr.Column():
233
  image_input_auth = gr.Image(label="Upload a facial image.", type="pil", sources="upload")
234
- image_input_auth.change(fn=resizeImage, inputs=image_input_auth, outputs=image_input_auth)
235
  with gr.Column():
236
  example_gallery = gr.Gallery(value=example_images_auth, columns=3)
237
  with gr.Column():
 
16
  import os
17
 
18
 
19
+ from facenet_pytorch import MTCNN
20
 
21
+ mtcnn = MTCNN(keep_all=False)
22
+
23
+
24
+
25
+ def crop_face_to_112x112(image: Image.Image):
26
+ if image.size == (112, 112):
27
+ return image
28
 
29
+ boxes, _ = mtcnn.detect(image)
30
+
31
+ if boxes is None:
32
+ raise ValueError("No face detected.")
33
+
34
+ x1, y1, x2, y2 = map(int, boxes[0])
35
+ cropped = image.crop((x1, y1, x2, y2))
36
+ resized = cropped.resize((112, 112), Image.BILINEAR)
37
+ return resized
38
 
39
 
40
 
 
168
  with gr.Row():
169
  with gr.Column():
170
  image_input_enroll = gr.Image(label="Upload a reference facial image.", type="pil", sources="upload")
171
+ image_input_enroll.change(fn=crop_face_to_112x112, inputs=image_input_enroll, outputs=image_input_enroll)
172
  with gr.Column():
173
  example_gallery = gr.Gallery(value=example_images, columns=3)
174
  with gr.Column():
 
245
  with gr.Row():
246
  with gr.Column():
247
  image_input_auth = gr.Image(label="Upload a facial image.", type="pil", sources="upload")
248
+ image_input_auth.change(fn=crop_face_to_112x112, inputs=image_input_auth, outputs=image_input_auth)
249
  with gr.Column():
250
  example_gallery = gr.Gallery(value=example_images_auth, columns=3)
251
  with gr.Column():
requirements.txt CHANGED
@@ -4,4 +4,5 @@ torch
4
  timm
5
  opencv-python
6
  pillow
7
- torchvision
 
 
4
  timm
5
  opencv-python
6
  pillow
7
+ torchvision
8
+ facenet-pytorch