Spaces:
Running
on
Zero
Running
on
Zero
Upload farl.py
Browse files
src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Dict, Any
|
2 |
+
import functools
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from ..util import download_jit
|
7 |
+
from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
|
8 |
+
make_inverted_tanh_warp_grid, make_tanh_warp_grid)
|
9 |
+
from .base import FaceParser
|
10 |
+
|
11 |
+
# Per-configuration settings for the pretrained FaRL face-parsing models.
# Each entry maps a conf name ('<dataset>/<input-size>') to:
#   url:             download location(s) of the TorchScript checkpoint
#   matrix_src_tag:  key of the `data` dict entry fed to get_matrix_fn
#   get_matrix_fn:   builds the face-alignment transform matrix
#   get_grid_fn:     builds the warp grid used to sample the aligned crop
#   get_inv_grid_fn: builds the inverse grid mapping logits back to image space
#   label_names:     class names, one per output channel of the model
pretrain_settings = {
    'lapa/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
        ],
        'matrix_src_tag': 'points',
        # LaPa: landmark-based alignment, tanh warping with warp_factor=0.8.
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448), target_face_scale=1.0),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.8, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.8, warped_shape=(448, 448)),
        'label_names': ['background', 'face', 'rb', 'lb', 're',
                        'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
    },
    'celebm/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
        ],
        'matrix_src_tag': 'points',
        # CelebAMask-HQ: bbox-style celebm alignment, no tanh warping (warp_factor=0).
        'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
                                           target_shape=(448, 448)),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0, warped_shape=(448, 448)),
        'label_names': [
            'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
            'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
            'eyeg', 'hat', 'earr', 'neck_l']
    }
}
|
43 |
+
|
44 |
+
|
45 |
+
class FaRLFaceParser(FaceParser):
    """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
    @article{zheng2021farl,
        title={General Facial Representation Learning in a Visual-Linguistic Manner},
        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
            Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
            Dong and Zeng, Ming and Wen, Fang},
        journal={arXiv preprint arXiv:2112.03109},
        year={2021}
    }
    ```
    """

    def __init__(self, conf_name: Optional[str] = None,
                 model_path: Optional[str] = None, device=None) -> None:
        """Load a pretrained FaRL face-parsing TorchScript model.

        Args:
            conf_name: key into ``pretrain_settings``; defaults to 'lapa/448'.
            model_path: local path or URL(s) of the jit checkpoint. When None,
                the URL list of the selected configuration is used.
            device: passed to ``download_jit`` as ``map_location``.
        """
        super().__init__()
        if conf_name is None:
            conf_name = 'lapa/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        self.net = download_jit(model_path, map_location=device)
        self.eval()

    def forward(self, images: torch.Tensor, data: Dict[str, Any],
                bbox_scale_factor: float = 1.0) -> Dict[str, Any]:
        """Run face parsing on detected faces and add the result to `data`.

        Args:
            images: uint8 image batch, b x c x h x w, values in [0, 255].
            data: detection results; must contain 'image_ids' (face -> image
                index) and the key named by the configuration's
                'matrix_src_tag' (e.g. 'points').
            bbox_scale_factor: scale applied to the detected face bbox; only
                used by celebm-style (bbox-based) alignment. Default 1.0
                keeps the pretrained configuration unchanged.

        Returns:
            `data` with a new 'seg' entry: per-pixel class 'logits' warped
            back to the original image resolution, plus 'label_names'.
        """
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0
        _, _, h, w = images.shape

        # One entry per detected face, gathered from its source image.
        simages = images[data['image_ids']]

        # BUGFIX: previously the alignment matrix was always built with
        # get_face_align_matrix_celebm, ignoring the configured
        # 'get_matrix_fn'. That is wrong for 'lapa/448', which was trained
        # with landmark-based alignment (and whose warp grids below use
        # warp_factor=0.8). Use the configured function, and inject
        # bbox_scale_factor only when the configuration's alignment is the
        # celebm one — identical to the old behavior for 'celebm/448'.
        matrix_fn = setting['get_matrix_fn']
        if getattr(matrix_fn, 'func', None) is get_face_align_matrix_celebm:
            kwargs = dict(matrix_fn.keywords, bbox_scale_factor=bbox_scale_factor)
            matrix_fn = functools.partial(get_face_align_matrix_celebm, **kwargs)
        matrix = matrix_fn(data[setting['matrix_src_tag']])

        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))

        # Warp each face into the model's aligned 448x448 input space.
        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        w_seg_logits, _ = self.net(w_images)  # (b*n) x c x h x w

        # Warp logits back to the original image resolution.
        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)

        data['seg'] = {'logits': seg_logits,
                       'label_names': setting['label_names']}
        return data
|