Spaces: Running on Zero

Upload 52 files

This view is limited to 50 files because it contains too many changes. See raw diff for the complete change set.
- .gitattributes +5 -0
- src/pixel3dmm/preprocessing/facer/.gitignore +134 -0
- src/pixel3dmm/preprocessing/facer/LICENSE +21 -0
- src/pixel3dmm/preprocessing/facer/README.md +187 -0
- src/pixel3dmm/preprocessing/facer/facer/__init__.py +55 -0
- src/pixel3dmm/preprocessing/facer/facer/draw.py +186 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/__init__.py +2 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/base.py +24 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/farl.py +180 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/__init__.py +42 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/common.py +91 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/geometry.py +45 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/mmseg.py +29 -0
- src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/transformers.py +173 -0
- src/pixel3dmm/preprocessing/facer/facer/face_attribute/__init__.py +2 -0
- src/pixel3dmm/preprocessing/facer/facer/face_attribute/base.py +24 -0
- src/pixel3dmm/preprocessing/facer/facer/face_attribute/farl.py +158 -0
- src/pixel3dmm/preprocessing/facer/facer/face_detection/__init__.py +2 -0
- src/pixel3dmm/preprocessing/facer/facer/face_detection/base.py +19 -0
- src/pixel3dmm/preprocessing/facer/facer/face_detection/retinaface.py +677 -0
- src/pixel3dmm/preprocessing/facer/facer/face_parsing/__init__.py +2 -0
- src/pixel3dmm/preprocessing/facer/facer/face_parsing/base.py +27 -0
- src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py +174 -0
- src/pixel3dmm/preprocessing/facer/facer/farl/__init__.py +5 -0
- src/pixel3dmm/preprocessing/facer/facer/farl/classification.py +149 -0
- src/pixel3dmm/preprocessing/facer/facer/farl/model.py +419 -0
- src/pixel3dmm/preprocessing/facer/facer/io.py +28 -0
- src/pixel3dmm/preprocessing/facer/facer/show.py +36 -0
- src/pixel3dmm/preprocessing/facer/facer/transform.py +384 -0
- src/pixel3dmm/preprocessing/facer/facer/util.py +169 -0
- src/pixel3dmm/preprocessing/facer/facer/version.py +1 -0
- src/pixel3dmm/preprocessing/facer/requirements.txt +11 -0
- src/pixel3dmm/preprocessing/facer/samples/data/ffhq_15723.jpg +3 -0
- src/pixel3dmm/preprocessing/facer/samples/data/fire.webp +0 -0
- src/pixel3dmm/preprocessing/facer/samples/data/girl.jpg +0 -0
- src/pixel3dmm/preprocessing/facer/samples/data/sideface.jpg +0 -0
- src/pixel3dmm/preprocessing/facer/samples/data/twogirls.jpg +0 -0
- src/pixel3dmm/preprocessing/facer/samples/data/weirdface.jpg +3 -0
- src/pixel3dmm/preprocessing/facer/samples/data/weirdface2.jpg +0 -0
- src/pixel3dmm/preprocessing/facer/samples/data/weirdface3.jpg +0 -0
- src/pixel3dmm/preprocessing/facer/samples/download.ipynb +66 -0
- src/pixel3dmm/preprocessing/facer/samples/example_output/alignment.png +3 -0
- src/pixel3dmm/preprocessing/facer/samples/example_output/detect.png +3 -0
- src/pixel3dmm/preprocessing/facer/samples/example_output/parsing.png +3 -0
- src/pixel3dmm/preprocessing/facer/samples/face_alignment.ipynb +0 -0
- src/pixel3dmm/preprocessing/facer/samples/face_attribute.ipynb +0 -0
- src/pixel3dmm/preprocessing/facer/samples/face_detect.ipynb +0 -0
- src/pixel3dmm/preprocessing/facer/samples/face_parsing.ipynb +0 -0
- src/pixel3dmm/preprocessing/facer/samples/transform.ipynb +0 -0
- src/pixel3dmm/preprocessing/facer/scripts/build.sh +1 -0
.gitattributes
CHANGED
@@ -41,3 +41,8 @@ example_videos/ex3.mp4 filter=lfs diff=lfs merge=lfs -text
 example_videos/ex4.mp4 filter=lfs diff=lfs merge=lfs -text
 example_videos/ex5.mp4 filter=lfs diff=lfs merge=lfs -text
 media/banner.gif filter=lfs diff=lfs merge=lfs -text
+src/pixel3dmm/preprocessing/facer/samples/data/ffhq_15723.jpg filter=lfs diff=lfs merge=lfs -text
+src/pixel3dmm/preprocessing/facer/samples/data/weirdface.jpg filter=lfs diff=lfs merge=lfs -text
+src/pixel3dmm/preprocessing/facer/samples/example_output/alignment.png filter=lfs diff=lfs merge=lfs -text
+src/pixel3dmm/preprocessing/facer/samples/example_output/detect.png filter=lfs diff=lfs merge=lfs -text
+src/pixel3dmm/preprocessing/facer/samples/example_output/parsing.png filter=lfs diff=lfs merge=lfs -text
src/pixel3dmm/preprocessing/facer/.gitignore
ADDED
@@ -0,0 +1,134 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+samples/output
+.local/
+
+.token.txt
+.downloaded/
src/pixel3dmm/preprocessing/facer/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 FacePerceiver
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
src/pixel3dmm/preprocessing/facer/README.md
ADDED
@@ -0,0 +1,187 @@
+# FACER
+
+Face-related toolkit. This repo is still under construction to include more models.
+
+## Updates
+- [01/17/2025] FaRL Face Parsing is now available in [elliottzheng/batch-face](https://github.com/elliottzheng/batch-face?tab=readme-ov-file#face-parsing), which provides faster batch processing.
+- [14/05/2023] A face attribute recognition model trained on CelebA is available, check it out [here](./samples/face_attribute.ipynb).
+- [04/05/2023] Face alignment models trained on the IBUG300W, AFLW19, and WFLW datasets are available, check them out [here](./samples/face_alignment.ipynb).
+- [27/04/2023] A face parsing model trained on the CelebM dataset is available, check it out [here](./samples/face_parsing.ipynb).
+
+## Install
+
+The easiest way to install it is using pip:
+
+```bash
+pip install git+https://github.com/FacePerceiver/facer.git@main
+```
+No extra setup is needed; pretrained weights are downloaded automatically.
+
+If you have trouble installing from source, you can try installing from PyPI:
+```bash
+pip install pyfacer
+```
+The PyPI version is not guaranteed to be the latest, but we will try to keep it up to date.
+
+
+## Face Detection
+
+We simply wrap a RetinaFace detector for easy usage.
+```python
+import torch
+import facer
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+image = facer.hwc2bchw(facer.read_hwc('data/twogirls.jpg')).to(device=device)  # image: 1 x 3 x h x w
+
+face_detector = facer.face_detector('retinaface/mobilenet', device=device)
+with torch.inference_mode():
+    faces = face_detector(image)
+
+facer.show_bchw(facer.draw_bchw(image, faces))
+```
+
+
+Check [this notebook](./samples/face_detect.ipynb) for the full example.
+
+Please consider citing
+```
+@inproceedings{deng2020retinaface,
+  title={Retinaface: Single-shot multi-level face localisation in the wild},
+  author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={5203--5212},
+  year={2020}
+}
+```
+
+## Face Parsing
+
+We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face parsing.
+```python
+import torch
+import facer
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+image = facer.hwc2bchw(facer.read_hwc('data/twogirls.jpg')).to(device=device)  # image: 1 x 3 x h x w
+
+face_detector = facer.face_detector('retinaface/mobilenet', device=device)
+with torch.inference_mode():
+    faces = face_detector(image)
+
+face_parser = facer.face_parser('farl/lapa/448', device=device)  # optional "farl/celebm/448"
+
+with torch.inference_mode():
+    faces = face_parser(image, faces)
+
+seg_logits = faces['seg']['logits']
+seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
+n_classes = seg_probs.size(1)
+vis_seg_probs = seg_probs.argmax(dim=1).float() / n_classes * 255
+vis_img = vis_seg_probs.sum(0, keepdim=True)
+facer.show_bhw(vis_img)
+facer.show_bchw(facer.draw_bchw(image, faces))
+```
+
+
+Check [this notebook](./samples/face_parsing.ipynb) for the full example.
+
+Please consider citing
+```
+@inproceedings{zheng2022farl,
+  title={General facial representation learning in a visual-linguistic manner},
+  author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={18697--18709},
+  year={2022}
+}
+```
+
+
+## Face Alignment
+
+We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face alignment.
+```python
+import torch
+import cv2
+from matplotlib import pyplot as plt
+
+import facer
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+img_file = 'data/twogirls.jpg'
+# image: 1 x 3 x h x w
+image = facer.hwc2bchw(facer.read_hwc(img_file)).to(device=device)
+
+face_detector = facer.face_detector('retinaface/mobilenet', device=device)
+with torch.inference_mode():
+    faces = face_detector(image)
+
+face_aligner = facer.face_aligner('farl/ibug300w/448', device=device)  # optional: "farl/wflw/448", "farl/aflw19/448"
+
+with torch.inference_mode():
+    faces = face_aligner(image, faces)
+
+img = cv2.imread(img_file)[..., ::-1]
+vis_img = img.copy()
+for pts in faces['alignment']:
+    vis_img = facer.draw_landmarks(vis_img, None, pts.cpu().numpy())
+plt.imshow(vis_img)
+```
+
+
+Check [this notebook](./samples/face_alignment.ipynb) for the full example.
+
+Please consider citing
+```
+@inproceedings{zheng2022farl,
+  title={General facial representation learning in a visual-linguistic manner},
+  author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={18697--18709},
+  year={2022}
+}
+```
+
+## Face Attribute Recognition
+We wrap the [FaRL](https://github.com/faceperceiver/farl) models for face attribute recognition; the model achieves 92.06% accuracy on the [CelebA](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset.
+
+```python
+import sys
+import torch
+import facer
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# image: 1 x 3 x h x w
+image = facer.hwc2bchw(facer.read_hwc("data/girl.jpg")).to(device=device)
+
+face_detector = facer.face_detector("retinaface/mobilenet", device=device)
+with torch.inference_mode():
+    faces = face_detector(image)
+
+face_attr = facer.face_attr("farl/celeba/224", device=device)
+with torch.inference_mode():
+    faces = face_attr(image, faces)
+
+labels = face_attr.labels
+face1_attrs = faces["attrs"][0]  # get the first face's attributes
+
+print(labels)
+
+for prob, label in zip(face1_attrs, labels):
+    if prob > 0.5:
+        print(label, prob.item())
+```
+
+Check [this notebook](./samples/face_attribute.ipynb) for the full example.
+
+Please consider citing
+```
+@inproceedings{zheng2022farl,
+  title={General facial representation learning in a visual-linguistic manner},
+  author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen, Dongdong and Huang, Yangyu and Yuan, Lu and Chen, Dong and Zeng, Ming and Wen, Fang},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={18697--18709},
+  year={2022}
+}
+```
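Editorial note: the README shows each task separately, but the three FaRL heads can also be chained on one image, since each head reads from and writes back into the same `faces` dict produced by the detector. The sketch below assumes the package is importable as `facer` (installed as described above); the model names and image path are illustrative, and the chaining itself is an assumption based on the per-task examples rather than something the README states.

```python
import torch
import facer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
image = facer.hwc2bchw(facer.read_hwc('data/twogirls.jpg')).to(device=device)  # 1 x 3 x h x w

face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/lapa/448', device=device)
face_aligner = facer.face_aligner('farl/ibug300w/448', device=device)

with torch.inference_mode():
    faces = face_detector(image)        # adds 'rects', 'points', 'scores', 'image_ids'
    faces = face_parser(image, faces)   # adds faces['seg']['logits']
    faces = face_aligner(image, faces)  # adds faces['alignment']

facer.show_bchw(facer.draw_bchw(image, faces))  # boxes, keypoints and parsing overlay
```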
src/pixel3dmm/preprocessing/facer/facer/__init__.py
ADDED
@@ -0,0 +1,55 @@
+from typing import Optional, Tuple
+import torch
+
+from .io import read_hwc, write_hwc
+from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
+from .draw import draw_bchw, draw_landmarks
+from .show import show_bchw, show_bhw
+
+from .face_detection import FaceDetector
+from .face_parsing import FaceParser
+from .face_alignment import FaceAlignment
+from .face_attribute import FaceAttribute
+
+
+def _split_name(name: str) -> Tuple[str, Optional[str]]:
+    if '/' in name:
+        detector_type, conf_name = name.split('/', 1)
+    else:
+        detector_type, conf_name = name, None
+    return detector_type, conf_name
+
+
+def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
+    detector_type, conf_name = _split_name(name)
+    if detector_type == 'retinaface':
+        from .face_detection import RetinaFaceDetector
+        return RetinaFaceDetector(conf_name, **kwargs).to(device)
+    else:
+        raise RuntimeError(f'Unknown detector type: {detector_type}')
+
+
+def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
+    parser_type, conf_name = _split_name(name)
+    if parser_type == 'farl':
+        from .face_parsing import FaRLFaceParser
+        return FaRLFaceParser(conf_name, device=device, **kwargs).to(device)
+    else:
+        raise RuntimeError(f'Unknown parser type: {parser_type}')
+
+
+def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
+    aligner_type, conf_name = _split_name(name)
+    if aligner_type == 'farl':
+        from .face_alignment import FaRLFaceAlignment
+        return FaRLFaceAlignment(conf_name, device=device, **kwargs).to(device)
+    else:
+        raise RuntimeError(f'Unknown aligner type: {aligner_type}')
+
+
+def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
+    attr_type, conf_name = _split_name(name)
+    if attr_type == 'farl':
+        from .face_attribute import FaRLFaceAttribute
+        return FaRLFaceAttribute(conf_name, device=device, **kwargs).to(device)
+    else:
+        raise RuntimeError(f'Unknown attribute type: {attr_type}')
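A standalone illustration (no `facer` import needed) of the name convention the factory functions above rely on: `_split_name` splits `'<family>/<config>'` on the first slash only, so everything after the family name is passed through as the config string. The helper below simply mirrors that logic for demonstration.

```python
def split_name(name):
    # mirrors facer._split_name: split on the first '/' only
    return tuple(name.split('/', 1)) if '/' in name else (name, None)

print(split_name('retinaface/mobilenet'))  # ('retinaface', 'mobilenet')
print(split_name('farl/lapa/448'))         # ('farl', 'lapa/448')
print(split_name('retinaface'))            # ('retinaface', None)
```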
src/pixel3dmm/preprocessing/facer/facer/draw.py
ADDED
@@ -0,0 +1,186 @@
+from typing import Dict, List
+import torch
+import colorsys
+import random
+import numpy as np
+from skimage.draw import line_aa, circle_perimeter_aa
+import cv2
+from .util import select_data
+
+
+def _gen_random_colors(N, bright=True):
+    brightness = 1.0 if bright else 0.7
+    hsv = [(i / N, 1, brightness) for i in range(N)]
+    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
+    random.shuffle(colors)
+    return colors
+
+
+_static_label_colors = [
+    np.array((1.0, 1.0, 1.0), np.float32),
+    np.array((255, 250, 79), np.float32) / 255.0,  # face
+    np.array([255, 125, 138], np.float32) / 255.0,  # lb
+    np.array([213, 32, 29], np.float32) / 255.0,  # rb
+    np.array([0, 144, 187], np.float32) / 255.0,  # le
+    np.array([0, 196, 253], np.float32) / 255.0,  # re
+    np.array([255, 129, 54], np.float32) / 255.0,  # nose
+    np.array([88, 233, 135], np.float32) / 255.0,  # ulip
+    np.array([0, 117, 27], np.float32) / 255.0,  # llip
+    np.array([255, 76, 249], np.float32) / 255.0,  # imouth
+    np.array((1.0, 0.0, 0.0), np.float32),  # hair
+    np.array((255, 250, 100), np.float32) / 255.0,  # lr
+    np.array((255, 250, 100), np.float32) / 255.0,  # rr
+    np.array((250, 245, 50), np.float32) / 255.0,  # neck
+    np.array((0.0, 1.0, 0.5), np.float32),  # cloth
+    np.array((1.0, 0.0, 0.5), np.float32),
+] + _gen_random_colors(256)
+
+_names_in_static_label_colors = [
+    'background', 'face', 'lb', 'rb', 'le', 're', 'nose',
+    'ulip', 'llip', 'imouth', 'hair', 'lr', 'rr', 'neck',
+    'cloth', 'eyeg', 'hat', 'earr'
+]
+
+
+def _blend_labels(image, labels, label_names_dict=None,
+                  default_alpha=0.6, color_offset=None):
+    assert labels.ndim == 2
+    bg_mask = labels == 0
+    if label_names_dict is None:
+        colors = _static_label_colors
+    else:
+        colors = [np.array((1.0, 1.0, 1.0), np.float32)]
+        for i in range(1, labels.max() + 1):
+            if isinstance(label_names_dict, dict) and i not in label_names_dict:
+                bg_mask = np.logical_or(bg_mask, labels == i)
+                colors.append(np.zeros((3)))
+                continue
+            label_name = label_names_dict[i]
+            if label_name in _names_in_static_label_colors:
+                color = _static_label_colors[
+                    _names_in_static_label_colors.index(
+                        label_name)]
+            else:
+                color = np.array((1.0, 1.0, 1.0), np.float32)
+            colors.append(color)
+
+    if color_offset is not None:
+        ncolors = []
+        for c in colors:
+            nc = np.array(c)
+            if (nc != np.zeros(3)).any():
+                nc += color_offset
+            ncolors.append(nc)
+        colors = ncolors
+
+    if image is None:
+        image = orig_image = np.zeros(
+            [labels.shape[0], labels.shape[1], 3], np.float32)
+        alpha = 1.0
+    else:
+        orig_image = image / np.max(image)
+        image = orig_image * (1.0 - default_alpha)
+        alpha = default_alpha
+    for i in range(1, np.max(labels) + 1):
+        image += alpha * \
+            np.tile(
+                np.expand_dims(
+                    (labels == i).astype(np.float32), -1),
+                [1, 1, 3]) * colors[(i) % len(colors)]
+    image[np.where(image > 1.0)] = 1.0
+    image[np.where(image < 0)] = 0.0
+    image[np.where(bg_mask)] = orig_image[np.where(bg_mask)]
+    return image
+
+
+def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
+    device = image.device
+    image = np.array(image.cpu().numpy(), copy=True)
+    dtype = image.dtype
+    h, w, _ = image.shape
+
+    draw_score_error = False
+    for tag, batch_content in data.items():
+        if tag == 'rects':
+            for cid, content in enumerate(batch_content):
+                x1, y1, x2, y2 = [int(v) for v in content]
+                y1, y2 = [max(min(v, h-1), 0) for v in [y1, y2]]
+                x1, x2 = [max(min(v, w-1), 0) for v in [x1, x2]]
+                for xx1, yy1, xx2, yy2 in [
+                    [x1, y1, x2, y1],
+                    [x1, y2, x2, y2],
+                    [x1, y1, x1, y2],
+                    [x2, y1, x2, y2]
+                ]:
+                    rr, cc, val = line_aa(yy1, xx1, yy2, xx2)
+                    val = val[:, None][:, [0, 0, 0]]
+                    image[rr, cc] = image[rr, cc] * (1.0-val) + val * 255
+
+                if 'scores' in data:
+                    try:
+                        import cv2
+                        score = data['scores'][cid].item()
+                        score_str = f'{score:0.3f}'
+                        image_c = np.array(image).copy()
+                        cv2.putText(image_c, score_str, org=(int(x1), int(y2)),
+                                    fontFace=cv2.FONT_HERSHEY_TRIPLEX,
+                                    fontScale=0.6, color=(255, 255, 255), thickness=1)
+                        image[:, :, :] = image_c
+                    except Exception as e:
+                        if not draw_score_error:
+                            print(f'Failed to draw scores on image.')
+                            print(e)
+                            draw_score_error = True
+
+        if tag == 'points':
+            for content in batch_content:
+                # content: npoints x 2
+                for x, y in content:
+                    x = max(min(int(x), w-1), 0)
+                    y = max(min(int(y), h-1), 0)
+                    rr, cc, val = circle_perimeter_aa(y, x, 1)
+                    valid = np.all([rr >= 0, rr < h, cc >= 0, cc < w], axis=0)
+                    rr = rr[valid]
+                    cc = cc[valid]
+                    val = val[valid]
+                    val = val[:, None][:, [0, 0, 0]]
+                    image[rr, cc] = image[rr, cc] * (1.0-val) + val * 255
+
+        if tag == 'seg':
+            label_names = batch_content['label_names']
+            for seg_logits in batch_content['logits']:
+                # content: nclasses x h x w
+                seg_probs = seg_logits.softmax(dim=0)
+                seg_labels = seg_probs.argmax(dim=0).cpu().numpy()
+                image = (_blend_labels(image.astype(np.float32) /
+                                       255, seg_labels,
+                                       label_names_dict=label_names) * 255).astype(dtype)
+
+    return torch.from_numpy(image).to(device=device)
+
+
+def draw_bchw(images: torch.Tensor, data: Dict[str, torch.Tensor]) -> torch.Tensor:
+    images2 = []
+    for image_id, image_chw in enumerate(images):
+        selected_data = select_data(image_id == data['image_ids'], data)
+        images2.append(
+            _draw_hwc(image_chw.permute(1, 2, 0), selected_data).permute(2, 0, 1))
+    return torch.stack(images2, dim=0)
+
+
+def draw_landmarks(img, bbox=None, landmark=None, color=(0, 255, 0)):
+    """
+    Input:
+    - img: gray or RGB
+    - bbox: type of BBox
+    - landmark: reproject landmark of (5L, 2L)
+    Output:
+    - img marked with landmark and bbox
+    """
+    img = cv2.UMat(img).get()
+    if bbox is not None:
+        x1, y1, x2, y2 = np.array(bbox)[:4].astype(np.int32)
+        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
+    if landmark is not None:
+        for x, y in np.array(landmark).astype(np.int32):
+            cv2.circle(img, (int(x), int(y)), 2, color, -1)
+    return img
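A quick editorial sketch of `draw_landmarks` on a synthetic image; the box and the five points are made up for illustration, and the import path assumes the package is installed as `facer` with the layout shown in this diff.

```python
import numpy as np
import cv2

from facer.draw import draw_landmarks

img = np.zeros((256, 256, 3), dtype=np.uint8)   # black canvas
bbox = [60, 60, 200, 220]                       # x1, y1, x2, y2
landmarks = [[100, 120], [160, 120], [130, 160],
             [105, 190], [155, 190]]            # five (x, y) points

vis = draw_landmarks(img, bbox=bbox, landmark=landmarks)
cv2.imwrite('landmarks_demo.png', vis)          # box in red, points in green
```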
src/pixel3dmm/preprocessing/facer/facer/face_alignment/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .base import FaceAlignment
+from .farl import FaRLFaceAlignment
src/pixel3dmm/preprocessing/facer/facer/face_alignment/base.py
ADDED
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+
+
+class FaceAlignment(nn.Module):
+    """ face alignment
+
+    Args:
+        images (torch.Tensor): b x c x h x w
+
+        data (Dict[str, Any]):
+
+            * image_ids (torch.Tensor): nfaces
+            * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
+            * points (torch.Tensor): nfaces x 5 x 2 (x, y)
+
+    Returns:
+        data (Dict[str, Any]):
+
+            * image_ids (torch.Tensor): nfaces
+            * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
+            * points (torch.Tensor): nfaces x 5 x 2 (x, y)
+            * alignment
+    """
+    pass
src/pixel3dmm/preprocessing/facer/facer/face_alignment/farl.py
ADDED
@@ -0,0 +1,180 @@
+from typing import Optional, Dict, Any
+import functools
+import torch
+import torch.nn.functional as F
+from .network import FaRLVisualFeatures, MMSEG_UPerHead, FaceAlignmentTransformer, denormalize_points, heatmap2points
+from ..transform import (get_face_align_matrix,
+                         make_inverted_tanh_warp_grid, make_tanh_warp_grid)
+from .base import FaceAlignment
+from ..util import download_jit
+import io
+
+pretrain_settings = {
+    'ibug300w/448': {
+        # inter_ocular 0.028835 epoch 60
+        'num_classes': 68,
+        'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.ibug300w.main_ema_jit.pt",
+        'matrix_src_tag': 'points',
+        'get_matrix_fn': functools.partial(get_face_align_matrix,
+                                           target_shape=(448, 448), target_face_scale=0.8),
+        'get_grid_fn': functools.partial(make_tanh_warp_grid,
+                                         warp_factor=0.0, warped_shape=(448, 448)),
+        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
+                                             warp_factor=0.0, warped_shape=(448, 448)),
+    },
+    'aflw19/448': {
+        # diag 0.009329 epoch 15
+        'num_classes': 19,
+        'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.aflw19.main_ema_jit.pt",
+        'matrix_src_tag': 'points',
+        'get_matrix_fn': functools.partial(get_face_align_matrix,
+                                           target_shape=(448, 448), target_face_scale=0.8),
+        'get_grid_fn': functools.partial(make_tanh_warp_grid,
+                                         warp_factor=0.0, warped_shape=(448, 448)),
+        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
+                                             warp_factor=0.0, warped_shape=(448, 448)),
+    },
+    'wflw/448': {
+        # inter_ocular 0.038933 epoch 20
+        'num_classes': 98,
+        'url': "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_alignment.farl.wflw.main_ema_jit.pt",
+        'matrix_src_tag': 'points',
+        'get_matrix_fn': functools.partial(get_face_align_matrix,
+                                           target_shape=(448, 448), target_face_scale=0.8),
+        'get_grid_fn': functools.partial(make_tanh_warp_grid,
+                                         warp_factor=0.0, warped_shape=(448, 448)),
+        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
+                                             warp_factor=0.0, warped_shape=(448, 448)),
+    },
+}
+
+
+def load_face_alignment_model(model_path: str, num_classes=68):
+    backbone = FaRLVisualFeatures("base", None, forced_input_resolution=448, output_indices=None).cpu()
+    if "jit" in model_path:
+        extra_files = {"backbone": None}
+        heatmap_head = download_jit(model_path, map_location="cpu", _extra_files=extra_files)
+        backbone_weight_io = io.BytesIO(extra_files["backbone"])
+        backbone.load_state_dict(torch.load(backbone_weight_io))
+        # print("load from jit")
+    else:
+        channels = backbone.get_output_channel("base")
+        in_channels = [channels] * 4
+        num_classes = num_classes
+        heatmap_head = MMSEG_UPerHead(in_channels=in_channels, channels=channels, num_classes=num_classes)  # this requires mmseg as a dependency
+        state = torch.load(model_path, map_location="cpu")["networks"]["main_ema"]
+        # print("load from checkpoint")
+
+    main_network = FaceAlignmentTransformer(backbone, heatmap_head, heatmap_act="sigmoid").cpu()
+
+    if "jit" not in model_path:
+        main_network.load_state_dict(state, strict=True)
+
+    return main_network
+
+
+class FaRLFaceAlignment(FaceAlignment):
+    """ The face alignment models from [FaRL](https://github.com/FacePerceiver/FaRL).
+
+    Please consider citing
+    ```bibtex
+    @article{zheng2021farl,
+        title={General Facial Representation Learning in a Visual-Linguistic Manner},
+        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
+            Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
+            Dong and Zeng, Ming and Wen, Fang},
+        journal={arXiv preprint arXiv:2112.03109},
+        year={2021}
+    }
+    ```
+    """
+
+    def __init__(self, conf_name: Optional[str] = None,
+                 model_path: Optional[str] = None, device=None) -> None:
+        super().__init__()
+        if conf_name is None:
+            conf_name = 'ibug300w/448'
+        if model_path is None:
+            model_path = pretrain_settings[conf_name]['url']
+        self.conf_name = conf_name
+
+        setting = pretrain_settings[self.conf_name]
+        self.net = load_face_alignment_model(model_path, num_classes=setting["num_classes"])
+        if device is not None:
+            self.net = self.net.to(device)
+
+        self.heatmap_interpolate_mode = 'bilinear'
+        self.eval()
+
+    def forward(self, images: torch.Tensor, data: Dict[str, Any]):
+        setting = pretrain_settings[self.conf_name]
+        images = images.float() / 255.0  # the backbone applies its own normalization
+        _, _, h, w = images.shape
+
+        simages = images[data['image_ids']]
+        matrix = setting['get_matrix_fn'](data[setting['matrix_src_tag']])
+        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
+        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
+
+        w_images = F.grid_sample(
+            simages, grid, mode='bilinear', align_corners=False)
+
+        _, _, warp_h, warp_w = w_images.shape
+
+        heatmap_acted = self.net(w_images)
+
+        warpped_heatmap = F.interpolate(
+            heatmap_acted, size=(warp_h, warp_w),
+            mode=self.heatmap_interpolate_mode, align_corners=False)
+
+        pred_heatmap = F.grid_sample(
+            warpped_heatmap, inv_grid, mode='bilinear', align_corners=False)
+
+        landmark = heatmap2points(pred_heatmap)
+
+        landmark = denormalize_points(landmark, h, w)
+
+        data['alignment'] = landmark
+
+        return data
+
+
+if __name__ == "__main__":
+    image = torch.randn(1, 3, 448, 448)
+
+    aligner1 = FaRLFaceAlignment("wflw/448")
+
+    x1 = aligner1.net(image)
+
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--jit_path", type=str, default=None)
+    args = parser.parse_args()
+
+    if args.jit_path is None:
+        exit(0)
+
+    net = aligner1.net.cpu()
+
+    features, _ = net.backbone(image)
+
+    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
+    traced_script_module = torch.jit.trace(net.heatmap_head, example_inputs=[features])
+
+    buffer = io.BytesIO()
+
+    torch.save(net.backbone.state_dict(), buffer)
+
+    # Save to file
+    torch.jit.save(traced_script_module, args.jit_path,
+                   _extra_files={"backbone": buffer.getvalue()})
+
+    aligner2 = FaRLFaceAlignment(model_path=args.jit_path)
+
+    # compare the output
+    x2 = aligner2.net(image)
+    print(torch.allclose(x1, x2))
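A small sketch listing the alignment configs defined in the `pretrain_settings` table above together with their landmark counts; it assumes the package is installed so that this module path resolves.

```python
from facer.face_alignment.farl import pretrain_settings

for conf_name, setting in pretrain_settings.items():
    print(f"{conf_name}: {setting['num_classes']} landmarks")
# ibug300w/448: 68 landmarks
# aflw19/448: 19 landmarks
# wflw/448: 98 landmarks
```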
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/__init__.py
ADDED
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from .common import (load_checkpoint, Activation, MLP, Residual)
+from .geometry import (normalize_points, denormalize_points,
+                       heatmap2points)
+from .mmseg import MMSEG_UPerHead
+from .transformers import FaRLVisualFeatures
+from torch import nn
+from typing import Optional, List, Tuple
+
+
+class FaceAlignmentTransformer(nn.Module):
+    """Face alignment transformer.
+
+    Args:
+        image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].
+
+    Returns:
+        landmark (torch.Tensor): Float32 tensor with shape [b, npoints, 2], coordinates normalized to [0, 1].
+        aux_outputs:
+            heatmap (torch.Tensor): Float32 tensor with shape [b, npoints, S, S]
+    """
+
+    def __init__(self, backbone: nn.Module, heatmap_head: nn.Module,
+                 heatmap_act: Optional[str] = 'relu'):
+        super().__init__()
+        self.backbone = backbone
+        self.heatmap_head = heatmap_head
+        self.heatmap_act = Activation(heatmap_act)
+        self.float()
+
+    def forward(self, image):
+        features, _ = self.backbone(image)
+        heatmap = self.heatmap_head(features)  # b x npoints x s x s
+        heatmap_acted = self.heatmap_act(heatmap)
+        # landmark = heatmap2points(heatmap_acted)  # b x npoints x 2
+        # return landmark, {'heatmap': heatmap, 'heatmap_acted': heatmap_acted}
+        return heatmap_acted
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/common.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def load_checkpoint(net: nn.Module, checkpoint_path: str, network_name: str):
+    states = torch.load(open(checkpoint_path, 'rb'), map_location={
+        'cuda:0': f'cuda:{torch.cuda.current_device()}'})
+    network_states = states['networks']
+    net.load_state_dict(network_states[network_name])
+    return net
+
+
+class Activation(nn.Module):
+    def __init__(self, name: Optional[str], **kwargs):
+        super().__init__()
+        if name == 'relu':
+            self.fn = F.relu
+        elif name == 'softplus':
+            self.fn = F.softplus
+        elif name == 'gelu':
+            self.fn = F.gelu
+        elif name == 'sigmoid':
+            self.fn = torch.sigmoid
+        elif name == 'sigmoid_x':
+            self.epsilon = kwargs.get('epsilon', 1e-3)
+            self.fn = lambda x: torch.clamp(
+                x.sigmoid() * (1.0 + self.epsilon*2.0) - self.epsilon,
+                min=0.0, max=1.0)
+        elif name == None:
+            self.fn = lambda x: x
+        else:
+            raise RuntimeError(f'Unknown activation name: {name}')
+
+    def forward(self, x):
+        return self.fn(x)
+
+
+class MLP(nn.Module):
+    def __init__(self, channels: List[int], act: Optional[str]):
+        super().__init__()
+        assert len(channels) > 1
+        layers = []
+        for i in range(len(channels)-1):
+            layers.append(nn.Linear(channels[i], channels[i+1]))
+            if i+1 < len(channels):
+                layers.append(Activation(act))
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class Residual(nn.Module):
+    def __init__(self, net: nn.Module, res_weight_init: Optional[float] = 0.0):
+        super().__init__()
+        self.net = net
+        if res_weight_init is not None:
+            self.res_weight = nn.Parameter(torch.tensor(res_weight_init))
+        else:
+            self.res_weight = None
+
+    def forward(self, x):
+        if self.res_weight is not None:
+            return self.res_weight * self.net(x) + x
+        else:
+            return self.net(x) + x
+
+
+class SE(nn.Module):
+    def __init__(self, channel: int, r: int = 1):
+        super().__init__()
+        self.branch = nn.Sequential(
+            nn.Conv2d(channel, channel//r, (1, 1)),
+            nn.ReLU(),
+            nn.Conv2d(channel//r, channel, (1, 1)),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # x: b x channel x h x w
+        v = x.mean([2, 3], keepdim=True)  # b x channel x 1 x 1
+        v = self.branch(v)  # b x channel x 1 x 1
+        return x * v
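A minimal usage sketch for the helpers above (`Activation`, `MLP`, `Residual`); the shapes are illustrative and the import path assumes the installed package layout from this diff.

```python
import torch
from facer.face_alignment.network.common import Activation, MLP, Residual

x = torch.randn(4, 32)

act = Activation('sigmoid_x', epsilon=1e-3)   # clamped sigmoid variant
y = act(x)
print(y.min().item() >= 0.0, y.max().item() <= 1.0)   # True True

mlp = MLP([32, 64, 16], act='relu')   # Linear(32->64), ReLU, Linear(64->16), ReLU
print(mlp(x).shape)                   # torch.Size([4, 16])

res = Residual(torch.nn.Linear(32, 32))   # res_weight starts at 0.0, so identity at init
print(torch.allclose(res(x), x))          # True
```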
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/geometry.py
ADDED
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from typing import Tuple, Union
+
+import torch
+
+
+def normalize_points(points: torch.Tensor, h: int, w: int) -> torch.Tensor:
+    """ Normalize coordinates to [0, 1].
+    """
+    return (points + 0.5) / torch.tensor([[[w, h]]]).to(points)
+
+
+def denormalize_points(normalized_points: torch.Tensor, h: int, w: int) -> torch.Tensor:
+    """ Reverse normalize_points.
+    """
+    return normalized_points * torch.tensor([[[w, h]]]).to(normalized_points) - 0.5
+
+
+def heatmap2points(heatmap, t_scale: Union[None, float, torch.Tensor] = None):
+    """ Heatmaps -> normalized points [b x npoints x 2(XY)].
+    """
+    dtype = heatmap.dtype
+    _, _, h, w = heatmap.shape
+
+    # 0 ~ h-1, 0 ~ w-1
+    yy, xx = torch.meshgrid(
+        torch.arange(h).float(),
+        torch.arange(w).float())
+
+    yy = yy.view(1, 1, h, w).to(heatmap)
+    xx = xx.view(1, 1, h, w).to(heatmap)
+
+    if t_scale is not None:
+        heatmap = (heatmap * t_scale).exp()
+    heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)
+
+    yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum  # b x npoints
+    xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum  # b x npoints
+
+    points = torch.stack([xx_coord, yy_coord], dim=-1)  # b x npoints x 2
+
+    normalized_points = normalize_points(points, h, w)
+    return normalized_points
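A toy check of the soft-argmax in `heatmap2points` above: a single hot pixel should map back to its own pixel centre after `denormalize_points`. The import path assumes the installed package layout.

```python
import torch
from facer.face_alignment.network.geometry import heatmap2points, denormalize_points

h, w = 8, 8
heatmap = torch.zeros(1, 1, h, w)
heatmap[0, 0, 5, 2] = 1.0              # one "landmark" at row 5, col 2

norm_pts = heatmap2points(heatmap)     # b x npoints x 2, normalized to [0, 1]
pts = denormalize_points(norm_pts, h, w)
print(pts)                             # tensor([[[2., 5.]]])  -> (x, y)
```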
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/mmseg.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch.nn as nn
+
+
+class MMSEG_UPerHead(nn.Module):
+    """Wraps the UPerHead from mmseg for segmentation.
+    """
+
+    def __init__(self, num_classes: int,
+                 in_channels: list = [384, 384, 384, 384], channels: int = 512):
+        super().__init__()
+
+        from mmseg.models.decode_heads import UPerHead
+        self.head = UPerHead(
+            in_channels=in_channels,
+            in_index=[0, 1, 2, 3],
+            pool_scales=(1, 2, 3, 6),
+            channels=channels,
+            dropout_ratio=0.1,
+            num_classes=num_classes,
+            norm_cfg=dict(type='SyncBN', requires_grad=True),
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+
+    def forward(self, inputs):
+        return self.head(inputs)
src/pixel3dmm/preprocessing/facer/facer/face_alignment/network/transformers.py
ADDED
@@ -0,0 +1,173 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import math
+
+from typing import Optional, List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from ... import farl
+
+
+def _make_fpns(vision_patch_size: int, output_channels: int):
+    if vision_patch_size in {16, 14}:
+        fpn1 = nn.Sequential(
+            nn.ConvTranspose2d(output_channels, output_channels,
+                               kernel_size=2, stride=2),
+            nn.SyncBatchNorm(output_channels),
+            nn.GELU(),
+            nn.ConvTranspose2d(output_channels, output_channels, kernel_size=2, stride=2))
+
+        fpn2 = nn.ConvTranspose2d(
+            output_channels, output_channels, kernel_size=2, stride=2)
+        fpn3 = nn.Identity()
+        fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        return nn.ModuleList([fpn1, fpn2, fpn3, fpn4])
+    elif vision_patch_size == 8:
+        fpn1 = nn.Sequential(nn.ConvTranspose2d(
+            output_channels, output_channels, kernel_size=2, stride=2))
+        fpn2 = nn.Identity()
+        fpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
+        fpn4 = nn.MaxPool2d(kernel_size=4, stride=4)
+        return nn.ModuleList([fpn1, fpn2, fpn3, fpn4])
+    else:
+        raise NotImplementedError()
+
+
+def _resize_pe(pe: torch.Tensor, new_size: int, mode: str = 'bicubic', num_tokens: int = 1) -> torch.Tensor:
+    """Resize positional embeddings.
+
+    Args:
+        pe (torch.Tensor): A tensor with shape (num_tokens + old_size ** 2, width). pe[0, :] is the CLS token.
+
+    Returns:
+        torch.Tensor: A tensor with shape (num_tokens + new_size ** 2, width).
+    """
+    l, w = pe.shape
+    old_size = int(math.sqrt(l-num_tokens))
+    assert old_size ** 2 + num_tokens == l
+    return torch.cat([
+        pe[:num_tokens, :],
+        F.interpolate(pe[num_tokens:, :].reshape(1, old_size, old_size, w).permute(0, 3, 1, 2),
+                      (new_size, new_size), mode=mode, align_corners=False).view(w, -1).t()], dim=0)
+
+
+class FaRLVisualFeatures(nn.Module):
+    """Extract features from FaRL visual encoder.
+
+    Args:
+        image (torch.Tensor): Float32 tensor with shape [b, 3, h, w],
+            normalized to [0, 1].
+
+    Returns:
+        List[torch.Tensor]: A list of features.
+    """
+    image_mean: torch.Tensor
+    image_std: torch.Tensor
+    output_channels: int
+    num_outputs: int
+
+    def __init__(self, model_type: str,
+                 model_path: Optional[str] = None, output_indices: Optional[List[int]] = None,
+                 forced_input_resolution: Optional[int] = None,
+                 apply_fpn: bool = True):
+        super().__init__()
+        self.visual = farl.load_farl(model_type, model_path)
+
+        vision_patch_size = self.visual.conv1.weight.shape[-1]
+
+        self.input_resolution = self.visual.input_resolution
+        if forced_input_resolution is not None and \
+                self.input_resolution != forced_input_resolution:
+            # resizing the positional embeddings
+            self.visual.positional_embedding = nn.Parameter(
+                _resize_pe(self.visual.positional_embedding,
+                           forced_input_resolution//vision_patch_size))
+            self.input_resolution = forced_input_resolution
+
+        self.output_channels = self.visual.transformer.width
+
+        if output_indices is None:
+            output_indices = self.__class__.get_default_output_indices(
+                model_type)
+        self.output_indices = output_indices
+        self.num_outputs = len(output_indices)
+
+        self.register_buffer('image_mean', torch.tensor(
+            [0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1))
+        self.register_buffer('image_std', torch.tensor(
+            [0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1))
+
+        if apply_fpn:
+            self.fpns = _make_fpns(vision_patch_size, self.output_channels)
+        else:
+            self.fpns = None
+
+    @staticmethod
+    def get_output_channel(model_type):
+        if model_type == 'base':
+            return 768
+        if model_type == 'large':
+            return 1024
+        if model_type == 'huge':
+            return 1280
+
+    @staticmethod
+    def get_default_output_indices(model_type):
+        if model_type == 'base':
+            return [3, 5, 7, 11]
+        if model_type == 'large':
+            return [7, 11, 15, 23]
+        if model_type == 'huge':
+            return [8, 14, 20, 31]
+
+    def forward(self, image: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        # b x 3 x res x res
+        _, _, input_h, input_w = image.shape
+        if input_h != self.input_resolution or input_w != self.input_resolution:
+            image = F.interpolate(image, self.input_resolution,
+                                  mode='bilinear', align_corners=False)
+
+        image = (image - self.image_mean.to(image.device)) / self.image_std.to(image.device)
+
+        x = image.to(self.visual.conv1.weight.data)
+
+        x = self.visual.conv1(x)  # shape = [*, width, grid, grid]
+        N, _, S, S = x.shape
+
+        # shape = [*, width, grid ** 2]
+        x = x.reshape(x.shape[0], x.shape[1], -1)
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.visual.class_embedding.to(x.dtype) +
+                       torch.zeros(x.shape[0], 1, x.shape[-1],
+                                   dtype=x.dtype, device=x.device),
+                       x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+
+        x = x + self.visual.positional_embedding.to(x.dtype)
+
+        x = self.visual.ln_pre(x)
+
+        x = x.permute(1, 0, 2).contiguous()  # NLD -> LND
+
+        features = []
+        cls_tokens = []
+        for blk in self.visual.transformer.resblocks:
+            x = blk(x)  # [S ** 2 + 1, N, D]
+            # if idx in self.output_indices:
+            feature = x[1:, :, :].permute(
+                1, 2, 0).view(N, -1, S, S).contiguous().float()
+            features.append(feature)
+            cls_tokens.append(x[0, :, :])
+
+        features = [features[ind] for ind in self.output_indices]
+        cls_tokens = [cls_tokens[ind] for ind in self.output_indices]
+
+        if self.fpns is not None:
+            for i, fpn in enumerate(self.fpns):
+                features[i] = fpn(features[i])
+
+        return features, cls_tokens
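A shape-level sketch of `_resize_pe` above: the CLS row is kept and the grid part is interpolated, for example from a 14x14 to a 28x28 patch grid. Importing this module pulls in the `facer.farl` package, so the snippet assumes the full package (and its dependencies) is installed; the sizes are illustrative.

```python
import torch
from facer.face_alignment.network.transformers import _resize_pe

pe = torch.randn(1 + 14 * 14, 768)   # CLS token + 14x14 grid of 768-d embeddings
pe28 = _resize_pe(pe, new_size=28)
print(pe28.shape)                    # torch.Size([785, 768]) == 1 + 28*28
```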
src/pixel3dmm/preprocessing/facer/facer/face_attribute/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .base import FaceAttribute
+from .farl import FaRLFaceAttribute
src/pixel3dmm/preprocessing/facer/facer/face_attribute/base.py
ADDED
@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+
+class FaceAttribute(nn.Module):
+    """ face attribute base class
+
+    Args:
+        images (torch.Tensor): b x c x h x w
+
+        data (Dict[str, Any]):
+
+            * image_ids (torch.Tensor): nfaces
+            * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
+            * points (torch.Tensor): nfaces x 5 x 2 (x, y)
+
+    Returns:
+        data (Dict[str, Any]):
+
+            * image_ids (torch.Tensor): nfaces
+            * rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
+            * points (torch.Tensor): nfaces x 5 x 2 (x, y)
+            * attrs (Dict[str, Any]):
+                * logits (torch.Tensor): nfaces x nclasses
+    """
+    pass
src/pixel3dmm/preprocessing/facer/facer/face_attribute/farl.py
ADDED
@@ -0,0 +1,158 @@
+from typing import Optional, Dict, Any
+import functools
+import torch
+import torch.nn.functional as F
+from ..transform import get_face_align_matrix, make_tanh_warp_grid
+from .base import FaceAttribute
+from ..farl import farl_classification
+from ..util import download_jit
+import numpy as np
+from torchvision.transforms import Normalize
+
+
+def get_std_points_xray(out_size=256, mid_size=500):
+    std_points_256 = np.array(
+        [
+            [85.82991, 85.7792],
+            [169.0532, 84.3381],
+            [127.574, 137.0006],
+            [90.6964, 174.7014],
+            [167.3069, 173.3733],
+        ]
+    )
+    std_points_256[:, 1] += 30
+    old_size = 256
+    mid = mid_size / 2
+    new_std_points = std_points_256 - old_size / 2 + mid
+    target_pts = new_std_points * out_size / mid_size
+    target_pts = torch.from_numpy(target_pts).float()
+    return target_pts
+
+
+pretrain_settings = {
+    "celeba/224": {
+        # acc 92.06617474555969
+        "num_classes": 40,
+        "layers": [11],
+        "url": "https://github.com/FacePerceiver/facer/releases/download/models-v1/face_attribute.farl.celeba.pt",
+        "matrix_src_tag": "points",
+        "get_matrix_fn": functools.partial(
+            get_face_align_matrix,
+            target_shape=(224, 224),
+            target_pts=get_std_points_xray(out_size=224, mid_size=500),
+        ),
+        "get_grid_fn": functools.partial(
+            make_tanh_warp_grid, warp_factor=0.0, warped_shape=(224, 224)
+        ),
+        "classes": [
+            "5_o_Clock_Shadow",
+            "Arched_Eyebrows",
+            "Attractive",
+            "Bags_Under_Eyes",
+            "Bald",
+            "Bangs",
+            "Big_Lips",
+            "Big_Nose",
+            "Black_Hair",
+            "Blond_Hair",
+            "Blurry",
+            "Brown_Hair",
+            "Bushy_Eyebrows",
+            "Chubby",
+            "Double_Chin",
+            "Eyeglasses",
+            "Goatee",
+            "Gray_Hair",
+            "Heavy_Makeup",
+            "High_Cheekbones",
+            "Male",
+            "Mouth_Slightly_Open",
+            "Mustache",
+            "Narrow_Eyes",
+            "No_Beard",
+            "Oval_Face",
+            "Pale_Skin",
+            "Pointy_Nose",
+            "Receding_Hairline",
+            "Rosy_Cheeks",
+            "Sideburns",
+            "Smiling",
+            "Straight_Hair",
+            "Wavy_Hair",
+            "Wearing_Earrings",
+            "Wearing_Hat",
+            "Wearing_Lipstick",
+            "Wearing_Necklace",
+            "Wearing_Necktie",
+            "Young",
+        ],
+    }
+}
+
+
+def load_face_attr(model_path, num_classes=40, layers=[11]):
+    model = farl_classification(num_classes=num_classes, layers=layers)
+    state_dict = download_jit(model_path, jit=False)
+    model.load_state_dict(state_dict)
+    return model
+
+
+class FaRLFaceAttribute(FaceAttribute):
+    """The face attribute recognition models from [FaRL](https://github.com/FacePerceiver/FaRL).
+
+    Please consider citing
+    ```bibtex
+    @article{zheng2021farl,
+        title={General Facial Representation Learning in a Visual-Linguistic Manner},
+        author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
+            Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
+            Dong and Zeng, Ming and Wen, Fang},
+        journal={arXiv preprint arXiv:2112.03109},
+        year={2021}
+    }
+    ```
+    """
+
+    def __init__(
+        self,
+        conf_name: Optional[str] = None,
+        model_path: Optional[str] = None,
+        device=None,
+    ) -> None:
+        super().__init__()
+        if conf_name is None:
+            conf_name = "celeba/224"
+        if model_path is None:
+            model_path = pretrain_settings[conf_name]["url"]
+        self.conf_name = conf_name
+
+        setting = pretrain_settings[self.conf_name]
+        self.labels = setting["classes"]
+        self.net = load_face_attr(model_path, num_classes=setting["num_classes"], layers=setting["layers"])
|
132 |
+
if device is not None:
|
133 |
+
self.net = self.net.to(device)
|
134 |
+
self.normalize = Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
135 |
+
self.eval()
|
136 |
+
|
137 |
+
def forward(self, images: torch.Tensor, data: Dict[str, Any]):
|
138 |
+
setting = pretrain_settings[self.conf_name]
|
139 |
+
images = images.float() / 255.0
|
140 |
+
_, _, h, w = images.shape
|
141 |
+
|
142 |
+
simages = images[data["image_ids"]]
|
143 |
+
matrix = setting["get_matrix_fn"](data[setting["matrix_src_tag"]])
|
144 |
+
grid = setting["get_grid_fn"](matrix=matrix, orig_shape=(h, w))
|
145 |
+
|
146 |
+
w_images = F.grid_sample(simages, grid, mode="bilinear", align_corners=False)
|
147 |
+
w_images = self.normalize(w_images)
|
148 |
+
|
149 |
+
outputs = self.net(w_images)
|
150 |
+
probs = torch.sigmoid(outputs)
|
151 |
+
|
152 |
+
data["attrs"] = probs
|
153 |
+
|
154 |
+
return data
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == "__main__":
|
158 |
+
model = FaRLFaceAttribute()
|
src/pixel3dmm/preprocessing/facer/facer/face_detection/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import FaceDetector
|
2 |
+
from .retinaface import RetinaFaceDetector
|
src/pixel3dmm/preprocessing/facer/facer/face_detection/base.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class FaceDetector(nn.Module):
|
6 |
+
""" face detector
|
7 |
+
|
8 |
+
Args:
|
9 |
+
images (torch.Tensor): b x c x h x w
|
10 |
+
|
11 |
+
Returns:
|
12 |
+
data (Dict[str, torch.Tensor]):
|
13 |
+
|
14 |
+
* rects: nfaces x 4 (x1, y1, x2, y2)
|
15 |
+
* points: nfaces x 5 x 2 (x, y)
|
16 |
+
* scores: nfaces
|
17 |
+
* image_ids: nfaces
|
18 |
+
"""
|
19 |
+
pass
|
src/pixel3dmm/preprocessing/facer/facer/face_detection/retinaface.py
ADDED
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# largely borrowed from https://github.dev/elliottzheng/batch-face/face_detection/alignment.py
|
2 |
+
|
3 |
+
from typing import Dict, List, Optional, Tuple
|
4 |
+
import torch
|
5 |
+
import torch.backends.cudnn as cudnn
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision.models._utils as _utils
|
9 |
+
from .base import FaceDetector
|
10 |
+
|
11 |
+
|
12 |
+
from itertools import product as product
|
13 |
+
from math import ceil
|
14 |
+
|
15 |
+
|
16 |
+
pretrained_urls = {
|
17 |
+
"mobilenet": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/mobilenet0.25_Final.pth",
|
18 |
+
"resnet50": "https://github.com/elliottzheng/face-detection/releases/download/0.0.1/Resnet50_Final.pth"
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def conv_bn(inp, oup, stride=1, leaky=0):
|
23 |
+
return nn.Sequential(
|
24 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
25 |
+
nn.BatchNorm2d(oup),
|
26 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def conv_bn_no_relu(inp, oup, stride):
|
31 |
+
return nn.Sequential(
|
32 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
33 |
+
nn.BatchNorm2d(oup),
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def conv_bn1X1(inp, oup, stride, leaky=0):
|
38 |
+
return nn.Sequential(
|
39 |
+
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
|
40 |
+
nn.BatchNorm2d(oup),
|
41 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def conv_dw(inp, oup, stride, leaky=0.1):
|
46 |
+
return nn.Sequential(
|
47 |
+
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
|
48 |
+
nn.BatchNorm2d(inp),
|
49 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
50 |
+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
51 |
+
nn.BatchNorm2d(oup),
|
52 |
+
nn.LeakyReLU(negative_slope=leaky, inplace=True),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
class SSH(nn.Module):
|
57 |
+
def __init__(self, in_channel, out_channel):
|
58 |
+
super(SSH, self).__init__()
|
59 |
+
assert out_channel % 4 == 0
|
60 |
+
leaky = 0
|
61 |
+
if out_channel <= 64:
|
62 |
+
leaky = 0.1
|
63 |
+
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
|
64 |
+
|
65 |
+
self.conv5X5_1 = conv_bn(
|
66 |
+
in_channel, out_channel // 4, stride=1, leaky=leaky)
|
67 |
+
self.conv5X5_2 = conv_bn_no_relu(
|
68 |
+
out_channel // 4, out_channel // 4, stride=1)
|
69 |
+
|
70 |
+
self.conv7X7_2 = conv_bn(
|
71 |
+
out_channel // 4, out_channel // 4, stride=1, leaky=leaky
|
72 |
+
)
|
73 |
+
self.conv7x7_3 = conv_bn_no_relu(
|
74 |
+
out_channel // 4, out_channel // 4, stride=1)
|
75 |
+
|
76 |
+
def forward(self, input):
|
77 |
+
conv3X3 = self.conv3X3(input)
|
78 |
+
|
79 |
+
conv5X5_1 = self.conv5X5_1(input)
|
80 |
+
conv5X5 = self.conv5X5_2(conv5X5_1)
|
81 |
+
|
82 |
+
conv7X7_2 = self.conv7X7_2(conv5X5_1)
|
83 |
+
conv7X7 = self.conv7x7_3(conv7X7_2)
|
84 |
+
|
85 |
+
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
|
86 |
+
out = F.relu(out)
|
87 |
+
return out
|
88 |
+
|
89 |
+
|
90 |
+
class FPN(nn.Module):
|
91 |
+
def __init__(self, in_channels_list, out_channels):
|
92 |
+
super(FPN, self).__init__()
|
93 |
+
leaky = 0
|
94 |
+
if out_channels <= 64:
|
95 |
+
leaky = 0.1
|
96 |
+
self.output1 = conv_bn1X1(
|
97 |
+
in_channels_list[0], out_channels, stride=1, leaky=leaky
|
98 |
+
)
|
99 |
+
self.output2 = conv_bn1X1(
|
100 |
+
in_channels_list[1], out_channels, stride=1, leaky=leaky
|
101 |
+
)
|
102 |
+
self.output3 = conv_bn1X1(
|
103 |
+
in_channels_list[2], out_channels, stride=1, leaky=leaky
|
104 |
+
)
|
105 |
+
|
106 |
+
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
|
107 |
+
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
|
108 |
+
|
109 |
+
def forward(self, input):
|
110 |
+
# names = list(input.keys())
|
111 |
+
input = list(input.values())
|
112 |
+
|
113 |
+
output1 = self.output1(input[0])
|
114 |
+
output2 = self.output2(input[1])
|
115 |
+
output3 = self.output3(input[2])
|
116 |
+
|
117 |
+
up3 = F.interpolate(
|
118 |
+
output3, size=[output2.size(2), output2.size(3)], mode="nearest"
|
119 |
+
)
|
120 |
+
output2 = output2 + up3
|
121 |
+
output2 = self.merge2(output2)
|
122 |
+
|
123 |
+
up2 = F.interpolate(
|
124 |
+
output2, size=[output1.size(2), output1.size(3)], mode="nearest"
|
125 |
+
)
|
126 |
+
output1 = output1 + up2
|
127 |
+
output1 = self.merge1(output1)
|
128 |
+
|
129 |
+
out = [output1, output2, output3]
|
130 |
+
return out
|
131 |
+
|
132 |
+
|
133 |
+
class MobileNetV1(nn.Module):
|
134 |
+
def __init__(self):
|
135 |
+
super(MobileNetV1, self).__init__()
|
136 |
+
self.stage1 = nn.Sequential(
|
137 |
+
conv_bn(3, 8, 2, leaky=0.1), # 3
|
138 |
+
conv_dw(8, 16, 1), # 7
|
139 |
+
conv_dw(16, 32, 2), # 11
|
140 |
+
conv_dw(32, 32, 1), # 19
|
141 |
+
conv_dw(32, 64, 2), # 27
|
142 |
+
conv_dw(64, 64, 1), # 43
|
143 |
+
)
|
144 |
+
self.stage2 = nn.Sequential(
|
145 |
+
conv_dw(64, 128, 2), # 43 + 16 = 59
|
146 |
+
conv_dw(128, 128, 1), # 59 + 32 = 91
|
147 |
+
conv_dw(128, 128, 1), # 91 + 32 = 123
|
148 |
+
conv_dw(128, 128, 1), # 123 + 32 = 155
|
149 |
+
conv_dw(128, 128, 1), # 155 + 32 = 187
|
150 |
+
conv_dw(128, 128, 1), # 187 + 32 = 219
|
151 |
+
)
|
152 |
+
self.stage3 = nn.Sequential(
|
153 |
+
conv_dw(128, 256, 2), # 219 +3 2 = 241
|
154 |
+
conv_dw(256, 256, 1), # 241 + 64 = 301
|
155 |
+
)
|
156 |
+
self.avg = nn.AdaptiveAvgPool2d((1, 1))
|
157 |
+
self.fc = nn.Linear(256, 1000)
|
158 |
+
|
159 |
+
def forward(self, x):
|
160 |
+
x = self.stage1(x)
|
161 |
+
x = self.stage2(x)
|
162 |
+
x = self.stage3(x)
|
163 |
+
x = self.avg(x)
|
164 |
+
# x = self.model(x)
|
165 |
+
x = x.view(-1, 256)
|
166 |
+
x = self.fc(x)
|
167 |
+
return x
|
168 |
+
|
169 |
+
|
170 |
+
class ClassHead(nn.Module):
|
171 |
+
def __init__(self, inchannels=512, num_anchors=3):
|
172 |
+
super(ClassHead, self).__init__()
|
173 |
+
self.num_anchors = num_anchors
|
174 |
+
self.conv1x1 = nn.Conv2d(
|
175 |
+
inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0
|
176 |
+
)
|
177 |
+
|
178 |
+
def forward(self, x):
|
179 |
+
out = self.conv1x1(x)
|
180 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
181 |
+
return out.view(out.shape[0], -1, 2)
|
182 |
+
|
183 |
+
|
184 |
+
class BboxHead(nn.Module):
|
185 |
+
def __init__(self, inchannels=512, num_anchors=3):
|
186 |
+
super(BboxHead, self).__init__()
|
187 |
+
self.conv1x1 = nn.Conv2d(
|
188 |
+
inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0
|
189 |
+
)
|
190 |
+
|
191 |
+
def forward(self, x):
|
192 |
+
out = self.conv1x1(x)
|
193 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
194 |
+
return out.view(out.shape[0], -1, 4)
|
195 |
+
|
196 |
+
|
197 |
+
class LandmarkHead(nn.Module):
|
198 |
+
def __init__(self, inchannels=512, num_anchors=3):
|
199 |
+
super(LandmarkHead, self).__init__()
|
200 |
+
self.conv1x1 = nn.Conv2d(
|
201 |
+
inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
out = self.conv1x1(x)
|
206 |
+
out = out.permute(0, 2, 3, 1).contiguous()
|
207 |
+
return out.view(out.shape[0], -1, 10)
|
208 |
+
|
209 |
+
|
210 |
+
class RetinaFace(nn.Module):
|
211 |
+
def __init__(self, cfg=None, phase="train"):
|
212 |
+
"""
|
213 |
+
:param cfg: Network related settings.
|
214 |
+
:param phase: train or test.
|
215 |
+
"""
|
216 |
+
super(RetinaFace, self).__init__()
|
217 |
+
self.phase = phase
|
218 |
+
backbone = None
|
219 |
+
if cfg["name"] == "mobilenet0.25":
|
220 |
+
backbone = MobileNetV1()
|
221 |
+
elif cfg["name"] == "Resnet50":
|
222 |
+
import torchvision.models as models
|
223 |
+
backbone = models.resnet50(pretrained=cfg["pretrain"])
|
224 |
+
|
225 |
+
self.body = _utils.IntermediateLayerGetter(
|
226 |
+
backbone, cfg["return_layers"])
|
227 |
+
in_channels_stage2 = cfg["in_channel"]
|
228 |
+
in_channels_list = [
|
229 |
+
in_channels_stage2 * 2,
|
230 |
+
in_channels_stage2 * 4,
|
231 |
+
in_channels_stage2 * 8,
|
232 |
+
]
|
233 |
+
out_channels = cfg["out_channel"]
|
234 |
+
self.fpn = FPN(in_channels_list, out_channels)
|
235 |
+
self.ssh1 = SSH(out_channels, out_channels)
|
236 |
+
self.ssh2 = SSH(out_channels, out_channels)
|
237 |
+
self.ssh3 = SSH(out_channels, out_channels)
|
238 |
+
|
239 |
+
self.ClassHead = self._make_class_head(
|
240 |
+
fpn_num=3, inchannels=cfg["out_channel"])
|
241 |
+
self.BboxHead = self._make_bbox_head(
|
242 |
+
fpn_num=3, inchannels=cfg["out_channel"])
|
243 |
+
self.LandmarkHead = self._make_landmark_head(
|
244 |
+
fpn_num=3, inchannels=cfg["out_channel"]
|
245 |
+
)
|
246 |
+
|
247 |
+
def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
|
248 |
+
classhead = nn.ModuleList()
|
249 |
+
for i in range(fpn_num):
|
250 |
+
classhead.append(ClassHead(inchannels, anchor_num))
|
251 |
+
return classhead
|
252 |
+
|
253 |
+
def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
|
254 |
+
bboxhead = nn.ModuleList()
|
255 |
+
for i in range(fpn_num):
|
256 |
+
bboxhead.append(BboxHead(inchannels, anchor_num))
|
257 |
+
return bboxhead
|
258 |
+
|
259 |
+
def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
|
260 |
+
landmarkhead = nn.ModuleList()
|
261 |
+
for i in range(fpn_num):
|
262 |
+
landmarkhead.append(LandmarkHead(inchannels, anchor_num))
|
263 |
+
return landmarkhead
|
264 |
+
|
265 |
+
def forward(self, inputs):
|
266 |
+
out = self.body(inputs)
|
267 |
+
|
268 |
+
# FPN
|
269 |
+
fpn = self.fpn(out)
|
270 |
+
|
271 |
+
# SSH
|
272 |
+
feature1 = self.ssh1(fpn[0])
|
273 |
+
feature2 = self.ssh2(fpn[1])
|
274 |
+
feature3 = self.ssh3(fpn[2])
|
275 |
+
features = [feature1, feature2, feature3]
|
276 |
+
|
277 |
+
bbox_regressions = torch.cat(
|
278 |
+
[self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1
|
279 |
+
)
|
280 |
+
classifications = torch.cat(
|
281 |
+
[self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1
|
282 |
+
)
|
283 |
+
ldm_regressions = torch.cat(
|
284 |
+
[self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1
|
285 |
+
)
|
286 |
+
|
287 |
+
if self.phase == "train":
|
288 |
+
output = (bbox_regressions, classifications, ldm_regressions)
|
289 |
+
else:
|
290 |
+
output = (
|
291 |
+
bbox_regressions,
|
292 |
+
F.softmax(classifications, dim=-1),
|
293 |
+
ldm_regressions,
|
294 |
+
)
|
295 |
+
return output
|
296 |
+
|
297 |
+
|
298 |
+
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
299 |
+
def decode(loc: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor:
|
300 |
+
boxes = torch.cat(
|
301 |
+
(
|
302 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
303 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
|
304 |
+
),
|
305 |
+
1,
|
306 |
+
)
|
307 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
308 |
+
boxes[:, 2:] += boxes[:, :2]
|
309 |
+
return boxes
|
310 |
+
|
311 |
+
|
312 |
+
def decode_landm(pre: torch.Tensor, priors: torch.Tensor, variances: Tuple[float, float]) -> torch.Tensor:
|
313 |
+
landms = torch.cat(
|
314 |
+
(
|
315 |
+
priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
|
316 |
+
priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
|
317 |
+
priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
|
318 |
+
priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
|
319 |
+
priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
|
320 |
+
),
|
321 |
+
dim=1,
|
322 |
+
)
|
323 |
+
return landms
|
324 |
+
|
325 |
+
|
326 |
+
def nms(dets: torch.Tensor, thresh: float) -> List[int]:
|
327 |
+
"""Pure Python NMS baseline."""
|
328 |
+
x1 = dets[:, 0]
|
329 |
+
y1 = dets[:, 1]
|
330 |
+
x2 = dets[:, 2]
|
331 |
+
y2 = dets[:, 3]
|
332 |
+
scores = dets[:, 4]
|
333 |
+
|
334 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
335 |
+
order = torch.flip(scores.argsort(), [0])
|
336 |
+
|
337 |
+
keep = []
|
338 |
+
while order.numel() > 0:
|
339 |
+
i = order[0].item()
|
340 |
+
keep.append(i)
|
341 |
+
xx1 = torch.maximum(x1[i], x1[order[1:]])
|
342 |
+
yy1 = torch.maximum(y1[i], y1[order[1:]])
|
343 |
+
xx2 = torch.minimum(x2[i], x2[order[1:]])
|
344 |
+
yy2 = torch.minimum(y2[i], y2[order[1:]])
|
345 |
+
|
346 |
+
w = torch.maximum(torch.tensor(0.0).to(dets), xx2 - xx1 + 1)
|
347 |
+
h = torch.maximum(torch.tensor(0.0).to(dets), yy2 - yy1 + 1)
|
348 |
+
inter = w * h
|
349 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
350 |
+
|
351 |
+
inds = torch.where(ovr <= thresh)[0]
|
352 |
+
order = order[inds + 1]
|
353 |
+
|
354 |
+
return keep
|
355 |
+
|
356 |
+
|
357 |
+
class PriorBox:
|
358 |
+
def __init__(self, cfg: dict, image_size: Tuple[int, int]):
|
359 |
+
self.min_sizes = cfg["min_sizes"]
|
360 |
+
self.steps = cfg["steps"]
|
361 |
+
self.clip = cfg["clip"]
|
362 |
+
self.image_size = image_size
|
363 |
+
self.feature_maps = [
|
364 |
+
[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)]
|
365 |
+
for step in self.steps
|
366 |
+
]
|
367 |
+
|
368 |
+
def generate_anchors(self, device) -> torch.Tensor:
|
369 |
+
anchors = []
|
370 |
+
for k, f in enumerate(self.feature_maps):
|
371 |
+
min_sizes = self.min_sizes[k]
|
372 |
+
for i, j in product(range(f[0]), range(f[1])):
|
373 |
+
for min_size in min_sizes:
|
374 |
+
s_kx = min_size / self.image_size[1]
|
375 |
+
s_ky = min_size / self.image_size[0]
|
376 |
+
dense_cx = [
|
377 |
+
x * self.steps[k] / self.image_size[1] for x in [j + 0.5]
|
378 |
+
]
|
379 |
+
dense_cy = [
|
380 |
+
y * self.steps[k] / self.image_size[0] for y in [i + 0.5]
|
381 |
+
]
|
382 |
+
for cy, cx in product(dense_cy, dense_cx):
|
383 |
+
anchors += [cx, cy, s_kx, s_ky]
|
384 |
+
|
385 |
+
# back to torch land
|
386 |
+
output = torch.tensor(anchors).view(-1, 4)
|
387 |
+
if self.clip:
|
388 |
+
output.clamp_(max=1, min=0)
|
389 |
+
return output.to(device=device)
|
390 |
+
|
391 |
+
|
392 |
+
cfg_mnet = {
|
393 |
+
"name": "mobilenet0.25",
|
394 |
+
"min_sizes": [[16, 32], [64, 128], [256, 512]],
|
395 |
+
"steps": [8, 16, 32],
|
396 |
+
"variance": [0.1, 0.2],
|
397 |
+
"clip": False,
|
398 |
+
"loc_weight": 2.0,
|
399 |
+
"gpu_train": True,
|
400 |
+
"batch_size": 32,
|
401 |
+
"ngpu": 1,
|
402 |
+
"epoch": 250,
|
403 |
+
"decay1": 190,
|
404 |
+
"decay2": 220,
|
405 |
+
"image_size": 640,
|
406 |
+
"pretrain": True,
|
407 |
+
"return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},
|
408 |
+
"in_channel": 32,
|
409 |
+
"out_channel": 64,
|
410 |
+
}
|
411 |
+
|
412 |
+
cfg_re50 = {
|
413 |
+
"name": "Resnet50",
|
414 |
+
"min_sizes": [[16, 32], [64, 128], [256, 512]],
|
415 |
+
"steps": [8, 16, 32],
|
416 |
+
"variance": [0.1, 0.2],
|
417 |
+
"clip": False,
|
418 |
+
"loc_weight": 2.0,
|
419 |
+
"gpu_train": True,
|
420 |
+
"batch_size": 24,
|
421 |
+
"ngpu": 4,
|
422 |
+
"epoch": 100,
|
423 |
+
"decay1": 70,
|
424 |
+
"decay2": 90,
|
425 |
+
"image_size": 840,
|
426 |
+
"pretrain": False,
|
427 |
+
"return_layers": {"layer2": 1, "layer3": 2, "layer4": 3},
|
428 |
+
"in_channel": 256,
|
429 |
+
"out_channel": 256,
|
430 |
+
}
|
431 |
+
|
432 |
+
|
433 |
+
def check_keys(model, pretrained_state_dict):
|
434 |
+
ckpt_keys = set(pretrained_state_dict.keys())
|
435 |
+
model_keys = set(model.state_dict().keys())
|
436 |
+
used_pretrained_keys = model_keys & ckpt_keys
|
437 |
+
assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
|
438 |
+
return True
|
439 |
+
|
440 |
+
|
441 |
+
def remove_prefix(state_dict, prefix):
|
442 |
+
""" Old style model is stored with all names of parameters sharing common prefix 'module.' """
|
443 |
+
def f(x): return x.split(prefix, 1)[-1] if x.startswith(prefix) else x
|
444 |
+
return {f(key): value for key, value in state_dict.items()}
|
445 |
+
|
446 |
+
|
447 |
+
def load_model(model, pretrained_path, load_to_cpu, network: str):
|
448 |
+
if pretrained_path is None:
|
449 |
+
url = pretrained_urls[network]
|
450 |
+
if load_to_cpu:
|
451 |
+
pretrained_dict = torch.utils.model_zoo.load_url(
|
452 |
+
url, map_location=lambda storage, loc: storage
|
453 |
+
)
|
454 |
+
else:
|
455 |
+
pretrained_dict = torch.utils.model_zoo.load_url(
|
456 |
+
url, map_location=lambda storage, loc: storage.cuda(device)
|
457 |
+
)
|
458 |
+
else:
|
459 |
+
if load_to_cpu:
|
460 |
+
pretrained_dict = torch.load(
|
461 |
+
pretrained_path, map_location=lambda storage, loc: storage
|
462 |
+
)
|
463 |
+
else:
|
464 |
+
device = torch.cuda.current_device()
|
465 |
+
pretrained_dict = torch.load(
|
466 |
+
pretrained_path, map_location=lambda storage, loc: storage.cuda(
|
467 |
+
device)
|
468 |
+
)
|
469 |
+
if "state_dict" in pretrained_dict.keys():
|
470 |
+
pretrained_dict = remove_prefix(
|
471 |
+
pretrained_dict["state_dict"], "module.")
|
472 |
+
else:
|
473 |
+
pretrained_dict = remove_prefix(pretrained_dict, "module.")
|
474 |
+
check_keys(model, pretrained_dict)
|
475 |
+
model.load_state_dict(pretrained_dict, strict=False)
|
476 |
+
return model
|
477 |
+
|
478 |
+
|
479 |
+
def load_net(model_path, network="mobilenet"):
|
480 |
+
if network == "mobilenet":
|
481 |
+
cfg = cfg_mnet
|
482 |
+
elif network == "resnet50":
|
483 |
+
cfg = cfg_re50
|
484 |
+
else:
|
485 |
+
raise NotImplementedError(network)
|
486 |
+
# net and model
|
487 |
+
net = RetinaFace(cfg=cfg, phase="test")
|
488 |
+
net = load_model(net, model_path, True, network=network)
|
489 |
+
net.eval()
|
490 |
+
cudnn.benchmark = True
|
491 |
+
# net = net.to(device)
|
492 |
+
return net
|
493 |
+
|
494 |
+
|
495 |
+
def parse_det(det: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, float]:
|
496 |
+
landmarks = det[5:].reshape(5, 2)
|
497 |
+
box = det[:4]
|
498 |
+
score = det[4]
|
499 |
+
return box, landmarks, score.item()
|
500 |
+
|
501 |
+
|
502 |
+
def post_process(
|
503 |
+
loc: torch.Tensor,
|
504 |
+
conf: torch.Tensor,
|
505 |
+
landms: torch.Tensor,
|
506 |
+
prior_data: torch.Tensor,
|
507 |
+
cfg: dict,
|
508 |
+
scale: float,
|
509 |
+
scale1: float,
|
510 |
+
resize,
|
511 |
+
confidence_threshold,
|
512 |
+
top_k,
|
513 |
+
nms_threshold,
|
514 |
+
keep_top_k,
|
515 |
+
):
|
516 |
+
boxes = decode(loc, prior_data, cfg["variance"])
|
517 |
+
boxes = boxes * scale / resize
|
518 |
+
# boxes = boxes.cpu().numpy()
|
519 |
+
# scores = conf.cpu().numpy()[:, 1]
|
520 |
+
scores = conf[:, 1]
|
521 |
+
landms_copy = decode_landm(landms, prior_data, cfg["variance"])
|
522 |
+
|
523 |
+
landms_copy = landms_copy * scale1 / resize
|
524 |
+
# landms_copy = landms_copy.cpu().numpy()
|
525 |
+
|
526 |
+
# ignore low scores
|
527 |
+
inds = torch.where(scores > confidence_threshold)[0]
|
528 |
+
boxes = boxes[inds]
|
529 |
+
landms_copy = landms_copy[inds]
|
530 |
+
scores = scores[inds]
|
531 |
+
|
532 |
+
# keep top-K before NMS
|
533 |
+
order = torch.flip(scores.argsort(), [0])[:top_k]
|
534 |
+
boxes = boxes[order]
|
535 |
+
landms_copy = landms_copy[order]
|
536 |
+
scores = scores[order]
|
537 |
+
|
538 |
+
# do NMS
|
539 |
+
dets = torch.hstack((boxes, scores.unsqueeze(-1))).to(
|
540 |
+
dtype=torch.float32, copy=False)
|
541 |
+
keep = nms(dets, nms_threshold)
|
542 |
+
# keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)
|
543 |
+
dets = dets[keep, :]
|
544 |
+
landms_copy = landms_copy[keep]
|
545 |
+
|
546 |
+
# keep top-K faster NMS
|
547 |
+
dets = dets[:keep_top_k, :]
|
548 |
+
landms_copy = landms_copy[:keep_top_k, :]
|
549 |
+
|
550 |
+
dets = torch.cat((dets, landms_copy), dim=1)
|
551 |
+
# show image
|
552 |
+
dets = sorted(dets, key=lambda x: x[4], reverse=True)
|
553 |
+
dets = [parse_det(x) for x in dets]
|
554 |
+
|
555 |
+
return dets
|
556 |
+
|
557 |
+
|
558 |
+
# @torch.no_grad()
|
559 |
+
def batch_detect(net: nn.Module, images: torch.Tensor, threshold: float = 0.5):
|
560 |
+
confidence_threshold = threshold
|
561 |
+
cfg = cfg_mnet
|
562 |
+
top_k = 5000
|
563 |
+
nms_threshold = 0.4
|
564 |
+
keep_top_k = 750
|
565 |
+
resize = 1
|
566 |
+
|
567 |
+
img = images.float()
|
568 |
+
mean = torch.as_tensor([104, 117, 123], dtype=img.dtype, device=img.device).view(
|
569 |
+
1, 3, 1, 1
|
570 |
+
)
|
571 |
+
img -= mean
|
572 |
+
(
|
573 |
+
_,
|
574 |
+
_,
|
575 |
+
im_height,
|
576 |
+
im_width,
|
577 |
+
) = img.shape
|
578 |
+
scale = torch.as_tensor(
|
579 |
+
[im_width, im_height, im_width, im_height],
|
580 |
+
dtype=img.dtype,
|
581 |
+
device=img.device,
|
582 |
+
)
|
583 |
+
scale = scale.to(img.device)
|
584 |
+
|
585 |
+
loc, conf, landms = net(img) # forward pass
|
586 |
+
|
587 |
+
priorbox = PriorBox(cfg, image_size=(im_height, im_width))
|
588 |
+
prior_data = priorbox.generate_anchors(device=img.device)
|
589 |
+
scale1 = torch.as_tensor(
|
590 |
+
[
|
591 |
+
img.shape[3],
|
592 |
+
img.shape[2],
|
593 |
+
img.shape[3],
|
594 |
+
img.shape[2],
|
595 |
+
img.shape[3],
|
596 |
+
img.shape[2],
|
597 |
+
img.shape[3],
|
598 |
+
img.shape[2],
|
599 |
+
img.shape[3],
|
600 |
+
img.shape[2],
|
601 |
+
],
|
602 |
+
dtype=img.dtype,
|
603 |
+
device=img.device,
|
604 |
+
)
|
605 |
+
scale1 = scale1.to(img.device)
|
606 |
+
|
607 |
+
all_dets = [
|
608 |
+
post_process(
|
609 |
+
loc_i,
|
610 |
+
conf_i,
|
611 |
+
landms_i,
|
612 |
+
prior_data,
|
613 |
+
cfg,
|
614 |
+
scale,
|
615 |
+
scale1,
|
616 |
+
resize,
|
617 |
+
confidence_threshold,
|
618 |
+
top_k,
|
619 |
+
nms_threshold,
|
620 |
+
keep_top_k,
|
621 |
+
)
|
622 |
+
for loc_i, conf_i, landms_i in zip(loc, conf, landms)
|
623 |
+
]
|
624 |
+
|
625 |
+
rects = []
|
626 |
+
points = []
|
627 |
+
scores = []
|
628 |
+
image_ids = []
|
629 |
+
for image_id, faces_in_one_image in enumerate(all_dets):
|
630 |
+
for rect, landmarks, score in faces_in_one_image:
|
631 |
+
rects.append(rect)
|
632 |
+
points.append(landmarks)
|
633 |
+
scores.append(score)
|
634 |
+
image_ids.append(image_id)
|
635 |
+
|
636 |
+
if len(rects) == 0:
|
637 |
+
return {
|
638 |
+
'rects': torch.Tensor().to(img.device),
|
639 |
+
'points': torch.Tensor().to(img.device),
|
640 |
+
'scores': torch.Tensor().to(img.device),
|
641 |
+
'image_ids': torch.Tensor().to(img.device),
|
642 |
+
}
|
643 |
+
|
644 |
+
return {
|
645 |
+
'rects': torch.stack(rects, dim=0).to(img.device),
|
646 |
+
'points': torch.stack(points, dim=0).to(img.device),
|
647 |
+
'scores': torch.tensor(scores).to(img.device),
|
648 |
+
'image_ids': torch.tensor(image_ids).to(img.device),
|
649 |
+
}
|
650 |
+
|
651 |
+
|
652 |
+
class RetinaFaceDetector(FaceDetector):
|
653 |
+
"""RetinaFaceDetector
|
654 |
+
|
655 |
+
Args:
|
656 |
+
images (torch.Tensor): b x c x h x w, uint8, 0~255.
|
657 |
+
|
658 |
+
Returns:
|
659 |
+
faces (Dict[str, torch.Tensor]):
|
660 |
+
|
661 |
+
* image_ids: n, int
|
662 |
+
* rects: n x 4 (x1, y1, x2, y2)
|
663 |
+
* points: n x 5 x 2 (x, y)
|
664 |
+
* scores: n
|
665 |
+
"""
|
666 |
+
|
667 |
+
def __init__(self, conf_name: Optional[str] = None,
|
668 |
+
model_path: Optional[str] = None, threshold=0.8) -> None:
|
669 |
+
super().__init__()
|
670 |
+
if conf_name is None:
|
671 |
+
conf_name = 'mobilenet'
|
672 |
+
self.net = load_net(model_path, conf_name)
|
673 |
+
self.threshold = threshold
|
674 |
+
self.eval()
|
675 |
+
|
676 |
+
def forward(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
|
677 |
+
return batch_detect(self.net, images.clone(), threshold=self.threshold)
|
src/pixel3dmm/preprocessing/facer/facer/face_parsing/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .base import FaceParser
|
2 |
+
from .farl import FaRLFaceParser
|
src/pixel3dmm/preprocessing/facer/facer/face_parsing/base.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class FaceParser(nn.Module):
|
5 |
+
""" face parser
|
6 |
+
|
7 |
+
Args:
|
8 |
+
images (torch.Tensor): b x c x h x w
|
9 |
+
|
10 |
+
data (Dict[str, Any]):
|
11 |
+
|
12 |
+
* image_ids (torch.Tensor): nfaces
|
13 |
+
* rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
|
14 |
+
* points (torch.Tensor): nfaces x 5 x 2 (x, y)
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
data (Dict[str, Any]):
|
18 |
+
|
19 |
+
* image_ids (torch.Tensor): nfaces
|
20 |
+
* rects (torch.Tensor): nfaces x 4 (x1, y1, x2, y2)
|
21 |
+
* points (torch.Tensor): nfaces x 5 x 2 (x, y)
|
22 |
+
* seg (Dict[str, Any]):
|
23 |
+
|
24 |
+
* logits (torch.Tensor): nfaces x nclasses x h x w
|
25 |
+
* label_names (List[str]): nclasses
|
26 |
+
"""
|
27 |
+
pass
|
src/pixel3dmm/preprocessing/facer/facer/face_parsing/farl.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Dict, Any
|
2 |
+
import functools
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from ..util import download_jit
|
7 |
+
from ..transform import (get_crop_and_resize_matrix, get_face_align_matrix, get_face_align_matrix_celebm,
|
8 |
+
make_inverted_tanh_warp_grid, make_tanh_warp_grid)
|
9 |
+
from .base import FaceParser
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
pretrain_settings = {
|
13 |
+
'lapa/448': {
|
14 |
+
'url': [
|
15 |
+
'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt',
|
16 |
+
],
|
17 |
+
'matrix_src_tag': 'points',
|
18 |
+
'get_matrix_fn': functools.partial(get_face_align_matrix,
|
19 |
+
target_shape=(448, 448), target_face_scale=1.0),
|
20 |
+
'get_grid_fn': functools.partial(make_tanh_warp_grid,
|
21 |
+
warp_factor=0.8, warped_shape=(448, 448)),
|
22 |
+
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
|
23 |
+
warp_factor=0.8, warped_shape=(448, 448)),
|
24 |
+
'label_names': ['background', 'face', 'rb', 'lb', 're',
|
25 |
+
'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
|
26 |
+
},
|
27 |
+
'celebm/448': {
|
28 |
+
'url': [
|
29 |
+
'https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.celebm.main_ema_181500_jit.pt',
|
30 |
+
],
|
31 |
+
'matrix_src_tag': 'points',
|
32 |
+
'get_matrix_fn': functools.partial(get_face_align_matrix_celebm,
|
33 |
+
target_shape=(448, 448)),
|
34 |
+
'get_grid_fn': functools.partial(make_tanh_warp_grid,
|
35 |
+
warp_factor=0, warped_shape=(448, 448)),
|
36 |
+
'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
|
37 |
+
warp_factor=0, warped_shape=(448, 448)),
|
38 |
+
'label_names': [
|
39 |
+
'background', 'neck', 'face', 'cloth', 'rr', 'lr', 'rb', 'lb', 're',
|
40 |
+
'le', 'nose', 'imouth', 'llip', 'ulip', 'hair',
|
41 |
+
'eyeg', 'hat', 'earr', 'neck_l']
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
|
46 |
+
class FaRLFaceParser(FaceParser):
|
47 |
+
""" The face parsing models from [FaRL](https://github.com/FacePerceiver/FaRL).
|
48 |
+
|
49 |
+
Please consider citing
|
50 |
+
```bibtex
|
51 |
+
@article{zheng2021farl,
|
52 |
+
title={General Facial Representation Learning in a Visual-Linguistic Manner},
|
53 |
+
author={Zheng, Yinglin and Yang, Hao and Zhang, Ting and Bao, Jianmin and Chen,
|
54 |
+
Dongdong and Huang, Yangyu and Yuan, Lu and Chen,
|
55 |
+
Dong and Zeng, Ming and Wen, Fang},
|
56 |
+
journal={arXiv preprint arXiv:2112.03109},
|
57 |
+
year={2021}
|
58 |
+
}
|
59 |
+
```
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, conf_name: Optional[str] = None, model_path: Optional[str] = None, device=None) -> None:
|
63 |
+
super().__init__()
|
64 |
+
if conf_name is None:
|
65 |
+
conf_name = 'lapa/448'
|
66 |
+
if model_path is None:
|
67 |
+
model_path = pretrain_settings[conf_name]['url']
|
68 |
+
self.conf_name = conf_name
|
69 |
+
self.net = download_jit(model_path, map_location=device)
|
70 |
+
self.eval()
|
71 |
+
self.device = device
|
72 |
+
self.setting = pretrain_settings[conf_name]
|
73 |
+
self.label_names = self.setting['label_names']
|
74 |
+
|
75 |
+
|
76 |
+
def get_warp_grid(self, images: torch.Tensor, matrix_src):
|
77 |
+
_, _, h, w = images.shape
|
78 |
+
matrix = self.setting['get_matrix_fn'](matrix_src)
|
79 |
+
grid = self.setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
|
80 |
+
inv_grid = self.setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
|
81 |
+
return grid, inv_grid
|
82 |
+
|
83 |
+
def warp_images(self, images: torch.Tensor, data: Dict[str, Any]):
|
84 |
+
simages = self.unify_image_dtype(images)
|
85 |
+
simages = simages[data['image_ids']]
|
86 |
+
matrix_src = data[self.setting['matrix_src_tag']]
|
87 |
+
grid, inv_grid = self.get_warp_grid(simages, matrix_src)
|
88 |
+
|
89 |
+
w_images = F.grid_sample(
|
90 |
+
simages, grid, mode='bilinear', align_corners=False)
|
91 |
+
return w_images, grid, inv_grid
|
92 |
+
|
93 |
+
|
94 |
+
def decode_image_to_cv2(self, images: torch.Tensor):
|
95 |
+
'''
|
96 |
+
output: b x 3 x h x w, torch.uint8, [0, 255]
|
97 |
+
'''
|
98 |
+
assert images.ndim == 4
|
99 |
+
assert images.shape[1] == 3
|
100 |
+
images = images.permute(0, 2, 3, 1).cpu().numpy() * 255
|
101 |
+
images = images.astype(np.uint8)
|
102 |
+
return images
|
103 |
+
|
104 |
+
def unify_image_dtype(self, images: torch.Tensor|np.ndarray|list):
|
105 |
+
'''
|
106 |
+
output: b x 3 x h x w, torch.float32, [0, 1]
|
107 |
+
'''
|
108 |
+
if isinstance(images, np.ndarray):
|
109 |
+
images = torch.from_numpy(images)
|
110 |
+
elif isinstance(images, torch.Tensor):
|
111 |
+
pass
|
112 |
+
elif isinstance(images, list):
|
113 |
+
assert len(images) > 0, "images is empty"
|
114 |
+
first_image = images[0]
|
115 |
+
if isinstance(first_image, np.ndarray):
|
116 |
+
images = [torch.from_numpy(image).permute(2, 0, 1) for image in images]
|
117 |
+
images = torch.stack(images)
|
118 |
+
elif isinstance(first_image, torch.Tensor):
|
119 |
+
images = torch.stack(images)
|
120 |
+
else:
|
121 |
+
raise ValueError(f"Unsupported image type: {type(first_image)}")
|
122 |
+
|
123 |
+
else:
|
124 |
+
raise ValueError(f"Unsupported image type: {type(images)}")
|
125 |
+
|
126 |
+
assert images.ndim == 4
|
127 |
+
assert images.shape[1] == 3
|
128 |
+
|
129 |
+
max_val = images.max()
|
130 |
+
if max_val <= 1:
|
131 |
+
assert images.dtype == torch.float32 or images.dtype == torch.float16
|
132 |
+
elif max_val <= 255:
|
133 |
+
assert images.dtype == torch.uint8
|
134 |
+
images = images.float() / 255.0
|
135 |
+
else:
|
136 |
+
raise ValueError(f"Unsupported image type: {images.dtype}")
|
137 |
+
if images.device != self.device:
|
138 |
+
images = images.to(device=self.device)
|
139 |
+
return images
|
140 |
+
|
141 |
+
@torch.no_grad()
|
142 |
+
@torch.inference_mode()
|
143 |
+
def forward(self, images: torch.Tensor, data: Dict[str, Any]):
|
144 |
+
'''
|
145 |
+
images: b x 3 x h x w , torch.uint8, [0, 255]
|
146 |
+
data: {'rects': rects, 'points': points, 'scores': scores, 'image_ids': image_ids}
|
147 |
+
'''
|
148 |
+
w_images, grid, inv_grid = self.warp_images(images, data)
|
149 |
+
w_seg_logits = self.forward_warped(w_images, return_preds=False)
|
150 |
+
|
151 |
+
seg_logits = F.grid_sample(
|
152 |
+
w_seg_logits, inv_grid, mode='bilinear', align_corners=False)
|
153 |
+
|
154 |
+
data['seg'] = {'logits': seg_logits, 'label_names': self.label_names}
|
155 |
+
return data
|
156 |
+
|
157 |
+
|
158 |
+
def logits2predictions(self, logits: torch.Tensor):
|
159 |
+
return logits.argmax(dim=1)
|
160 |
+
|
161 |
+
@torch.no_grad()
|
162 |
+
@torch.inference_mode()
|
163 |
+
def forward_warped(self, images: torch.Tensor, return_preds: bool = True):
|
164 |
+
'''
|
165 |
+
images: b x 3 x h x w , torch.uint8, [0, 255]
|
166 |
+
'''
|
167 |
+
images = self.unify_image_dtype(images)
|
168 |
+
seg_logits, _ = self.net(images) # nfaces x c x h x w
|
169 |
+
# seg_probs = seg_logits.softmax(dim=1) # nfaces x nclasses x h x w
|
170 |
+
if return_preds:
|
171 |
+
seg_preds = self.logits2predictions(seg_logits)
|
172 |
+
return seg_logits, seg_preds, self.label_names
|
173 |
+
else:
|
174 |
+
return seg_logits
|
src/pixel3dmm/preprocessing/facer/facer/farl/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from .model import load_farl, VisualTransformer
|
5 |
+
from .classification import farl_classification
|
src/pixel3dmm/preprocessing/facer/facer/farl/classification.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
import torch.utils.checkpoint as checkpoint
|
4 |
+
from .model import VisualTransformer
|
5 |
+
|
6 |
+
|
7 |
+
class VITClassificationHeadV0(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
num_features: int,
|
11 |
+
channel: int,
|
12 |
+
num_labels: int,
|
13 |
+
norm=False,
|
14 |
+
dropout=0.0,
|
15 |
+
ret_feat=False,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.weights = nn.Parameter(
|
19 |
+
torch.ones(1, num_features * 3, 1, dtype=torch.float32)
|
20 |
+
)
|
21 |
+
self.final_fc = nn.Linear(channel, num_labels)
|
22 |
+
self.norm = norm
|
23 |
+
if self.norm:
|
24 |
+
for i in range(num_features * 3):
|
25 |
+
setattr(self, f"norm_{i}", nn.LayerNorm(channel))
|
26 |
+
self.dropout = nn.Dropout(p=dropout)
|
27 |
+
self.ret_feat = ret_feat
|
28 |
+
|
29 |
+
def forward(self, features, cls_tokens):
|
30 |
+
xs = []
|
31 |
+
for feature, cls_token in zip(features, cls_tokens):
|
32 |
+
# feature: b x c x s x s
|
33 |
+
# cls_token: b x c
|
34 |
+
xs.append(feature.mean([2, 3]))
|
35 |
+
xs.append(feature.max(-1).values.max(-1).values)
|
36 |
+
xs.append(cls_token)
|
37 |
+
if self.norm:
|
38 |
+
xs = [getattr(self, f"norm_{i}")(x) for i, x in enumerate(xs)]
|
39 |
+
xs = torch.stack(xs, dim=1) # b x 3N x c
|
40 |
+
feat = (xs * self.weights.softmax(dim=1)).sum(1) # b x c
|
41 |
+
x = self.dropout(feat)
|
42 |
+
x = self.final_fc(x) # b x num_labels
|
43 |
+
if self.ret_feat:
|
44 |
+
return x, feat
|
45 |
+
else:
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class FACTransformer(nn.Module):
|
50 |
+
"""A face attribute classification transformer leveraging multiple cls_tokens.
|
51 |
+
Args:
|
52 |
+
image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].
|
53 |
+
Returns:
|
54 |
+
logits (torch.Tensor): Float32 tensor with shape [b, n_classes].
|
55 |
+
aux_outputs:
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self, backbone: nn.Module, head: nn.Module):
|
59 |
+
super().__init__()
|
60 |
+
self.backbone = backbone
|
61 |
+
self.head = head
|
62 |
+
self.cuda().float()
|
63 |
+
|
64 |
+
def forward(self, image):
|
65 |
+
logits = self.head(*self.backbone(image))
|
66 |
+
return logits
|
67 |
+
|
68 |
+
|
69 |
+
def add_method(obj, name, method):
|
70 |
+
import types
|
71 |
+
|
72 |
+
setattr(obj, name, types.MethodType(method, obj))
|
73 |
+
|
74 |
+
|
75 |
+
def get_clip_encode_func(layers):
|
76 |
+
def func(self, x):
|
77 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
78 |
+
# shape = [*, width, grid ** 2]
|
79 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
80 |
+
extra_tokens = getattr(self, "extra_tokens", [])
|
81 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
82 |
+
class_token = self.class_embedding.to(x.dtype) + torch.zeros(
|
83 |
+
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
|
84 |
+
)
|
85 |
+
special_tokens = [
|
86 |
+
getattr(self, name).to(x.dtype)
|
87 |
+
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
|
88 |
+
for name in extra_tokens
|
89 |
+
]
|
90 |
+
x = torch.cat(
|
91 |
+
[class_token, *special_tokens, x], dim=1
|
92 |
+
) # shape = [*, grid ** 2 + 1, width]
|
93 |
+
x = x + self.positional_embedding.to(x.dtype)
|
94 |
+
x = self.ln_pre(x)
|
95 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
96 |
+
outs = []
|
97 |
+
max_layer = max(layers)
|
98 |
+
use_checkpoint = self.transformer.use_checkpoint
|
99 |
+
for layer_i, blk in enumerate(self.transformer.resblocks):
|
100 |
+
if layer_i > max_layer:
|
101 |
+
break
|
102 |
+
if self.training and use_checkpoint:
|
103 |
+
x = checkpoint.checkpoint(blk, x)
|
104 |
+
else:
|
105 |
+
x = blk(x)
|
106 |
+
outs.append(x)
|
107 |
+
|
108 |
+
outs = torch.stack(outs).permute(0, 2, 1, 3)
|
109 |
+
cls_tokens = outs[layers, :, 0, :]
|
110 |
+
|
111 |
+
extra_token_feats = {}
|
112 |
+
for i, name in enumerate(extra_tokens):
|
113 |
+
extra_token_feats[name] = outs[layers, :, i + 1, :]
|
114 |
+
L, B, N, C = outs.shape
|
115 |
+
import math
|
116 |
+
|
117 |
+
W = int(math.sqrt(N - 1 - len(extra_tokens)))
|
118 |
+
features = (
|
119 |
+
outs[layers, :, 1 + len(extra_tokens) :, :]
|
120 |
+
.reshape(len(layers), B, W, W, C)
|
121 |
+
.permute(0, 1, 4, 2, 3)
|
122 |
+
)
|
123 |
+
if getattr(self, "ret_special", False):
|
124 |
+
return features, cls_tokens, extra_token_feats
|
125 |
+
else:
|
126 |
+
return features, cls_tokens
|
127 |
+
|
128 |
+
return func
|
129 |
+
|
130 |
+
|
131 |
+
def farl_classification(num_classes=2, layers=list(range(12))):
|
132 |
+
model = VisualTransformer(
|
133 |
+
input_resolution=224,
|
134 |
+
patch_size=16,
|
135 |
+
width=768,
|
136 |
+
layers=12,
|
137 |
+
heads=12,
|
138 |
+
output_dim=512,
|
139 |
+
)
|
140 |
+
channel = 768
|
141 |
+
model = model.cuda()
|
142 |
+
del model.proj
|
143 |
+
del model.ln_post
|
144 |
+
add_method(model, "forward", get_clip_encode_func(layers))
|
145 |
+
head = VITClassificationHeadV0(
|
146 |
+
num_features=len(layers), channel=channel, num_labels=num_classes, norm=True
|
147 |
+
)
|
148 |
+
model = FACTransformer(model, head)
|
149 |
+
return model
|
src/pixel3dmm/preprocessing/facer/facer/farl/model.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft Corporation.
|
2 |
+
# Licensed under the MIT License.
|
3 |
+
|
4 |
+
from collections import OrderedDict
|
5 |
+
import logging
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
import torch.utils.checkpoint as checkpoint
|
10 |
+
import numpy as np
|
11 |
+
from timm.models.layers import trunc_normal_, DropPath
|
12 |
+
|
13 |
+
|
14 |
+
class Bottleneck(nn.Module):
|
15 |
+
expansion = 4
|
16 |
+
|
17 |
+
def __init__(self, inplanes, planes, stride=1):
|
18 |
+
super().__init__()
|
19 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
20 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
21 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
22 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
23 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
24 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
25 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
26 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
self.stride = stride
|
30 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
31 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
32 |
+
self.downsample = nn.Sequential(OrderedDict([
|
33 |
+
("-1", nn.AvgPool2d(stride)),
|
34 |
+
("0", nn.Conv2d(inplanes, planes *
|
35 |
+
self.expansion, 1, stride=1, bias=False)),
|
36 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
37 |
+
]))
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor):
|
40 |
+
identity = x
|
41 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
42 |
+
out = self.relu(self.bn2(self.conv2(out)))
|
43 |
+
out = self.avgpool(out)
|
44 |
+
out = self.bn3(self.conv3(out))
|
45 |
+
if self.downsample is not None:
|
46 |
+
identity = self.downsample(x)
|
47 |
+
out += identity
|
48 |
+
out = self.relu(out)
|
49 |
+
return out
|
50 |
+
|
51 |
+
|
52 |
+
class AttentionPool2d(nn.Module):
|
53 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
54 |
+
super().__init__()
|
55 |
+
self.positional_embedding = nn.Parameter(
|
56 |
+
torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
|
57 |
+
)
|
58 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
59 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
60 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
61 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
62 |
+
self.num_heads = num_heads
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2]
|
66 |
+
* x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
67 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
68 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
69 |
+
x, _ = F.multi_head_attention_forward(
|
70 |
+
query=x, key=x, value=x,
|
71 |
+
embed_dim_to_check=x.shape[-1],
|
72 |
+
num_heads=self.num_heads,
|
73 |
+
q_proj_weight=self.q_proj.weight,
|
74 |
+
k_proj_weight=self.k_proj.weight,
|
75 |
+
v_proj_weight=self.v_proj.weight,
|
76 |
+
in_proj_weight=None,
|
77 |
+
in_proj_bias=torch.cat(
|
78 |
+
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
|
79 |
+
),
|
80 |
+
bias_k=None,
|
81 |
+
bias_v=None,
|
82 |
+
add_zero_attn=False,
|
83 |
+
dropout_p=0,
|
84 |
+
out_proj_weight=self.c_proj.weight,
|
85 |
+
out_proj_bias=self.c_proj.bias,
|
86 |
+
use_separate_proj_weight=True,
|
87 |
+
training=self.training,
|
88 |
+
need_weights=False
|
89 |
+
)
|
90 |
+
return x[0]
|
91 |
+
|
92 |
+
|
93 |
+
class ModifiedResNet(nn.Module):
|
94 |
+
"""
|
95 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
96 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
97 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
98 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
102 |
+
super().__init__()
|
103 |
+
self.output_dim = output_dim
|
104 |
+
self.input_resolution = input_resolution
|
105 |
+
# the 3-layer stem
|
106 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3,
|
107 |
+
stride=2, padding=1, bias=False)
|
108 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
109 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2,
|
110 |
+
                               kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(
            width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(
            input_resolution // 32, embed_dim, heads, output_dim
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.BatchNorm2d, LayerNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [
                (self.conv1, self.bn1),
                (self.conv2, self.bn2),
                (self.conv3, self.bn3)
            ]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        pdtype = x.dtype
        x = x.float()
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x.to(pdtype) + self.bias


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def add_drop_path(self, drop_path):
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None,
                 use_checkpoint=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.use_checkpoint = use_checkpoint
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers)]
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(width, heads, attn_mask, drop_path=dpr[i])
            for i in range(layers)
        ])
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor):
        for i, blk in enumerate(self.resblocks):
            x = blk(x)
        return x


class VisualTransformer(nn.Module):
    positional_embedding: nn.Parameter

    def __init__(self,
                 input_resolution: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 output_dim: int,
                 pool_type: str = 'default',
                 skip_cls: bool = False,
                 drop_path_rate=0.,
                 **kwargs):
        super().__init__()
        self.pool_type = pool_type
        self.skip_cls = skip_cls
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False
        )
        self.config = kwargs.get("config", None)
        self.sequence_length = (input_resolution // patch_size) ** 2 + 1
        self.conv_pool = nn.Identity()
        if (self.pool_type == 'linear'):
            if (not self.skip_cls):
                self.conv_pool = nn.Conv1d(
                    width, width, self.sequence_length, stride=self.sequence_length, groups=width)
            else:
                self.conv_pool = nn.Conv1d(
                    width, width, self.sequence_length-1, stride=self.sequence_length, groups=width)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(
                self.sequence_length, width
            )
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(
            width, layers, heads, drop_path_rate=drop_path_rate)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        if self.config is not None and self.config.MIM.ENABLE:
            logging.info("MIM ENABLED")
            self.mim = True
            self.lm_transformer = Transformer(
                width, self.config.MIM.LAYERS, heads)
            self.ln_lm = LayerNorm(width)
            self.lm_head = nn.Linear(width, self.config.MIM.VOCAB_SIZE)
            self.mask_token = nn.Parameter(scale * torch.randn(width))
        else:
            self.mim = False
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv1d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1],
                      dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        if (self.pool_type == 'average'):
            if self.skip_cls:
                x = x[:, 1:, :]
            x = torch.mean(x, dim=1)
        elif (self.pool_type == 'linear'):
            if self.skip_cls:
                x = x[:, 1:, :]
            x = x.permute(0, 2, 1)
            x = self.conv_pool(x)
            x = x.permute(0, 2, 1).squeeze()
        else:
            x = x[:, 0, :]
        x = self.ln_post(x)
        if self.proj is not None:
            x = x @ self.proj
        return x

    def forward_mim(self, x: torch.Tensor, bool_masked_pos, return_all_tokens=False, disable_vlc=False):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        batch_size, seq_len, _ = x.size()
        mask_token = self.mask_token.unsqueeze(
            0).unsqueeze(0).expand(batch_size, seq_len, -1)
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
        masked_x = x * (1 - w) + mask_token * w
        if disable_vlc:
            x = masked_x
            masked_start = 0
        else:
            x = torch.cat([x, masked_x], 0)
            masked_start = batch_size
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1],
            dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        masked_x = x[:, masked_start:]
        masked_x = self.lm_transformer(masked_x)
        masked_x = masked_x.permute(1, 0, 2)
        masked_x = masked_x[:, 1:]
        masked_x = self.ln_lm(masked_x)
        if not return_all_tokens:
            masked_x = masked_x[bool_masked_pos]
        logits = self.lm_head(masked_x)
        assert self.pool_type == "default"
        result = {"logits": logits}
        if not disable_vlc:
            x = x[0, :batch_size]
            x = self.ln_post(x)
            if self.proj is not None:
                x = x @ self.proj
            result["feature"] = x
        return result


def load_farl(model_type, model_file=None) -> VisualTransformer:
    if model_type == "base":
        model = VisualTransformer(
            input_resolution=224, patch_size=16, width=768, layers=12, heads=12, output_dim=512)
    elif model_type == "large":
        model = VisualTransformer(
            input_resolution=224, patch_size=16, width=1024, layers=24, heads=16, output_dim=512)
    elif model_type == "huge":
        model = VisualTransformer(
            input_resolution=224, patch_size=14, width=1280, layers=32, heads=16, output_dim=512)
    else:
        raise ValueError(f"unknown model_type: {model_type}")
    model.transformer.use_checkpoint = False
    if model_file is not None:
        checkpoint = torch.load(model_file, map_location='cpu')
        state_dict = {}
        for name, weight in checkpoint["state_dict"].items():
            if name.startswith("visual"):
                state_dict[name[7:]] = weight
        inco = model.load_state_dict(state_dict, strict=False)
        # print(inco.missing_keys)
        assert len(inco.missing_keys) == 0
    return model

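To make the backbone above concrete, here is a minimal usage sketch of `load_farl`. It is an assumption-laden example, not part of the commit: it assumes the vendored package is importable as `facer`, and it loads no pretrained weights, so the features come from a randomly initialized model.

```python
# Minimal usage sketch (assumes `facer` is on PYTHONPATH; no checkpoint is loaded,
# so the output features are from random initialization).
import torch
from facer.farl.model import load_farl

model = load_farl("base")              # ViT-B/16-style backbone, 512-d output
model.eval()
images = torch.randn(2, 3, 224, 224)   # two dummy 224x224 RGB crops
with torch.no_grad():
    feats = model(images)              # default pooling keeps the CLS token
print(feats.shape)                     # torch.Size([2, 512])
```
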
src/pixel3dmm/preprocessing/facer/facer/io.py
ADDED
@@ -0,0 +1,28 @@
import torch
import numpy as np
from PIL import Image


def read_hwc(path: str) -> torch.Tensor:
    """Read an image from a given path.

    Args:
        path (str): The given path.
    """
    image = Image.open(path)
    np_image = np.array(image.convert('RGB'))
    return torch.from_numpy(np_image)


def write_hwc(image: torch.Tensor, path: str):
    """Write an image to a given path.

    Args:
        image (torch.Tensor): The image.
        path (str): The given path.
    """
    Image.fromarray(image.cpu().numpy()).save(path)

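A minimal round-trip sketch for these two helpers follows; the output file name is hypothetical, and the input path points at one of the sample images added in this commit.

```python
# Round-trip sketch: read an image as an HWC uint8 tensor, then write it back.
from facer.io import read_hwc, write_hwc

image = read_hwc('samples/data/girl.jpg')   # torch.uint8, shape (H, W, 3)
print(image.shape, image.dtype)
write_hwc(image, 'girl_copy.jpg')           # hypothetical output path, saved via PIL
```
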
src/pixel3dmm/preprocessing/facer/facer/show.py
ADDED
@@ -0,0 +1,36 @@
from typing import Optional, Tuple
import torch
from PIL import Image
import matplotlib.pyplot as plt

from .util import bchw2hwc


def set_figsize(*args):
    if len(args) == 0:
        plt.rcParams["figure.figsize"] = plt.rcParamsDefault["figure.figsize"]
    elif len(args) == 1:
        plt.rcParams["figure.figsize"] = (args[0], args[0])
    elif len(args) == 2:
        plt.rcParams["figure.figsize"] = tuple(args)
    else:
        raise RuntimeError(
            'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')


def show_hwc(image: torch.Tensor):
    if image.dtype != torch.uint8:
        image = image.to(torch.uint8)
    if image.size(2) == 1:
        image = image.repeat(1, 1, 3)
    pimage = Image.fromarray(image.cpu().numpy())
    plt.imshow(pimage)
    plt.show()


def show_bchw(image: torch.Tensor):
    show_hwc(bchw2hwc(image))


def show_bhw(image: torch.Tensor):
    show_bchw(image.unsqueeze(1))

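A minimal sketch of how these helpers are typically called (assuming a working matplotlib backend; the random batch is purely illustrative):

```python
# Visualize a batch of random uint8 images as a single tiled figure.
import torch
from facer.show import set_figsize, show_bchw

set_figsize(6)                                         # 6 x 6 inch figure
batch = torch.randint(0, 256, (4, 3, 64, 64), dtype=torch.uint8)
show_bchw(batch)                                       # tiles the batch via bchw2hwc
```
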
src/pixel3dmm/preprocessing/facer/facer/transform.py
ADDED
@@ -0,0 +1,384 @@
from typing import List, Dict, Callable, Tuple, Optional
import torch
import torch.nn.functional as F
import functools
import numpy as np


def get_crop_and_resize_matrix(
        box: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, make_square_crop: bool = True,
        offset_xy: Optional[Tuple[float, float]] = None, align_corners: bool = True,
        offset_box_coords: bool = False) -> torch.Tensor:
    """
    Args:
        box: b x 4 (x1, y1, x2, y2)
        align_corners (bool): Set this to `True` only if the box you give has coordinates
            ranging from `0` to `h-1` or `w-1`.

        offset_box_coords (bool): Set this to `True` if the box you give has coordinates
            ranging from `0` to `h` or `w`.

            Set this to `False` if the box coordinates range from `-0.5` to `h-0.5` or `w-0.5`.

            If the box coordinates range from `0` to `h-1` or `w-1`, set `align_corners=True`.

    Returns:
        torch.Tensor: b x 3 x 3.
    """
    if offset_xy is None:
        offset_xy = (0.0, 0.0)

    x1, y1, x2, y2 = box.split(1, dim=1)  # b x 1
    cx = (x1 + x2) / 2 + offset_xy[0]
    cy = (y1 + y2) / 2 + offset_xy[1]
    rx = (x2 - x1) / 2 / target_face_scale
    ry = (y2 - y1) / 2 / target_face_scale
    if make_square_crop:
        rx = ry = torch.maximum(rx, ry)

    x1, y1, x2, y2 = cx - rx, cy - ry, cx + rx, cy + ry

    h, w, *_ = target_shape

    zeros_pl = torch.zeros_like(x1)
    ones_pl = torch.ones_like(x1)

    if align_corners:
        # x -> (x - x1) / (x2 - x1) * (w - 1)
        # y -> (y - y1) / (y2 - y1) * (h - 1)
        ax = 1.0 / (x2 - x1) * (w - 1)
        ay = 1.0 / (y2 - y1) * (h - 1)
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax,
            zeros_pl, ay, -y1 * ay,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    else:
        if offset_box_coords:
            # x1, x2 \in [0, w], y1, y2 \in [0, h]
            # first we should offset x1, x2, y1, y2 to be ranging in
            # [-0.5, w-0.5] and [-0.5, h-0.5]
            # so to convert these pixel coordinates into boundary coordinates.
            x1, x2, y1, y2 = x1-0.5, x2-0.5, y1-0.5, y2-0.5

        # x -> (x - x1) / (x2 - x1) * w - 0.5
        # y -> (y - y1) / (y2 - y1) * h - 0.5
        ax = 1.0 / (x2 - x1) * w
        ay = 1.0 / (y2 - y1) * h
        matrix = torch.cat([
            ax, zeros_pl, -x1 * ax - 0.5*ones_pl,
            zeros_pl, ay, -y1 * ay - 0.5*ones_pl,
            zeros_pl, zeros_pl, ones_pl
        ], dim=1).reshape(-1, 3, 3)  # b x 3 x 3
    return matrix


def get_similarity_transform_matrix(
        from_pts: torch.Tensor, to_pts: torch.Tensor) -> torch.Tensor:
    """
    Args:
        from_pts, to_pts: b x n x 2

    Returns:
        torch.Tensor: b x 3 x 3
    """
    mfrom = from_pts.mean(dim=1, keepdim=True)  # b x 1 x 2
    mto = to_pts.mean(dim=1, keepdim=True)  # b x 1 x 2

    a1 = (from_pts - mfrom).square().sum([1, 2], keepdim=False)  # b
    c1 = ((to_pts - mto) * (from_pts - mfrom)).sum([1, 2], keepdim=False)  # b

    to_delta = to_pts - mto
    from_delta = from_pts - mfrom
    c2 = (to_delta[:, :, 0] * from_delta[:, :, 1] - to_delta[:,
          :, 1] * from_delta[:, :, 0]).sum([1], keepdim=False)  # b

    a = c1 / a1
    b = c2 / a1
    dx = mto[:, 0, 0] - a * mfrom[:, 0, 0] - b * mfrom[:, 0, 1]  # b
    dy = mto[:, 0, 1] + b * mfrom[:, 0, 0] - a * mfrom[:, 0, 1]  # b

    ones_pl = torch.ones_like(a1)
    zeros_pl = torch.zeros_like(a1)

    return torch.stack([
        a, b, dx,
        -b, a, dy,
        zeros_pl, zeros_pl, ones_pl,
    ], dim=-1).reshape(-1, 3, 3)


@functools.lru_cache()
def _standard_face_pts():
    pts = torch.tensor([
        196.0, 226.0,
        316.0, 226.0,
        256.0, 286.0,
        220.0, 360.4,
        292.0, 360.4], dtype=torch.float32) / 256.0 - 1.0
    return torch.reshape(pts, (5, 2))


def get_face_align_matrix(
        face_pts: torch.Tensor, target_shape: Tuple[int, int],
        target_face_scale: float = 1.0, offset_xy: Optional[Tuple[float, float]] = None,
        target_pts: Optional[torch.Tensor] = None):

    if target_pts is None:
        with torch.no_grad():
            std_pts = _standard_face_pts().to(face_pts)  # [-1, 1]
            h, w, *_ = target_shape
            target_pts = (std_pts * target_face_scale + 1) * \
                torch.tensor([w-1, h-1]).to(face_pts) / 2.0
            if offset_xy is not None:
                target_pts[:, 0] += offset_xy[0]
                target_pts[:, 1] += offset_xy[1]
    else:
        target_pts = target_pts.to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


def rot90(v):
    return np.array([-v[1], v[0]])


def get_quad(lm: torch.Tensor):
    # lm: N x 2
    lm = lm.detach().cpu().numpy()
    # Choose oriented crop rectangle.
    eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
    mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
    eye_to_eye = lm[1] - lm[0]
    eye_to_mouth = mouth_avg - eye_avg
    x = eye_to_eye - rot90(eye_to_mouth)
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = rot90(x)
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    quad_for_coeffs = quad[[0, 3, 2, 1]]  # reorder the corners
    return torch.from_numpy(quad_for_coeffs).float()


def get_face_align_matrix_celebm(
        face_pts: torch.Tensor, target_shape: Tuple[int, int]):

    face_pts = torch.stack([get_quad(pts) for pts in face_pts], dim=0).to(face_pts)

    assert target_shape[0] == target_shape[1]
    target_size = target_shape[0]
    target_pts = torch.as_tensor(
        [[0, 0], [target_size, 0], [target_size, target_size], [0, target_size]]).to(face_pts)

    if target_pts.dim() == 2:
        target_pts = target_pts.unsqueeze(0)
    if target_pts.size(0) == 1:
        target_pts = target_pts.broadcast_to(face_pts.shape)

    assert target_pts.shape == face_pts.shape

    return get_similarity_transform_matrix(face_pts, target_pts)


@functools.lru_cache(maxsize=128)
def _meshgrid(h, w) -> Tuple[torch.Tensor, torch.Tensor]:
    yy, xx = torch.meshgrid(torch.arange(h).float(),
                            torch.arange(w).float(),
                            indexing='ij')
    return yy + 0.5, xx + 0.5


def _forge_grid(batch_size: int, device: torch.device,
                output_shape: Tuple[int, int],
                fn: Callable[[torch.Tensor], torch.Tensor]
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Forge transform maps with a given function `fn`.

    Args:
        output_shape (tuple): (b, h, w, ...).
        fn (Callable[[torch.Tensor], torch.Tensor]): The function that accepts
            a b x n x 2 array and outputs the transformed b x n x 2 array. Both
            input and output store (x, y) coordinates.

    Returns:
        torch.Tensor: Transform map of shape b x h x w x 2, where for each
            pixel (y, x) the entry stores `fn([x, y])`.
    """
    h, w, *_ = output_shape
    yy, xx = _meshgrid(h, w)  # h x w
    yy = yy.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)
    xx = xx.unsqueeze(0).broadcast_to(batch_size, h, w).to(device)

    in_xxyy = torch.stack(
        [xx, yy], dim=-1).reshape([batch_size, h*w, 2])  # b x (h x w) x 2
    out_xxyy: torch.Tensor = fn(in_xxyy)  # b x (h x w) x 2
    return out_xxyy.reshape(batch_size, h, w, 2)


def _safe_arctanh(x: torch.Tensor, eps: float = 0.001) -> torch.Tensor:
    return torch.clamp(x, -1+eps, 1-eps).arctanh()


def inverted_tanh_warp_transform(coords: torch.Tensor, matrix: torch.Tensor,
                                 warp_factor: float, warped_shape: Tuple[int, int]):
    """ Inverted tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The transformed coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The original coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1

    w_h = torch.tensor([[w, h]]).to(coords)

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = _safe_arctanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = _safe_arctanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    inv_matrix = torch.linalg.inv(matrix)  # b x 3 x 3
    coords_homo = torch.bmm(
        coords_homo, inv_matrix.permute(0, 2, 1))  # b x n x 3
    return coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]]


def tanh_warp_transform(
        coords: torch.Tensor, matrix: torch.Tensor,
        warp_factor: float, warped_shape: Tuple[int, int]):
    """ Tanh-warp function.

    Args:
        coords (torch.Tensor): b x n x 2 (x, y). The original coordinates.
        matrix: b x 3 x 3. A matrix that transforms un-normalized coordinates
            from the original image to the aligned yet not-warped image.
        warp_factor (float): The warp factor.
            0 means linear transform, 1 means full tanh warp.
        warped_shape (tuple): [height, width].

    Returns:
        torch.Tensor: b x n x 2 (x, y). The transformed coordinates.
    """
    h, w, *_ = warped_shape
    # h -= 1
    # w -= 1
    w_h = torch.tensor([[w, h]]).to(coords)

    coords_homo = torch.cat(
        [coords, torch.ones_like(coords[:, :, [0]])], dim=-1)  # b x n x 3

    coords_homo = torch.bmm(coords_homo, matrix.transpose(2, 1))  # b x n x 3
    coords = (coords_homo[:, :, :2] / coords_homo[:, :, [2, 2]])  # b x n x 2

    if warp_factor > 0:
        # normalize coordinates to [-1, +1]
        coords = coords / w_h * 2 - 1

        nl_part1 = coords > 1.0 - warp_factor
        nl_part2 = coords < -1.0 + warp_factor

        ret_nl_part1 = torch.tanh(
            (coords - 1.0 + warp_factor) /
            warp_factor) * warp_factor + \
            1.0 - warp_factor
        ret_nl_part2 = torch.tanh(
            (coords + 1.0 - warp_factor) /
            warp_factor) * warp_factor - \
            1.0 + warp_factor

        coords = torch.where(nl_part1, ret_nl_part1,
                             torch.where(nl_part2, ret_nl_part2, coords))

        # denormalize
        coords = (coords + 1) / 2 * w_h

    return coords


def make_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                        warped_shape: Tuple[int, int],
                        orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    orig_h, orig_w, *_ = orig_shape
    w_h = torch.tensor([orig_w, orig_h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        warped_shape,
        functools.partial(inverted_tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1


def make_inverted_tanh_warp_grid(matrix: torch.Tensor, warp_factor: float,
                                 warped_shape: Tuple[int, int],
                                 orig_shape: Tuple[int, int]):
    """
    Args:
        matrix: b x 3 x 3 matrix.
        warp_factor: The warping factor. `warp_factor=1.0` represents a vanilla tanh-warping,
            `warp_factor=0.0` represents a cropping.
        warped_shape: The target image shape to transform to.
        orig_shape: The original image shape that is transformed from.

    Returns:
        torch.Tensor: b x h x w x 2 (x, y).
    """
    h, w, *_ = warped_shape
    w_h = torch.tensor([w, h]).to(matrix).reshape(1, 1, 1, 2)
    return _forge_grid(
        matrix.size(0), matrix.device,
        orig_shape,
        functools.partial(tanh_warp_transform,
                          matrix=matrix,
                          warp_factor=warp_factor,
                          warped_shape=warped_shape)) / w_h * 2 - 1

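The pieces above compose into the usual facer warping pipeline: estimate a similarity matrix from five landmarks, turn it into a (possibly tanh-warped) sampling grid, and feed that grid to `torch.nn.functional.grid_sample`. The sketch below shows that composition; the landmark values are made up for illustration, and real ones would come from a face detector or alignment model.

```python
# Sketch: align one face from a 512x512 image into a 448x448 crop.
import torch
import torch.nn.functional as F
from facer.transform import get_face_align_matrix, make_tanh_warp_grid

image = torch.rand(1, 3, 512, 512)                         # b x c x h x w, float
landmarks = torch.tensor([[[200., 210.], [310., 205.],     # hypothetical eyes
                           [255., 270.],                    # nose tip
                           [215., 330.], [300., 325.]]])    # mouth corners
matrix = get_face_align_matrix(landmarks, target_shape=(448, 448))   # b x 3 x 3
grid = make_tanh_warp_grid(matrix, warp_factor=0.0,        # 0.0 = plain similarity crop
                           warped_shape=(448, 448), orig_shape=(512, 512))
crop = F.grid_sample(image, grid, mode='bilinear', align_corners=False)
print(crop.shape)                                          # torch.Size([1, 3, 448, 448])
```

Setting `warp_factor` above 0 bends the sampling grid with a tanh so the face region keeps more resolution than the background, which is how the FaRL-based parsers use these functions.
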
src/pixel3dmm/preprocessing/facer/facer/util.py
ADDED
@@ -0,0 +1,169 @@
import torch
from typing import Any, Optional, Union, List, Dict
import math
import os
from urllib.parse import urlparse
import errno
import sys
import validators
import requests
import json


def hwc2bchw(images: torch.Tensor) -> torch.Tensor:
    return images.unsqueeze(0).permute(0, 3, 1, 2)


def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
             background_value: float = 0) -> torch.Tensor:
    """ Make a grid image from an image batch.

    Args:
        images (torch.Tensor): input image batch.
        nrows: rows of grid.
        border: border size in pixels.
        background_value: color value of the background.
    """
    assert images.ndim == 4  # n x c x h x w
    images = images.permute(0, 2, 3, 1)  # n x h x w x c
    n, h, w, c = images.shape
    if nrows is None:
        nrows = max(int(math.sqrt(n)), 1)
    ncols = (n + nrows - 1) // nrows
    result = torch.full([(h + border) * nrows - border,
                         (w + border) * ncols - border, c], background_value,
                        device=images.device,
                        dtype=images.dtype)

    for i, single_image in enumerate(images):
        row = i // ncols
        col = i % ncols
        yy = (h + border) * row
        xx = (w + border) * col
        result[yy:(yy + h), xx:(xx + w), :] = single_image
    return result


def bchw2bhwc(images: torch.Tensor) -> torch.Tensor:
    return images.permute(0, 2, 3, 1)


def bhwc2bchw(images: torch.Tensor) -> torch.Tensor:
    return images.permute(0, 3, 1, 2)


def bhwc2hwc(images: torch.Tensor, *kargs, **kwargs) -> torch.Tensor:
    return bchw2hwc(bhwc2bchw(images), *kargs, **kwargs)


def select_data(selection, data):
    if isinstance(data, dict):
        return {name: select_data(selection, val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        return [select_data(selection, val) for val in data]
    elif isinstance(data, torch.Tensor):
        return data[selection]
    return data


def download_from_github(to_path, organisation, repository, file_path, branch='main', username=None, access_token=None):
    """ Download files (including LFS files) from GitHub.

    For example, in order to download https://github.com/FacePerceiver/facer/blob/main/README.md, call with
    ```
    download_from_github(
        to_path='README.md', organisation='FacePerceiver',
        repository='facer', file_path='README.md', branch='main')
    ```
    """
    if username is not None:
        assert access_token is not None
        auth = (username, access_token)
    else:
        auth = None
    r = requests.get(f'https://api.github.com/repos/{organisation}/{repository}/contents/{file_path}?ref={branch}',
                     auth=auth)
    data = json.loads(r.content)
    torch.hub.download_url_to_file(data['download_url'], to_path)


def is_github_url(url: str):
    """
    A typical GitHub url should look like
    https://github.com/FacePerceiver/facer/blob/main/facer/util.py or
    https://github.com/FacePerceiver/facer/raw/main/facer/util.py.
    """
    return ('blob' in url or 'raw' in url) and url.startswith('https://github.com/')


def get_github_components(url: str):
    assert is_github_url(url)
    organisation, repository, blob_or_raw, branch, * \
        path = url[len('https://github.com/'):].split('/')
    assert blob_or_raw in {'blob', 'raw'}
    return organisation, repository, branch, '/'.join(path)


def download_url_to_file(url, dst, **kwargs):
    if is_github_url(url):
        org, rep, branch, path = get_github_components(url)
        download_from_github(dst, org, rep, path, branch, kwargs.get(
            'username', None), kwargs.get('access_token', None))
    else:
        torch.hub.download_url_to_file(url, dst)


def select_data(selection, data):
    if isinstance(data, dict):
        return {name: select_data(selection, val) for name, val in data.items()}
    elif isinstance(data, (list, tuple)):
        return [select_data(selection, val) for val in data]
    elif isinstance(data, torch.Tensor):
        return data[selection]
    return data


def download_jit(url_or_paths: Union[str, List[str]], model_dir=None, map_location=None, jit=True, **kwargs):
    if isinstance(url_or_paths, str):
        url_or_paths = [url_or_paths]

    for url_or_path in url_or_paths:
        try:
            if validators.url(url_or_path):
                url = url_or_path
                if model_dir is None:
                    if hasattr(torch.hub, 'get_dir'):
                        hub_dir = torch.hub.get_dir()
                    else:
                        hub_dir = os.path.join(os.path.expanduser(
                            '~'), '.cache', 'torch', 'hub')
                    model_dir = os.path.join(hub_dir, 'checkpoints')

                try:
                    os.makedirs(model_dir)
                except OSError as e:
                    if e.errno == errno.EEXIST:
                        # Directory already exists, ignore.
                        pass
                    else:
                        # Unexpected OSError, re-raise.
                        raise

                parts = urlparse(url)
                filename = os.path.basename(parts.path)
                cached_file = os.path.join(model_dir, filename)
                if not os.path.exists(cached_file):
                    sys.stderr.write(
                        'Downloading: "{}" to {}\n'.format(url, cached_file))
                    download_url_to_file(url, cached_file)
            else:
                cached_file = url_or_path
            if jit:
                return torch.jit.load(cached_file, map_location=map_location, **kwargs)
            else:
                return torch.load(cached_file, map_location=map_location, **kwargs)
        except:
            sys.stderr.write(f'failed downloading from {url_or_path}\n')
            raise

    raise RuntimeError('failed to download jit models from all given urls')

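The layout helpers at the top of this file are what the visualization code builds on; `download_jit` accepts either a URL (cached under the torch hub checkpoints directory) or an existing local path, and loads the result with `torch.jit.load` by default. A minimal sketch of the layout helpers, with illustrative shapes only:

```python
# Tile a batch into a single HWC grid image, then convert back to BCHW.
import torch
from facer.util import bchw2hwc, hwc2bchw

batch = torch.rand(4, 3, 32, 32)        # n x c x h x w
grid = bchw2hwc(batch, nrows=2)         # 2x2 grid with a 2-pixel border
print(grid.shape)                       # torch.Size([66, 66, 3])

single = hwc2bchw(grid)                 # back to 1 x 3 x 66 x 66
print(single.shape)
```
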
src/pixel3dmm/preprocessing/facer/facer/version.py
ADDED
@@ -0,0 +1 @@
__version__ = "0.0.5"

src/pixel3dmm/preprocessing/facer/requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch >= 1.9.1
torchvision
timm
pillow
numpy
ipywidgets
scikit-image
matplotlib
validators
requests
opencv-python

src/pixel3dmm/preprocessing/facer/samples/data/ffhq_15723.jpg
ADDED (binary image, tracked with Git LFS)
src/pixel3dmm/preprocessing/facer/samples/data/fire.webp
ADDED (binary image)
src/pixel3dmm/preprocessing/facer/samples/data/girl.jpg
ADDED (binary image)
src/pixel3dmm/preprocessing/facer/samples/data/sideface.jpg
ADDED (binary image)
src/pixel3dmm/preprocessing/facer/samples/data/twogirls.jpg
ADDED (binary image)
src/pixel3dmm/preprocessing/facer/samples/data/weirdface.jpg
ADDED (binary image, tracked with Git LFS)
src/pixel3dmm/preprocessing/facer/samples/data/weirdface2.jpg
ADDED (binary image)
src/pixel3dmm/preprocessing/facer/samples/data/weirdface3.jpg
ADDED (binary image)
src/pixel3dmm/preprocessing/facer/samples/download.ipynb
ADDED
@@ -0,0 +1,66 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03bf1e12ed8a4b4ebf4ee9c5acda4a4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0.00/1.25k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import facer\n",
    "facer.util.download_from_github(\n",
    "    to_path='.downloaded/download.ipynb', organisation='FacePerceiver',\n",
    "    repository='facer', file_path='samples/download.ipynb', branch='main')\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "f55a4f7b5691d52d9535fa31bc30af4661bee88c7bc930914ceb21aeaa908798"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('haya38')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

src/pixel3dmm/preprocessing/facer/samples/example_output/alignment.png
ADDED (binary image, tracked with Git LFS)
src/pixel3dmm/preprocessing/facer/samples/example_output/detect.png
ADDED (binary image, tracked with Git LFS)
src/pixel3dmm/preprocessing/facer/samples/example_output/parsing.png
ADDED (binary image, tracked with Git LFS)
src/pixel3dmm/preprocessing/facer/samples/face_alignment.ipynb
ADDED (diff too large to render; see the raw diff)
src/pixel3dmm/preprocessing/facer/samples/face_attribute.ipynb
ADDED (diff too large to render; see the raw diff)
src/pixel3dmm/preprocessing/facer/samples/face_detect.ipynb
ADDED (diff too large to render; see the raw diff)
src/pixel3dmm/preprocessing/facer/samples/face_parsing.ipynb
ADDED (diff too large to render; see the raw diff)
src/pixel3dmm/preprocessing/facer/samples/transform.ipynb
ADDED (diff too large to render; see the raw diff)
src/pixel3dmm/preprocessing/facer/scripts/build.sh
ADDED
@@ -0,0 +1 @@
python setup.py bdist_wheel