alexnasa commited on
Commit
f0d65fe
·
verified ·
1 Parent(s): 44fb8b6

Upload farl.py

Browse files
src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Any
2
+ import functools
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ..util import download_jit
7
+ from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
8
+ make_inverted_tanh_warp_grid, make_tanh_warp_grid)
9
+ from .base import FaceParser
10
+
11
# Per-configuration settings for the pretrained FaRL face parsers.
# Each entry provides:
#   'url'            -- checkpoint location(s) for the torch-jit model,
#   'matrix_src_tag' -- which key of the detector output feeds alignment,
#   'get_matrix_fn'  -- builds the face-alignment matrix from that source,
#   'get_grid_fn' / 'get_inv_grid_fn' -- forward / inverse tanh warp grids,
#   'label_names'    -- semantic classes emitted by the model, in channel order.
pretrain_settings = {
    'lapa/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(
            get_face_align_matrix,
            target_shape=(448, 448),
            target_face_scale=1.0),
        'get_grid_fn': functools.partial(
            make_tanh_warp_grid,
            warp_factor=0.8,
            warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(
            make_inverted_tanh_warp_grid,
            warp_factor=0.8,
            warped_shape=(448, 448)),
        'label_names': [
            'background', 'face', 'rb', 'lb', 're',
            'le', 'nose', 'ulip', 'imouth', 'llip', 'hair'],
    },
    'celebm/448': {
        'url': [
            'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
        ],
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(
            get_face_align_matrix_celebm,
            target_shape=(448, 448)),
        'get_grid_fn': functools.partial(
            make_tanh_warp_grid,
            warp_factor=0,
            warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(
            make_inverted_tanh_warp_grid,
            warp_factor=0,
            warped_shape=(448, 448)),
        'label_names': [
            'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
            'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
            'eyeg', 'hat', 'earr', 'neck_l'],
    },
}
43
+
44
+
45
class FaRLFaceParser(FaceParser):
    """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).

    Please consider citing
    ```bibtex
        @article{zheng2021farl,
            title={General Facial Representation Learning in a Visual-Linguistic Manner},
            author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
                Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
                Dong and Zeng, Ming and Wen, Fang},
            journal={arXiv preprint arXiv:2112.03109},
            year={2021}
        }
    ```
    """

    def __init__(self, conf_name: Optional[str] = None,
                 model_path: Optional[str] = None, device=None) -> None:
        """Load a pretrained FaRL face-parsing jit model.

        Args:
            conf_name: key into ``pretrain_settings``; defaults to ``'lapa/448'``.
            model_path: explicit checkpoint path/url; defaults to the chosen
                setting's ``'url'`` entry.
            device: forwarded to ``download_jit`` as ``map_location``.
        """
        super().__init__()
        if conf_name is None:
            conf_name = 'lapa/448'
        if model_path is None:
            model_path = pretrain_settings[conf_name]['url']
        self.conf_name = conf_name
        self.net = download_jit(model_path, map_location=device)
        self.eval()

    def forward(self, images: torch.Tensor, data: Dict[str, Any],
                bbox_scale_factor: float = 1.0):
        """Run face parsing on detected faces and attach logits to ``data``.

        Args:
            images: image batch; converted with ``.float() / 255.0``, so it is
                expected to hold values in [0, 255] (presumably uint8 b x c x h x w
                -- TODO confirm against caller).
            data: detection dict providing ``'image_ids'`` (one per face) and the
                alignment source under the setting's ``matrix_src_tag`` key
                (``'points'`` for both shipped configurations).
            bbox_scale_factor: extra scale on the source box for the celebm
                alignment; ignored for non-celebm configurations.

        Returns:
            ``data`` with a ``'seg'`` entry: ``{'logits': (b*n) x c x h x w
            tensor in the original image frame, 'label_names': [...]}``.
        """
        setting = pretrain_settings[self.conf_name]
        images = images.float() / 255.0  # normalize to [0, 1]
        _, _, h, w = images.shape

        # One image slice per detected face.
        simages = images[data['image_ids']]

        # BUG FIX: the original unconditionally built the alignment matrix with
        # get_face_align_matrix_celebm, ignoring setting['get_matrix_fn'] and
        # thereby breaking non-celebm configurations such as 'lapa/448' (whose
        # warp grids with warp_factor=0.8 pair with get_face_align_matrix).
        # Keep the celebm path -- including the adjustable bbox_scale_factor --
        # only for the celebm configuration.
        if self.conf_name == 'celebm/448':
            matrix_fn = functools.partial(
                get_face_align_matrix_celebm,
                target_shape=(448, 448), bbox_scale_factor=bbox_scale_factor)
        else:
            matrix_fn = setting['get_matrix_fn']
        matrix = matrix_fn(data[setting['matrix_src_tag']])
        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))

        # Warp each face crop into the aligned 448x448 frame.
        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        w_seg_logits, _ = self.net(w_images)  # (b*n) x c x h x w

        # Warp the logits back onto the original image plane.
        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)

        data['seg'] = {'logits': seg_logits,
                       'label_names': setting['label_names']}
        return data