osanseviero commited on
Commit
91fd28c
·
1 Parent(s): fd743d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -61
app.py CHANGED
@@ -1,92 +1,73 @@
1
  import os
2
  os.system("pip install git+https://github.com/elliottzheng/face-detection.git@master")
3
  os.system("git clone https://github.com/thohemp/6DRepNet")
4
- import sys
5
- sys.path.append("6DRepNet")
6
 
7
- from model import SixDRepNet
8
- import math
9
- import re
10
- from matplotlib import pyplot as plt
11
  import sys
12
- import os
13
 
14
  import numpy as np
15
- import cv2
16
- import matplotlib.pyplot as plt
17
- from numpy.lib.function_base import _quantile_unchecked
18
-
19
  import torch
20
- import torch.nn as nn
21
- from torch.utils.data import DataLoader
22
- from torchvision import transforms
23
- import torchvision
24
- import torch.nn.functional as F
25
- import utils
26
- import matplotlib
27
- from PIL import Image
28
- import time
29
- from face_detection import RetinaFace
30
  from huggingface_hub import hf_hub_download
31
 
 
 
 
 
32
  snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")
33
 
34
  model = SixDRepNet(backbone_name='RepVGG-B1g2',
35
- backbone_file='',
36
- deploy=True,
37
- pretrained=False)
38
-
39
- detector = RetinaFace()
40
  saved_state_dict = torch.load(os.path.join(
41
  snapshot_path), map_location='cpu')
42
-
43
  if 'model_state_dict' in saved_state_dict:
44
  model.load_state_dict(saved_state_dict['model_state_dict'])
45
  else:
46
- model.load_state_dict(saved_state_dict)
 
47
  model.eval()
48
 
49
-
50
  def predict(img):
51
  faces = detector(frame)
52
  for box, landmarks, score in faces:
53
- # Print the location of each face in this image
54
- if score < .95:
55
- continue
56
- x_min = int(box[0])
57
- y_min = int(box[1])
58
- x_max = int(box[2])
59
- y_max = int(box[3])
60
- bbox_width = abs(x_max - x_min)
61
- bbox_height = abs(y_max - y_min)
62
-
63
- x_min = max(0,x_min-int(0.2*bbox_height))
64
- y_min = max(0,y_min-int(0.2*bbox_width))
65
- x_max = x_max+int(0.2*bbox_height)
66
- y_max = y_max+int(0.2*bbox_width)
67
-
68
- img = frame[y_min:y_max,x_min:x_max]
69
- img = cv2.resize(img, (244, 244))/255.0
70
- img = img.transpose(2, 0, 1)
71
- img = torch.from_numpy(img).type(torch.FloatTensor)
72
- img = torch.Tensor(img)
73
- img=img.unsqueeze(0)
74
 
75
- R_pred = model(img)
76
- euler = utils.compute_euler_angles_from_rotation_matrices(
77
- R_pred)*180/np.pi
78
- p_pred_deg = euler[:, 0].cpu()
79
- y_pred_deg = euler[:, 1].cpu()
80
- r_pred_deg = euler[:, 2].cpu()
81
- utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg, x_min + int(.5*(x_max-x_min)), y_min + int(.5*(y_max-y_min)), size = bbox_width)
82
-
83
- return img
 
 
 
 
84
 
85
 
86
  iface = gr.Interface(
87
  fn=predict,
88
- inputs='img',
89
- outputs='img',
90
  )
91
 
92
  iface.launch()
 
1
  import os
2
  os.system("pip install git+https://github.com/elliottzheng/face-detection.git@master")
3
  os.system("git clone https://github.com/thohemp/6DRepNet")
 
 
4
 
 
 
 
 
5
  import sys
6
+ sys.path.append("6DRepNet")
7
 
8
  import numpy as np
 
 
 
 
9
  import torch
 
 
 
 
 
 
 
 
 
 
10
  from huggingface_hub import hf_hub_download
11
 
12
+ from face_detection import RetinaFace
13
+ from model import SixDRepNet
14
+ import utils
15
+
16
  snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")
17
 
18
  model = SixDRepNet(backbone_name='RepVGG-B1g2',
19
+ backbone_file='',
20
+ deploy=True,
21
+ pretrained=False)
22
+
23
+ detector = RetinaFace(0)
24
  saved_state_dict = torch.load(os.path.join(
25
  snapshot_path), map_location='cpu')
26
+
27
  if 'model_state_dict' in saved_state_dict:
28
  model.load_state_dict(saved_state_dict['model_state_dict'])
29
  else:
30
+ model.load_state_dict(saved_state_dict)
31
+ model.cuda(0)
32
  model.eval()
33
 
 
34
  def predict(img):
35
  faces = detector(frame)
36
  for box, landmarks, score in faces:
37
+ # Print the location of each face in this image
38
+ if score < .95:
39
+ continue
40
+ x_min = int(box[0])
41
+ y_min = int(box[1])
42
+ x_max = int(box[2])
43
+ y_max = int(box[3])
44
+ bbox_width = abs(x_max - x_min)
45
+ bbox_height = abs(y_max - y_min)
46
+
47
+ x_min = max(0,x_min-int(0.2*bbox_height))
48
+ y_min = max(0,y_min-int(0.2*bbox_width))
49
+ x_max = x_max+int(0.2*bbox_height)
50
+ y_max = y_max+int(0.2*bbox_width)
 
 
 
 
 
 
 
51
 
52
+ img = frame[y_min:y_max,x_min:x_max]
53
+ img = cv2.resize(img, (244, 244))/255.0
54
+ img = img.transpose(2, 0, 1)
55
+ img = torch.from_numpy(img).type(torch.FloatTensor)
56
+ img = torch.Tensor(img).cuda(0)
57
+ img=img.unsqueeze(0)
58
+ R_pred = model(img)
59
+ euler = utils.compute_euler_angles_from_rotation_matrices(
60
+ R_pred)*180/np.pi
61
+ p_pred_deg = euler[:, 0].cpu()
62
+ y_pred_deg = euler[:, 1].cpu()
63
+ r_pred_deg = euler[:, 2].cpu()
64
+ return utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg, x_min + int(.5*(x_max-x_min)), y_min + int(.5*(y_max-y_min)), size = bbox_width)
65
 
66
 
67
  iface = gr.Interface(
68
  fn=predict,
69
+ inputs=gr.inputs.Image(label="Input Image", source="webcam"),
70
+ outputs='image',
71
  )
72
 
73
  iface.launch()