alexnasa committed on
Commit 44fb8b6 · verified · 1 Parent(s): b28d79e

Delete src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py

src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py DELETED
@@ -1,174 +0,0 @@
- from typing import Optional, Dict, Any, Union
- import functools
- import torch
- import torch.nn.functional as F
-
- from ..util import download_jit
- from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
-                          make_inverted_tanh_warp_grid, make_tanh_warp_grid)
- from .base import FaceParser
- import numpy as np
-
- pretrain_settings = {
-     'lapa/448': {
-         'url': [
-             'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
-         ],
-         'matrix_src_tag': 'points',
-         'get_matrix_fn': functools.partial(get_face_align_matrix,
-                                            target_shape=(448, 448), target_face_scale=1.0),
-         'get_grid_fn': functools.partial(make_tanh_warp_grid,
-                                          warp_factor=0.8, warped_shape=(448, 448)),
-         'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
-                                              warp_factor=0.8, warped_shape=(448, 448)),
-         'label_names': ['background', 'face', 'rb', 'lb', 're',
-                         'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
-     },
-     'celebm/448': {
-         'url': [
-             'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
-         ],
-         'matrix_src_tag': 'points',
-         'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
-                                            target_shape=(448, 448)),
-         'get_grid_fn': functools.partial(make_tanh_warp_grid,
-                                          warp_factor=0, warped_shape=(448, 448)),
-         'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
-                                              warp_factor=0, warped_shape=(448, 448)),
-         'label_names': [
-             'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
-             'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
-             'eyeg', 'hat', 'earr', 'neck_l']
-     }
- }
-
-
- class FaRLFaceParser(FaceParser):
-     """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).
-
-     Please consider citing
-     ```bibtex
-     @article{zheng2021farl,
-         title={General Facial Representation Learning in a Visual-Linguistic Manner},
-         author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
-                 Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
-                 Dong and Zeng, Ming and Wen, Fang},
-         journal={arXiv preprint arXiv:2112.03109},
-         year={2021}
-     }
-     ```
-     """
-
-     def __init__(self, conf_name: Optional[str] = None, model_path: Optional[str] = None, device=None) -> None:
-         super().__init__()
-         if conf_name is None:
-             conf_name = 'lapa/448'
-         if model_path is None:
-             model_path = pretrain_settings[conf_name]['url']
-         self.conf_name = conf_name
-         self.net = download_jit(model_path, map_location=device)
-         self.eval()
-         self.device = device
-         self.setting = pretrain_settings[conf_name]
-         self.label_names = self.setting['label_names']
-
-
-     def get_warp_grid(self, images: torch.Tensor, matrix_src):
-         _, _, h, w = images.shape
-         matrix = self.setting['get_matrix_fn'](matrix_src)
-         grid = self.setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
-         inv_grid = self.setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
-         return grid, inv_grid
-
-     def warp_images(self, images: torch.Tensor, data: Dict[str, Any]):
-         simages = self.unify_image_dtype(images)
-         simages = simages[data['image_ids']]
-         matrix_src = data[self.setting['matrix_src_tag']]
-         grid, inv_grid = self.get_warp_grid(simages, matrix_src)
-
-         w_images = F.grid_sample(
-             simages, grid, mode='bilinear', align_corners=False)
-         return w_images, grid, inv_grid
-
-
-     def decode_image_to_cv2(self, images: torch.Tensor):
-         '''
-         output: b x h x w x 3, np.uint8, [0, 255]
-         '''
-         assert images.ndim == 4
-         assert images.shape[1] == 3
-         images = images.permute(0, 2, 3, 1).cpu().numpy() * 255
-         images = images.astype(np.uint8)
-         return images
-
-     def unify_image_dtype(self, images: Union[torch.Tensor, np.ndarray, list]):
-         '''
-         output: b x 3 x h x w, torch.float32, [0, 1]
-         '''
-         if isinstance(images, np.ndarray):
-             images = torch.from_numpy(images)
-         elif isinstance(images, torch.Tensor):
-             pass
-         elif isinstance(images, list):
-             assert len(images) > 0, "images is empty"
-             first_image = images[0]
-             if isinstance(first_image, np.ndarray):
-                 images = [torch.from_numpy(image).permute(2, 0, 1) for image in images]
-                 images = torch.stack(images)
-             elif isinstance(first_image, torch.Tensor):
-                 images = torch.stack(images)
-             else:
-                 raise ValueError(f"Unsupported image type: {type(first_image)}")
-
-         else:
-             raise ValueError(f"Unsupported image type: {type(images)}")
-
-         assert images.ndim == 4
-         assert images.shape[1] == 3
-
-         max_val = images.max()
-         if max_val <= 1:
-             assert images.dtype == torch.float32 or images.dtype == torch.float16
-         elif max_val <= 255:
-             assert images.dtype == torch.uint8
-             images = images.float() / 255.0
-         else:
-             raise ValueError(f"Unsupported image dtype: {images.dtype}")
-         if images.device != self.device:
-             images = images.to(device=self.device)
-         return images
-
-     @torch.no_grad()
-     @torch.inference_mode()
-     def forward(self, images: torch.Tensor, data: Dict[str, Any]):
-         '''
-         images: b x 3 x h x w, torch.uint8, [0, 255]
-         data: {'rects': rects, 'points': points, 'scores': scores, 'image_ids': image_ids}
-         '''
-         w_images, grid, inv_grid = self.warp_images(images, data)
-         w_seg_logits = self.forward_warped(w_images, return_preds=False)
-
-         seg_logits = F.grid_sample(
-             w_seg_logits, inv_grid, mode='bilinear', align_corners=False)
-
-         data['seg'] = {'logits': seg_logits, 'label_names': self.label_names}
-         return data
-
-
-     def logits2predictions(self, logits: torch.Tensor):
-         return logits.argmax(dim=1)
-
-     @torch.no_grad()
-     @torch.inference_mode()
-     def forward_warped(self, images: torch.Tensor, return_preds: bool = True):
-         '''
-         images: b x 3 x h x w, torch.uint8, [0, 255]
-         '''
-         images = self.unify_image_dtype(images)
-         seg_logits, _ = self.net(images)  # nfaces x c x h x w
-         # seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
-         if return_preds:
-             seg_preds = self.logits2predictions(seg_logits)
-             return seg_logits, seg_preds, self.label_names
-         else:
-             return seg_logits
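
For reference, a minimal usage sketch of the deleted class. Assumptions not confirmed by the diff: the import path mirrors this repository's layout; `FaceParser` is an `nn.Module`, so the instance is callable; the `data` dict normally comes from a face detector, and the placeholder tensors below (e.g. 5 landmarks per face) only illustrate the expected keys, not real detector output.

```python
import torch
from pixel3dmm.preprocessing.facer.facer.face_parsing.farl import FaRLFaceParser

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Downloads the JIT weights listed in pretrain_settings on first use.
parser = FaRLFaceParser(conf_name='lapa/448', device=device)

# images: b x 3 x h x w, torch.uint8, [0, 255]
images = torch.randint(0, 256, (1, 3, 512, 512), dtype=torch.uint8, device=device)

# Placeholder detector output; in practice these come from a face detector.
data = {
    'rects': torch.zeros(1, 4, device=device),            # one box per face
    'points': torch.rand(1, 5, 2, device=device) * 512,   # landmarks (shape assumed)
    'scores': torch.ones(1, device=device),
    'image_ids': torch.zeros(1, dtype=torch.long, device=device),
}

data = parser(images, data)
seg_logits = data['seg']['logits']                  # nfaces x nclasses x h x w
seg_preds = parser.logits2predictions(seg_logits)   # nfaces x h x w, class indices
print(data['seg']['label_names'])
```

The forward pass warps each detected face into the aligned 448x448 crop via the tanh warp grid, runs the JIT model, then maps the logits back to the original image frame with the inverted grid, which is why `data['seg']['logits']` has the input image's spatial size.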