alexnasa commited on
Commit
03f1807
·
verified ·
1 Parent(s): bd096d2

Delete src/pixel3dmm/preprocessing/MICA/demo.py

Browse files
src/pixel3dmm/preprocessing/MICA/demo.py DELETED
@@ -1,156 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
- # holder of all proprietary rights on this computer program.
5
- # You can only use this computer program if you have closed
6
- # a license agreement with MPG or you get the right to use the computer
7
- # program from someone who is authorized to grant you that right.
8
- # Any use of the computer program without a valid license is prohibited and
9
- # liable to prosecution.
10
- #
11
- # Copyright©2023 Max-Planck-Gesellschaft zur Förderung
12
- # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
- # for Intelligent Systems. All rights reserved.
14
- #
15
- # Contact: [email protected]
16
-
17
-
18
- import argparse
19
- import os
20
- import random
21
- from glob import glob
22
- from pathlib import Path
23
-
24
- import cv2
25
- import numpy as np
26
- import torch
27
- import torch.backends.cudnn as cudnn
28
- import trimesh
29
- from insightface.app.common import Face
30
- from insightface.utils import face_align
31
- from loguru import logger
32
- from skimage.io import imread
33
- from tqdm import tqdm
34
-
35
- from configs.config import get_cfg_defaults
36
- from datasets.creation.util import get_arcface_input, get_center, draw_on
37
- from utils import util
38
- from utils.landmark_detector import LandmarksDetector, detectors
39
-
40
-
41
- def deterministic(rank):
42
- torch.manual_seed(rank)
43
- torch.cuda.manual_seed(rank)
44
- np.random.seed(rank)
45
- random.seed(rank)
46
-
47
- cudnn.deterministic = True
48
- cudnn.benchmark = False
49
-
50
-
51
- def process(args, app, image_size=224, draw_bbox=False):
52
- dst = Path(args.a)
53
- dst.mkdir(parents=True, exist_ok=True)
54
- processes = []
55
- image_paths = sorted(glob(args.i + '/*.*'))
56
- for image_path in tqdm(image_paths):
57
- name = Path(image_path).stem
58
- img = cv2.imread(image_path)
59
- bboxes, kpss = app.detect(img)
60
- if bboxes.shape[0] == 0:
61
- logger.error(f'[ERROR] Face not detected for {image_path}')
62
- continue
63
- i = get_center(bboxes, img)
64
- bbox = bboxes[i, 0:4]
65
- det_score = bboxes[i, 4]
66
- kps = None
67
- if kpss is not None:
68
- kps = kpss[i]
69
- face = Face(bbox=bbox, kps=kps, det_score=det_score)
70
- blob, aimg = get_arcface_input(face, img)
71
- file = str(Path(dst, name))
72
- np.save(file, blob)
73
- processes.append(file + '.npy')
74
- cv2.imwrite(file + '.jpg', face_align.norm_crop(img, landmark=face.kps, image_size=image_size))
75
- if draw_bbox:
76
- dimg = draw_on(img, [face])
77
- cv2.imwrite(file + '_bbox.jpg', dimg)
78
-
79
- return processes
80
-
81
-
82
- def to_batch(path):
83
- src = path.replace('npy', 'jpg')
84
- if not os.path.exists(src):
85
- src = path.replace('npy', 'png')
86
-
87
- image = imread(src)[:, :, :3]
88
- image = image / 255.
89
- image = cv2.resize(image, (224, 224)).transpose(2, 0, 1)
90
- image = torch.tensor(image).cuda()[None]
91
-
92
- arcface = np.load(path)
93
- arcface = torch.tensor(arcface).cuda()[None]
94
-
95
- return image, arcface
96
-
97
-
98
- def load_checkpoint(args, mica):
99
- checkpoint = torch.load(args.m)
100
- if 'arcface' in checkpoint:
101
- mica.arcface.load_state_dict(checkpoint['arcface'])
102
- if 'flameModel' in checkpoint:
103
- mica.flameModel.load_state_dict(checkpoint['flameModel'])
104
-
105
-
106
- def main(cfg, args):
107
- device = 'cuda:0'
108
- cfg.model.testing = True
109
- mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
110
- load_checkpoint(args, mica)
111
- mica.eval()
112
-
113
- faces = mica.flameModel.generator.faces_tensor.cpu()
114
- Path(args.o).mkdir(exist_ok=True, parents=True)
115
-
116
- app = LandmarksDetector(model=detectors.RETINAFACE)
117
-
118
- with torch.no_grad():
119
- logger.info(f'Processing has started...')
120
- paths = process(args, app, draw_bbox=False)
121
- for path in tqdm(paths):
122
- name = Path(path).stem
123
- images, arcface = to_batch(path)
124
- codedict = mica.encode(images, arcface)
125
- opdict = mica.decode(codedict)
126
- meshes = opdict['pred_canonical_shape_vertices']
127
- code = opdict['pred_shape_code']
128
- lmk = mica.flame.compute_landmarks(meshes)
129
-
130
- mesh = meshes[0]
131
- landmark_51 = lmk[0, 17:]
132
- landmark_7 = landmark_51[[19, 22, 25, 28, 16, 31, 37]]
133
-
134
- dst = Path(args.o, name)
135
- dst.mkdir(parents=True, exist_ok=True)
136
- trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.ply') # save in millimeters
137
- trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.obj')
138
- np.save(f'{dst}/identity', code[0].cpu().numpy())
139
- np.save(f'{dst}/kpt7', landmark_7.cpu().numpy() * 1000.0)
140
- np.save(f'{dst}/kpt68', lmk.cpu().numpy() * 1000.0)
141
-
142
- logger.info(f'Processing finished. Results has been saved in {args.o}')
143
-
144
-
145
- if __name__ == '__main__':
146
- parser = argparse.ArgumentParser(description='MICA - Towards Metrical Reconstruction of Human Faces')
147
- parser.add_argument('-i', default='demo/input', type=str, help='Input folder with images')
148
- parser.add_argument('-o', default='demo/output', type=str, help='Output folder')
149
- parser.add_argument('-a', default='demo/arcface', type=str, help='Processed images for MICA input')
150
- parser.add_argument('-m', default='data/pretrained/mica.tar', type=str, help='Pretrained model path')
151
-
152
- args = parser.parse_args()
153
- cfg = get_cfg_defaults()
154
-
155
- deterministic(42)
156
- main(cfg, args)