alexnasa commited on
Commit
6086b37
·
verified ·
1 Parent(s): 03f1807

Upload demo.py

Browse files
src/pixel3dmm/preprocessing/MICA/demo.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2023 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: [email protected]
16
+
17
+
18
+ import argparse
19
+ import os
20
+ import random
21
+ import traceback
22
+ from glob import glob
23
+ from pathlib import Path
24
+ from PIL import Image
25
+
26
+ import cv2
27
+ import numpy as np
28
+ import torch
29
+ import torch.backends.cudnn as cudnn
30
+ import trimesh
31
+ from insightface.app.common import Face
32
+ from insightface.utils import face_align
33
+ from loguru import logger
34
+ from skimage.io import imread
35
+ from tqdm import tqdm
36
+ #from retinaface.pre_trained_models import get_model
37
+ #from retinaface.utils import vis_annotations
38
+ #from matplotlib import pyplot as plt
39
+
40
+
41
+ from pixel3dmm.preprocessing.MICA.configs.config import get_cfg_defaults
42
+ from pixel3dmm.preprocessing.MICA.datasets.creation.util import get_arcface_input, get_center, draw_on
43
+ from pixel3dmm.preprocessing.MICA.utils import util
44
+ from pixel3dmm.preprocessing.MICA.utils.landmark_detector import LandmarksDetector, detectors
45
+ from pixel3dmm import env_paths
46
+
47
+
48
+ #model = get_model("resnet50_2020-07-20", max_size=512)
49
+ #model.eval()
50
+
51
+
52
+ def deterministic(rank):
53
+ torch.manual_seed(rank)
54
+ torch.cuda.manual_seed(rank)
55
+ np.random.seed(rank)
56
+ random.seed(rank)
57
+
58
+ cudnn.deterministic = True
59
+ cudnn.benchmark = False
60
+
61
+
62
+ def process(args, app, image_size=224, draw_bbox=False):
63
+ dst = Path(args.a)
64
+ dst.mkdir(parents=True, exist_ok=True)
65
+ processes = []
66
+ image_paths = sorted(glob(args.i + '/*.*'))#[:1]
67
+ image_paths = image_paths[::max(1, len(image_paths)//10)]
68
+ for image_path in tqdm(image_paths):
69
+ name = Path(image_path).stem
70
+ img = cv2.imread(image_path)
71
+
72
+
73
+ # FOR pytorch retinaface use this: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
74
+ # I had issues with onnxruntime!
75
+ bboxes, kpss = app.detect(img)
76
+
77
+ #annotation = model.predict_jsons(img)
78
+ #Image.fromarray(vis_annotations(img, annotation)).show()
79
+
80
+ #bboxes = np.stack([np.array( annotation[0]['bbox'] + [annotation[0]['score']] ) for i in range(len(annotation))], axis=0)
81
+ #kpss = np.stack([np.array( annotation[0]['landmarks'] ) for i in range(len(annotation))], axis=0)
82
+ if bboxes.shape[0] == 0:
83
+ logger.error(f'[ERROR] Face not detected for {image_path}')
84
+ continue
85
+ i = get_center(bboxes, img)
86
+ bbox = bboxes[i, 0:4]
87
+ det_score = bboxes[i, 4]
88
+ kps = None
89
+ if kpss is not None:
90
+ kps = kpss[i]
91
+
92
+ ##for ikp in range(kps.shape[0]):
93
+ # img[int(kps[ikp][1]), int(kps[ikp][0]), 0] = 255
94
+ # img[int(kpss_[0][ikp][1]), int(kpss_[0][ikp][0]), 1] = 255
95
+ #Image.fromarray(img).show()
96
+ face = Face(bbox=bbox, kps=kps, det_score=det_score)
97
+ blob, aimg = get_arcface_input(face, img)
98
+ file = str(Path(dst, name))
99
+ np.save(file, blob)
100
+ processes.append(file + '.npy')
101
+ cv2.imwrite(file + '.jpg', face_align.norm_crop(img, landmark=face.kps, image_size=image_size))
102
+ if draw_bbox:
103
+ dimg = draw_on(img, [face])
104
+ cv2.imwrite(file + '_bbox.jpg', dimg)
105
+
106
+ return processes
107
+
108
+
109
+ def to_batch(path):
110
+ src = path.replace('npy', 'jpg')
111
+ if not os.path.exists(src):
112
+ src = path.replace('npy', 'png')
113
+
114
+ image = imread(src)[:, :, :3]
115
+ image = image / 255.
116
+ image = cv2.resize(image, (224, 224)).transpose(2, 0, 1)
117
+ image = torch.tensor(image).cuda()[None]
118
+
119
+ arcface = np.load(path)
120
+ arcface = torch.tensor(arcface).cuda()[None]
121
+
122
+ return image, arcface
123
+
124
+
125
+ def load_checkpoint(args, mica):
126
+ checkpoint = torch.load(args.m, weights_only=False)
127
+ if 'arcface' in checkpoint:
128
+ mica.arcface.load_state_dict(checkpoint['arcface'])
129
+ if 'flameModel' in checkpoint:
130
+ mica.flameModel.load_state_dict(checkpoint['flameModel'])
131
+
132
+
133
+ def main(cfg, args):
134
+ device = 'cuda:0'
135
+ cfg.model.testing = True
136
+ mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
137
+ load_checkpoint(args, mica)
138
+ mica.eval()
139
+
140
+ faces = mica.flameModel.generator.faces_tensor.cpu()
141
+ Path(args.o).mkdir(exist_ok=True, parents=True)
142
+
143
+ app = LandmarksDetector(model=detectors.RETINAFACE)
144
+
145
+ with torch.no_grad():
146
+ logger.info(f'Processing has started...')
147
+ paths = process(args, app, draw_bbox=False)
148
+ for path in tqdm(paths):
149
+ name = Path(path).stem
150
+ images, arcface = to_batch(path)
151
+ codedict = mica.encode(images, arcface)
152
+ opdict = mica.decode(codedict)
153
+ meshes = opdict['pred_canonical_shape_vertices']
154
+ code = opdict['pred_shape_code']
155
+ lmk = mica.flameModel.generator.compute_landmarks(meshes)
156
+
157
+ mesh = meshes[0]
158
+ landmark_51 = lmk[0, 17:]
159
+ landmark_7 = landmark_51[[19, 22, 25, 28, 16, 31, 37]]
160
+
161
+ dst = Path(args.o, name)
162
+ dst.mkdir(parents=True, exist_ok=True)
163
+ trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.ply') # save in millimeters
164
+ trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.obj')
165
+ np.save(f'{dst}/identity', code[0].cpu().numpy())
166
+ np.save(f'{dst}/kpt7', landmark_7.cpu().numpy() * 1000.0)
167
+ np.save(f'{dst}/kpt68', lmk.cpu().numpy() * 1000.0)
168
+
169
+ logger.info(f'Processing finished. Results has been saved in {args.o}')
170
+
171
+
172
+ if __name__ == '__main__':
173
+ parser = argparse.ArgumentParser(description='MICA - Towards Metrical Reconstruction of Human Faces')
174
+ parser.add_argument('-video_name', required=True, type=str)
175
+ parser.add_argument('-a', default='demo/arcface', type=str, help='Processed images for MICA input')
176
+ parser.add_argument('-m', default='data/pretrained/mica.tar', type=str, help='Pretrained model path')
177
+
178
+ args = parser.parse_args()
179
+ cfg = get_cfg_defaults()
180
+ args.i = f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/cropped/'
181
+ args.o = f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/'
182
+ if os.path.exists(f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/'):
183
+ if len(os.listdir(f'{env_paths.PREPROCESSED_DATA}/{args.video_name}/mica/')) >= 10:
184
+ print(f'''
185
+ <<<<<<<< ALREADY COMPLETE MICA PREDICTION FOR {args.video_name}, SKIPPING >>>>>>>>
186
+ ''')
187
+ exit()
188
+ main(cfg, args)