alexnasa committed (verified) · Commit b28d79e · 1 Parent(s): 4522fd6

Upload 52 files

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. .gitattributes +5 -0
  2. src/pixel3dmm/preprocessing/facer/.gitignore +134 -0
  3. src/pixel3dmm/preprocessing/facer/LICENSE +21 -0
  4. src/pixel3dmm/preprocessing/facer/README.md +187 -0
  5. src/pixel3dmm/preprocessing/facer/facer/__init__.py +55 -0
  6. src/pixel3dmm/preprocessing/facer/facer/draw.py +186 -0
  7. src/pixel3dmm/preprocessing/facer/facer/face_alignment/__init__.py +2 -0
  8. src/pixel3dmm/preprocessing/facer/facer/face_alignment/base.py +24 -0
  9. src/pixel3dmm/preprocessing/facer/facer/face_alignment/farl.py +180 -0
  10. src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/__init__.py +42 -0
  11. src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/common.py +91 -0
  12. src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/geometry.py +45 -0
  13. src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/mmseg.py +29 -0
  14. src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/transformers.py +173 -0
  15. src/pixel3dmm/preprocessing/facer/facer/face_attribute/__init__.py +2 -0
  16. src/pixel3dmm/preprocessing/facer/facer/face_attribute/base.py +24 -0
  17. src/pixel3dmm/preprocessing/facer/facer/face_attribute/farl.py +158 -0
  18. src/pixel3dmm/preprocessing/facer/facer/face_detection/__init__.py +2 -0
  19. src/pixel3dmm/preprocessing/facer/facer/face_detection/base.py +19 -0
  20. src/pixel3dmm/preprocessing/facer/facer/face_detection/retinaface.py +677 -0
  21. src/pixel3dmm/preprocessing/facer/facer/face_parsing/__init__.py +2 -0
  22. src/pixel3dmm/preprocessing/facer/facer/face_parsing/base.py +27 -0
  23. src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py +174 -0
  24. src/pixel3dmm/preprocessing/facer/facer/farl/__init__.py +5 -0
  25. src/pixel3dmm/preprocessing/facer/facer/farl/classification.py +149 -0
  26. src/pixel3dmm/preprocessing/facer/facer/farl/model.py +419 -0
  27. src/pixel3dmm/preprocessing/facer/facer/io.py +28 -0
  28. src/pixel3dmm/preprocessing/facer/facer/show.py +36 -0
  29. src/pixel3dmm/preprocessing/facer/facer/transform.py +384 -0
  30. src/pixel3dmm/preprocessing/facer/facer/util.py +169 -0
  31. src/pixel3dmm/preprocessing/facer/facer/version.py +1 -0
  32. src/pixel3dmm/preprocessing/facer/requirements.txt +11 -0
  33. src/pixel3dmm/preprocessing/facer/samples/data/ffhq_15723.jpg +3 -0
  34. src/pixel3dmm/preprocessing/facer/samples/data/fire.webp +0 -0
  35. src/pixel3dmm/preprocessing/facer/samples/data/girl.jpg +0 -0
  36. src/pixel3dmm/preprocessing/facer/samples/data/sideface.jpg +0 -0
  37. src/pixel3dmm/preprocessing/facer/samples/data/twogirls.jpg +0 -0
  38. src/pixel3dmm/preprocessing/facer/samples/data/weirdface.jpg +3 -0
  39. src/pixel3dmm/preprocessing/facer/samples/data/weirdface2.jpg +0 -0
  40. src/pixel3dmm/preprocessing/facer/samples/data/weirdface3.jpg +0 -0
  41. src/pixel3dmm/preprocessing/facer/samples/download.ipynb +66 -0
  42. src/pixel3dmm/preprocessing/facer/samples/example_output/alignment.png +3 -0
  43. src/pixel3dmm/preprocessing/facer/samples/example_output/detect.png +3 -0
  44. src/pixel3dmm/preprocessing/facer/samples/example_output/parsing.png +3 -0
  45. src/pixel3dmm/preprocessing/facer/samples/face_alignment.ipynb +0 -0
  46. src/pixel3dmm/preprocessing/facer/samples/face_attribute.ipynb +0 -0
  47. src/pixel3dmm/preprocessing/facer/samples/face_detect.ipynb +0 -0
  48. src/pixel3dmm/preprocessing/facer/samples/face_parsing.ipynb +0 -0
  49. src/pixel3dmm/preprocessing/facer/samples/transform.ipynb +0 -0
  50. src/pixel3dmm/preprocessing/facer/scripts/build.sh +1 -0
.gitattributes CHANGED
@@ -41,3 +41,8 @@ example_videos/ex3.mp4 filter=lfs diff=lfs merge=lfs -text
  example_videos/ex4.mp4 filter=lfs diff=lfs merge=lfs -text
  example_videos/ex5.mp4 filter=lfs diff=lfs merge=lfs -text
  media/banner.gif filter=lfs diff=lfs merge=lfs -text
+ src/pixel3dmm/preprocessing/facer/samples/data/ffhq_15723.jpg filter=lfs diff=lfs merge=lfs -text
+ src/pixel3dmm/preprocessing/facer/samples/data/weirdface.jpg filter=lfs diff=lfs merge=lfs -text
+ src/pixel3dmm/preprocessing/facer/samples/example_output/alignment.png filter=lfs diff=lfs merge=lfs -text
+ src/pixel3dmm/preprocessing/facer/samples/example_output/detect.png filter=lfs diff=lfs merge=lfs -text
+ src/pixel3dmm/preprocessing/facer/samples/example_output/parsing.png filter=lfs diff=lfs merge=lfs -text
src/pixel3dmm/preprocessing/facer/.gitignore ADDED
@@ -0,0 +1,134 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ samples/output
131
+ .local/
132
+
133
+ .token.txt
134
+ .downloaded/
src/pixel3dmm/preprocessing/facer/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 FacePerceiver
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
src/pixel3dmm/preprocessing/facer/README.md ADDED
@@ -0,0 +1,187 @@
1
+ # FACER
2
+
3
+ Face-related toolkit. This repo is still under construction; more models will be added.
4
+
5
+ ## Updates
6
+ - [01/17/2025] FaRL Face Parsing is now available in [elliottzheng/batch-face](https://github.com/elliottzheng/batch-face?tab=readme-ov-file#face-parsing), which provides faster batch processing.
7
+ - [14/05/2023] Face attribute recognition model trained on CelebA is available, check it out [here](./samples/face_attribute.ipynb).
8
+ - [04/05/2023] Face alignment model trained on IBUG300W, AFLW19, WFLW dataset is available, check it out [here](./samples/face_alignment.ipynb).
9
+ - [27/04/2023] Face parsing model trained on CelebM dataset is available, check it out [here](./samples/face_parsing.ipynb).
10
+
11
+ ## Install
12
+
13
+ The easiest way to install it is using pip:
14
+
15
+ ```bash
16
+ pip install git+https://github.com/FacePerceiver/facer.git@main
17
+ ```
18
+ No extra setup is needed; pretrained weights are downloaded automatically.
19
+
20
+ If you have trouble installing from source, you can try installing from PyPI:
21
+ ```bash
22
+ pip install pyfacer
23
+ ```
24
+ The PyPI version is not guaranteed to be the latest, but we try to keep it up to date.
25
+
26
+
27
+ ## Face Detection
28
+
29
+ We simply wrap a retinaface detector for easy usage.
30
+ ```python
31
+ import torch
+ import facer
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
32
+
33
+ image = facer.hwc2bchw(facer.read_hwc('data/twogirls.jpg')).to(device=device) # image: 1 x 3 x h x w
34
+
35
+ face_detector = facer.face_detector('retinaface/mobilenet', device=device)
36
+ with torch.inference_mode():
37
+ faces = face_detector(image)
38
+
39
+ facer.show_bchw(facer.draw_bchw(image, faces))
40
+ ```
41
+ ![](./samples/example_output/detect.png)
42
+
43
+ Check [this notebook](./samples/face_detect.ipynb) for a full example.
44
+
45
+ Please consider citing
46
+ ```
47
+ @inproceedings{deng2020retinaface,
48
+ title={Retinaface: Single-shot multi-level face localisation in the wild},
49
+ author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos},
50
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
51
+ pages={5203--5212},
52
+ year={2020}
53
+ }
54
+ ```
55
+
56
+ ## Face Parsing
57
+
58
+ We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face parsing.
59
+ ```python
60
+ import torch
61
+ import facer
62
+
63
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
64
+
65
+ image = facer.hwc2bchw(facer.read_hwc('data/twogirls.jpg')).to(device=device) # image: 1 x 3 x h x w
66
+
67
+ face_detector = facer.face_detector('retinaface/mobilenet', device=device)
68
+ with torch.inference_mode():
69
+ faces = face_detector(image)
70
+
71
+ face_parser = facer.face_parser('farl/lapa/448', device=device) # optional "farl/celebm/448"
72
+
73
+ with torch.inference_mode():
74
+ faces = face_parser(image, faces)
75
+
76
+ seg_logits = faces['seg']['logits']
77
+ seg_probs = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w
78
+ n_classes = seg_probs.size(1)
79
+ vis_seg_probs = seg_probs.argmax(dim=1).float()/n_classes*255
80
+ vis_img = vis_seg_probs.sum(0, keepdim=True)
81
+ facer.show_bhw(vis_img)
82
+ facer.show_bchw(facer.draw_bchw(image, faces))
83
+ ```
84
+ ![](./samples/example_output/parsing.png)
85
+
86
+ Check [this notebook](./samples/face_parsing.ipynb) for a full example.
87
+
88
+ Please consider citing
89
+ ```
90
+ @inproceedings{zheng2022farl,
91
+ title={General facial representation learning in a visual-linguistic manner},
92
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
93
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
94
+ pages={18697--18709},
95
+ year={2022}
96
+ }
97
+ ```
98
+
99
+
100
+ ## Face Alignment
101
+
102
+ We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face alignment.
103
+ ```python
104
+ import torch
105
+ import cv2
106
+ from matplotlib import pyplot as plt
107
+
108
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
109
+
110
+ import facer
111
+ img_file = 'data/twogirls.jpg'
112
+ # image: 1 x 3 x h x w
113
+ image = facer.hwc2bchw(facer.read_hwc(img_file)).to(device=device)
114
+
115
+ face_detector = facer.face_detector('retinaface/mobilenet', device=device)
116
+ with torch.inference_mode():
117
+ faces = face_detector(image)
118
+
119
+ face_aligner = facer.face_aligner('farl/ibug300w/448', device=device) # optional: "farl/wflw/448", "farl/aflw19/448"
120
+
121
+ with torch.inference_mode():
122
+ faces = face_aligner(image, faces)
123
+
124
+ img = cv2.imread(img_file)[..., ::-1]
125
+ vis_img = img.copy()
126
+ for pts in faces['alignment']:
127
+ vis_img = facer.draw_landmarks(vis_img, None, pts.cpu().numpy())
128
+ plt.imshow(vis_img)
129
+ ```
130
+ ![](./samples/example_output/alignment.png)
131
+
132
+ Check [this notebook](./samples/face_alignment.ipynb) for a full example.
133
+
134
+ Please consider citing
135
+ ```
136
+ @inproceedings{zheng2022farl,
137
+ title={General facial representation learning in a visual-linguistic manner},
138
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
139
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
140
+ pages={18697--18709},
141
+ year={2022}
142
+ }
143
+ ```
144
+
145
+ ## Face Attribute Recognition
146
+ We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face attribute recognition; the model achieves 92.06% accuracy on the [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset.
147
+
148
+ ```python
149
+ import sys
150
+ import torch
151
+ import facer
152
+
153
+ device = "cuda" if torch.cuda.is_available() else "cpu"
154
+
155
+ # image: 1 x 3 x h x w
156
+ image = facer.hwc2bchw(facer.read_hwc("data/girl.jpg")).to(device=device)
157
+
158
+ face_detector = facer.face_detector("retinaface/mobilenet", device=device)
159
+ with torch.inference_mode():
160
+ faces = face_detector(image)
161
+
162
+ face_attr = facer.face_attr("farl/celeba/224", device=device)
163
+ with torch.inference_mode():
164
+ faces = face_attr(image, faces)
165
+
166
+ labels = face_attr.labels
167
+ face1_attrs = faces["attrs"][0] # get the first face's attributes
168
+
169
+ print(labels)
170
+
171
+ for prob, label in zip(face1_attrs, labels):
172
+ if prob > 0.5:
173
+ print(label, prob.item())
174
+ ```
175
+
176
+ Check [this notebook](./samples/face_attribute.ipynb) for a full example.
177
+
178
+ Please consider citing
179
+ ```
180
+ @inproceedings{zheng2022farl,
181
+ title={General facial representation learning in a visual-linguistic manner},
182
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
183
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
184
+ pages={18697--18709},
185
+ year={2022}
186
+ }
187
+ ```
src/pixel3dmm/preprocessing/facer/facer/__init__.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Optional, Tuple
2
+ import torch
3
+
4
+ from .io import read_hwc, write_hwc
5
+ from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
6
+ from .draw import draw_bchw, draw_landmarks
7
+ from .show import show_bchw, show_bhw
8
+
9
+ from .face_detection import FaceDetector
10
+ from .face_parsing import FaceParser
11
+ from .face_alignment import FaceAlignment
12
+ from .face_attribute import FaceAttribute
13
+
14
+
15
+ def _split_name(name: str) -> Tuple[str, Optional[str]]:
16
+ if '/' in name:
17
+ detector_type, conf_name = name.split('/', 1)
18
+ else:
19
+ detector_type, conf_name = name, None
20
+ return detector_type, conf_name
21
+
22
+
23
+ def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
24
+ detector_type, conf_name = _split_name(name)
25
+ if detector_type == 'retinaface':
26
+ from .face_detection import RetinaFaceDetector
27
+ return RetinaFaceDetector(conf_name, **kwargs).to(device)
28
+ else:
29
+ raise RuntimeError(f'Unknown detector type: {detector_type}')
30
+
31
+
32
+ def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
33
+ parser_type, conf_name = _split_name(name)
34
+ if parser_type == 'farl':
35
+ from .face_parsing import FaRLFaceParser
36
+ return FaRLFaceParser(conf_name, device=device, **kwargs).to(device)
37
+ else:
38
+ raise RuntimeError(f'Unknown parser type: {parser_type}')
39
+
40
+
41
+ def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
42
+ aligner_type, conf_name = _split_name(name)
43
+ if aligner_type == 'farl':
44
+ from .face_alignment import FaRLFaceAlignment
45
+ return FaRLFaceAlignment(conf_name, device=device, **kwargs).to(device)
46
+ else:
47
+ raise RuntimeError(f'Unknown aligner type: {aligner_type}')
48
+
49
+ def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
50
+ attr_type, conf_name = _split_name(name)
51
+ if attr_type == 'farl':
52
+ from .face_attribute import FaRLFaceAttribute
53
+ return FaRLFaceAttribute(conf_name, device=device, **kwargs).to(device)
54
+ else:
55
+ raise RuntimeError(f'Unknown attribute type: {attr_type}')
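A minimal usage sketch of the factory functions above, assuming the vendored package is importable as `facer` (the configuration names come from the README and the models' `pretrain_settings` tables):

```python
import torch
import facer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Each factory takes a '<type>/<config>' string; _split_name splits on the
# first '/', so 'farl/lapa/448' becomes ('farl', 'lapa/448').
detector = facer.face_detector('retinaface/mobilenet', device=device)  # RetinaFaceDetector
parser = facer.face_parser('farl/lapa/448', device=device)             # FaRLFaceParser
aligner = facer.face_aligner('farl/ibug300w/448', device=device)       # FaRLFaceAlignment
attrs = facer.face_attr('farl/celeba/224', device=device)              # FaRLFaceAttribute

# An unrecognized prefix raises RuntimeError, e.g. facer.face_detector('unknown/x', device=device)
```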
src/pixel3dmm/preprocessing/facer/facer/draw.py ADDED
@@ -0,0 +1,186 @@
1
+ from typing import Dict, List
2
+ import torch
3
+ import colorsys
4
+ import random
5
+ import numpy as np
6
+ from skimage.draw import line_aa, circle_perimeter_aa
7
+ import cv2
8
+ from .util import select_data
9
+
10
+
11
+ def _gen_random_colors(N, bright=True):
12
+ brightness = 1.0 if bright else 0.7
13
+ hsv = [(i / N, 1, brightness) for i in range(N)]
14
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
15
+ random.shuffle(colors)
16
+ return colors
17
+
18
+
19
+ _static_label_colors = [
20
+ np.array((1.0, 1.0, 1.0), np.float32),
21
+ np.array((255, 250, 79), np.float32) / 255.0, # face
22
+ np.array([255, 125, 138], np.float32) / 255.0, # lb
23
+ np.array([213, 32, 29], np.float32) / 255.0, # rb
24
+ np.array([0, 144, 187], np.float32) / 255.0, # le
25
+ np.array([0, 196, 253], np.float32) / 255.0, # re
26
+ np.array([255, 129, 54], np.float32) / 255.0, # nose
27
+ np.array([88, 233, 135], np.float32) / 255.0, # ulip
28
+ np.array([0, 117, 27], np.float32) / 255.0, # llip
29
+ np.array([255, 76, 249], np.float32) / 255.0, # imouth
30
+ np.array((1.0, 0.0, 0.0), np.float32), # hair
31
+ np.array((255, 250, 100), np.float32) / 255.0, # lr
32
+ np.array((255, 250, 100), np.float32) / 255.0, # rr
33
+ np.array((250, 245, 50), np.float32) / 255.0, # neck
34
+ np.array((0.0, 1.0, 0.5), np.float32), # cloth
35
+ np.array((1.0, 0.0, 0.5), np.float32),
36
+ ] + _gen_random_colors(256)
37
+
38
+ _names_in_static_label_colors = [
39
+ 'background', 'face', 'lb', 'rb', 'le', 're', 'nose',
40
+ 'ulip', 'llip', 'imouth', 'hair', 'lr', 'rr', 'neck',
41
+ 'cloth', 'eyeg', 'hat', 'earr'
42
+ ]
43
+
44
+
45
+ def _blend_labels(image, labels, label_names_dict=None,
46
+ default_alpha=0.6, color_offset=None):
47
+ assert labels.ndim == 2
48
+ bg_mask = labels == 0
49
+ if label_names_dict is None:
50
+ colors = _static_label_colors
51
+ else:
52
+ colors = [np.array((1.0, 1.0, 1.0), np.float32)]
53
+ for i in range(1, labels.max() + 1):
54
+ if isinstance(label_names_dict, dict) and i not in label_names_dict:
55
+ bg_mask = np.logical_or(bg_mask, labels == i)
56
+ colors.append(np.zeros((3)))
57
+ continue
58
+ label_name = label_names_dict[i]
59
+ if label_name in _names_in_static_label_colors:
60
+ color = _static_label_colors[
61
+ _names_in_static_label_colors.index(
62
+ label_name)]
63
+ else:
64
+ color = np.array((1.0, 1.0, 1.0), np.float32)
65
+ colors.append(color)
66
+
67
+ if color_offset is not None:
68
+ ncolors = []
69
+ for c in colors:
70
+ nc = np.array(c)
71
+ if (nc != np.zeros(3)).any():
72
+ nc += color_offset
73
+ ncolors.append(nc)
74
+ colors = ncolors
75
+
76
+ if image is None:
77
+ image = orig_image = np.zeros(
78
+ [labels.shape[0], labels.shape[1], 3], np.float32)
79
+ alpha = 1.0
80
+ else:
81
+ orig_image = image / np.max(image)
82
+ image = orig_image * (1.0 - default_alpha)
83
+ alpha = default_alpha
84
+ for i in range(1, np.max(labels) + 1):
85
+ image += alpha * \
86
+ np.tile(
87
+ np.expand_dims(
88
+ (labels == i).astype(np.float32), -1),
89
+ [1, 1, 3]) * colors[(i) % len(colors)]
90
+ image[np.where(image > 1.0)] = 1.0
91
+ image[np.where(image < 0)] = 0.0
92
+ image[np.where(bg_mask)] = orig_image[np.where(bg_mask)]
93
+ return image
94
+
95
+
96
+ def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
97
+ device = image.device
98
+ image = np.array(image.cpu().numpy(), copy=True)
99
+ dtype = image.dtype
100
+ h, w, _ = image.shape
101
+
102
+ draw_score_error = False
103
+ for tag, batch_content in data.items():
104
+ if tag == 'rects':
105
+ for cid, content in enumerate(batch_content):
106
+ x1, y1, x2, y2 = [int(v) for v in content]
107
+ y1, y2 = [max(min(v, h-1), 0) for v in [y1, y2]]
108
+ x1, x2 = [max(min(v, w-1), 0) for v in [x1, x2]]
109
+ for xx1, yy1, xx2, yy2 in [
110
+ [x1, y1, x2, y1],
111
+ [x1, y2, x2, y2],
112
+ [x1, y1, x1, y2],
113
+ [x2, y1, x2, y2]
114
+ ]:
115
+ rr, cc, val = line_aa(yy1, xx1, yy2, xx2)
116
+ val = val[:, None][:, [0, 0, 0]]
117
+ image[rr, cc] = image[rr, cc] * (1.0-val) + val * 255
118
+
119
+ if 'scores' in data:
120
+ try:
121
+ import cv2
122
+ score = data['scores'][cid].item()
123
+ score_str = f'{score:0.3f}'
124
+ image_c = np.array(image).copy()
125
+ cv2.putText(image_c, score_str, org=(int(x1), int(y2)),
126
+ fontFace=cv2.FONT_HERSHEY_TRIPLEX,
127
+ fontScale=0.6, color=(255, 255, 255), thickness=1)
128
+ image[:, :, :] = image_c
129
+ except Exception as e:
130
+ if not draw_score_error:
131
+ print(f'Failed to draw scores on image.')
132
+ print(e)
133
+ draw_score_error = True
134
+
135
+ if tag == 'points':
136
+ for content in batch_content:
137
+ # content: npoints x 2
138
+ for x, y in content:
139
+ x = max(min(int(x), w-1), 0)
140
+ y = max(min(int(y), h-1), 0)
141
+ rr, cc, val = circle_perimeter_aa(y, x, 1)
142
+ valid = np.all([rr >= 0, rr < h, cc >= 0, cc < w], axis=0)
143
+ rr = rr[valid]
144
+ cc = cc[valid]
145
+ val = val[valid]
146
+ val = val[:, None][:, [0, 0, 0]]
147
+ image[rr, cc] = image[rr, cc] * (1.0-val) + val * 255
148
+
149
+ if tag == 'seg':
150
+ label_names = batch_content['label_names']
151
+ for seg_logits in batch_content['logits']:
152
+ # content: nclasses x h x w
153
+ seg_probs = seg_logits.softmax(dim=0)
154
+ seg_labels = seg_probs.argmax(dim=0).cpu().numpy()
155
+ image = (_blend_labels(image.astype(np.float32) /
156
+ 255, seg_labels,
157
+ label_names_dict=label_names) * 255).astype(dtype)
158
+
159
+ return torch.from_numpy(image).to(device=device)
160
+
161
+
162
+ def draw_bchw(images: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
163
+ images2 = []
164
+ for image_id, image_chw in enumerate(images):
165
+ selected_data = select_data(image_id == data['image_ids'], data)
166
+ images2.append(
167
+ _draw_hwc(image_chw.permute(1, 2, 0), selected_data).permute(2, 0, 1))
168
+ return torch.stack(images2, dim=0)
169
+
170
+ def draw_landmarks(img, bbox=None, landmark=None, color=(0, 255, 0)):
171
+ """
172
+ Input:
173
+ - img: gray or RGB
174
+ - bbox: type of BBox
175
+ - landmark: reproject landmark of (5L, 2L)
176
+ Output:
177
+ - img marked with landmark and bbox
178
+ """
179
+ img = cv2.UMat(img).get()
180
+ if bbox is not None:
181
+ x1, y1, x2, y2 = np.array(bbox)[:4].astype(np.int32)
182
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
183
+ if landmark is not None:
184
+ for x, y in np.array(landmark).astype(np.int32):
185
+ cv2.circle(img, (int(x), int(y)), 2, color, -1)
186
+ return img
src/pixel3dmm/preprocessing/facer/facer/face_alignment/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .base import FaceAlignment
2
+ from .farl import FaRLFaceAlignment
src/pixel3dmm/preprocessing/facer/facer/face_alignment/base.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class FaceAlignment(nn.Module):
5
+ """ face alignment
6
+
7
+ Args:
8
+ images (torch.Tensor): b x c x h x w
9
+
10
+ data (Dict[str, Any]):
11
+
12
+ * image_ids (torch.Tensor): nfaces
13
+ * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
14
+ * points (torch.Tensor): nfaces x 5 x 2 (x, y)
15
+
16
+ Returns:
17
+ data (Dict[str, Any]):
18
+
19
+ * image_ids (torch.Tensor): nfaces
20
+ * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
21
+ * points (torch.Tensor): nfaces x 5 x 2 (x, y)
22
+ * alignment
23
+ """
24
+ pass
src/pixel3dmm/preprocessing/facer/facer/face_alignment/farl.py ADDED
@@ -0,0 +1,180 @@
1
+ from typing import Optional, Dict, Any
2
+ import functools
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from .network import FaRLVisualFeatures, MMSEG_UPerHead, FaceAlignmentTransformer, denormalize_points, heatmap2points
6
+ from ..transform import (get_face_align_matrix,
7
+ make_inverted_tanh_warp_grid, make_tanh_warp_grid)
8
+ from .base import FaceAlignment
9
+ from ..util import download_jit
10
+ import io
11
+
12
+ pretrain_settings = {
13
+ 'ibug300w/448': {
14
+ # inter_ocular 0.028835 epoch 60
15
+ 'num_classes': 68,
16
+ 'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.ibug300w.main_ema_jit.pt",
17
+ 'matrix_src_tag': 'points',
18
+ 'get_matrix_fn': functools.partial(get_face_align_matrix,
19
+ target_shape=(448, 448), target_face_scale=0.8),
20
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
21
+ warp_factor=0.0, warped_shape=(448, 448)),
22
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
23
+ warp_factor=0.0, warped_shape=(448, 448)),
24
+
25
+ },
26
+ 'aflw19/448': {
27
+ # diag 0.009329 epoch 15
28
+ 'num_classes': 19,
29
+ 'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.aflw19.main_ema_jit.pt",
30
+ 'matrix_src_tag': 'points',
31
+ 'get_matrix_fn': functools.partial(get_face_align_matrix,
32
+ target_shape=(448, 448), target_face_scale=0.8),
33
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
34
+ warp_factor=0.0, warped_shape=(448, 448)),
35
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
36
+ warp_factor=0.0, warped_shape=(448, 448)),
37
+ },
38
+ 'wflw/448': {
39
+ # inter_ocular 0.038933 epoch 20
40
+ 'num_classes': 98,
41
+ 'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.wflw.main_ema_jit.pt",
42
+ 'matrix_src_tag': 'points',
43
+ 'get_matrix_fn': functools.partial(get_face_align_matrix,
44
+ target_shape=(448, 448), target_face_scale=0.8),
45
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
46
+ warp_factor=0.0, warped_shape=(448, 448)),
47
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
48
+ warp_factor=0.0, warped_shape=(448, 448)),
49
+ },
50
+
51
+ }
52
+
53
+
54
+ def load_face_alignment_model(model_path: str, num_classes=68):
55
+ backbone = FaRLVisualFeatures("base", None, forced_input_resolution=448, output_indices=None).cpu()
56
+ if "jit" in model_path:
57
+ extra_files = {"backbone": None}
58
+ heatmap_head = download_jit(model_path, map_location="cpu", _extra_files=extra_files)
59
+ backbone_weight_io = io.BytesIO(extra_files["backbone"])
60
+ backbone.load_state_dict(torch.load(backbone_weight_io))
61
+ # print("load from jit")
62
+ else:
63
+ channels = backbone.get_output_channel("base")
64
+ in_channels = [channels] * 4
65
+ num_classes = num_classes
66
+ heatmap_head = MMSEG_UPerHead(in_channels=in_channels, channels=channels, num_classes=num_classes) # this requires mmseg as a dependency
67
+ state = torch.load(model_path,map_location="cpu")["networks"]["main_ema"]
68
+ # print("load from checkpoint")
69
+
70
+ main_network = FaceAlignmentTransformer(backbone, heatmap_head, heatmap_act="sigmoid").cpu()
71
+
72
+ if "jit" not in model_path:
73
+ main_network.load_state_dict(state, strict=True)
74
+
75
+ return main_network
76
+
77
+
78
+
79
+ class FaRLFaceAlignment(FaceAlignment):
80
+ """ The face alignment models from [FaRL](https://github.com/FacePerceiver/FaRL).
81
+
82
+ Please consider citing
83
+ ```bibtex
84
+ @article{zheng2021farl,
85
+ title={General Facial Representation Learning in a Visual-Linguistic Manner},
86
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
87
+ Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
88
+ Dong and Zeng, Ming and Wen, Fang},
89
+ journal={arXiv preprint arXiv:2112.03109},
90
+ year={2021}
91
+ }
92
+ ```
93
+ """
94
+
95
+ def __init__(self, conf_name: Optional[str] = None,
96
+ model_path: Optional[str] = None, device=None) -> None:
97
+ super().__init__()
98
+ if conf_name is None:
99
+ conf_name = 'ibug300w/448'
100
+ if model_path is None:
101
+ model_path = pretrain_settings[conf_name]['url']
102
+ self.conf_name = conf_name
103
+
104
+ setting = pretrain_settings[self.conf_name]
105
+ self.net = load_face_alignment_model(model_path, num_classes = setting["num_classes"])
106
+ if device is not None:
107
+ self.net = self.net.to(device)
108
+
109
+ self.heatmap_interpolate_mode = 'bilinear'
110
+ self.eval()
111
+
112
+ def forward(self, images: torch.Tensor, data: Dict[str, Any]):
113
+ setting = pretrain_settings[self.conf_name]
114
+ images = images.float() / 255.0 # the backbone applies its own normalization
115
+ _, _, h, w = images.shape
116
+
117
+ simages = images[data['image_ids']]
118
+ matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
119
+ grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
120
+ inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
121
+
122
+ w_images = F.grid_sample(
123
+ simages, grid, mode='bilinear', align_corners=False)
124
+
125
+ _, _, warp_h, warp_w = w_images.shape
126
+
127
+ heatmap_acted = self.net(w_images)
128
+
129
+ warpped_heatmap = F.interpolate(
130
+ heatmap_acted, size=(warp_h, warp_w),
131
+ mode=self.heatmap_interpolate_mode, align_corners=False)
132
+
133
+ pred_heatmap = F.grid_sample(
134
+ warpped_heatmap, inv_grid, mode='bilinear', align_corners=False)
135
+
136
+ landmark = heatmap2points(pred_heatmap)
137
+
138
+ landmark = denormalize_points(landmark, h, w)
139
+
140
+ data['alignment'] = landmark
141
+
142
+ return data
143
+
144
+
145
+ if __name__=="__main__":
146
+ image = torch.randn(1, 3, 448, 448)
147
+
148
+ aligner1 = FaRLFaceAlignment("wflw/448")
149
+
150
+ x1 = aligner1.net(image)
151
+
152
+ import argparse
153
+
154
+ parser = argparse.ArgumentParser()
155
+ parser.add_argument("--jit_path", type=str, default=None)
156
+ args = parser.parse_args()
157
+
158
+ if args.jit_path is None:
159
+ exit(0)
160
+
161
+ net = aligner1.net.cpu()
162
+
163
+ features, _ = net.backbone(image)
164
+
165
+ # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
166
+ traced_script_module = torch.jit.trace(net.heatmap_head, example_inputs=[features])
167
+
168
+ buffer = io.BytesIO()
169
+
170
+ torch.save(net.backbone.state_dict(), buffer)
171
+
172
+ # Save to file
173
+ torch.jit.save(traced_script_module, args.jit_path,
174
+ _extra_files={"backbone": buffer.getvalue()})
175
+
176
+ aligner2 = FaRLFaceAlignment(model_path=args.jit_path)
177
+
178
+ # compare the output
179
+ x2 = aligner2.net(image)
180
+ print(torch.allclose(x1, x2))
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/__init__.py ADDED
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from .common import (load_checkpoint, Activation, MLP, Residual)
5
+ from .geometry import (normalize_points, denormalize_points,
6
+ heatmap2points)
7
+ from .mmseg import MMSEG_UPerHead
8
+ from .transformers import FaRLVisualFeatures
9
+ from torch import nn
10
+ from typing import Optional, List, Tuple
11
+
12
+
13
+
14
+ class FaceAlignmentTransformer(nn.Module):
15
+ """Face alignment transformer.
16
+
17
+ Args:
18
+ image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].
19
+
20
+ Returns:
21
+ landmark (torch.Tensor): Float32 tensor with shape [b, npoints, 2], coordinates normalized to [0, 1].
22
+ aux_outputs:
23
+ heatmap (torch.Tensor): Float32 tensor with shape [b, npoints, S, S]
24
+ """
25
+
26
+ def __init__(self, backbone: nn.Module, heatmap_head: nn.Module,
27
+ heatmap_act: Optional[str] = 'relu'):
28
+ super().__init__()
29
+ self.backbone = backbone
30
+ self.heatmap_head = heatmap_head
31
+ self.heatmap_act = Activation(heatmap_act)
32
+ self.float()
33
+
34
+ def forward(self, image):
35
+ features, _ = self.backbone(image)
36
+ heatmap = self.heatmap_head(features) # b x npoints x s x s
37
+ heatmap_acted = self.heatmap_act(heatmap)
38
+ # landmark = heatmap2points(heatmap_acted) # b x npoints x 2
39
+ # return landmark, {'heatmap': heatmap, 'heatmap_acted': heatmap_acted}
40
+ return heatmap_acted
41
+
42
+
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/common.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ def load_checkpoint(net: nn.Module, checkpoint_path: str, network_name: str):
12
+ states = torch.load(open(checkpoint_path, 'rb'), map_location={
13
+ 'cuda:0': f'cuda:{torch.cuda.current_device()}'})
14
+ network_states = states['networks']
15
+ net.load_state_dict(network_states[network_name])
16
+ return net
17
+
18
+
19
+ class Activation(nn.Module):
20
+ def __init__(self, name: Optional[str], **kwargs):
21
+ super().__init__()
22
+ if name == 'relu':
23
+ self.fn = F.relu
24
+ elif name == 'softplus':
25
+ self.fn = F.softplus
26
+ elif name == 'gelu':
27
+ self.fn = F.gelu
28
+ elif name == 'sigmoid':
29
+ self.fn = torch.sigmoid
30
+ elif name == 'sigmoid_x':
31
+ self.epsilon = kwargs.get('epsilon', 1e-3)
32
+ self.fn = lambda x: torch.clamp(
33
+ x.sigmoid() * (1.0 + self.epsilon*2.0) - self.epsilon,
34
+ min=0.0, max=1.0)
35
+ elif name is None:
36
+ self.fn = lambda x: x
37
+ else:
38
+ raise RuntimeError(f'Unknown activation name: {name}')
39
+
40
+ def forward(self, x):
41
+ return self.fn(x)
42
+
43
+
44
+ class MLP(nn.Module):
45
+ def __init__(self, channels: List[int], act: Optional[str]):
46
+ super().__init__()
47
+ assert len(channels) > 1
48
+ layers = []
49
+ for i in range(len(channels)-1):
50
+ layers.append(nn.Linear(channels[i], channels[i+1]))
51
+ if i+1 < len(channels):
52
+ layers.append(Activation(act))
53
+ self.layers = nn.Sequential(*layers)
54
+
55
+ def forward(self, x):
56
+ return self.layers(x)
57
+
58
+
59
+ class Residual(nn.Module):
60
+ def __init__(self, net: nn.Module, res_weight_init: Optional[float] = 0.0):
61
+ super().__init__()
62
+ self.net = net
63
+ if res_weight_init is not None:
64
+ self.res_weight = nn.Parameter(torch.tensor(res_weight_init))
65
+ else:
66
+ self.res_weight = None
67
+
68
+ def forward(self, x):
69
+ if self.res_weight is not None:
70
+ return self.res_weight * self.net(x) + x
71
+ else:
72
+ return self.net(x) + x
73
+
74
+
75
+ class SE(nn.Module):
76
+ def __init__(self, channel: int, r: int = 1):
77
+ super().__init__()
78
+ self.branch = nn.Sequential(
79
+ nn.Conv2d(channel, channel//r, (1, 1)),
80
+ nn.ReLU(),
81
+ nn.Conv2d(channel//r, channel, (1, 1)),
82
+ nn.Sigmoid()
83
+ )
84
+
85
+ def forward(self, x):
86
+ # x: b x channel x h x w
87
+ v = x.mean([2, 3], keepdim=True) # b x channel x 1 x 1
88
+ v = self.branch(v) # b x channel x 1 x 1
89
+ return x * v
90
+
91
+
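As a side note on the `Residual` wrapper above: with the default `res_weight_init=0.0`, the residual branch is scaled by a weight initialized to zero, so the module starts out as the identity mapping. A minimal sketch, assuming the module is importable under the `facer` package:

```python
import torch
import torch.nn as nn
from facer.face_alignment.network.common import Residual  # assumed import path

block = Residual(nn.Linear(8, 8))   # res_weight starts at 0.0
x = torch.randn(2, 8)
print(torch.allclose(block(x), x))  # True: 0 * net(x) + x == x at initialization
```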
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/geometry.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from typing import Tuple, Union
5
+
6
+ import torch
7
+
8
+
9
+ def normalize_points(points: torch.Tensor, h: int, w: int) -> torch.Tensor:
10
+ """ Normalize coordinates to [0, 1].
11
+ """
12
+ return (points + 0.5) / torch.tensor([[[w, h]]]).to(points)
13
+
14
+
15
+ def denormalize_points(normalized_points: torch.Tensor, h: int, w: int) -> torch.Tensor:
16
+ """ Reverse normalize_points.
17
+ """
18
+ return normalized_points * torch.tensor([[[w, h]]]).to(normalized_points) - 0.5
19
+
20
+
21
+ def heatmap2points(heatmap, t_scale: Union[None, float, torch.Tensor] = None):
22
+ """ Heatmaps -> normalized points [b x npoints x 2(XY)].
23
+ """
24
+ dtype = heatmap.dtype
25
+ _, _, h, w = heatmap.shape
26
+
27
+ # 0 ~ h-1, 0 ~ w-1
28
+ yy, xx = torch.meshgrid(
29
+ torch.arange(h).float(),
30
+ torch.arange(w).float())
31
+
32
+ yy = yy.view(1, 1, h, w).to(heatmap)
33
+ xx = xx.view(1, 1, h, w).to(heatmap)
34
+
35
+ if t_scale is not None:
36
+ heatmap = (heatmap * t_scale).exp()
37
+ heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
38
+
39
+ yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum # b x npoints
40
+ xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum # b x npoints
41
+
42
+ points = torch.stack([xx_coord, yy_coord], dim=-1) # b x npoints x 2
43
+
44
+ normalized_points = normalize_points(points, h, w)
45
+ return normalized_points
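A small sanity check of the soft-argmax in `heatmap2points` above (a sketch, assuming the functions in this file are in scope): a unit peak at pixel (x=3, y=1) of a 4x4 heatmap decodes to the normalized centre of that pixel.

```python
import torch

heatmap = torch.zeros(1, 1, 4, 4)  # b x npoints x h x w
heatmap[0, 0, 1, 3] = 1.0          # single peak at row y=1, col x=3

pts = heatmap2points(heatmap)      # b x npoints x 2, normalized to [0, 1]
print(pts)                         # tensor([[[0.8750, 0.3750]]]) == ((3 + 0.5)/4, (1 + 0.5)/4)
```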
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/mmseg.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import torch.nn as nn
5
+
6
+
7
+ class MMSEG_UPerHead(nn.Module):
8
+ """Wraps the UPerHead from mmseg for segmentation.
9
+ """
10
+
11
+ def __init__(self, num_classes: int,
12
+ in_channels: list = [384, 384, 384, 384], channels: int = 512):
13
+ super().__init__()
14
+
15
+ from mmseg.models.decode_heads import UPerHead
16
+ self.head = UPerHead(
17
+ in_channels=in_channels,
18
+ in_index=[0, 1, 2, 3],
19
+ pool_scales=(1, 2, 3, 6),
20
+ channels=channels,
21
+ dropout_ratio=0.1,
22
+ num_classes=num_classes,
23
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
24
+ align_corners=False,
25
+ loss_decode=dict(
26
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
27
+
28
+ def forward(self, inputs):
29
+ return self.head(inputs)
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/transformers.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import math
5
+
6
+ from typing import Optional, List, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ from ... import farl
14
+
15
+
16
+ def _make_fpns(vision_patch_size: int, output_channels: int):
17
+ if vision_patch_size in {16, 14}:
18
+ fpn1 = nn.Sequential(
19
+ nn.ConvTranspose2d(output_channels, output_channels,
20
+ kernel_size=2, stride=2),
21
+ nn.SyncBatchNorm(output_channels),
22
+ nn.GELU(),
23
+ nn.ConvTranspose2d(output_channels, output_channels, kernel_size=2, stride=2))
24
+
25
+ fpn2 = nn.ConvTranspose2d(
26
+ output_channels, output_channels, kernel_size=2, stride=2)
27
+ fpn3 = nn.Identity()
28
+ fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
29
+ return nn.ModuleList([fpn1, fpn2, fpn3, fpn4])
30
+ elif vision_patch_size == 8:
31
+ fpn1 = nn.Sequential(nn.ConvTranspose2d(
32
+ output_channels, output_channels, kernel_size=2, stride=2))
33
+ fpn2 = nn.Identity()
34
+ fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
35
+ fpn4 = nn.MaxPool2d(kernel_size=4, stride=4)
36
+ return nn.ModuleList([fpn1, fpn2, fpn3, fpn4])
37
+ else:
38
+ raise NotImplementedError()
39
+
40
+
41
+ def _resize_pe(pe: torch.Tensor, new_size: int, mode: str = 'bicubic', num_tokens: int = 1) -> torch.Tensor:
42
+ """Resize positional embeddings.
43
+
44
+ Args:
45
+ pe (torch.Tensor): A tensor with shape (num_tokens + old_size ** 2, width). pe[0, :] is the CLS token.
46
+
47
+ Returns:
48
+ torch.Tensor: A tensor with shape (num_tokens + new_size **2, width).
49
+ """
50
+ l, w = pe.shape
51
+ old_size = int(math.sqrt(l-num_tokens))
52
+ assert old_size ** 2 + num_tokens == l
53
+ return torch.cat([
54
+ pe[:num_tokens, :],
55
+ F.interpolate(pe[num_tokens:, :].reshape(1, old_size, old_size, w).permute(0, 3, 1, 2),
56
+ (new_size, new_size), mode=mode, align_corners=False).view(w, -1).t()], dim=0)
57
+
58
+
59
+ class FaRLVisualFeatures(nn.Module):
60
+ """Extract features from FaRL visual encoder.
61
+
62
+ Args:
63
+ image (torch.Tensor): Float32 tensor with shape [b, 3, h, w],
64
+ normalized to [0, 1].
65
+
66
+ Returns:
67
+ List[torch.Tensor]: A list of features.
68
+ """
69
+ image_mean: torch.Tensor
70
+ image_std: torch.Tensor
71
+ output_channels: int
72
+ num_outputs: int
73
+
74
+ def __init__(self, model_type: str,
75
+ model_path: Optional[str] = None, output_indices: Optional[List[int]] = None,
76
+ forced_input_resolution: Optional[int] = None,
77
+ apply_fpn: bool = True):
78
+ super().__init__()
79
+ self.visual = farl.load_farl(model_type, model_path)
80
+
81
+ vision_patch_size = self.visual.conv1.weight.shape[-1]
82
+
83
+ self.input_resolution = self.visual.input_resolution
84
+ if forced_input_resolution is not None and \
85
+ self.input_resolution != forced_input_resolution:
86
+ # resizing the positonal embeddings
87
+ self.visual.positional_embedding = nn.Parameter(
88
+ _resize_pe(self.visual.positional_embedding,
89
+ forced_input_resolution//vision_patch_size))
90
+ self.input_resolution = forced_input_resolution
91
+
92
+ self.output_channels = self.visual.transformer.width
93
+
94
+ if output_indices is None:
95
+ output_indices = self.__class__.get_default_output_indices(
96
+ model_type)
97
+ self.output_indices = output_indices
98
+ self.num_outputs = len(output_indices)
99
+
100
+ self.register_buffer('image_mean', torch.tensor(
101
+ [0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
102
+ self.register_buffer('image_std', torch.tensor(
103
+ [0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))
104
+
105
+ if apply_fpn:
106
+ self.fpns = _make_fpns(vision_patch_size, self.output_channels)
107
+ else:
108
+ self.fpns = None
109
+
110
+ @staticmethod
111
+ def get_output_channel(model_type):
112
+ if model_type == 'base':
113
+ return 768
114
+ if model_type == 'large':
115
+ return 1024
116
+ if model_type == 'huge':
117
+ return 1280
118
+
119
+ @staticmethod
120
+ def get_default_output_indices(model_type):
121
+ if model_type == 'base':
122
+ return [3, 5, 7, 11]
123
+ if model_type == 'large':
124
+ return [7, 11, 15, 23]
125
+ if model_type == 'huge':
126
+ return [8, 14, 20, 31]
127
+
128
+ def forward(self, image: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
129
+ # b x 3 x res x res
130
+ _, _, input_h, input_w = image.shape
131
+ if input_h != self.input_resolution or input_w != self.input_resolution:
132
+ image = F.interpolate(image, self.input_resolution,
133
+ mode='bilinear', align_corners=False)
134
+
135
+ image = (image - self.image_mean.to(image.device)) / self.image_std.to(image.device)
136
+
137
+ x = image.to(self.visual.conv1.weight.data)
138
+
139
+ x = self.visual.conv1(x) # shape = [*, width, grid, grid]
140
+ N, _, S, S = x.shape
141
+
142
+ # shape = [*, width, grid ** 2]
143
+ x = x.reshape(x.shape[0], x.shape[1], -1)
144
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
145
+ x = torch.cat([self.visual.class_embedding.to(x.dtype) +
146
+ torch.zeros(x.shape[0], 1, x.shape[-1],
147
+ dtype=x.dtype, device=x.device),
148
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
149
+
150
+ x = x + self.visual.positional_embedding.to(x.dtype)
151
+
152
+ x = self.visual.ln_pre(x)
153
+
154
+ x = x.permute(1, 0, 2).contiguous() # NLD -> LND
155
+
156
+ features = []
157
+ cls_tokens = []
158
+ for blk in self.visual.transformer.resblocks:
159
+ x = blk(x) # [S ** 2 + 1, N, D]
160
+ # if idx in self.output_indices:
161
+ feature = x[1:, :, :].permute(
162
+ 1, 2, 0).view(N, -1, S, S).contiguous().float()
163
+ features.append(feature)
164
+ cls_tokens.append(x[0, :, :])
165
+
166
+ features = [features[ind] for ind in self.output_indices]
167
+ cls_tokens = [cls_tokens[ind] for ind in self.output_indices]
168
+
169
+ if self.fpns is not None:
170
+ for i, fpn in enumerate(self.fpns):
171
+ features[i] = fpn(features[i])
172
+
173
+ return features, cls_tokens
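For reference, a toy check of the `_resize_pe` helper above (illustrative only, a standalone re-derivation rather than a call into the module): a CLS token plus a 14x14 grid of 768-d positional embeddings is resized to a 28x28 grid, keeping the CLS row and bicubically upsampling the rest, which mirrors what happens when `forced_input_resolution` differs from the checkpoint's native resolution.

```python
import math
import torch
import torch.nn.functional as F

pe = torch.randn(1 + 14 * 14, 768)              # CLS token + 14x14 grid of embeddings
num_tokens, w = 1, pe.shape[1]
old = int(math.sqrt(pe.shape[0] - num_tokens))  # 14

grid = pe[num_tokens:].reshape(1, old, old, w).permute(0, 3, 1, 2)
resized = F.interpolate(grid, (28, 28), mode='bicubic', align_corners=False)
new_pe = torch.cat([pe[:num_tokens], resized.view(w, -1).t()], dim=0)
print(new_pe.shape)                             # torch.Size([785, 768])
```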
src/pixel3dmm/preprocessing/facer/facer/face_attribute/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .base import FaceAttribute
2
+ from .farl import FaRLFaceAttribute
src/pixel3dmm/preprocessing/facer/facer/face_attribute/base.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch.nn as nn
2
+
3
+ class FaceAttribute(nn.Module):
4
+ """ face attribute base class
5
+
6
+ Args:
7
+ images (torch.Tensor): b x c x h x w
8
+
9
+ data (Dict[str, Any]):
10
+
11
+ * image_ids (torch.Tensor): nfaces
12
+ * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
13
+ * points (torch.Tensor): nfaces x 5 x 2 (x, y)
14
+
15
+ Returns:
16
+ data (Dict[str, Any]):
17
+
18
+ * image_ids (torch.Tensor): nfaces
19
+ * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
20
+ * points (torch.Tensor): nfaces x 5 x 2 (x, y)
21
+ * attrs (Dict[str, Any]):
22
+ * logits (torch.Tensor): nfaces x nclasses
23
+ """
24
+ pass
src/pixel3dmm/preprocessing/facer/facer/face_attribute/farl.py ADDED
@@ -0,0 +1,158 @@
1
+ from typing import Optional, Dict, Any
2
+ import functools
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from ..transform import get_face_align_matrix, make_tanh_warp_grid
6
+ from .base import FaceAttribute
7
+ from ..farl import farl_classification
8
+ from ..util import download_jit
9
+ import numpy as np
10
+ from torchvision.transforms import Normalize
11
+
12
+
13
+ def get_std_points_xray(out_size=256, mid_size=500):
14
+ std_points_256 = np.array(
15
+ [
16
+ [85.82991, 85.7792],
17
+ [169.0532, 84.3381],
18
+ [127.574, 137.0006],
19
+ [90.6964, 174.7014],
20
+ [167.3069, 173.3733],
21
+ ]
22
+ )
23
+ std_points_256[:, 1] += 30
24
+ old_size = 256
25
+ mid = mid_size / 2
26
+ new_std_points = std_points_256 - old_size / 2 + mid
27
+ target_pts = new_std_points * out_size / mid_size
28
+ target_pts = torch.from_numpy(target_pts).float()
29
+ return target_pts
30
+
31
+
32
+ pretrain_settings = {
33
+ "celeba/224": {
34
+ # acc 92.06617474555969
35
+ "num_classes": 40,
36
+ "layers": [11],
37
+ "url": "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_attribute.farl.celeba.pt",
38
+ "matrix_src_tag": "points",
39
+ "get_matrix_fn": functools.partial(
40
+ get_face_align_matrix,
41
+ target_shape=(224, 224),
42
+ target_pts=get_std_points_xray(out_size=224, mid_size=500),
43
+ ),
44
+ "get_grid_fn": functools.partial(
45
+ make_tanh_warp_grid, warp_factor=0.0, warped_shape=(224, 224)
46
+ ),
47
+ "classes": [
48
+ "5_o_Clock_Shadow",
49
+ "Arched_Eyebrows",
50
+ "Attractive",
51
+ "Bags_Under_Eyes",
52
+ "Bald",
53
+ "Bangs",
54
+ "Big_Lips",
55
+ "Big_Nose",
56
+ "Black_Hair",
57
+ "Blond_Hair",
58
+ "Blurry",
59
+ "Brown_Hair",
60
+ "Bushy_Eyebrows",
61
+ "Chubby",
62
+ "Double_Chin",
63
+ "Eyeglasses",
64
+ "Goatee",
65
+ "Gray_Hair",
66
+ "Heavy_Makeup",
67
+ "High_Cheekbones",
68
+ "Male",
69
+ "Mouth_Slightly_Open",
70
+ "Mustache",
71
+ "Narrow_Eyes",
72
+ "No_Beard",
73
+ "Oval_Face",
74
+ "Pale_Skin",
75
+ "Pointy_Nose",
76
+ "Receding_Hairline",
77
+ "Rosy_Cheeks",
78
+ "Sideburns",
79
+ "Smiling",
80
+ "Straight_Hair",
81
+ "Wavy_Hair",
82
+ "Wearing_Earrings",
83
+ "Wearing_Hat",
84
+ "Wearing_Lipstick",
85
+ "Wearing_Necklace",
86
+ "Wearing_Necktie",
87
+ "Young",
88
+ ],
89
+ }
90
+ }
91
+
92
+
93
+ def load_face_attr(model_path, num_classes=40, layers=[11]):
94
+ model = farl_classification(num_classes=num_classes, layers=layers)
95
+ state_dict = download_jit(model_path, jit=False)
96
+ model.load_state_dict(state_dict)
97
+ return model
98
+
99
+
100
+ class FaRLFaceAttribute(FaceAttribute):
101
+ """The face attribute recognition models from [FaRL](https://github.com/FacePerceiver/FaRL).
102
+
103
+ Please consider citing
104
+ ```bibtex
105
+ @article{zheng2021farl,
106
+ title={General Facial Representation Learning in a Visual-Linguistic Manner},
107
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
108
+ Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
109
+ Dong and Zeng, Ming and Wen, Fang},
110
+ journal={arXiv preprint arXiv:2112.03109},
111
+ year={2021}
112
+ }
113
+ ```
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ conf_name: Optional[str] = None,
119
+ model_path: Optional[str] = None,
120
+ device=None,
121
+ ) -> None:
122
+ super().__init__()
123
+ if conf_name is None:
124
+ conf_name = "celeba/224"
125
+ if model_path is None:
126
+ model_path = pretrain_settings[conf_name]["url"]
127
+ self.conf_name = conf_name
128
+
129
+ setting = pretrain_settings[self.conf_name]
130
+ self.labels = setting["classes"]
131
+ self.net = load_face_attr(model_path, num_classes=setting["num_classes"], layers = setting["layers"])
132
+ if device is not None:
133
+ self.net = self.net.to(device)
134
+ self.normalize = Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
135
+ self.eval()
136
+
137
+ def forward(self, images: torch.Tensor, data: Dict[str, Any]):
138
+ setting = pretrain_settings[self.conf_name]
139
+ images = images.float() / 255.0
140
+ _, _, h, w = images.shape
141
+
142
+ simages = images[data["image_ids"]]
143
+ matrix = setting["get_matrix_fn"](data[setting["matrix_src_tag"]])
144
+ grid = setting["get_grid_fn"](matrix=matrix, orig_shape=(h, w))
145
+
146
+ w_images = F.grid_sample(simages, grid, mode="bilinear", align_corners=False)
147
+ w_images = self.normalize(w_images)
148
+
149
+ outputs = self.net(w_images)
150
+ probs = torch.sigmoid(outputs)
151
+
152
+ data["attrs"] = probs
153
+
154
+ return data
155
+
156
+
157
+ if __name__ == "__main__":
158
+ model = FaRLFaceAttribute()
src/pixel3dmm/preprocessing/facer/facer/face_detection/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .base import FaceDetector
2
+ from .retinaface import RetinaFaceDetector
src/pixel3dmm/preprocessing/facer/facer/face_detection/base.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class FaceDetector(nn.Module):
6
+ """ face detector
7
+
8
+ Args:
9
+ images (torch.Tensor): b x c x h x w
10
+
11
+ Returns:
12
+ data (Dict[str, torch.Tensor]):
13
+
14
+ * rects: nfaces x 4 (x1, y1, x2, y2)
15
+ * points: nfaces x 5 x 2 (x, y)
16
+ * scores: nfaces
17
+ * image_ids: nfaces
18
+ """
19
+ pass
src/pixel3dmm/preprocessing/facer/facer/face_detection/retinaface.py ADDED
@@ -0,0 +1,677 @@
1
+ # largely borrowed from https://github.dev/elliottzheng/batch-face/face_detection/alignment.py
2
+
3
+ from typing import Dict, List, Optional, Tuple
4
+ import torch
5
+ import torch.backends.cudnn as cudnn
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.models._utils as _utils
9
+ from .base import FaceDetector
10
+
11
+
12
+ from itertools import product as product
13
+ from math import ceil
14
+
15
+
16
+ pretrained_urls = {
17
+ "mobilenet": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/mobilenet0.25_Final.pth",
18
+ "resnet50": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/Resnet50_Final.pth"
19
+ }
20
+
21
+
22
+ def conv_bn(inp, oup, stride=1, leaky=0):
23
+ return nn.Sequential(
24
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
25
+ nn.BatchNorm2d(oup),
26
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
27
+ )
28
+
29
+
30
+ def conv_bn_no_relu(inp, oup, stride):
31
+ return nn.Sequential(
32
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
33
+ nn.BatchNorm2d(oup),
34
+ )
35
+
36
+
37
+ def conv_bn1X1(inp, oup, stride, leaky=0):
38
+ return nn.Sequential(
39
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
40
+ nn.BatchNorm2d(oup),
41
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
42
+ )
43
+
44
+
45
+ def conv_dw(inp, oup, stride, leaky=0.1):
46
+ return nn.Sequential(
47
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
48
+ nn.BatchNorm2d(inp),
49
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
50
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
51
+ nn.BatchNorm2d(oup),
52
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
53
+ )
54
+
55
+
56
+ class SSH(nn.Module):
57
+ def __init__(self, in_channel, out_channel):
58
+ super(SSH, self).__init__()
59
+ assert out_channel % 4 == 0
60
+ leaky = 0
61
+ if out_channel <= 64:
62
+ leaky = 0.1
63
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
64
+
65
+ self.conv5X5_1 = conv_bn(
66
+ in_channel, out_channel // 4, stride=1, leaky=leaky)
67
+ self.conv5X5_2 = conv_bn_no_relu(
68
+ out_channel // 4, out_channel // 4, stride=1)
69
+
70
+ self.conv7X7_2 = conv_bn(
71
+ out_channel // 4, out_channel // 4, stride=1, leaky=leaky
72
+ )
73
+ self.conv7x7_3 = conv_bn_no_relu(
74
+ out_channel // 4, out_channel // 4, stride=1)
75
+
76
+ def forward(self, input):
77
+ conv3X3 = self.conv3X3(input)
78
+
79
+ conv5X5_1 = self.conv5X5_1(input)
80
+ conv5X5 = self.conv5X5_2(conv5X5_1)
81
+
82
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
83
+ conv7X7 = self.conv7x7_3(conv7X7_2)
84
+
85
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
86
+ out = F.relu(out)
87
+ return out
88
+
89
+
90
+ class FPN(nn.Module):
91
+ def __init__(self, in_channels_list, out_channels):
92
+ super(FPN, self).__init__()
93
+ leaky = 0
94
+ if out_channels <= 64:
95
+ leaky = 0.1
96
+ self.output1 = conv_bn1X1(
97
+ in_channels_list[0], out_channels, stride=1, leaky=leaky
98
+ )
99
+ self.output2 = conv_bn1X1(
100
+ in_channels_list[1], out_channels, stride=1, leaky=leaky
101
+ )
102
+ self.output3 = conv_bn1X1(
103
+ in_channels_list[2], out_channels, stride=1, leaky=leaky
104
+ )
105
+
106
+ self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
107
+ self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
108
+
109
+ def forward(self, input):
110
+ # names = list(input.keys())
111
+ input = list(input.values())
112
+
113
+ output1 = self.output1(input[0])
114
+ output2 = self.output2(input[1])
115
+ output3 = self.output3(input[2])
116
+
117
+ up3 = F.interpolate(
118
+ output3, size=[output2.size(2), output2.size(3)], mode="nearest"
119
+ )
120
+ output2 = output2 + up3
121
+ output2 = self.merge2(output2)
122
+
123
+ up2 = F.interpolate(
124
+ output2, size=[output1.size(2), output1.size(3)], mode="nearest"
125
+ )
126
+ output1 = output1 + up2
127
+ output1 = self.merge1(output1)
128
+
129
+ out = [output1, output2, output3]
130
+ return out
131
+
132
+
133
+ class MobileNetV1(nn.Module):
134
+ def __init__(self):
135
+ super(MobileNetV1, self).__init__()
136
+ self.stage1 = nn.Sequential(
137
+ conv_bn(3, 8, 2, leaky=0.1), # 3
138
+ conv_dw(8, 16, 1), # 7
139
+ conv_dw(16, 32, 2), # 11
140
+ conv_dw(32, 32, 1), # 19
141
+ conv_dw(32, 64, 2), # 27
142
+ conv_dw(64, 64, 1), # 43
143
+ )
144
+ self.stage2 = nn.Sequential(
145
+ conv_dw(64, 128, 2), # 43 + 16 = 59
146
+ conv_dw(128, 128, 1), # 59 + 32 = 91
147
+ conv_dw(128, 128, 1), # 91 + 32 = 123
148
+ conv_dw(128, 128, 1), # 123 + 32 = 155
149
+ conv_dw(128, 128, 1), # 155 + 32 = 187
150
+ conv_dw(128, 128, 1), # 187 + 32 = 219
151
+ )
152
+ self.stage3 = nn.Sequential(
153
+ conv_dw(128, 256, 2), # 219 +3 2 = 241
154
+ conv_dw(256, 256, 1), # 241 + 64 = 301
155
+ )
156
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
157
+ self.fc = nn.Linear(256, 1000)
158
+
159
+ def forward(self, x):
160
+ x = self.stage1(x)
161
+ x = self.stage2(x)
162
+ x = self.stage3(x)
163
+ x = self.avg(x)
164
+ # x = self.model(x)
165
+ x = x.view(-1, 256)
166
+ x = self.fc(x)
167
+ return x
168
+
169
+
170
+ class ClassHead(nn.Module):
171
+ def __init__(self, inchannels=512, num_anchors=3):
172
+ super(ClassHead, self).__init__()
173
+ self.num_anchors = num_anchors
174
+ self.conv1x1 = nn.Conv2d(
175
+ inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0
176
+ )
177
+
178
+ def forward(self, x):
179
+ out = self.conv1x1(x)
180
+ out = out.permute(0, 2, 3, 1).contiguous()
181
+ return out.view(out.shape[0], -1, 2)
182
+
183
+
184
+ class BboxHead(nn.Module):
185
+ def __init__(self, inchannels=512, num_anchors=3):
186
+ super(BboxHead, self).__init__()
187
+ self.conv1x1 = nn.Conv2d(
188
+ inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0
189
+ )
190
+
191
+ def forward(self, x):
192
+ out = self.conv1x1(x)
193
+ out = out.permute(0, 2, 3, 1).contiguous()
194
+ return out.view(out.shape[0], -1, 4)
195
+
196
+
197
+ class LandmarkHead(nn.Module):
198
+ def __init__(self, inchannels=512, num_anchors=3):
199
+ super(LandmarkHead, self).__init__()
200
+ self.conv1x1 = nn.Conv2d(
201
+ inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0
202
+ )
203
+
204
+ def forward(self, x):
205
+ out = self.conv1x1(x)
206
+ out = out.permute(0, 2, 3, 1).contiguous()
207
+ return out.view(out.shape[0], -1, 10)
208
+
209
+
210
+ class RetinaFace(nn.Module):
211
+ def __init__(self, cfg=None, phase="train"):
212
+ """
213
+ :param cfg: Network related settings.
214
+ :param phase: train or test.
215
+ """
216
+ super(RetinaFace, self).__init__()
217
+ self.phase = phase
218
+ backbone = None
219
+ if cfg["name"] == "mobilenet0.25":
220
+ backbone = MobileNetV1()
221
+ elif cfg["name"] == "Resnet50":
222
+ import torchvision.models as models
223
+ backbone = models.resnet50(pretrained=cfg["pretrain"])
224
+
225
+ self.body = _utils.IntermediateLayerGetter(
226
+ backbone, cfg["return_layers"])
227
+ in_channels_stage2 = cfg["in_channel"]
228
+ in_channels_list = [
229
+ in_channels_stage2 * 2,
230
+ in_channels_stage2 * 4,
231
+ in_channels_stage2 * 8,
232
+ ]
233
+ out_channels = cfg["out_channel"]
234
+ self.fpn = FPN(in_channels_list, out_channels)
235
+ self.ssh1 = SSH(out_channels, out_channels)
236
+ self.ssh2 = SSH(out_channels, out_channels)
237
+ self.ssh3 = SSH(out_channels, out_channels)
238
+
239
+ self.ClassHead = self._make_class_head(
240
+ fpn_num=3, inchannels=cfg["out_channel"])
241
+ self.BboxHead = self._make_bbox_head(
242
+ fpn_num=3, inchannels=cfg["out_channel"])
243
+ self.LandmarkHead = self._make_landmark_head(
244
+ fpn_num=3, inchannels=cfg["out_channel"]
245
+ )
246
+
247
+ def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
248
+ classhead = nn.ModuleList()
249
+ for i in range(fpn_num):
250
+ classhead.append(ClassHead(inchannels, anchor_num))
251
+ return classhead
252
+
253
+ def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
254
+ bboxhead = nn.ModuleList()
255
+ for i in range(fpn_num):
256
+ bboxhead.append(BboxHead(inchannels, anchor_num))
257
+ return bboxhead
258
+
259
+ def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
260
+ landmarkhead = nn.ModuleList()
261
+ for i in range(fpn_num):
262
+ landmarkhead.append(LandmarkHead(inchannels, anchor_num))
263
+ return landmarkhead
264
+
265
+ def forward(self, inputs):
266
+ out = self.body(inputs)
267
+
268
+ # FPN
269
+ fpn = self.fpn(out)
270
+
271
+ # SSH
272
+ feature1 = self.ssh1(fpn[0])
273
+ feature2 = self.ssh2(fpn[1])
274
+ feature3 = self.ssh3(fpn[2])
275
+ features = [feature1, feature2, feature3]
276
+
277
+ bbox_regressions = torch.cat(
278
+ [self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
279
+ )
280
+ classifications = torch.cat(
281
+ [self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
282
+ )
283
+ ldm_regressions = torch.cat(
284
+ [self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1
285
+ )
286
+
287
+ if self.phase == "train":
288
+ output = (bbox_regressions, classifications, ldm_regressions)
289
+ else:
290
+ output = (
291
+ bbox_regressions,
292
+ F.softmax(classifications, dim=-1),
293
+ ldm_regressions,
294
+ )
295
+ return output
296
+
297
+
298
+ # Adapted from https://github.com/Hakuyume/chainer-ssd
299
+ def decode(loc: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor:
300
+ boxes = torch.cat(
301
+ (
302
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
303
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
304
+ ),
305
+ 1,
306
+ )
307
+ boxes[:, :2] -= boxes[:, 2:] / 2
308
+ boxes[:, 2:] += boxes[:, :2]
309
+ return boxes
310
+
311
+
312
+ def decode_landm(pre: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor:
313
+ landms = torch.cat(
314
+ (
315
+ priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
316
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
317
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
318
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
319
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
320
+ ),
321
+ dim=1,
322
+ )
323
+ return landms
324
+
325
+
326
+ def nms(dets: torch.Tensor, thresh: float) -> List[int]:
327
+ """Pure Python NMS baseline."""
328
+ x1 = dets[:, 0]
329
+ y1 = dets[:, 1]
330
+ x2 = dets[:, 2]
331
+ y2 = dets[:, 3]
332
+ scores = dets[:, 4]
333
+
334
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
335
+ order = torch.flip(scores.argsort(), [0])
336
+
337
+ keep = []
338
+ while order.numel() > 0:
339
+ i = order[0].item()
340
+ keep.append(i)
341
+ xx1 = torch.maximum(x1[i], x1[order[1:]])
342
+ yy1 = torch.maximum(y1[i], y1[order[1:]])
343
+ xx2 = torch.minimum(x2[i], x2[order[1:]])
344
+ yy2 = torch.minimum(y2[i], y2[order[1:]])
345
+
346
+ w = torch.maximum(torch.tensor(0.0).to(dets), xx2 - xx1 + 1)
347
+ h = torch.maximum(torch.tensor(0.0).to(dets), yy2 - yy1 + 1)
348
+ inter = w * h
349
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
350
+
351
+ inds = torch.where(ovr <= thresh)[0]
352
+ order = order[inds + 1]
353
+
354
+ return keep
355
+
356
+
357
+ class PriorBox:
358
+ def __init__(self, cfg: dict, image_size: Tuple[int, int]):
359
+ self.min_sizes = cfg["min_sizes"]
360
+ self.steps = cfg["steps"]
361
+ self.clip = cfg["clip"]
362
+ self.image_size = image_size
363
+ self.feature_maps = [
364
+ [ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)]
365
+ for step in self.steps
366
+ ]
367
+
368
+ def generate_anchors(self, device) -> torch.Tensor:
369
+ anchors = []
370
+ for k, f in enumerate(self.feature_maps):
371
+ min_sizes = self.min_sizes[k]
372
+ for i, j in product(range(f[0]), range(f[1])):
373
+ for min_size in min_sizes:
374
+ s_kx = min_size / self.image_size[1]
375
+ s_ky = min_size / self.image_size[0]
376
+ dense_cx = [
377
+ x * self.steps[k] / self.image_size[1] for x in [j + 0.5]
378
+ ]
379
+ dense_cy = [
380
+ y * self.steps[k] / self.image_size[0] for y in [i + 0.5]
381
+ ]
382
+ for cy, cx in product(dense_cy, dense_cx):
383
+ anchors += [cx, cy, s_kx, s_ky]
384
+
385
+ # back to torch land
386
+ output = torch.tensor(anchors).view(-1, 4)
387
+ if self.clip:
388
+ output.clamp_(max=1, min=0)
389
+ return output.to(device=device)
390
+
391
+
392
+ cfg_mnet = {
393
+ "name": "mobilenet0.25",
394
+ "min_sizes": [[16, 32], [64, 128], [256, 512]],
395
+ "steps": [8, 16, 32],
396
+ "variance": [0.1, 0.2],
397
+ "clip": False,
398
+ "loc_weight": 2.0,
399
+ "gpu_train": True,
400
+ "batch_size": 32,
401
+ "ngpu": 1,
402
+ "epoch": 250,
403
+ "decay1": 190,
404
+ "decay2": 220,
405
+ "image_size": 640,
406
+ "pretrain": True,
407
+ "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
408
+ "in_channel": 32,
409
+ "out_channel": 64,
410
+ }
411
+
412
+ cfg_re50 = {
413
+ "name": "Resnet50",
414
+ "min_sizes": [[16, 32], [64, 128], [256, 512]],
415
+ "steps": [8, 16, 32],
416
+ "variance": [0.1, 0.2],
417
+ "clip": False,
418
+ "loc_weight": 2.0,
419
+ "gpu_train": True,
420
+ "batch_size": 24,
421
+ "ngpu": 4,
422
+ "epoch": 100,
423
+ "decay1": 70,
424
+ "decay2": 90,
425
+ "image_size": 840,
426
+ "pretrain": False,
427
+ "return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
428
+ "in_channel": 256,
429
+ "out_channel": 256,
430
+ }
431
+
432
+
433
+ def check_keys(model, pretrained_state_dict):
434
+ ckpt_keys = set(pretrained_state_dict.keys())
435
+ model_keys = set(model.state_dict().keys())
436
+ used_pretrained_keys = model_keys & ckpt_keys
437
+ assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
438
+ return True
439
+
440
+
441
+ def remove_prefix(state_dict, prefix):
442
+ """ Old-style checkpoints store every parameter name with a shared 'module.' prefix; strip it here. """
443
+ def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x
444
+ return {f(key): value for key, value in state_dict.items()}
445
+
446
+
447
+ def load_model(model, pretrained_path, load_to_cpu, network: str):
448
+ if pretrained_path is None:
449
+ url = pretrained_urls[network]
450
+ if load_to_cpu:
451
+ pretrained_dict = torch.utils.model_zoo.load_url(
452
+ url, map_location=lambda storage, loc: storage
453
+ )
454
+ else:
455
+ pretrained_dict = torch.utils.model_zoo.load_url(
456
+ url, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device())  # fix: 'device' was undefined on this path
457
+ )
458
+ else:
459
+ if load_to_cpu:
460
+ pretrained_dict = torch.load(
461
+ pretrained_path, map_location=lambda storage, loc: storage
462
+ )
463
+ else:
464
+ device = torch.cuda.current_device()
465
+ pretrained_dict = torch.load(
466
+ pretrained_path, map_location=lambda storage, loc: storage.cuda(
467
+ device)
468
+ )
469
+ if "state_dict" in pretrained_dict.keys():
470
+ pretrained_dict = remove_prefix(
471
+ pretrained_dict["state_dict"], "module.")
472
+ else:
473
+ pretrained_dict = remove_prefix(pretrained_dict, "module.")
474
+ check_keys(model, pretrained_dict)
475
+ model.load_state_dict(pretrained_dict, strict=False)
476
+ return model
477
+
478
+
479
+ def load_net(model_path, network="mobilenet"):
480
+ if network == "mobilenet":
481
+ cfg = cfg_mnet
482
+ elif network == "resnet50":
483
+ cfg = cfg_re50
484
+ else:
485
+ raise NotImplementedError(network)
486
+ # net and model
487
+ net = RetinaFace(cfg=cfg, phase="test")
488
+ net = load_model(net, model_path, True, network=network)
489
+ net.eval()
490
+ cudnn.benchmark = True
491
+ # net = net.to(device)
492
+ return net
493
+
494
+
495
+ def parse_det(det: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, float]:
496
+ landmarks = det[5:].reshape(5, 2)
497
+ box = det[:4]
498
+ score = det[4]
499
+ return box, landmarks, score.item()
500
+
501
+
502
+ def post_process(
503
+ loc: torch.Tensor,
504
+ conf: torch.Tensor,
505
+ landms: torch.Tensor,
506
+ prior_data: torch.Tensor,
507
+ cfg: dict,
508
+ scale: float,
509
+ scale1: float,
510
+ resize,
511
+ confidence_threshold,
512
+ top_k,
513
+ nms_threshold,
514
+ keep_top_k,
515
+ ):
516
+ boxes = decode(loc, prior_data, cfg["variance"])
517
+ boxes = boxes * scale / resize
518
+ # boxes = boxes.cpu().numpy()
519
+ # scores = conf.cpu().numpy()[:, 1]
520
+ scores = conf[:, 1]
521
+ landms_copy = decode_landm(landms, prior_data, cfg["variance"])
522
+
523
+ landms_copy = landms_copy * scale1 / resize
524
+ # landms_copy = landms_copy.cpu().numpy()
525
+
526
+ # ignore low scores
527
+ inds = torch.where(scores > confidence_threshold)[0]
528
+ boxes = boxes[inds]
529
+ landms_copy = landms_copy[inds]
530
+ scores = scores[inds]
531
+
532
+ # keep top-K before NMS
533
+ order = torch.flip(scores.argsort(), [0])[:top_k]
534
+ boxes = boxes[order]
535
+ landms_copy = landms_copy[order]
536
+ scores = scores[order]
537
+
538
+ # do NMS
539
+ dets = torch.hstack((boxes, scores.unsqueeze(-1))).to(
540
+ dtype=torch.float32, copy=False)
541
+ keep = nms(dets, nms_threshold)
542
+ # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
543
+ dets = dets[keep, :]
544
+ landms_copy = landms_copy[keep]
545
+
546
+ # keep top-K faster NMS
547
+ dets = dets[:keep_top_k, :]
548
+ landms_copy = landms_copy[:keep_top_k, :]
549
+
550
+ dets = torch.cat((dets, landms_copy), dim=1)
551
+ # show image
552
+ dets = sorted(dets, key=lambda x: x[4], reverse=True)
553
+ dets = [parse_det(x) for x in dets]
554
+
555
+ return dets
556
+
557
+
558
+ # @torch.no_grad()
559
+ def batch_detect(net: nn.Module, images: torch.Tensor, threshold: float = 0.5):
560
+ confidence_threshold = threshold
561
+ cfg = cfg_mnet
562
+ top_k = 5000
563
+ nms_threshold = 0.4
564
+ keep_top_k = 750
565
+ resize = 1
566
+
567
+ img = images.float()
568
+ mean = torch.as_tensor([104, 117, 123], dtype=img.dtype, device=img.device).view(
569
+ 1, 3, 1, 1
570
+ )
571
+ img -= mean
572
+ (
573
+ _,
574
+ _,
575
+ im_height,
576
+ im_width,
577
+ ) = img.shape
578
+ scale = torch.as_tensor(
579
+ [im_width, im_height, im_width, im_height],
580
+ dtype=img.dtype,
581
+ device=img.device,
582
+ )
583
+ scale = scale.to(img.device)
584
+
585
+ loc, conf, landms = net(img) # forward pass
586
+
587
+ priorbox = PriorBox(cfg, image_size=(im_height, im_width))
588
+ prior_data = priorbox.generate_anchors(device=img.device)
589
+ scale1 = torch.as_tensor(
590
+ [
591
+ img.shape[3],
592
+ img.shape[2],
593
+ img.shape[3],
594
+ img.shape[2],
595
+ img.shape[3],
596
+ img.shape[2],
597
+ img.shape[3],
598
+ img.shape[2],
599
+ img.shape[3],
600
+ img.shape[2],
601
+ ],
602
+ dtype=img.dtype,
603
+ device=img.device,
604
+ )
605
+ scale1 = scale1.to(img.device)
606
+
607
+ all_dets = [
608
+ post_process(
609
+ loc_i,
610
+ conf_i,
611
+ landms_i,
612
+ prior_data,
613
+ cfg,
614
+ scale,
615
+ scale1,
616
+ resize,
617
+ confidence_threshold,
618
+ top_k,
619
+ nms_threshold,
620
+ keep_top_k,
621
+ )
622
+ for loc_i, conf_i, landms_i in zip(loc, conf, landms)
623
+ ]
624
+
625
+ rects = []
626
+ points = []
627
+ scores = []
628
+ image_ids = []
629
+ for image_id, faces_in_one_image in enumerate(all_dets):
630
+ for rect, landmarks, score in faces_in_one_image:
631
+ rects.append(rect)
632
+ points.append(landmarks)
633
+ scores.append(score)
634
+ image_ids.append(image_id)
635
+
636
+ if len(rects) == 0:
637
+ return {
638
+ 'rects': torch.Tensor().to(img.device),
639
+ 'points': torch.Tensor().to(img.device),
640
+ 'scores': torch.Tensor().to(img.device),
641
+ 'image_ids': torch.Tensor().to(img.device),
642
+ }
643
+
644
+ return {
645
+ 'rects': torch.stack(rects, dim=0).to(img.device),
646
+ 'points': torch.stack(points, dim=0).to(img.device),
647
+ 'scores': torch.tensor(scores).to(img.device),
648
+ 'image_ids': torch.tensor(image_ids).to(img.device),
649
+ }
650
+
651
+
652
+ class RetinaFaceDetector(FaceDetector):
653
+ """RetinaFaceDetector
654
+
655
+ Args:
656
+ images (torch.Tensor): b x c x h x w, uint8, 0~255.
657
+
658
+ Returns:
659
+ faces (Dict[str, torch.Tensor]):
660
+
661
+ * image_ids: n, int
662
+ * rects: n x 4 (x1, y1, x2, y2)
663
+ * points: n x 5 x 2 (x, y)
664
+ * scores: n
665
+ """
666
+
667
+ def __init__(self, conf_name: Optional[str] = None,
668
+ model_path: Optional[str] = None, threshold=0.8) -> None:
669
+ super().__init__()
670
+ if conf_name is None:
671
+ conf_name = 'mobilenet'
672
+ self.net = load_net(model_path, conf_name)
673
+ self.threshold = threshold
674
+ self.eval()
675
+
676
+ def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
677
+ return batch_detect(self.net, images.clone(), threshold=self.threshold)
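
A short sketch of running the detector end to end, using the `read_hwc` helper added later in this commit and one of the sample images; the pretrained MobileNet weights are fetched from the URL in `pretrained_urls` on first use.

import torch
from facer.io import read_hwc
from facer.face_detection import RetinaFaceDetector

image = read_hwc("samples/data/twogirls.jpg")        # h x w x 3, uint8
images = image.permute(2, 0, 1).unsqueeze(0)         # 1 x 3 x h x w, 0~255

detector = RetinaFaceDetector()                      # defaults to the 'mobilenet' config
with torch.inference_mode():
    faces = detector(images)

print(faces["rects"].shape)      # nfaces x 4
print(faces["points"].shape)     # nfaces x 5 x 2
print(faces["image_ids"])        # which input image each face came from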
src/pixel3dmm/preprocessing/facer/facer/face_parsing/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .base import FaceParser
2
+ from .farl import FaRLFaceParser
src/pixel3dmm/preprocessing/facer/facer/face_parsing/base.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class FaceParser(nn.Module):
5
+ """ face parser
6
+
7
+ Args:
8
+ images (torch.Tensor): b x c x h x w
9
+
10
+ data (Dict[str, Any]):
11
+
12
+ * image_ids (torch.Tensor): nfaces
13
+ * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
14
+ * points (torch.Tensor): nfaces x 5 x 2 (x, y)
15
+
16
+ Returns:
17
+ data (Dict[str, Any]):
18
+
19
+ * image_ids (torch.Tensor): nfaces
20
+ * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
21
+ * points (torch.Tensor): nfaces x 5 x 2 (x, y)
22
+ * seg (Dict[str, Any]):
23
+
24
+ * logits (torch.Tensor): nfaces x nclasses x h x w
25
+ * label_names (List[str]): nclasses
26
+ """
27
+ pass
src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py ADDED
@@ -0,0 +1,174 @@
1
+ from typing import Optional, Dict, Any
2
+ import functools
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from ..util import download_jit
7
+ from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
8
+ make_inverted_tanh_warp_grid, make_tanh_warp_grid)
9
+ from .base import FaceParser
10
+ import numpy as np
11
+
12
+ pretrain_settings = {
13
+ 'lapa/448': {
14
+ 'url': [
15
+ 'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
16
+ ],
17
+ 'matrix_src_tag': 'points',
18
+ 'get_matrix_fn': functools.partial(get_face_align_matrix,
19
+ target_shape=(448, 448), target_face_scale=1.0),
20
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
21
+ warp_factor=0.8, warped_shape=(448, 448)),
22
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
23
+ warp_factor=0.8, warped_shape=(448, 448)),
24
+ 'label_names': ['background', 'face', 'rb', 'lb', 're',
25
+ 'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
26
+ },
27
+ 'celebm/448': {
28
+ 'url': [
29
+ 'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
30
+ ],
31
+ 'matrix_src_tag': 'points',
32
+ 'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
33
+ target_shape=(448, 448)),
34
+ 'get_grid_fn': functools.partial(make_tanh_warp_grid,
35
+ warp_factor=0, warped_shape=(448, 448)),
36
+ 'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
37
+ warp_factor=0, warped_shape=(448, 448)),
38
+ 'label_names': [
39
+ 'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
40
+ 'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
41
+ 'eyeg', 'hat', 'earr', 'neck_l']
42
+ }
43
+ }
44
+
45
+
46
+ class FaRLFaceParser(FaceParser):
47
+ """ The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).
48
+
49
+ Please consider citing
50
+ ```bibtex
51
+ @article{zheng2021farl,
52
+ title={General Facial Representation Learning in a Visual-Linguistic Manner},
53
+ author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
54
+ Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
55
+ Dong and Zeng, Ming and Wen, Fang},
56
+ journal={arXiv preprint arXiv:2112.03109},
57
+ year={2021}
58
+ }
59
+ ```
60
+ """
61
+
62
+ def __init__(self, conf_name: Optional[str] = None, model_path: Optional[str] = None, device=None) -> None:
63
+ super().__init__()
64
+ if conf_name is None:
65
+ conf_name = 'lapa/448'
66
+ if model_path is None:
67
+ model_path = pretrain_settings[conf_name]['url']
68
+ self.conf_name = conf_name
69
+ self.net = download_jit(model_path, map_location=device)
70
+ self.eval()
71
+ self.device = device
72
+ self.setting = pretrain_settings[conf_name]
73
+ self.label_names = self.setting['label_names']
74
+
75
+
76
+ def get_warp_grid(self, images: torch.Tensor, matrix_src):
77
+ _, _, h, w = images.shape
78
+ matrix = self.setting['get_matrix_fn'](matrix_src)
79
+ grid = self.setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
80
+ inv_grid = self.setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
81
+ return grid, inv_grid
82
+
83
+ def warp_images(self, images: torch.Tensor, data: Dict[str, Any]):
84
+ simages = self.unify_image_dtype(images)
85
+ simages = simages[data['image_ids']]
86
+ matrix_src = data[self.setting['matrix_src_tag']]
87
+ grid, inv_grid = self.get_warp_grid(simages, matrix_src)
88
+
89
+ w_images = F.grid_sample(
90
+ simages, grid, mode='bilinear', align_corners=False)
91
+ return w_images, grid, inv_grid
92
+
93
+
94
+ def decode_image_to_cv2(self, images: torch.Tensor):
95
+ '''
96
+ output: b x h x w x 3, np.uint8, [0, 255]
97
+ '''
98
+ assert images.ndim == 4
99
+ assert images.shape[1] == 3
100
+ images = images.permute(0, 2, 3, 1).cpu().numpy() * 255
101
+ images = images.astype(np.uint8)
102
+ return images
103
+
104
+ def unify_image_dtype(self, images: torch.Tensor|np.ndarray|list):
105
+ '''
106
+ output: b x 3 x h x w, torch.float32, [0, 1]
107
+ '''
108
+ if isinstance(images, np.ndarray):
109
+ images = torch.from_numpy(images)
110
+ elif isinstance(images, torch.Tensor):
111
+ pass
112
+ elif isinstance(images, list):
113
+ assert len(images) > 0, "images is empty"
114
+ first_image = images[0]
115
+ if isinstance(first_image, np.ndarray):
116
+ images = [torch.from_numpy(image).permute(2, 0, 1) for image in images]
117
+ images = torch.stack(images)
118
+ elif isinstance(first_image, torch.Tensor):
119
+ images = torch.stack(images)
120
+ else:
121
+ raise ValueError(f"Unsupported image type: {type(first_image)}")
122
+
123
+ else:
124
+ raise ValueError(f"Unsupported image type: {type(images)}")
125
+
126
+ assert images.ndim == 4
127
+ assert images.shape[1] == 3
128
+
129
+ max_val = images.max()
130
+ if max_val <= 1:
131
+ assert images.dtype == torch.float32 or images.dtype == torch.float16
132
+ elif max_val <= 255:
133
+ assert images.dtype == torch.uint8
134
+ images = images.float() / 255.0
135
+ else:
136
+ raise ValueError(f"Unsupported image type: {images.dtype}")
137
+ if images.device != self.device:
138
+ images = images.to(device=self.device)
139
+ return images
140
+
141
+ @torch.no_grad()
142
+ @torch.inference_mode()
143
+ def forward(self, images: torch.Tensor, data: Dict[str, Any]):
144
+ '''
145
+ images: b x 3 x h x w , torch.uint8, [0, 255]
146
+ data: {'rects': rects, 'points': points, 'scores': scores, 'image_ids': image_ids}
147
+ '''
148
+ w_images, grid, inv_grid = self.warp_images(images, data)
149
+ w_seg_logits = self.forward_warped(w_images, return_preds=False)
150
+
151
+ seg_logits = F.grid_sample(
152
+ w_seg_logits, inv_grid, mode='bilinear', align_corners=False)
153
+
154
+ data['seg'] = {'logits': seg_logits, 'label_names': self.label_names}
155
+ return data
156
+
157
+
158
+ def logits2predictions(self, logits: torch.Tensor):
159
+ return logits.argmax(dim=1)
160
+
161
+ @torch.no_grad()
162
+ @torch.inference_mode()
163
+ def forward_warped(self, images: torch.Tensor, return_preds: bool = True):
164
+ '''
165
+ images: b x 3 x h x w , torch.uint8, [0, 255]
166
+ '''
167
+ images = self.unify_image_dtype(images)
168
+ seg_logits, _ = self.net(images) # nfaces x c x h x w
169
+ # seg_probs = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w
170
+ if return_preds:
171
+ seg_preds = self.logits2predictions(seg_logits)
172
+ return seg_logits, seg_preds, self.label_names
173
+ else:
174
+ return seg_logits
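
A sketch chaining the detector with `FaRLFaceParser`, following the `forward()` contract above (uint8 `b x 3 x h x w` images plus the detector's output dict); the TorchScript weights are downloaded from the release URL on first construction, and the device handling here simply mirrors the constructor's `device` argument.

import torch
from facer.io import read_hwc
from facer.face_detection import RetinaFaceDetector
from facer.face_parsing import FaRLFaceParser

device = "cuda" if torch.cuda.is_available() else "cpu"
image = read_hwc("samples/data/twogirls.jpg")                 # h x w x 3, uint8
images = image.permute(2, 0, 1).unsqueeze(0).to(device)       # 1 x 3 x h x w

detector = RetinaFaceDetector().to(device)
parser = FaRLFaceParser(device=device)

with torch.inference_mode():
    faces = detector(images)
    faces = parser(images, faces)

seg_logits = faces["seg"]["logits"]            # nfaces x nclasses x h x w
seg_preds = seg_logits.argmax(dim=1)           # per-pixel label indices
print(faces["seg"]["label_names"])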
src/pixel3dmm/preprocessing/facer/facer/farl/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from .model import load_farl, VisualTransformer
5
+ from .classification import farl_classification
src/pixel3dmm/preprocessing/facer/facer/farl/classification.py ADDED
@@ -0,0 +1,149 @@
1
+ from torch import nn
2
+ import torch
3
+ import torch.utils.checkpoint as checkpoint
4
+ from .model import VisualTransformer
5
+
6
+
7
+ class VITClassificationHeadV0(nn.Module):
8
+ def __init__(
9
+ self,
10
+ num_features: int,
11
+ channel: int,
12
+ num_labels: int,
13
+ norm=False,
14
+ dropout=0.0,
15
+ ret_feat=False,
16
+ ):
17
+ super().__init__()
18
+ self.weights = nn.Parameter(
19
+ torch.ones(1, num_features * 3, 1, dtype=torch.float32)
20
+ )
21
+ self.final_fc = nn.Linear(channel, num_labels)
22
+ self.norm = norm
23
+ if self.norm:
24
+ for i in range(num_features * 3):
25
+ setattr(self, f"norm_{i}", nn.LayerNorm(channel))
26
+ self.dropout = nn.Dropout(p=dropout)
27
+ self.ret_feat = ret_feat
28
+
29
+ def forward(self, features, cls_tokens):
30
+ xs = []
31
+ for feature, cls_token in zip(features, cls_tokens):
32
+ # feature: b x c x s x s
33
+ # cls_token: b x c
34
+ xs.append(feature.mean([2, 3]))
35
+ xs.append(feature.max(-1).values.max(-1).values)
36
+ xs.append(cls_token)
37
+ if self.norm:
38
+ xs = [getattr(self, f"norm_{i}")(x) for i, x in enumerate(xs)]
39
+ xs = torch.stack(xs, dim=1) # b x 3N x c
40
+ feat = (xs * self.weights.softmax(dim=1)).sum(1) # b x c
41
+ x = self.dropout(feat)
42
+ x = self.final_fc(x) # b x num_labels
43
+ if self.ret_feat:
44
+ return x, feat
45
+ else:
46
+ return x
47
+
48
+
49
+ class FACTransformer(nn.Module):
50
+ """A face attribute classification transformer leveraging multiple cls_tokens.
51
+ Args:
52
+ image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].
53
+ Returns:
54
+ logits (torch.Tensor): Float32 tensor with shape [b, n_classes].
55
+ aux_outputs:
56
+ """
57
+
58
+ def __init__(self, backbone: nn.Module, head: nn.Module):
59
+ super().__init__()
60
+ self.backbone = backbone
61
+ self.head = head
62
+ self.cuda().float()
63
+
64
+ def forward(self, image):
65
+ logits = self.head(*self.backbone(image))
66
+ return logits
67
+
68
+
69
+ def add_method(obj, name, method):
70
+ import types
71
+
72
+ setattr(obj, name, types.MethodType(method, obj))
73
+
74
+
75
+ def get_clip_encode_func(layers):
76
+ def func(self, x):
77
+ x = self.conv1(x) # shape = [*, width, grid, grid]
78
+ # shape = [*, width, grid ** 2]
79
+ x = x.reshape(x.shape[0], x.shape[1], -1)
80
+ extra_tokens = getattr(self, "extra_tokens", [])
81
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
82
+ class_token = self.class_embedding.to(x.dtype) + torch.zeros(
83
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
84
+ )
85
+ special_tokens = [
86
+ getattr(self, name).to(x.dtype)
87
+ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
88
+ for name in extra_tokens
89
+ ]
90
+ x = torch.cat(
91
+ [class_token, *special_tokens, x], dim=1
92
+ ) # shape = [*, grid ** 2 + 1, width]
93
+ x = x + self.positional_embedding.to(x.dtype)
94
+ x = self.ln_pre(x)
95
+ x = x.permute(1, 0, 2) # NLD -> LND
96
+ outs = []
97
+ max_layer = max(layers)
98
+ use_checkpoint = self.transformer.use_checkpoint
99
+ for layer_i, blk in enumerate(self.transformer.resblocks):
100
+ if layer_i > max_layer:
101
+ break
102
+ if self.training and use_checkpoint:
103
+ x = checkpoint.checkpoint(blk, x)
104
+ else:
105
+ x = blk(x)
106
+ outs.append(x)
107
+
108
+ outs = torch.stack(outs).permute(0, 2, 1, 3)
109
+ cls_tokens = outs[layers, :, 0, :]
110
+
111
+ extra_token_feats = {}
112
+ for i, name in enumerate(extra_tokens):
113
+ extra_token_feats[name] = outs[layers, :, i + 1, :]
114
+ L, B, N, C = outs.shape
115
+ import math
116
+
117
+ W = int(math.sqrt(N - 1 - len(extra_tokens)))
118
+ features = (
119
+ outs[layers, :, 1 + len(extra_tokens) :, :]
120
+ .reshape(len(layers), B, W, W, C)
121
+ .permute(0, 1, 4, 2, 3)
122
+ )
123
+ if getattr(self, "ret_special", False):
124
+ return features, cls_tokens, extra_token_feats
125
+ else:
126
+ return features, cls_tokens
127
+
128
+ return func
129
+
130
+
131
+ def farl_classification(num_classes=2, layers=list(range(12))):
132
+ model = VisualTransformer(
133
+ input_resolution=224,
134
+ patch_size=16,
135
+ width=768,
136
+ layers=12,
137
+ heads=12,
138
+ output_dim=512,
139
+ )
140
+ channel = 768
141
+ model = model.cuda()
142
+ del model.proj
143
+ del model.ln_post
144
+ add_method(model, "forward", get_clip_encode_func(layers))
145
+ head = VITClassificationHeadV0(
146
+ num_features=len(layers), channel=channel, num_labels=num_classes, norm=True
147
+ )
148
+ model = FACTransformer(model, head)
149
+ return model
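
A sketch of using `farl_classification` directly as an image classifier; the construction above moves the model to CUDA, so a GPU is assumed, and the 40-class setting is only an illustrative choice.

import torch
from facer.farl import farl_classification

model = farl_classification(num_classes=40)           # ViT-B/16 backbone + multi-layer feature head
model.eval()

images = torch.rand(2, 3, 224, 224, device="cuda")    # float32 inputs, 224 x 224 as configured above
with torch.inference_mode():
    logits = model(images)                            # 2 x 40
print(logits.shape)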
src/pixel3dmm/preprocessing/facer/facer/farl/model.py ADDED
@@ -0,0 +1,419 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from collections import OrderedDict
5
+ import logging
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ import torch.utils.checkpoint as checkpoint
10
+ import numpy as np
11
+ from timm.models.layers import trunc_normal_, DropPath
12
+
13
+
14
+ class Bottleneck(nn.Module):
15
+ expansion = 4
16
+
17
+ def __init__(self, inplanes, planes, stride=1):
18
+ super().__init__()
19
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
20
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
21
+ self.bn1 = nn.BatchNorm2d(planes)
22
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
23
+ self.bn2 = nn.BatchNorm2d(planes)
24
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ self.stride = stride
30
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
31
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
32
+ self.downsample = nn.Sequential(OrderedDict([
33
+ ("-1", nn.AvgPool2d(stride)),
34
+ ("0", nn.Conv2d(inplanes, planes *
35
+ self.expansion, 1, stride=1, bias=False)),
36
+ ("1", nn.BatchNorm2d(planes * self.expansion))
37
+ ]))
38
+
39
+ def forward(self, x: torch.Tensor):
40
+ identity = x
41
+ out = self.relu(self.bn1(self.conv1(x)))
42
+ out = self.relu(self.bn2(self.conv2(out)))
43
+ out = self.avgpool(out)
44
+ out = self.bn3(self.conv3(out))
45
+ if self.downsample is not None:
46
+ identity = self.downsample(x)
47
+ out += identity
48
+ out = self.relu(out)
49
+ return out
50
+
51
+
52
+ class AttentionPool2d(nn.Module):
53
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
54
+ super().__init__()
55
+ self.positional_embedding = nn.Parameter(
56
+ torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
57
+ )
58
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
59
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
60
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
62
+ self.num_heads = num_heads
63
+
64
+ def forward(self, x):
65
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2]
66
+ * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
67
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
68
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
69
+ x, _ = F.multi_head_attention_forward(
70
+ query=x, key=x, value=x,
71
+ embed_dim_to_check=x.shape[-1],
72
+ num_heads=self.num_heads,
73
+ q_proj_weight=self.q_proj.weight,
74
+ k_proj_weight=self.k_proj.weight,
75
+ v_proj_weight=self.v_proj.weight,
76
+ in_proj_weight=None,
77
+ in_proj_bias=torch.cat(
78
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
79
+ ),
80
+ bias_k=None,
81
+ bias_v=None,
82
+ add_zero_attn=False,
83
+ dropout_p=0,
84
+ out_proj_weight=self.c_proj.weight,
85
+ out_proj_bias=self.c_proj.bias,
86
+ use_separate_proj_weight=True,
87
+ training=self.training,
88
+ need_weights=False
89
+ )
90
+ return x[0]
91
+
92
+
93
+ class ModifiedResNet(nn.Module):
94
+ """
95
+ A ResNet class that is similar to torchvision's but contains the following changes:
96
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98
+ - The final pooling layer is a QKV attention instead of an average pool
99
+ """
100
+
101
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102
+ super().__init__()
103
+ self.output_dim = output_dim
104
+ self.input_resolution = input_resolution
105
+ # the 3-layer stem
106
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3,
107
+ stride=2, padding=1, bias=False)
108
+ self.bn1 = nn.BatchNorm2d(width // 2)
109
+ self.conv2 = nn.Conv2d(width // 2, width // 2,
110
+ kernel_size=3, padding=1, bias=False)
111
+ self.bn2 = nn.BatchNorm2d(width // 2)
112
+ self.conv3 = nn.Conv2d(
113
+ width // 2, width, kernel_size=3, padding=1, bias=False)
114
+ self.bn3 = nn.BatchNorm2d(width)
115
+ self.avgpool = nn.AvgPool2d(2)
116
+ self.relu = nn.ReLU(inplace=True)
117
+ # residual layers
118
+ self._inplanes = width # this is a *mutable* variable used during construction
119
+ self.layer1 = self._make_layer(width, layers[0])
120
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
121
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
122
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
123
+ embed_dim = width * 32 # the ResNet feature dimension
124
+ self.attnpool = AttentionPool2d(
125
+ input_resolution // 32, embed_dim, heads, output_dim
126
+ )
127
+ self.apply(self._init_weights)
128
+
129
+ def _init_weights(self, m):
130
+ if isinstance(m, (nn.BatchNorm2d, LayerNorm)):
131
+ nn.init.constant_(m.weight, 1)
132
+ nn.init.constant_(m.bias, 0)
133
+ elif isinstance(m, (nn.Linear, nn.Conv2d)):
134
+ trunc_normal_(m.weight, std=0.02)
135
+ if m.bias is not None:
136
+ nn.init.constant_(m.bias, 0)
137
+
138
+ def _make_layer(self, planes, blocks, stride=1):
139
+ layers = [Bottleneck(self._inplanes, planes, stride)]
140
+ self._inplanes = planes * Bottleneck.expansion
141
+ for _ in range(1, blocks):
142
+ layers.append(Bottleneck(self._inplanes, planes))
143
+ return nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ def stem(x):
147
+ for conv, bn in [
148
+ (self.conv1, self.bn1),
149
+ (self.conv2, self.bn2),
150
+ (self.conv3, self.bn3)
151
+ ]:
152
+ x = self.relu(bn(conv(x)))
153
+ x = self.avgpool(x)
154
+ return x
155
+ x = x.type(self.conv1.weight.dtype)
156
+ x = stem(x)
157
+ x = self.layer1(x)
158
+ x = self.layer2(x)
159
+ x = self.layer3(x)
160
+ x = self.layer4(x)
161
+ x = self.attnpool(x)
162
+ return x
163
+
164
+
165
+ class LayerNorm(nn.Module):
166
+ def __init__(self, hidden_size, eps=1e-5):
167
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
168
+ """
169
+ super(LayerNorm, self).__init__()
170
+ self.weight = nn.Parameter(torch.ones(hidden_size))
171
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
172
+ self.variance_epsilon = eps
173
+
174
+ def forward(self, x):
175
+ pdtype = x.dtype
176
+ x = x.float()
177
+ u = x.mean(-1, keepdim=True)
178
+ s = (x - u).pow(2).mean(-1, keepdim=True)
179
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
180
+ return self.weight * x.to(pdtype) + self.bias
181
+
182
+
183
+ class QuickGELU(nn.Module):
184
+ def forward(self, x: torch.Tensor):
185
+ return x * torch.sigmoid(1.702 * x)
186
+
187
+
188
+ class ResidualAttentionBlock(nn.Module):
189
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.):
190
+ super().__init__()
191
+ self.attn = nn.MultiheadAttention(d_model, n_head)
192
+ self.ln_1 = LayerNorm(d_model)
193
+ self.mlp = nn.Sequential(OrderedDict([
194
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
195
+ ("gelu", QuickGELU()),
196
+ ("c_proj", nn.Linear(d_model * 4, d_model))
197
+ ]))
198
+ self.ln_2 = LayerNorm(d_model)
199
+ self.attn_mask = attn_mask
200
+ self.drop_path = DropPath(
201
+ drop_path) if drop_path > 0. else nn.Identity()
202
+
203
+ def add_drop_path(self, drop_path):
204
+ self.drop_path = DropPath(
205
+ drop_path) if drop_path > 0. else nn.Identity()
206
+
207
+ def attention(self, x: torch.Tensor):
208
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
209
+ if self.attn_mask is not None else None
210
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
211
+
212
+ def forward(self, x: torch.Tensor):
213
+ x = x + self.drop_path(self.attention(self.ln_1(x)))
214
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
215
+ return x
216
+
217
+
218
+ class Transformer(nn.Module):
219
+ def __init__(self,
220
+ width: int,
221
+ layers: int,
222
+ heads: int,
223
+ attn_mask: torch.Tensor = None,
224
+ use_checkpoint=True,
225
+ drop_rate=0.,
226
+ attn_drop_rate=0.,
227
+ drop_path_rate=0.,
228
+ ):
229
+ super().__init__()
230
+ self.width = width
231
+ self.layers = layers
232
+ self.use_checkpoint = use_checkpoint
233
+ # stochastic depth decay rule
234
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers)]
235
+ self.resblocks = nn.ModuleList([
236
+ ResidualAttentionBlock(width, heads, attn_mask, drop_path=dpr[i])
237
+ for i in range(layers)
238
+ ])
239
+ self.apply(self._init_weights)
240
+
241
+ def _init_weights(self, m):
242
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
243
+ trunc_normal_(m.weight, std=0.02)
244
+ if m.bias is not None:
245
+ nn.init.constant_(m.bias, 0)
246
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
247
+ nn.init.constant_(m.bias, 0)
248
+ nn.init.constant_(m.weight, 1.0)
249
+
250
+ def forward(self, x: torch.Tensor):
251
+ for i, blk in enumerate(self.resblocks):
252
+ x = blk(x)
253
+ return x
254
+
255
+
256
+
257
+ class VisualTransformer(nn.Module):
258
+ positional_embedding: nn.Parameter
259
+
260
+ def __init__(self,
261
+ input_resolution: int,
262
+ patch_size: int,
263
+ width: int,
264
+ layers: int,
265
+ heads: int,
266
+ output_dim: int,
267
+ pool_type: str = 'default',
268
+ skip_cls: bool = False,
269
+ drop_path_rate=0.,
270
+ **kwargs):
271
+ super().__init__()
272
+ self.pool_type = pool_type
273
+ self.skip_cls = skip_cls
274
+ self.input_resolution = input_resolution
275
+ self.output_dim = output_dim
276
+ self.conv1 = nn.Conv2d(
277
+ in_channels=3,
278
+ out_channels=width,
279
+ kernel_size=patch_size,
280
+ stride=patch_size,
281
+ bias=False
282
+ )
283
+ self.config = kwargs.get("config", None)
284
+ self.sequence_length = (input_resolution // patch_size) ** 2 + 1
285
+ self.conv_pool = nn.Identity()
286
+ if (self.pool_type == 'linear'):
287
+ if (not self.skip_cls):
288
+ self.conv_pool = nn.Conv1d(
289
+ width, width, self.sequence_length, stride=self.sequence_length, groups=width)
290
+ else:
291
+ self.conv_pool = nn.Conv1d(
292
+ width, width, self.sequence_length-1, stride=self.sequence_length, groups=width)
293
+ scale = width ** -0.5
294
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
295
+ self.positional_embedding = nn.Parameter(
296
+ scale * torch.randn(
297
+ self.sequence_length, width
298
+ )
299
+ )
300
+ self.ln_pre = LayerNorm(width)
301
+ self.transformer = Transformer(
302
+ width, layers, heads, drop_path_rate=drop_path_rate)
303
+ self.ln_post = LayerNorm(width)
304
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
305
+ if self.config is not None and self.config.MIM.ENABLE:
306
+ logging.info("MIM ENABLED")
307
+ self.mim = True
308
+ self.lm_transformer = Transformer(
309
+ width, self.config.MIM.LAYERS, heads)
310
+ self.ln_lm = LayerNorm(width)
311
+ self.lm_head = nn.Linear(width, self.config.MIM.VOCAB_SIZE)
312
+ self.mask_token = nn.Parameter(scale * torch.randn(width))
313
+ else:
314
+ self.mim = False
315
+ self.apply(self._init_weights)
316
+
317
+ def _init_weights(self, m):
318
+ if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
319
+ trunc_normal_(m.weight, std=0.02)
320
+ if m.bias is not None:
321
+ nn.init.constant_(m.bias, 0)
322
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
323
+ nn.init.constant_(m.bias, 0)
324
+ nn.init.constant_(m.weight, 1.0)
325
+
326
+ def forward(self, x: torch.Tensor):
327
+ x = self.conv1(x) # shape = [*, width, grid, grid]
328
+ # shape = [*, width, grid ** 2]
329
+ x = x.reshape(x.shape[0], x.shape[1], -1)
330
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
331
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1],
332
+ dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
333
+ x = x + self.positional_embedding.to(x.dtype)
334
+ x = self.ln_pre(x)
335
+ x = x.permute(1, 0, 2) # NLD -> LND
336
+ x = self.transformer(x)
337
+ x = x.permute(1, 0, 2) # LND -> NLD
338
+ if (self.pool_type == 'average'):
339
+ if self.skip_cls:
340
+ x = x[:, 1:, :]
341
+ x = torch.mean(x, dim=1)
342
+ elif (self.pool_type == 'linear'):
343
+ if self.skip_cls:
344
+ x = x[:, 1:, :]
345
+ x = x.permute(0, 2, 1)
346
+ x = self.conv_pool(x)
347
+ x = x.permute(0, 2, 1).squeeze()
348
+ else:
349
+ x = x[:, 0, :]
350
+ x = self.ln_post(x)
351
+ if self.proj is not None:
352
+ x = x @ self.proj
353
+ return x
354
+
355
+ def forward_mim(self, x: torch.Tensor, bool_masked_pos, return_all_tokens=False, disable_vlc=False):
356
+ x = self.conv1(x) # shape = [*, width, grid, grid]
357
+ # shape = [*, width, grid ** 2]
358
+ x = x.reshape(x.shape[0], x.shape[1], -1)
359
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
360
+ batch_size, seq_len, _ = x.size()
361
+ mask_token = self.mask_token.unsqueeze(
362
+ 0).unsqueeze(0).expand(batch_size, seq_len, -1)
363
+ w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
364
+ masked_x = x * (1 - w) + mask_token * w
365
+ if disable_vlc:
366
+ x = masked_x
367
+ masked_start = 0
368
+ else:
369
+ x = torch.cat([x, masked_x], 0)
370
+ masked_start = batch_size
371
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(
372
+ x.shape[0], 1, x.shape[-1],
373
+ dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
374
+ x = x + self.positional_embedding.to(x.dtype)
375
+ x = self.ln_pre(x)
376
+ x = x.permute(1, 0, 2) # NLD -> LND
377
+ x = self.transformer(x)
378
+ masked_x = x[:, masked_start:]
379
+ masked_x = self.lm_transformer(masked_x)
380
+ masked_x = masked_x.permute(1, 0, 2)
381
+ masked_x = masked_x[:, 1:]
382
+ masked_x = self.ln_lm(masked_x)
383
+ if not return_all_tokens:
384
+ masked_x = masked_x[bool_masked_pos]
385
+ logits = self.lm_head(masked_x)
386
+ assert self.pool_type == "default"
387
+ result = {"logits": logits}
388
+ if not disable_vlc:
389
+ x = x[0, :batch_size]
390
+ x = self.ln_post(x)
391
+ if self.proj is not None:
392
+ x = x @ self.proj
393
+ result["feature"] = x
394
+ return result
395
+
396
+
397
+ def load_farl(model_type, model_file=None) -> VisualTransformer:
398
+ if model_type == "base":
399
+ model = VisualTransformer(
400
+ input_resolution=224, patch_size=16, width=768, layers=12, heads=12, output_dim=512)
401
+ elif model_type == "large":
402
+ model = VisualTransformer(
403
+ input_resolution=224, patch_size=16, width=1024, layers=24, heads=16, output_dim=512)
404
+ elif model_type == "huge":
405
+ model = VisualTransformer(
406
+ input_resolution=224, patch_size=14, width=1280, layers=32, heads=16, output_dim=512)
407
+ else:
408
+ raise ValueError(f"unknown FaRL model_type: {model_type}")
409
+ model.transformer.use_checkpoint = False
410
+ if model_file is not None:
411
+ checkpoint = torch.load(model_file, map_location='cpu')
412
+ state_dict = {}
413
+ for name, weight in checkpoint["state_dict"].items():
414
+ if name.startswith("visual"):
415
+ state_dict[name[7:]] = weight
416
+ inco = model.load_state_dict(state_dict, strict=False)
417
+ # print(inco.missing_keys)
418
+ assert len(inco.missing_keys) == 0
419
+ return model
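
A sketch of `load_farl` usage; the checkpoint filename below is a placeholder for a locally downloaded FaRL checkpoint whose `state_dict` stores the visual tower under a `visual.` prefix (which `load_farl` strips).

import torch
from facer.farl import load_farl

backbone = load_farl("base")                           # randomly initialized ViT-B/16
# backbone = load_farl("base", model_file="path/to/farl_base_checkpoint.pth")  # placeholder path

x = torch.rand(1, 3, 224, 224)
with torch.inference_mode():
    feat = backbone(x)                                 # 1 x 512 projected feature
print(feat.shape)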
src/pixel3dmm/preprocessing/facer/facer/io.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def read_hwc(path: str) -> torch.Tensor:
7
+ """Read an image from a given path.
8
+
9
+ Args:
10
+ path (str): The given path.
11
+ """
12
+ image = Image.open(path)
13
+ np_image = np.array(image.convert('RGB'))
14
+ return torch.from_numpy(np_image)
15
+
16
+
17
+ def write_hwc(image: torch.Tensor, path: str):
18
+ """Write an image to a given path.
19
+
20
+ Args:
21
+ image (torch.Tensor): The image.
22
+ path (str): The given path.
23
+ """
24
+
25
+ Image.fromarray(image.cpu().numpy()).save(path)
26
+
27
+
28
+
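
A round-trip sketch for the two helpers above: `read_hwc` returns an `h x w x 3` uint8 tensor and `write_hwc` expects the same layout; the output filename is arbitrary.

from facer.io import read_hwc, write_hwc

image = read_hwc("samples/data/girl.jpg")     # h x w x 3, torch.uint8
flipped = image.flip(dims=[1])                # horizontal flip, same layout
write_hwc(flipped, "girl_flipped.jpg")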
src/pixel3dmm/preprocessing/facer/facer/show.py ADDED
@@ -0,0 +1,36 @@
1
+ from typing import Optional, Tuple
2
+ import torch
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+
6
+ from .util import bchw2hwc
7
+
8
+
9
+ def set_figsize(*args):
10
+ if len(args) == 0:
11
+ plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"]
12
+ elif len(args) == 1:
13
+ plt.rcParams["figure.figsize"] = (args[0], args[0])
14
+ elif len(args) == 2:
15
+ plt.rcParams["figure.figsize"] = tuple(args)
16
+ else:
17
+ raise RuntimeError(
18
+ f'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')
19
+
20
+
21
+ def show_hwc(image: torch.Tensor):
22
+ if image.dtype != torch.uint8:
23
+ image = image.to(torch.uint8)
24
+ if image.size(2) == 1:
25
+ image = image.repeat(1, 1, 3)
26
+ pimage = Image.fromarray(image.cpu().numpy())
27
+ plt.imshow(pimage)
28
+ plt.show()
29
+
30
+
31
+ def show_bchw(image: torch.Tensor):
32
+ show_hwc(bchw2hwc(image))
33
+
34
+
35
+ def show_bhw(image: torch.Tensor):
36
+ show_bchw(image.unsqueeze(1))
src/pixel3dmm/preprocessing/facer/facer/transform.py ADDED
@@ -0,0 +1,384 @@
1
+ from typing import List, Dict, Callable, Tuple, Optional
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import functools
5
+ import numpy as np
6
+
7
+
8
+ def get_crop_and_resize_matrix(
9
+ box: torch.Tensor, target_shape: Tuple[int, int],
10
+ target_face_scale: float = 1.0, make_square_crop: bool = True,
11
+ offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
12
+ offset_box_coords: bool = False) -> torch.Tensor:
13
+ """
14
+ Args:
15
+ box: b x 4(x1, y1, x2, y2)
16
+ align_corners (bool): Set this to `True` only if the box you give has coordinates
17
+ ranging from `0` to `h-1` or `w-1`.
18
+
19
+ offset_box_coords (bool): Set this to `True` if the box you give has coordinates
20
+ ranging from `0` to `h` or `w`.
21
+
22
+ Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.
23
+
24
+ If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.
25
+
26
+ Returns:
27
+ torch.Tensor: b x 3 x 3.
28
+ """
29
+ if offset_xy is None:
30
+ offset_xy = (0.0, 0.0)
31
+
32
+ x1, y1, x2, y2 = box.split(1, dim=1) # b x 1
33
+ cx = (x1 + x2) / 2 + offset_xy[0]
34
+ cy = (y1 + y2) / 2 + offset_xy[1]
35
+ rx = (x2 - x1) / 2 / target_face_scale
36
+ ry = (y2 - y1) / 2 / target_face_scale
37
+ if make_square_crop:
38
+ rx = ry = torch.maximum(rx, ry)
39
+
40
+ x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry
41
+
42
+ h, w, *_ = target_shape
43
+
44
+ zeros_pl = torch.zeros_like(x1)
45
+ ones_pl = torch.ones_like(x1)
46
+
47
+ if align_corners:
48
+ # x -> (x - x1) / (x2 - x1) * (w - 1)
49
+ # y -> (y - y1) / (y2 - y1) * (h - 1)
50
+ ax = 1.0 / (x2 - x1) * (w - 1)
51
+ ay = 1.0 / (y2 - y1) * (h - 1)
52
+ matrix = torch.cat([
53
+ ax, zeros_pl, -x1 * ax,
54
+ zeros_pl, ay, -y1 * ay,
55
+ zeros_pl, zeros_pl, ones_pl
56
+ ], dim=1).reshape(-1, 3, 3) # b x 3 x 3
57
+ else:
58
+ if offset_box_coords:
59
+ # x1, x2 \in [0, w], y1, y2 \in [0, h]
60
+ # first we should offset x1, x2, y1, y2 to be ranging in
61
+ # [-0.5, w-0.5] and [-0.5, h-0.5]
62
+ # so to convert these pixel coordinates into boundary coordinates.
63
+ x1, x2, y1, y2 = x1-0.5, x2-0.5, y1-0.5, y2-0.5
64
+
65
+ # x -> (x - x1) / (x2 - x1) * w - 0.5
66
+ # y -> (y - y1) / (y2 - y1) * h - 0.5
67
+ ax = 1.0 / (x2 - x1) * w
68
+ ay = 1.0 / (y2 - y1) * h
69
+ matrix = torch.cat([
70
+ ax, zeros_pl, -x1 * ax - 0.5*ones_pl,
71
+ zeros_pl, ay, -y1 * ay - 0.5*ones_pl,
72
+ zeros_pl, zeros_pl, ones_pl
73
+ ], dim=1).reshape(-1, 3, 3) # b x 3 x 3
74
+ return matrix
75
+
76
+
77
+ def get_similarity_transform_matrix(
78
+ from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
79
+ """
80
+ Args:
81
+ from_pts, to_pts: b x n x 2
82
+
83
+ Returns:
84
+ torch.Tensor: b x 3 x 3
85
+ """
86
+ mfrom = from_pts.mean(dim=1, keepdim=True) # b x 1 x 2
87
+ mto = to_pts.mean(dim=1, keepdim=True) # b x 1 x 2
88
+
89
+ a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False) # b
90
+ c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False) # b
91
+
92
+ to_delta = to_pts - mto
93
+ from_delta = from_pts - mfrom
94
+ c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] - to_delta[:,
95
+ :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False) # b
96
+
97
+ a = c1 / a1
98
+ b = c2 / a1
99
+ dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1] # b
100
+ dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1] # b
101
+
102
+ ones_pl = torch.ones_like(a1)
103
+ zeros_pl = torch.zeros_like(a1)
104
+
105
+ return torch.stack([
106
+ a, b, dx,
107
+ -b, a, dy,
108
+ zeros_pl, zeros_pl, ones_pl,
109
+ ], dim=-1).reshape(-1, 3, 3)
110
+
111
+
112
+ @functools.lru_cache()
113
+ def _standard_face_pts():
114
+ pts = torch.tensor([
115
+ 196.0, 226.0,
116
+ 316.0, 226.0,
117
+ 256.0, 286.0,
118
+ 220.0, 360.4,
119
+ 292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
120
+ return torch.reshape(pts, (5, 2))
121
+
122
+
123
+ def get_face_align_matrix(
124
+ face_pts: torch.Tensor, target_shape: Tuple[int, int],
125
+ target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
126
+ target_pts: Optional[torch.Tensor] = None):
127
+
128
+ if target_pts is None:
129
+ with torch.no_grad():
130
+ std_pts = _standard_face_pts().to(face_pts) # [-1 1]
131
+ h, w, *_ = target_shape
132
+ target_pts = (std_pts * target_face_scale + 1) * \
133
+ torch.tensor([w-1, h-1]).to(face_pts) / 2.0
134
+ if offset_xy is not None:
135
+ target_pts[:, 0] += offset_xy[0]
136
+ target_pts[:, 1] += offset_xy[1]
137
+ else:
138
+ target_pts = target_pts.to(face_pts)
139
+
140
+ if target_pts.dim() == 2:
141
+ target_pts = target_pts.unsqueeze(0)
142
+ if target_pts.size(0) == 1:
143
+ target_pts = target_pts.broadcast_to(face_pts.shape)
144
+
145
+ assert target_pts.shape == face_pts.shape
146
+
147
+ return get_similarity_transform_matrix(face_pts, target_pts)
148
+
149
+
150
+ def rot90(v):
151
+ return np.array([-v[1], v[0]])
152
+
153
+
154
+ def get_quad(lm: torch.Tensor):
155
+ # N,2
156
+ lm = lm.detach().cpu().numpy()
157
+ # Choose oriented crop rectangle.
158
+ eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
159
+ mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
160
+ eye_to_eye = lm[1] - lm[0]
161
+ eye_to_mouth = mouth_avg - eye_avg
162
+ x = eye_to_eye - rot90(eye_to_mouth)
163
+ x /= np.hypot(*x)
164
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
165
+ y = rot90(x)
166
+ c = eye_avg + eye_to_mouth * 0.1
167
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
168
+ quad_for_coeffs = quad[[0, 3, 2, 1]] # reorder corners to match the target-point order
169
+ return torch.from_numpy(quad_for_coeffs).float()
170
+
171
+
172
+ def get_face_align_matrix_celebm(
173
+ face_pts: torch.Tensor, target_shape: Tuple[int, int]):
174
+
175
+ face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)
176
+
177
+ assert target_shape[0] == target_shape[1]
178
+ target_size = target_shape[0]
179
+ target_pts = torch.as_tensor([[0, 0], [target_size,0], [target_size, target_size], [0, target_size]]).to(face_pts)
180
+
181
+ if target_pts.dim() == 2:
182
+ target_pts = target_pts.unsqueeze(0)
183
+ if target_pts.size(0) == 1:
184
+ target_pts = target_pts.broadcast_to(face_pts.shape)
185
+
186
+ assert target_pts.shape == face_pts.shape
187
+
188
+ return get_similarity_transform_matrix(face_pts, target_pts)
189
+
190
+ @functools.lru_cache(maxsize=128)
191
+ def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
192
+ yy, xx = torch.meshgrid(torch.arange(h).float(),
193
+ torch.arange(w).float(),
194
+ indexing='ij')
195
+ return yy + 0.5, xx + 0.5
196
+
197
+
198
+ def _forge_grid(batch_size: int, device: torch.device,
199
+ output_shape: Tuple[int, int],
200
+ fn: Callable[[torch.Tensor], torch.Tensor]
201
+ ) -> torch.Tensor:
202
+ """ Forge transform maps with a given function `fn`.
203
+
204
+ Args:
205
+ output_shape (tuple): (h, w, ...).
206
+ fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
207
+ a bxnx2 array and outputs the transformed bxnx2 array. Both input
208
+ and output store (x, y) coordinates.
209
+
210
+ Note:
211
+ both the input and output arrays of `fn` store (x, y) coordinates.
212
+
213
+ Returns:
214
+ torch.Tensor: a b x h x w x 2 grid where, for each output pixel (y, x),
+ `out[:, y, x] = fn([x, y])`, evaluated batch-wise.
217
+ """
218
+ h, w, *_ = output_shape
219
+ yy, xx = _meshgrid(h, w) # h x w
220
+ yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
221
+ xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
222
+
223
+ in_xxyy = torch.stack(
224
+ [xx, yy], dim=-1).reshape([batch_size, h*w, 2]) # b x (h*w) x 2
225
+ out_xxyy: torch.Tensor = fn(in_xxyy) # b x (h*w) x 2
226
+ return out_xxyy.reshape(batch_size, h, w, 2)
227
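With an identity mapping, the forged grid is just the (x, y) pixel-centre coordinates, which makes for a quick sanity check of the convention used here:

```python
import torch

grid = _forge_grid(batch_size=1, device=torch.device('cpu'),
                   output_shape=(2, 3), fn=lambda pts: pts)
print(grid.shape)  # torch.Size([1, 2, 3, 2])
print(grid[0, 0])  # tensor([[0.5000, 0.5000], [1.5000, 0.5000], [2.5000, 0.5000]])
```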
+
228
+
229
+ def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
230
+ return torch.clamp(x, -1+eps, 1-eps).arctanh()
231
+
232
+
233
+ def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
234
+ warp_factor: float, warped_shape: Tuple[int, int]):
235
+ """ Inverted tanh-warp function.
236
+
237
+ Args:
238
+ coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
239
+ matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
240
+ from the original image to the aligned yet not-warped image.
241
+ warp_factor (float): The warp factor.
242
+ 0 means linear transform, 1 means full tanh warp.
243
+ warped_shape (tuple): [height, width].
244
+
245
+ Returns:
246
+ torch.Tensor: b x n x 2 (x, y). The original coordinates.
247
+ """
248
+ h, w, *_ = warped_shape
249
+ # h -= 1
250
+ # w -= 1
251
+
252
+ w_h = torch.tensor([[w, h]]).to(coords)
253
+
254
+ if warp_factor > 0:
255
+ # normalize coordinates to [-1, +1]
256
+ coords = coords / w_h * 2 - 1
257
+
258
+ nl_part1 = coords > 1.0 - warp_factor
259
+ nl_part2 = coords < -1.0 + warp_factor
260
+
261
+ ret_nl_part1 = _safe_arctanh(
262
+ (coords - 1.0 + warp_factor) /
263
+ warp_factor) * warp_factor + \
264
+ 1.0 - warp_factor
265
+ ret_nl_part2 = _safe_arctanh(
266
+ (coords + 1.0 - warp_factor) /
267
+ warp_factor) * warp_factor - \
268
+ 1.0 + warp_factor
269
+
270
+ coords = torch.where(nl_part1, ret_nl_part1,
271
+ torch.where(nl_part2, ret_nl_part2, coords))
272
+
273
+ # denormalize
274
+ coords = (coords + 1) / 2 * w_h
275
+
276
+ coords_homo = torch.cat(
277
+ [coords, torch.ones_like(coords[:, :, [0]])], dim=-1) # b x n x 3
278
+
279
+ inv_matrix = torch.linalg.inv(matrix) # b x 3 x 3
280
+ # inv_matrix = np.linalg.inv(matrix)
281
+ coords_homo = torch.bmm(
282
+ coords_homo, inv_matrix.permute(0, 2, 1)) # b x n x 3
283
+ return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]
284
+
285
+
286
+ def tanh_warp_transform(
287
+ coords: torch.Tensor, matrix: torch.Tensor,
288
+ warp_factor: float, warped_shape: Tuple[int, int]):
289
+ """ Tanh-warp function.
290
+
291
+ Args:
292
+ coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
293
+ matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
294
+ from the original image to the aligned yet not-warped image.
295
+ warp_factor (float): The warp factor.
296
+ 0 means linear transform, 1 means full tanh warp.
297
+ warped_shape (tuple): [height, width].
298
+
299
+ Returns:
300
+ torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
301
+ """
302
+ h, w, *_ = warped_shape
303
+ # h -= 1
304
+ # w -= 1
305
+ w_h = torch.tensor([[w, h]]).to(coords)
306
+
307
+ coords_homo = torch.cat(
308
+ [coords, torch.ones_like(coords[:, :, [0]])], dim=-1) # b x n x 3
309
+
310
+ coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1)) # b x n x 3
311
+ coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]) # b x n x 2
312
+
313
+ if warp_factor > 0:
314
+ # normalize coordinates to [-1, +1]
315
+ coords = coords / w_h * 2 - 1
316
+
317
+ nl_part1 = coords > 1.0 - warp_factor
318
+ nl_part2 = coords < -1.0 + warp_factor
319
+
320
+ ret_nl_part1 = torch.tanh(
321
+ (coords - 1.0 + warp_factor) /
322
+ warp_factor) * warp_factor + \
323
+ 1.0 - warp_factor
324
+ ret_nl_part2 = torch.tanh(
325
+ (coords + 1.0 - warp_factor) /
326
+ warp_factor) * warp_factor - \
327
+ 1.0 + warp_factor
328
+
329
+ coords = torch.where(nl_part1, ret_nl_part1,
330
+ torch.where(nl_part2, ret_nl_part2, coords))
331
+
332
+ # denormalize
333
+ coords = (coords + 1) / 2 * w_h
334
+
335
+ return coords
336
+
337
+
338
+ def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
339
+ warped_shape: Tuple[int, int],
340
+ orig_shape: Tuple[int, int]):
341
+ """
342
+ Args:
343
+ matrix: bx3x3 matrix.
344
+ warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
345
+ `warp_factor=0.0` represents a cropping.
346
+ warped_shape: The target image shape to transform to.
+ orig_shape: The original image shape that is transformed from.
347
+
348
+ Returns:
349
+ torch.Tensor: b x h x w x 2 (x, y).
350
+ """
351
+ orig_h, orig_w, *_ = orig_shape
352
+ w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
353
+ return _forge_grid(
354
+ matrix.size(0), matrix.device,
355
+ warped_shape,
356
+ functools.partial(inverted_tanh_warp_transform,
357
+ matrix=matrix,
358
+ warp_factor=warp_factor,
359
+ warped_shape=warped_shape)) / w_h*2-1
360
+
361
+
362
+ def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
363
+ warped_shape: Tuple[int, int],
364
+ orig_shape: Tuple[int, int]):
365
+ """
366
+ Args:
367
+ matrix: bx3x3 matrix.
368
+ warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla Tanh-warping,
369
+ `warp_factor=0.0` represents a cropping.
370
+ warped_shape: The target image shape to transform to.
371
+ orig_shape: The original image shape that is transformed from.
372
+
373
+ Returns:
374
+ torch.Tensor: b x h x w x 2 (x, y).
375
+ """
376
+ h, w, *_ = warped_shape
377
+ w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
378
+ return _forge_grid(
379
+ matrix.size(0), matrix.device,
380
+ orig_shape,
381
+ functools.partial(tanh_warp_transform,
382
+ matrix=matrix,
383
+ warp_factor=warp_factor,
384
+ warped_shape=warped_shape)) / w_h * 2-1
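Taken together, these helpers produce sampling grids that `torch.nn.functional.grid_sample` consumes directly. A minimal end-to-end sketch, assuming the functions above are importable; the image, landmark values and the `warp_factor` choice are only illustrative:

```python
import torch
import torch.nn.functional as F

b, c, H, W = 1, 3, 720, 1280
images = torch.rand(b, c, H, W)
landmarks = torch.tensor([[[600., 300.], [700., 298.], [650., 360.],
                           [610., 420.], [690., 422.]]])  # b x 5 x 2, (x, y)

warped_shape = (448, 448)
matrix = get_face_align_matrix(landmarks, warped_shape)
grid = make_tanh_warp_grid(matrix, warp_factor=0.8,
                           warped_shape=warped_shape, orig_shape=(H, W))
aligned = F.grid_sample(images, grid, mode='bilinear', align_corners=False)
print(aligned.shape)  # torch.Size([1, 3, 448, 448])

# the inverted grid warps aligned-space results (e.g. parsing maps) back
inv_grid = make_inverted_tanh_warp_grid(matrix, warp_factor=0.8,
                                        warped_shape=warped_shape,
                                        orig_shape=(H, W))
restored = F.grid_sample(aligned, inv_grid, mode='bilinear',
                         align_corners=False)
print(restored.shape)  # torch.Size([1, 3, 720, 1280])
```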
src/pixel3dmm/preprocessing/facer/facer/util.py ADDED
@@ -0,0 +1,169 @@
1
+ import torch
2
+ from typing import Any, Optional, Union, List, Dict
3
+ import math
4
+ import os
5
+ from urllib.parse import urlparse
6
+ import errno
7
+ import sys
8
+ import validators
9
+ import requests
10
+ import json
11
+
12
+
13
+ def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
14
+ return images.unsqueeze(0).permute(0, 3, 1, 2)
15
+
16
+
17
+ def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
18
+ background_value: float = 0) -> torch.Tensor:
19
+ """ make a grid image from an image batch.
20
+
21
+ Args:
22
+ images (torch.Tensor): input image batch.
23
+ nrows: rows of grid.
24
+ border: border size in pixel.
25
+ background_value: color value of background.
26
+ """
27
+ assert images.ndim == 4 # n x c x h x w
28
+ images = images.permute(0, 2, 3, 1) # n x h x w x c
29
+ n, h, w, c = images.shape
30
+ if nrows is None:
31
+ nrows = max(int(math.sqrt(n)), 1)
32
+ ncols = (n + nrows - 1) // nrows
33
+ result = torch.full([(h + border) * nrows - border,
34
+ (w + border) * ncols - border, c], background_value,
35
+ device=images.device,
36
+ dtype=images.dtype)
37
+
38
+ for i, single_image in enumerate(images):
39
+ row = i // ncols
40
+ col = i % ncols
41
+ yy = (h + border) * row
42
+ xx = (w + border) * col
43
+ result[yy:(yy + h), xx:(xx + w), :] = single_image
44
+ return result
45
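A quick sketch of tiling a batch into a single grid image with the helper above (random tensors stand in for real images):

```python
import torch

batch = torch.rand(5, 3, 64, 64)           # n x c x h x w
grid = bchw2hwc(batch, nrows=2, border=4)  # 2 rows x 3 columns of tiles
print(grid.shape)  # torch.Size([132, 200, 3])
```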
+
46
+
47
+ def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
48
+ return images.permute(0, 2, 3, 1)
49
+
50
+
51
+ def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
52
+ return images.permute(0, 3, 1, 2)
53
+
54
+
55
+ def bhwc2hwc(images: torch.Tensor, *kargs, **kwargs) -> torch.Tensor:
56
+ return bchw2hwc(bhwc2bchw(images), *kargs, **kwargs)
57
+
58
+
59
+ def select_data(selection, data):
60
+ if isinstance(data, dict):
61
+ return {name: select_data(selection, val) for name, val in data.items()}
62
+ elif isinstance(data, (list, tuple)):
63
+ return [select_data(selection, val) for val in data]
64
+ elif isinstance(data, torch.Tensor):
65
+ return data[selection]
66
+ return data
67
+
68
+
69
+ def download_from_github(to_path, organisation, repository, file_path, branch='main', username=None, access_token=None):
70
+ """ download files (including LFS files) from github.
71
+
72
+ For example, in order to download https://github.com/FacePerceiver/facer/blob/main/README.md, call with
73
+ ```
74
+ download_from_github(
75
+ to_path='README.md', organisation='FacePerceiver',
76
+ repository='facer', file_path='README.md', branch='main')
77
+ ```
78
+ """
79
+ if username is not None:
80
+ assert access_token is not None
81
+ auth = (username, access_token)
82
+ else:
83
+ auth = None
84
+ r = requests.get(f'https://api.github.com/repos/{organisation}/{repository}/contents/{file_path}?ref={branch}',
85
+ auth=auth)
86
+ data = json.loads(r.content)
87
+ torch.hub.download_url_to_file(data['download_url'], to_path)
88
+
89
+
90
+ def is_github_url(url: str):
91
+ """
92
+ A typical github url should be like
93
+ https://github.com/FacePerceiver/facer/blob/main/facer/util.py or
94
+ https://github.com/FacePerceiver/facer/raw/main/facer/util.py.
95
+ """
96
+ return ('blob' in url or 'raw' in url) and url.startswith('https://github.com/')
97
+
98
+
99
+ def get_github_components(url: str):
100
+ assert is_github_url(url)
101
+ organisation, repository, blob_or_raw, branch, * \
102
+ path = url[len('https://github.com/'):].split('/')
103
+ assert blob_or_raw in {'blob', 'raw'}
104
+ return organisation, repository, branch, '/'.join(path)
105
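For instance, splitting a typical blob URL into its components:

```python
org, repo, branch, path = get_github_components(
    'https://github.com/FacePerceiver/facer/blob/main/facer/util.py')
print(org, repo, branch, path)  # FacePerceiver facer main facer/util.py
```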
+
106
+
107
+ def download_url_to_file(url, dst, **kwargs):
108
+ if is_github_url(url):
109
+ org, rep, branch, path = get_github_components(url)
110
+ download_from_github(dst, org, rep, path, branch, kwargs.get(
111
+ 'username', None), kwargs.get('access_token', None))
112
+ else:
113
+ torch.hub.download_url_to_file(url, dst)
114
+
115
+
116
126
+ def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None, jit=True, **kwargs):
127
+ if isinstance(url_or_paths, str):
128
+ url_or_paths = [url_or_paths]
129
+
130
+ for url_or_path in url_or_paths:
131
+ try:
132
+ if validators.url(url_or_path):
133
+ url = url_or_path
134
+ if model_dir is None:
135
+ if hasattr(torch.hub, 'get_dir'):
136
+ hub_dir = torch.hub.get_dir()
137
+ else:
138
+ hub_dir = os.path.join(os.path.expanduser(
139
+ '~'), '.cache', 'torch', 'hub')
140
+ model_dir = os.path.join(hub_dir, 'checkpoints')
141
+
142
+ try:
143
+ os.makedirs(model_dir)
144
+ except OSError as e:
145
+ if e.errno == errno.EEXIST:
146
+ # Directory already exists, ignore.
147
+ pass
148
+ else:
149
+ # Unexpected OSError, re-raise.
150
+ raise
151
+
152
+ parts = urlparse(url)
153
+ filename = os.path.basename(parts.path)
154
+ cached_file = os.path.join(model_dir, filename)
155
+ if not os.path.exists(cached_file):
156
+ sys.stderr.write(
157
+ 'Downloading: "{}" to {}\n'.format(url, cached_file))
158
+ download_url_to_file(url, cached_file)
159
+ else:
160
+ cached_file = url_or_path
161
+ if jit:
162
+ return torch.jit.load(cached_file, map_location=map_location, **kwargs)
163
+ else:
164
+ return torch.load(cached_file, map_location=map_location, **kwargs)
165
+ except Exception:
166
+ sys.stderr.write(f'failed downloading from {url_or_path}\n')
167
+ raise
168
+
169
+ raise RuntimeError('failed to download jit models from all given urls')
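A hedged usage sketch of the loader above; the checkpoint URL here is a placeholder, not a real release asset:

```python
# hypothetical TorchScript checkpoint URL -- substitute a real one
model = download_jit('https://example.com/checkpoints/face_model.jit.pt',
                     map_location='cpu')
model.eval()
```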
src/pixel3dmm/preprocessing/facer/facer/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__="0.0.5"
src/pixel3dmm/preprocessing/facer/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch >= 1.9.1
2
+ torchvision
3
+ timm
4
+ pillow
5
+ numpy
6
+ ipywidgets
7
+ scikit-image
8
+ matplotlib
9
+ validators
10
+ requests
11
+ opencv-python
src/pixel3dmm/preprocessing/facer/samples/data/ffhq_15723.jpg ADDED

Git LFS Details

  • SHA256: 7de44ea12326e0c91d249af710d9469f20b9e08b1935a9c8e435b89d09f5e268
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
src/pixel3dmm/preprocessing/facer/samples/data/fire.webp ADDED
src/pixel3dmm/preprocessing/facer/samples/data/girl.jpg ADDED
src/pixel3dmm/preprocessing/facer/samples/data/sideface.jpg ADDED
src/pixel3dmm/preprocessing/facer/samples/data/twogirls.jpg ADDED
src/pixel3dmm/preprocessing/facer/samples/data/weirdface.jpg ADDED

Git LFS Details

  • SHA256: 3c97a31b51b239b8f727c15086cfe09a0b024900c4d0ebbba83e80fce8c6a51c
  • Pointer size: 131 Bytes
  • Size of remote file: 154 kB
src/pixel3dmm/preprocessing/facer/samples/data/weirdface2.jpg ADDED
src/pixel3dmm/preprocessing/facer/samples/data/weirdface3.jpg ADDED
src/pixel3dmm/preprocessing/facer/samples/download.ipynb ADDED
@@ -0,0 +1,66 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 5,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import sys\n",
10
+ "sys.path.append('..')"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 6,
16
+ "metadata": {},
17
+ "outputs": [
18
+ {
19
+ "data": {
20
+ "application/vnd.jupyter.widget-view+json": {
21
+ "model_id": "03bf1e12ed8a4b4ebf4ee9c5acda4a4f",
22
+ "version_major": 2,
23
+ "version_minor": 0
24
+ },
25
+ "text/plain": [
26
+ " 0%| | 0.00/1.25k [00:00<?, ?B/s]"
27
+ ]
28
+ },
29
+ "metadata": {},
30
+ "output_type": "display_data"
31
+ }
32
+ ],
33
+ "source": [
34
+ "import facer\n",
35
+ "facer.util.download_from_github(\n",
36
+ " to_path='.downloaded/download.ipynb', organisation='FacePerceiver',\n",
37
+ " repository='facer', file_path='samples/download.ipynb', branch='main')\n"
38
+ ]
39
+ }
40
+ ],
41
+ "metadata": {
42
+ "interpreter": {
43
+ "hash": "f55a4f7b5691d52d9535fa31bc30af4661bee88c7bc930914ceb21aeaa908798"
44
+ },
45
+ "kernelspec": {
46
+ "display_name": "Python 3.8.12 ('haya38')",
47
+ "language": "python",
48
+ "name": "python3"
49
+ },
50
+ "language_info": {
51
+ "codemirror_mode": {
52
+ "name": "ipython",
53
+ "version": 3
54
+ },
55
+ "file_extension": ".py",
56
+ "mimetype": "text/x-python",
57
+ "name": "python",
58
+ "nbconvert_exporter": "python",
59
+ "pygments_lexer": "ipython3",
60
+ "version": "3.8.13"
61
+ },
62
+ "orig_nbformat": 4
63
+ },
64
+ "nbformat": 4,
65
+ "nbformat_minor": 2
66
+ }
src/pixel3dmm/preprocessing/facer/samples/example_output/alignment.png ADDED

Git LFS Details

  • SHA256: 14df003541adf14ef4a163b7b463253f5bcdb17fa14f9cd70bd86047ce14ee6e
  • Pointer size: 131 Bytes
  • Size of remote file: 280 kB
src/pixel3dmm/preprocessing/facer/samples/example_output/detect.png ADDED

Git LFS Details

  • SHA256: d8beabe620fc331becdb989cd9d9e8b30c7f9ddd8a7b102b93aac1e16a0c7fe3
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB
src/pixel3dmm/preprocessing/facer/samples/example_output/parsing.png ADDED

Git LFS Details

  • SHA256: b63323df49fd2efdb9376d8454d501da2c71288858f91b750145bffb4f915fec
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
src/pixel3dmm/preprocessing/facer/samples/face_alignment.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/pixel3dmm/preprocessing/facer/samples/face_attribute.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/pixel3dmm/preprocessing/facer/samples/face_detect.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/pixel3dmm/preprocessing/facer/samples/face_parsing.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/pixel3dmm/preprocessing/facer/samples/transform.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/pixel3dmm/preprocessing/facer/scripts/build.sh ADDED
@@ -0,0 +1 @@
1
+ python setup.py bdist_wheel