Spaces: Running on Zero
Upload demo.py
src/pixel3dmm/preprocessing/MICA/demo.py
ADDED
@@ -0,0 +1,188 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

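# MICA demo for the Pixel3DMM preprocessing pipeline: detect a face in each
# cropped frame, build the ArcFace input, run MICA, and export the predicted
# FLAME mesh, shape code, and 3D landmarks per frame.
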
import argparse
import os
import random
import traceback
from glob import glob
from pathlib import Path
from PIL import Image

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import trimesh
from insightface.app.common import Face
from insightface.utils import face_align
from loguru import logger
from skimage.io import imread
from tqdm import tqdm
# from retinaface.pre_trained_models import get_model
# from retinaface.utils import vis_annotations
# from matplotlib import pyplot as plt


from pixel3dmm.preprocessing.MICA.configs.config import get_cfg_defaults
from pixel3dmm.preprocessing.MICA.datasets.creation.util import get_arcface_input, get_center, draw_on
from pixel3dmm.preprocessing.MICA.utils import util
from pixel3dmm.preprocessing.MICA.utils.landmark_detector import LandmarksDetector, detectors
from pixel3dmm import env_paths


# model = get_model("resnet50_2020-07-20", max_size=512)
# model.eval()


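# Helper that seeds the Python, NumPy, and PyTorch RNGs and switches cuDNN to
# deterministic mode; defined for reproducibility but not called elsewhere in
# this script.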
def deterministic(rank):
    torch.manual_seed(rank)
    torch.cuda.manual_seed(rank)
    np.random.seed(rank)
    random.seed(rank)

    cudnn.deterministic = True
    cudnn.benchmark = False


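# Detect the face closest to the image center (via get_center) in a subset of
# the cropped frames (every len//10-th image, i.e. roughly 10 evenly spaced
# frames), save the ArcFace input blob (.npy) and the aligned 224x224 crop
# (.jpg) for each frame, and return the list of saved .npy paths.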
def process(args, app, image_size=224, draw_bbox=False):
    dst = Path(args.a)
    dst.mkdir(parents=True, exist_ok=True)
    processes = []
    image_paths = sorted(glob(args.i + '/*.*'))  # [:1]
    image_paths = image_paths[::max(1, len(image_paths) // 10)]
    for image_path in tqdm(image_paths):
        name = Path(image_path).stem
        img = cv2.imread(image_path)

        # FOR pytorch retinaface use this: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # I had issues with onnxruntime!
        bboxes, kpss = app.detect(img)

        # annotation = model.predict_jsons(img)
        # Image.fromarray(vis_annotations(img, annotation)).show()

        # bboxes = np.stack([np.array(annotation[0]['bbox'] + [annotation[0]['score']]) for i in range(len(annotation))], axis=0)
        # kpss = np.stack([np.array(annotation[0]['landmarks']) for i in range(len(annotation))], axis=0)
        if bboxes.shape[0] == 0:
            logger.error(f'[ERROR] Face not detected for {image_path}')
            continue
        i = get_center(bboxes, img)
        bbox = bboxes[i, 0:4]
        det_score = bboxes[i, 4]
        kps = None
        if kpss is not None:
            kps = kpss[i]

        ## for ikp in range(kps.shape[0]):
        #     img[int(kps[ikp][1]), int(kps[ikp][0]), 0] = 255
        #     img[int(kpss_[0][ikp][1]), int(kpss_[0][ikp][0]), 1] = 255
        # Image.fromarray(img).show()
        face = Face(bbox=bbox, kps=kps, det_score=det_score)
        blob, aimg = get_arcface_input(face, img)
        file = str(Path(dst, name))
        np.save(file, blob)
        processes.append(file + '.npy')
        cv2.imwrite(file + '.jpg', face_align.norm_crop(img, landmark=face.kps, image_size=image_size))
        if draw_bbox:
            dimg = draw_on(img, [face])
            cv2.imwrite(file + '_bbox.jpg', dimg)

    return processes


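# Load one preprocessed sample: the aligned crop (.jpg or .png) resized to
# 224x224 as a 1x3x224x224 CUDA tensor in [0, 1], plus the saved ArcFace blob
# as a CUDA tensor with a batch dimension.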
def to_batch(path):
    src = path.replace('npy', 'jpg')
    if not os.path.exists(src):
        src = path.replace('npy', 'png')

    image = imread(src)[:, :, :3]
    image = image / 255.
    image = cv2.resize(image, (224, 224)).transpose(2, 0, 1)
    image = torch.tensor(image).cuda()[None]

    arcface = np.load(path)
    arcface = torch.tensor(arcface).cuda()[None]

    return image, arcface


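# Restore the pretrained ArcFace and flameModel weights from the checkpoint
# passed via -m (each is loaded only if present in the file).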
def load_checkpoint(args, mica):
    checkpoint = torch.load(args.m, weights_only=False)
    if 'arcface' in checkpoint:
        mica.arcface.load_state_dict(checkpoint['arcface'])
    if 'flameModel' in checkpoint:
        mica.flameModel.load_state_dict(checkpoint['flameModel'])


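# Build the MICA model, run face detection/alignment on the cropped frames,
# and for every processed frame predict the canonical FLAME shape. The mesh
# (scaled to millimeters), the shape code, and the 7-/68-point 3D landmarks
# are written to one subfolder per frame under args.o.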
def main(cfg, args):
    device = 'cuda:0'
    cfg.model.testing = True
    mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
    load_checkpoint(args, mica)
    mica.eval()

    faces = mica.flameModel.generator.faces_tensor.cpu()
    Path(args.o).mkdir(exist_ok=True, parents=True)

    app = LandmarksDetector(model=detectors.RETINAFACE)

    with torch.no_grad():
        logger.info('Processing has started...')
        paths = process(args, app, draw_bbox=False)
        for path in tqdm(paths):
            name = Path(path).stem
            images, arcface = to_batch(path)
            codedict = mica.encode(images, arcface)
            opdict = mica.decode(codedict)
            meshes = opdict['pred_canonical_shape_vertices']
            code = opdict['pred_shape_code']
            lmk = mica.flameModel.generator.compute_landmarks(meshes)

            mesh = meshes[0]
            landmark_51 = lmk[0, 17:]
            landmark_7 = landmark_51[[19, 22, 25, 28, 16, 31, 37]]

            dst = Path(args.o, name)
            dst.mkdir(parents=True, exist_ok=True)
            trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.ply')  # save in millimeters
            trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.obj')
            np.save(f'{dst}/identity', code[0].cpu().numpy())
            np.save(f'{dst}/kpt7', landmark_7.cpu().numpy() * 1000.0)
            np.save(f'{dst}/kpt68', lmk.cpu().numpy() * 1000.0)

    logger.info(f'Processing finished. Results have been saved in {args.o}')


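# Command-line entry point. Only -video_name is required; the input (cropped
# frames) and output (MICA predictions) directories are derived from
# env_paths.PREPROCESSED_DATA, and the step is skipped if at least 10 results
# already exist. A hypothetical invocation (the video name is just an example):
#   python demo.py -video_name example_clip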
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MICA - Towards Metrical Reconstruction of Human Faces')
    parser.add_argument('-video_name', required=True, type=str)
    parser.add_argument('-a', default='demo/arcface', type=str, help='Processed images for MICA input')
    parser.add_argument('-m', default='data/pretrained/mica.tar', type=str, help='Pretrained model path')

    args = parser.parse_args()
    cfg = get_cfg_defaults()
    args.i = f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/cropped/'
    args.o = f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/'
    if os.path.exists(f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/'):
        if len(os.listdir(f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/')) >= 10:
            print(f'''
            <<<<<<<< ALREADY COMPLETE MICA PREDICTION FOR {args.video_name}, SKIPPING >>>>>>>>
            ''')
            exit()
    main(cfg, args)