Spaces:
Running
on
Zero
Running
on
Zero
Delete src/pixel3dmm/preprocessing/MICA/demo.py
Browse files
src/pixel3dmm/preprocessing/MICA/demo.py
DELETED
@@ -1,156 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
-
# holder of all proprietary rights on this computer program.
|
5 |
-
# You can only use this computer program if you have closed
|
6 |
-
# a license agreement with MPG or you get the right to use the computer
|
7 |
-
# program from someone who is authorized to grant you that right.
|
8 |
-
# Any use of the computer program without a valid license is prohibited and
|
9 |
-
# liable to prosecution.
|
10 |
-
#
|
11 |
-
# Copyright©2023 Max-Planck-Gesellschaft zur Förderung
|
12 |
-
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
-
# for Intelligent Systems. All rights reserved.
|
14 |
-
#
|
15 |
-
# Contact: [email protected]
|
16 |
-
|
17 |
-
|
18 |
-
import argparse
|
19 |
-
import os
|
20 |
-
import random
|
21 |
-
from glob import glob
|
22 |
-
from pathlib import Path
|
23 |
-
|
24 |
-
import cv2
|
25 |
-
import numpy as np
|
26 |
-
import torch
|
27 |
-
import torch.backends.cudnn as cudnn
|
28 |
-
import trimesh
|
29 |
-
from insightface.app.common import Face
|
30 |
-
from insightface.utils import face_align
|
31 |
-
from loguru import logger
|
32 |
-
from skimage.io import imread
|
33 |
-
from tqdm import tqdm
|
34 |
-
|
35 |
-
from configs.config import get_cfg_defaults
|
36 |
-
from datasets.creation.util import get_arcface_input, get_center, draw_on
|
37 |
-
from utils import util
|
38 |
-
from utils.landmark_detector import LandmarksDetector, detectors
|
39 |
-
|
40 |
-
|
41 |
-
def deterministic(rank):
|
42 |
-
torch.manual_seed(rank)
|
43 |
-
torch.cuda.manual_seed(rank)
|
44 |
-
np.random.seed(rank)
|
45 |
-
random.seed(rank)
|
46 |
-
|
47 |
-
cudnn.deterministic = True
|
48 |
-
cudnn.benchmark = False
|
49 |
-
|
50 |
-
|
51 |
-
def process(args, app, image_size=224, draw_bbox=False):
|
52 |
-
dst = Path(args.a)
|
53 |
-
dst.mkdir(parents=True, exist_ok=True)
|
54 |
-
processes = []
|
55 |
-
image_paths = sorted(glob(args.i + '/*.*'))
|
56 |
-
for image_path in tqdm(image_paths):
|
57 |
-
name = Path(image_path).stem
|
58 |
-
img = cv2.imread(image_path)
|
59 |
-
bboxes, kpss = app.detect(img)
|
60 |
-
if bboxes.shape[0] == 0:
|
61 |
-
logger.error(f'[ERROR] Face not detected for {image_path}')
|
62 |
-
continue
|
63 |
-
i = get_center(bboxes, img)
|
64 |
-
bbox = bboxes[i, 0:4]
|
65 |
-
det_score = bboxes[i, 4]
|
66 |
-
kps = None
|
67 |
-
if kpss is not None:
|
68 |
-
kps = kpss[i]
|
69 |
-
face = Face(bbox=bbox, kps=kps, det_score=det_score)
|
70 |
-
blob, aimg = get_arcface_input(face, img)
|
71 |
-
file = str(Path(dst, name))
|
72 |
-
np.save(file, blob)
|
73 |
-
processes.append(file + '.npy')
|
74 |
-
cv2.imwrite(file + '.jpg', face_align.norm_crop(img, landmark=face.kps, image_size=image_size))
|
75 |
-
if draw_bbox:
|
76 |
-
dimg = draw_on(img, [face])
|
77 |
-
cv2.imwrite(file + '_bbox.jpg', dimg)
|
78 |
-
|
79 |
-
return processes
|
80 |
-
|
81 |
-
|
82 |
-
def to_batch(path):
|
83 |
-
src = path.replace('npy', 'jpg')
|
84 |
-
if not os.path.exists(src):
|
85 |
-
src = path.replace('npy', 'png')
|
86 |
-
|
87 |
-
image = imread(src)[:, :, :3]
|
88 |
-
image = image / 255.
|
89 |
-
image = cv2.resize(image, (224, 224)).transpose(2, 0, 1)
|
90 |
-
image = torch.tensor(image).cuda()[None]
|
91 |
-
|
92 |
-
arcface = np.load(path)
|
93 |
-
arcface = torch.tensor(arcface).cuda()[None]
|
94 |
-
|
95 |
-
return image, arcface
|
96 |
-
|
97 |
-
|
98 |
-
def load_checkpoint(args, mica):
|
99 |
-
checkpoint = torch.load(args.m)
|
100 |
-
if 'arcface' in checkpoint:
|
101 |
-
mica.arcface.load_state_dict(checkpoint['arcface'])
|
102 |
-
if 'flameModel' in checkpoint:
|
103 |
-
mica.flameModel.load_state_dict(checkpoint['flameModel'])
|
104 |
-
|
105 |
-
|
106 |
-
def main(cfg, args):
|
107 |
-
device = 'cuda:0'
|
108 |
-
cfg.model.testing = True
|
109 |
-
mica = util.find_model_using_name(model_dir='micalib.models', model_name=cfg.model.name)(cfg, device)
|
110 |
-
load_checkpoint(args, mica)
|
111 |
-
mica.eval()
|
112 |
-
|
113 |
-
faces = mica.flameModel.generator.faces_tensor.cpu()
|
114 |
-
Path(args.o).mkdir(exist_ok=True, parents=True)
|
115 |
-
|
116 |
-
app = LandmarksDetector(model=detectors.RETINAFACE)
|
117 |
-
|
118 |
-
with torch.no_grad():
|
119 |
-
logger.info(f'Processing has started...')
|
120 |
-
paths = process(args, app, draw_bbox=False)
|
121 |
-
for path in tqdm(paths):
|
122 |
-
name = Path(path).stem
|
123 |
-
images, arcface = to_batch(path)
|
124 |
-
codedict = mica.encode(images, arcface)
|
125 |
-
opdict = mica.decode(codedict)
|
126 |
-
meshes = opdict['pred_canonical_shape_vertices']
|
127 |
-
code = opdict['pred_shape_code']
|
128 |
-
lmk = mica.flame.compute_landmarks(meshes)
|
129 |
-
|
130 |
-
mesh = meshes[0]
|
131 |
-
landmark_51 = lmk[0, 17:]
|
132 |
-
landmark_7 = landmark_51[[19, 22, 25, 28, 16, 31, 37]]
|
133 |
-
|
134 |
-
dst = Path(args.o, name)
|
135 |
-
dst.mkdir(parents=True, exist_ok=True)
|
136 |
-
trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.ply') # save in millimeters
|
137 |
-
trimesh.Trimesh(vertices=mesh.cpu() * 1000.0, faces=faces, process=False).export(f'{dst}/mesh.obj')
|
138 |
-
np.save(f'{dst}/identity', code[0].cpu().numpy())
|
139 |
-
np.save(f'{dst}/kpt7', landmark_7.cpu().numpy() * 1000.0)
|
140 |
-
np.save(f'{dst}/kpt68', lmk.cpu().numpy() * 1000.0)
|
141 |
-
|
142 |
-
logger.info(f'Processing finished. Results has been saved in {args.o}')
|
143 |
-
|
144 |
-
|
145 |
-
if __name__ == '__main__':
|
146 |
-
parser = argparse.ArgumentParser(description='MICA - Towards Metrical Reconstruction of Human Faces')
|
147 |
-
parser.add_argument('-i', default='demo/input', type=str, help='Input folder with images')
|
148 |
-
parser.add_argument('-o', default='demo/output', type=str, help='Output folder')
|
149 |
-
parser.add_argument('-a', default='demo/arcface', type=str, help='Processed images for MICA input')
|
150 |
-
parser.add_argument('-m', default='data/pretrained/mica.tar', type=str, help='Pretrained model path')
|
151 |
-
|
152 |
-
args = parser.parse_args()
|
153 |
-
cfg = get_cfg_defaults()
|
154 |
-
|
155 |
-
deterministic(42)
|
156 |
-
main(cfg, args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|