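# Gradio demo for 6DRepNet head pose estimation.
# Detects faces with RetinaFace, predicts a rotation matrix per face with
# SixDRepNet, converts it to Euler angles, and draws a pose cube on the frame.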
import os

# Install the RetinaFace face detector and fetch the 6DRepNet sources at
# runtime (a common pattern in Spaces without a packaged dependency).
os.system("pip install git+https://github.com/elliottzheng/face-detection.git@master")
os.system("git clone https://github.com/thohemp/6DRepNet")

import sys
sys.path.append("6DRepNet")  # make the cloned repo importable (model, utils)

from model import SixDRepNet  # from the cloned 6DRepNet repo

import numpy as np
import cv2
import torch
import utils  # 6DRepNet helpers: Euler-angle conversion, pose-cube drawing
import gradio as gr
from face_detection import RetinaFace
from huggingface_hub import hf_hub_download

# Pretrained 6DRepNet weights published on the Hugging Face Hub.
snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")

model = SixDRepNet(backbone_name='RepVGG-B1g2',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

detector = RetinaFace()

saved_state_dict = torch.load(snapshot_path, map_location='cpu')
# Some checkpoints wrap the weights in a 'model_state_dict' key.
if 'model_state_dict' in saved_state_dict:
    model.load_state_dict(saved_state_dict['model_state_dict'])
else:
    model.load_state_dict(saved_state_dict)
model.eval()


def predict(frame):
    # `frame` is the input image as a NumPy array from Gradio.
    faces = detector(frame)
    for box, landmarks, score in faces:
        # Skip low-confidence detections.
        if score < .95:
            continue
        x_min = int(box[0])
        y_min = int(box[1])
        x_max = int(box[2])
        y_max = int(box[3])
        bbox_width = abs(x_max - x_min)
        bbox_height = abs(y_max - y_min)

        # Expand the box by 20% so the crop includes the full head.
        x_min = max(0, x_min - int(0.2 * bbox_height))
        y_min = max(0, y_min - int(0.2 * bbox_width))
        x_max = x_max + int(0.2 * bbox_height)
        y_max = y_max + int(0.2 * bbox_width)

        # Crop, resize, scale to [0, 1], and reorder to CHW with a batch dim.
        img = frame[y_min:y_max, x_min:x_max]
        img = cv2.resize(img, (244, 244)) / 255.0
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float().unsqueeze(0)

        # Predict a rotation matrix and convert it to Euler angles in degrees.
        with torch.no_grad():
            R_pred = model(img)
        euler = utils.compute_euler_angles_from_rotation_matrices(
            R_pred) * 180 / np.pi
        p_pred_deg = euler[:, 0].cpu()
        y_pred_deg = euler[:, 1].cpu()
        r_pred_deg = euler[:, 2].cpu()

        # Draw the pose cube in place on `frame`, centred on the box.
        utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg,
                             x_min + int(.5 * (x_max - x_min)),
                             y_min + int(.5 * (y_max - y_min)),
                             size=bbox_width)

    # Return the annotated frame so Gradio can display it.
    return frame
  
  
iface = gr.Interface(
    fn=predict,
    inputs='image',
    outputs='image',
)

iface.launch()
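# launch() serves the app at a local URL; inside a Hugging Face Space the
# interface is exposed automatically.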