Delete src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py
src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py
DELETED
@@ -1,174 +0,0 @@
from typing import Optional, Dict, Any
import functools
import torch
import torch.nn.functional as F

from ..util import download_jit
from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
from .base import FaceParser
import numpy as np

# Per-configuration settings: model download URL, face-alignment matrix
# function, forward/inverse warp-grid functions, and output label names.
pretrain_settings = {
    'lapa/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448), target_face_scale=1.0),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.8, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.8, warped_shape=(448, 448)),
        'label_names': ['background', 'face', 'rb', 'lb', 're',
                        'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
    },
    'celebm/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
                                           target_shape=(448, 448)),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0, warped_shape=(448, 448)),
        'label_names': [
            'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
            'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
            'eyeg', 'hat', 'earr', 'neck_l']
    }
}


class FaRLFaceParser(FaceParser):
    """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
    @article{zheng2021farl,
        title={General Facial Representation Learning in a Visual-Linguistic Manner},
        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
            Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
            Dong and Zeng, Ming and Wen, Fang},
        journal={arXiv preprint arXiv:2112.03109},
        year={2021}
    }
    ```
    """

    def __init__(self, conf_name: Optional[str] = None, model_path: Optional[str] = None, device=None) -> None:
        super().__init__()
        if conf_name is None:
            conf_name = 'lapa/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        self.net = download_jit(model_path, map_location=device)
        self.eval()
        self.device = device
        self.setting = pretrain_settings[conf_name]
        self.label_names = self.setting['label_names']

    def get_warp_grid(self, images: torch.Tensor, matrix_src):
        _, _, h, w = images.shape
        matrix = self.setting['get_matrix_fn'](matrix_src)
        grid = self.setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = self.setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
        return grid, inv_grid

    def warp_images(self, images: torch.Tensor, data: Dict[str, Any]):
        # Select one source image per detected face, then warp each face
        # into the 448 x 448 frame expected by the parsing network.
        simages = self.unify_image_dtype(images)
        simages = simages[data['image_ids']]
        matrix_src = data[self.setting['matrix_src_tag']]
        grid, inv_grid = self.get_warp_grid(simages, matrix_src)

        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)
        return w_images, grid, inv_grid

    def decode_image_to_cv2(self, images: torch.Tensor):
        '''
        input: b x 3 x h x w, torch.float32, [0, 1]
        output: b x h x w x 3, np.uint8, [0, 255]
        '''
        assert images.ndim == 4
        assert images.shape[1] == 3
        images = images.permute(0, 2, 3, 1).cpu().numpy() * 255
        images = images.astype(np.uint8)
        return images

    def unify_image_dtype(self, images: torch.Tensor | np.ndarray | list):
        '''
        output: b x 3 x h x w, torch.float32, [0, 1]
        '''
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        elif isinstance(images, torch.Tensor):
            pass
        elif isinstance(images, list):
            assert len(images) > 0, "images is empty"
            first_image = images[0]
            if isinstance(first_image, np.ndarray):
                # h x w x 3 numpy images -> 3 x h x w tensors.
                images = [torch.from_numpy(image).permute(2, 0, 1) for image in images]
                images = torch.stack(images)
            elif isinstance(first_image, torch.Tensor):
                images = torch.stack(images)
            else:
                raise ValueError(f"Unsupported image type: {type(first_image)}")
        else:
            raise ValueError(f"Unsupported image type: {type(images)}")

        assert images.ndim == 4
        assert images.shape[1] == 3

        # Normalize by value range: floats are assumed already in [0, 1],
        # uint8 values in [0, 255] are rescaled.
        max_val = images.max()
        if max_val <= 1:
            assert images.dtype == torch.float32 or images.dtype == torch.float16
        elif max_val <= 255:
            assert images.dtype == torch.uint8
            images = images.float() / 255.0
        else:
            raise ValueError(f"Unsupported image dtype: {images.dtype}")
        if images.device != self.device:
            images = images.to(device=self.device)
        return images

    @torch.no_grad()
    @torch.inference_mode()
    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
        '''
        images: b x 3 x h x w, torch.uint8, [0, 255]
        data: {'rects': rects, 'points': points, 'scores': scores, 'image_ids': image_ids}
        '''
        w_images, grid, inv_grid = self.warp_images(images, data)
        w_seg_logits = self.forward_warped(w_images, return_preds=False)

        # Map logits back from the warped face frame to the original image
        # frame via the inverted grid.
        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)

        data['seg'] = {'logits': seg_logits, 'label_names': self.label_names}
        return data

    def logits2predictions(self, logits: torch.Tensor):
        return logits.argmax(dim=1)

    @torch.no_grad()
    @torch.inference_mode()
    def forward_warped(self, images: torch.Tensor, return_preds: bool = True):
        '''
        images: b x 3 x h x w, torch.uint8, [0, 255]
        '''
        images = self.unify_image_dtype(images)
        seg_logits, _ = self.net(images)  # nfaces x nclasses x h x w
        # seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
        if return_preds:
            seg_preds = self.logits2predictions(seg_logits)
            return seg_logits, seg_preds, self.label_names
        else:
            return seg_logits