Spaces:
Configuration error
Configuration error
Commit
·
8dc9718
1
Parent(s):
736c8f2
Add files with Git LFS support
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- .gitignore +14 -0
- DockerfileAPI +7 -0
- LICENSE +33 -0
- README.md +183 -11
- README_ZH.md +173 -0
- api.py +479 -0
- assets/.gitignore +2 -0
- assets/docs/API.md +41 -0
- assets/docs/API_ZH.md +47 -0
- assets/gradio/gradio_description_animate_clear.md +6 -0
- assets/gradio/gradio_description_animation.md +19 -0
- assets/gradio/gradio_description_retargeting.md +14 -0
- assets/gradio/gradio_description_upload.md +16 -0
- assets/gradio/gradio_title.md +19 -0
- assets/mask_template.png +0 -0
- camera.bat +32 -0
- configs/onnx_infer.yaml +114 -0
- configs/onnx_mp_infer.yaml +108 -0
- configs/trt_infer.yaml +114 -0
- configs/trt_mp_infer.yaml +108 -0
- requirements.txt +18 -0
- requirements_macos.txt +18 -0
- requirements_win.txt +17 -0
- run.py +322 -0
- scripts/all_onnx2trt.bat +29 -0
- scripts/all_onnx2trt.sh +17 -0
- scripts/all_onnx2trt_animal.sh +12 -0
- scripts/onnx2trt.py +161 -0
- scripts/start_api.sh +3 -0
- src/__init__.py +5 -0
- src/models/JoyVASA/__init__.py +6 -0
- src/models/JoyVASA/common.py +46 -0
- src/models/JoyVASA/dit_talking_head.py +538 -0
- src/models/JoyVASA/helper.py +32 -0
- src/models/JoyVASA/hubert.py +51 -0
- src/models/JoyVASA/wav2vec2.py +119 -0
- src/models/XPose/__init__.py +6 -0
- src/models/XPose/config_model/UniPose_SwinT.py +125 -0
- src/models/XPose/config_model/__init__.py +6 -0
- src/models/XPose/config_model/coco_transformer.py +8 -0
- src/models/XPose/models/UniPose/__init__.py +10 -0
- src/models/XPose/models/UniPose/attention.py +373 -0
- src/models/XPose/models/UniPose/backbone.py +211 -0
- src/models/XPose/models/UniPose/deformable_transformer.py +1230 -0
- src/models/XPose/models/UniPose/fuse_modules.py +276 -0
- src/models/XPose/models/UniPose/mask_generate.py +56 -0
- src/models/XPose/models/UniPose/ops/__init__.py +6 -0
- src/models/XPose/models/UniPose/ops/functions/__init__.py +10 -0
- src/models/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py +61 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.idea
|
3 |
+
*.pyc
|
4 |
+
.DS_Store
|
5 |
+
checkpoints
|
6 |
+
results
|
7 |
+
venv
|
8 |
+
*.egg-info
|
9 |
+
build
|
10 |
+
dist
|
11 |
+
*.eg
|
12 |
+
checkpoints_test
|
13 |
+
logs
|
14 |
+
third_party
|
DockerfileAPI
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM shaoguo/faster_liveportrait:v3
|
2 |
+
USER root
|
3 |
+
RUN mkdir -p /root/FasterLiveportrait
|
4 |
+
RUN chown -R /root/FasterLiveportrait
|
5 |
+
COPY . /root/FasterLiveportrait
|
6 |
+
WORKDIR /root/FasterLiveportrait
|
7 |
+
CMD ["/bin/bash && bash scripts/start_api.sh"]
|
LICENSE
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 warmshao
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
22 |
+
|
23 |
+
---
|
24 |
+
|
25 |
+
ADDITIONAL NOTICE FOR MODELS:
|
26 |
+
|
27 |
+
This repository may contain or reference machine learning models. These models
|
28 |
+
are subject to their respective licenses, which may differ from the MIT license
|
29 |
+
applied to the code in this repository. Users are responsible for complying
|
30 |
+
with the license terms of any models they use. This repository and its
|
31 |
+
maintainers assume no responsibility for model licensing compliance.
|
32 |
+
|
33 |
+
Please check the original source and license of each model before use.
|
README.md
CHANGED
@@ -1,11 +1,183 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FasterLivePortrait: Bring portraits to life in Real Time!
|
2 |
+
<a href="README.md">English</a> | <a href="README_ZH.md">中文</a>
|
3 |
+
|
4 |
+
**Original repository: [LivePortrait](https://github.com/KwaiVGI/LivePortrait), thanks to the authors for sharing**
|
5 |
+
|
6 |
+
**New features:**
|
7 |
+
* Achieved real-time running of LivePortrait on RTX 3090 GPU using TensorRT, reaching speeds of 30+ FPS. This is the speed for rendering a single frame, including pre- and post-processing, not just the model inference speed.
|
8 |
+
* Seamless support for native gradio app, with several times faster speed and support for simultaneous inference on multiple faces and Animal Model.
|
9 |
+
* Added support for [JoyVASA](https://github.com/jdh-algo/JoyVASA), which can drive videos or images with audio.
|
10 |
+
|
11 |
+
**If you find this project useful, please give it a star ✨✨**
|
12 |
+
|
13 |
+
### Demo (Explore more features)
|
14 |
+
* Anyone want this? Fell free to contact me.
|
15 |
+
|
16 |
+
<video src="https://github.com/user-attachments/assets/554c37fc-d098-4938-a638-1660d85d222e" controls="controls" width="500" height="300">Your browser does not support this video!</video>
|
17 |
+
|
18 |
+
|
19 |
+
* Text-driven video, based on kokoro-82M:
|
20 |
+
|
21 |
+
<video src="https://github.com/user-attachments/assets/04e962e2-6c57-4d01-ae4a-2f6d2d501c5a" controls="controls" width="500" height="300">Your browser does not support this video!</video>
|
22 |
+
|
23 |
+
* Audio-driven video (real-time):
|
24 |
+
|
25 |
+
<video src="https://github.com/user-attachments/assets/98bb5ff7-0796-42db-9d7b-e04ddd2c3c14" controls="controls" width="500" height="300">Your browser does not support this video!</video>
|
26 |
+
|
27 |
+
* Animal-driven:
|
28 |
+
|
29 |
+
<video src="https://github.com/user-attachments/assets/dada0a92-593a-480b-a034-cbcce16e38b9" controls="controls" width="500" height="300">Your browser does not support this video!</video>
|
30 |
+
|
31 |
+
* Multiple faces driven simultaneously:
|
32 |
+
|
33 |
+
<video src="https://github.com/KwaiVGI/LivePortrait/assets/138360003/b37de35d-6feb-4100-b73f-58ac23121483" controls="controls" width="500" height="300">Your browser does not support this video!</video>
|
34 |
+
|
35 |
+
|
36 |
+
### Environment Setup
|
37 |
+
* Option 1 (recommended): If you are a Windows user, you can directly download the [integrated package](https://github.com/warmshao/FasterLivePortrait/releases/tag/v1.8).
|
38 |
+
* You need to install [git](https://git-scm.com/downloads) first, then double-click `update.bat` to update the code.
|
39 |
+
* Double-click `scripts/all_onnx2trt.bat` to convert onnx files to tensorrt files.
|
40 |
+
* Double-click `webui.bat` to open the webpage, or double-click `camera.bat` to open the camera for real-time operation.
|
41 |
+
* Option 2: Docker.A docker image is provided for eliminating the need to install onnxruntime-gpu and TensorRT manually.
|
42 |
+
* Install [Docker](https://docs.docker.com/desktop/install/windows-install/) according to your system
|
43 |
+
* Download the image: `docker pull shaoguo/faster_liveportrait:v3`
|
44 |
+
* Execute the command, replace `$FasterLivePortrait_ROOT` with the local directory where you downloaded FasterLivePortrait:
|
45 |
+
```shell
|
46 |
+
docker run -it --gpus=all \
|
47 |
+
--name faster_liveportrait \
|
48 |
+
-v $FasterLivePortrait_ROOT:/root/FasterLivePortrait \
|
49 |
+
--restart=always \
|
50 |
+
-p 9870:9870 \
|
51 |
+
shaoguo/faster_liveportrait:v3 \
|
52 |
+
/bin/bash
|
53 |
+
```
|
54 |
+
* Option 3: Create a new Python virtual environment and install the necessary Python packages manually.
|
55 |
+
* First, install [ffmpeg](https://www.ffmpeg.org/download.html)
|
56 |
+
* Run `pip install -r requirements.txt`
|
57 |
+
* Then follow the tutorials below to install onnxruntime-gpu or TensorRT. Note that this has only been tested on Linux systems.
|
58 |
+
|
59 |
+
### Usage
|
60 |
+
#### 1. TensorRT Inference(Recommended)
|
61 |
+
* (Ignored in Docker) Install TensorRT 8.x (versions >=10.x are not compatible). Remember the installation path of [TensorRT](https://developer.nvidia.com/tensorrt).
|
62 |
+
* (Ignored in Docker) Install the grid_sample TensorRT plugin, as the model uses grid sample that requires 5D input, which is not supported by the native grid_sample operator.
|
63 |
+
* `git clone https://github.com/SeanWangJS/grid-sample3d-trt-plugin`
|
64 |
+
* Modify line 30 in `CMakeLists.txt` to: `set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "60;70;75;80;86")`
|
65 |
+
* `export PATH=/usr/local/cuda/bin:$PATH`
|
66 |
+
* `mkdir build && cd build`
|
67 |
+
* `cmake .. -DTensorRT_ROOT=$TENSORRT_HOME`, replace $TENSORRT_HOME with your own TensorRT root directory.
|
68 |
+
* `make`, remember the address of the .so file, replace `/opt/grid-sample3d-trt-plugin/build/libgrid_sample_3d_plugin.so` in `scripts/onnx2trt.py` and `src/models/predictor.py` with your own .so file path
|
69 |
+
* Download ONNX model files:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`. Convert all ONNX models to TensorRT, run `sh scripts/all_onnx2trt.sh` and `sh scripts/all_onnx2trt_animal.sh`
|
70 |
+
* Test the pipeline using tensorrt:
|
71 |
+
```shell
|
72 |
+
python run.py \
|
73 |
+
--src_image assets/examples/source/s10.jpg \
|
74 |
+
--dri_video assets/examples/driving/d14.mp4 \
|
75 |
+
--cfg configs/trt_infer.yaml
|
76 |
+
* To run in real-time using a camera:
|
77 |
+
```shell
|
78 |
+
python run.py \
|
79 |
+
--src_image assets/examples/source/s10.jpg \
|
80 |
+
--dri_video 0 \
|
81 |
+
--cfg configs/trt_infer.yaml \
|
82 |
+
--realtime
|
83 |
+
```
|
84 |
+
|
85 |
+
#### 2. Onnxruntime Inference
|
86 |
+
* First, download the converted onnx model files:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`.
|
87 |
+
* (Ignored in Docker)If you want to use onnxruntime cpu inference, simply `pip install onnxruntime`. However, cpu inference is extremely slow and not recommended. The latest onnxruntime-gpu still doesn't support grid_sample cuda, but I found a branch that supports it. Follow these steps to install `onnxruntime-gpu` from source:
|
88 |
+
* `git clone https://github.com/microsoft/onnxruntime`
|
89 |
+
* `git checkout liqun/ImageDecoder-cuda`. Thanks to liqun for the grid_sample with cuda implementation!
|
90 |
+
* Run the following commands to compile, changing `cuda_version` and `CMAKE_CUDA_ARCHITECTURES` according to your machine (your cuDNN version must be 8.x, 9.x is not compatible):
|
91 |
+
```shell
|
92 |
+
./build.sh --parallel \
|
93 |
+
--build_shared_lib --use_cuda \
|
94 |
+
--cuda_version 11.8 \
|
95 |
+
--cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda/ \
|
96 |
+
--config Release --build_wheel --skip_tests \
|
97 |
+
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="60;70;75;80;86" \
|
98 |
+
--cmake_extra_defines CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
|
99 |
+
--disable_contrib_ops \
|
100 |
+
--allow_running_as_root
|
101 |
+
```
|
102 |
+
* `pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl`
|
103 |
+
* Test the pipeline using onnxruntime:
|
104 |
+
```
|
105 |
+
python run.py \
|
106 |
+
--src_image assets/examples/source/s10.jpg \
|
107 |
+
--dri_video assets/examples/driving/d14.mp4 \
|
108 |
+
--cfg configs/onnx_infer.yaml
|
109 |
+
```
|
110 |
+
|
111 |
+
|
112 |
+
### Gradio WebUI
|
113 |
+
* onnxruntime: `python webui.py --mode onnx`
|
114 |
+
* tensorrt: `python webui.py --mode trt`
|
115 |
+
* The default port is 9870. Open the webpage: `http://localhost:9870/`
|
116 |
+
|
117 |
+
Hotkeys for webcam mode (when render window is on focus)\
|
118 |
+
Q > exit\
|
119 |
+
S > Stitching\
|
120 |
+
Z > RelativeMotion\
|
121 |
+
X > AnimationRegion\
|
122 |
+
C > CropDrivingVideo\
|
123 |
+
K,L > AdjustSourceScale\
|
124 |
+
N,M > AdjustDriverScale
|
125 |
+
|
126 |
+
## License
|
127 |
+
|
128 |
+
- **Code**: This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
129 |
+
- **Models**: Any machine learning models used in this project are subject to their respective licenses. Please refer to the original model sources for license information. We do not take responsibility for model license compliance.
|
130 |
+
|
131 |
+
|
132 |
+
**Changelog**
|
133 |
+
- [x] **2025/06/29:** LivePortrait animal v1.1 onnx models are available. Download from [this](https://huggingface.co/warmshao/FasterLivePortrait/tree/main/liveportrait_animal_onnx_v1.1).
|
134 |
+
- [x] **2024/12/22:** Add API Deployment `python api.py`, For more information, please refer to the [tutorial](assets/docs/API.md).
|
135 |
+
- [x] **2024/12/21:** Added support for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M), enabling text-driven video or image generation.
|
136 |
+
- Updated code: `git pull origin master` and install the latest Python dependencies `pip install requirements.txt`, or simply double-click `update.bat` on Windows.
|
137 |
+
- Download the model: `huggingface-cli download hexgrad/Kokoro-82M --local-dir .\checkpoints\Kokoro-82M`.
|
138 |
+
- For Linux, install `espeak-ng`: `apt-get -qq -y install espeak-ng > /dev/null 2>&1`
|
139 |
+
- For Windows, refer to [manual installation instructions](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12) and configure the `espeak-ng` environment variables. The current read location is [here](src/pipelines/gradio_live_portrait_pipeline.py:437); modify it if your installation path differs.
|
140 |
+
- Now you can use it normally in the "Drive Text" tab.
|
141 |
+
- [x] **2024/12/16:** Added support for [JoyVASA](https://github.com/jdh-algo/JoyVASA), which can drive videos or images with audio. Very cool!
|
142 |
+
- Update code, then download the models: `huggingface-cli download TencentGameMate/chinese-hubert-base --local-dir .\checkpoints\chinese-hubert-base` and `huggingface-cli download jdh-algo/JoyVASA --local-dir ./checkpoints/JoyVASA`
|
143 |
+
- After launching the webui, follow the tutorial below. When the source is a video, it's recommended to only drive the mouth movements
|
144 |
+
|
145 |
+
<video src="https://github.com/user-attachments/assets/42fb24be-0cde-4138-9671-e52eec95e7f5" controls="controls" width="500" height="400">您的浏览器不支持播放该视频!</video>
|
146 |
+
|
147 |
+
- [x] **2024/12/14:** Added pickle and image driving, as well as region driving animation_region.
|
148 |
+
- Please update the latest code. Windows users can directly double-click `update.bat` to update, but note that your local code will be overwritten.
|
149 |
+
- Running `python run.py` now automatically saves the corresponding pickle to the same directory as the driving video, allowing for direct reuse.
|
150 |
+
- After opening webui, you can experience the new pickle and image driving, as well as the region driving animation_region features. Note that for image driving, remember to disable `relative motion`.
|
151 |
+
- [x] **2024/08/11:** Optimized paste_back speed and fixed some bugs.
|
152 |
+
- Used torchgeometry + cuda to optimize the paste_back function, significantly improving speed. Example: `python run.py --src_image assets/examples/source/s39.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_infer.yaml --paste_back --animal`
|
153 |
+
- Fixed issues with Xpose ops causing errors on some GPUs and other bugs. Please use the latest docker image: `docker pull shaoguo/faster_liveportrait:v3`
|
154 |
+
- [x] **2024/08/11:** Optimized paste_back speed and fixed some bugs.
|
155 |
+
- Used torchgeometry + cuda to optimize the paste_back function, significantly improving speed. Example: `python run.py --src_image assets/examples/source/s39.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_infer.yaml --paste_back --animal`
|
156 |
+
- Fixed issues with Xpose ops causing errors on some GPUs and other bugs. Please use the latest docker image: `docker pull shaoguo/faster_liveportrait:v3`
|
157 |
+
- [x] **2024/08/07:** Added support for animal models and MediaPipe models, so you no longer need to worry about copyright issues.
|
158 |
+
- Added support for animal models.
|
159 |
+
- Download the animal ONNX file: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`, then convert it to TRT format.
|
160 |
+
- Update the Docker image: `docker pull shaoguo/faster_liveportrait:v3`. Using animal model:`python run.py --src_image assets/examples/source/s39.jpg --dri_video 0 --cfg configs/trt_infer.yaml --realtime --animal`
|
161 |
+
- Windows users can download the latest [Windows all-in-one package](https://github.com/warmshao/FasterLivePortrait/releases) from the release page, then unzip and use it.
|
162 |
+
- Simple usage tutorial:
|
163 |
+
|
164 |
+
<video src="https://github.com/user-attachments/assets/dc37e2dd-551a-43b0-8929-fc5d5fe16ec5" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
|
165 |
+
|
166 |
+
- Using MediaPipe model to replace InsightFace
|
167 |
+
- For web usage: `python webui.py --mode trt --mp` or `python webui.py --mode onnx --mp`
|
168 |
+
- For local webcam: `python run.py --src_image assets/examples/source/s12.jpg --dri_video 0 --cfg configs/trt_mp_infer.yaml`
|
169 |
+
- [x] **2024/07/24:** Windows integration package, no installation required, one-click run, supports TensorRT and OnnxruntimeGPU. Thanks to @zhanghongyong123456 for their contribution in this [issue](https://github.com/warmshao/FasterLivePortrait/issues/22).
|
170 |
+
- [Optional] If you have already installed CUDA and cuDNN on your Windows computer, please skip this step. I have only verified on CUDA 12.2. If you haven't installed CUDA or encounter CUDA-related errors, you need to follow these steps:
|
171 |
+
- Download [CUDA 12.2](https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Windows&target_arch=x86_64), double-click the exe and install following the default settings step by step.
|
172 |
+
- Download the [cuDNN](https://developer.nvidia.com/downloads/compute/cudnn/secure/8.9.7/local_installers/12.x/cudnn-windows-x86_64-8.9.7.29_cuda12-archive.zip) zip file, extract it, and copy the lib, bin, and include folders from the cuDNN folder to the CUDA 12.2 folder (default is C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2)
|
173 |
+
- Download the installation-free [Windows integration package](https://github.com/warmshao/FasterLivePortrait/releases) from the release page and extract it.
|
174 |
+
- Enter `FasterLivePortrait-windows` and double-click `scripts/all_onnx2trt.bat` to convert onnx files, which will take some time.
|
175 |
+
- For web demo: Double-click `webui.bat`, open the webpage: `http://localhost:9870/`
|
176 |
+
- For real-time camera operation, double-click `camera.bat`,press `q` to stop. If you want to change the target image, run in command line: `camera.bat assets/examples/source/s9.jpg`
|
177 |
+
- [x] **2024/07/18:** macOS support added(No need for Docker, Python is enough). M1/M2 chips are faster, but it's still quite slow 😟
|
178 |
+
- Install ffmpeg: `brew install ffmpeg`
|
179 |
+
- Set up a Python 3.10 virtual environment. Recommend using [miniforge](https://github.com/conda-forge/miniforge): `conda create -n flip python=3.10 && conda activate flip`
|
180 |
+
- Install requirements: `pip install -r requirements_macos.txt`
|
181 |
+
- Download ONNX files: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`
|
182 |
+
- Test: `python webui.py --mode onnx`
|
183 |
+
- [x] **2024/07/17:** Added support for Docker environment, providing a runnable image.
|
README_ZH.md
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FasterLivePortrait:Bring portrait to life in Real Time!
|
2 |
+
<a href="README.md">English</a> | <a href="README_ZH.md">中文</a>
|
3 |
+
|
4 |
+
**原仓库: [LivePortrait](https://github.com/KwaiVGI/LivePortrait),感谢作者的分享**
|
5 |
+
|
6 |
+
**新增功能:**
|
7 |
+
* 通过TensorRT实现在RTX 3090显卡上**实时**运行LivePortrait,速度达到 30+ FPS. 这个速度是实测渲染出一帧的速度,而不仅仅是模型的推理时间。
|
8 |
+
* 无缝支持原生的gradio app, 速度快了好几倍,同时支持多张人脸、Animal模型。
|
9 |
+
* 增加[JoyVASA](https://github.com/jdh-algo/JoyVASA)的支持,可以用音频驱动视频或图片。
|
10 |
+
|
11 |
+
**如果你觉得这个项目有用,帮我点个star吧✨✨**
|
12 |
+
|
13 |
+
### Demo(还有很多功能等你探索)
|
14 |
+
* 文本驱动视频,基于kokoro-82M:
|
15 |
+
|
16 |
+
<video src="https://github.com/user-attachments/assets/04e962e2-6c57-4d01-ae4a-2f6d2d501c5a" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
|
17 |
+
* 声音驱动视频(可以实时):
|
18 |
+
|
19 |
+
<video src="https://github.com/user-attachments/assets/98bb5ff7-0796-42db-9d7b-e04ddd2c3c14" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
|
20 |
+
* 动物驱动:
|
21 |
+
|
22 |
+
<video src="https://github.com/user-attachments/assets/dada0a92-593a-480b-a034-cbcce16e38b9" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
|
23 |
+
* 多张人脸同时驱动:
|
24 |
+
|
25 |
+
<video src="https://github.com/KwaiVGI/LivePortrait/assets/138360003/b37de35d-6feb-4100-b73f-58ac23121483" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
|
26 |
+
|
27 |
+
|
28 |
+
### 环境安装
|
29 |
+
* 方式1:如果你是Windows用户,推荐可以直接下载[整合包](https://github.com/warmshao/FasterLivePortrait/releases/tag/v1.8)。
|
30 |
+
* 需要先安装好[git](https://git-scm.com/downloads), 双击`update.bat`更新代码。
|
31 |
+
* 双击`scripts/all_onnx2trt.bat`转换onnx文件为tensorrt文件。
|
32 |
+
* 双击`webui.bat`打开网页,或者双击`camera.bat`打开摄像头实时运行。
|
33 |
+
* 方式2:Docker,提供了一个镜像,不用再自己安装onnxruntime-gpu和TensorRT。
|
34 |
+
* 根据自己的系统安装[docker](https://docs.docker.com/desktop/install/windows-install/)
|
35 |
+
* 下载镜像:`docker pull shaoguo/faster_liveportrait:v3`
|
36 |
+
* 执行命令, `$FasterLivePortrait_ROOT`要替换成你下载的FasterLivePortrait在本地的目录:
|
37 |
+
```shell
|
38 |
+
docker run -it --gpus=all \
|
39 |
+
--name faster_liveportrait \
|
40 |
+
-v $FasterLivePortrait_ROOT:/root/FasterLivePortrait \
|
41 |
+
--restart=always \
|
42 |
+
-p 9870:9870 \
|
43 |
+
shaoguo/faster_liveportrait:v3 \
|
44 |
+
/bin/bash
|
45 |
+
```
|
46 |
+
* 然后可以根据下面Onnxruntime 推理和TensorRT 推理教程进行使用。
|
47 |
+
|
48 |
+
* 方式3:新建一个python虚拟环境,自己安装必要的python包
|
49 |
+
* 请先安装[ffmpeg](https://www.ffmpeg.org/download.html)
|
50 |
+
* `pip install -r requirements.txt`
|
51 |
+
* 再根据以下教程安装onnxruntime-gpu或TensorRT。
|
52 |
+
|
53 |
+
### 使用方法
|
54 |
+
#### 1. TensorRT 推理(推荐, 可以实时)
|
55 |
+
* (Docker环境可忽略)安装TensorRT,请记住[TensorRT](https://developer.nvidia.com/tensorrt)安装的路径。
|
56 |
+
* (Docker环境可忽略)安装 grid_sample的tensorrt插件,因为模型用到的grid sample需要有5d的输入,原生的grid_sample 算子不支持。
|
57 |
+
* `git clone https://github.com/SeanWangJS/grid-sample3d-trt-plugin`
|
58 |
+
* 修改`CMakeLists.txt`中第30行为:`set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "60;70;75;80;86")`
|
59 |
+
* `export PATH=/usr/local/cuda/bin:$PATH`
|
60 |
+
* `mkdir build && cd build`
|
61 |
+
* `cmake .. -DTensorRT_ROOT=$TENSORRT_HOME`,$TENSORRT_HOME 替换成你自己TensorRT的根目录。
|
62 |
+
* `make`,记住so文件的地址,将`scripts/onnx2trt.py`和`src/models/predictor.py`里`/opt/grid-sample3d-trt-plugin/build/libgrid_sample_3d_plugin.so`替换成自己的so路径
|
63 |
+
* 下载Onnx文件:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`。将onnx模型转为tensorrt,运行`sh scripts/all_onnx2trt.sh`和`sh scripts/all_onnx2trt_animal.sh`
|
64 |
+
* 用tensorrt测试pipeline:
|
65 |
+
```shell
|
66 |
+
python run.py \
|
67 |
+
--src_image assets/examples/source/s10.jpg \
|
68 |
+
--dri_video assets/examples/driving/d14.mp4 \
|
69 |
+
--cfg configs/trt_infer.yaml
|
70 |
+
```
|
71 |
+
如果要使用摄像头实时运行:
|
72 |
+
```shell
|
73 |
+
python run.py \
|
74 |
+
--src_image assets/examples/source/s10.jpg \
|
75 |
+
--dri_video 0 \
|
76 |
+
--cfg configs/trt_infer.yaml \
|
77 |
+
--realtime
|
78 |
+
```
|
79 |
+
#### 2. Onnxruntime 推理
|
80 |
+
* 首先下载我转换好的[模型onnx文件](https://huggingface.co/warmshao/FasterLivePortrait): `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`。
|
81 |
+
* (Docker环境可忽略)如果你要用onnxruntime cpu推理的话,直接`pip install onnxruntime`即可,但是cpu推理超级慢。但是最新的onnxruntime-gpu仍然无法支持grid_sample cuda,好在我看到一位大佬在分支上支持了,按照以下步骤源码安装`onnxruntime-gpu`:
|
82 |
+
* `git clone https://github.com/microsoft/onnxruntime`
|
83 |
+
* `git checkout liqun/ImageDecoder-cuda`. Thanks for liqun's grid_sample with cuda implementation!
|
84 |
+
* 运行以下命令编译,`cuda_version`和`CMAKE_CUDA_ARCHITECTURES`根据自己的机器更改:
|
85 |
+
```shell
|
86 |
+
./build.sh --parallel \
|
87 |
+
--build_shared_lib --use_cuda \
|
88 |
+
--cuda_version 11.8 \
|
89 |
+
--cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda/ \
|
90 |
+
--config Release --build_wheel --skip_tests \
|
91 |
+
--cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="60;70;75;80;86" \
|
92 |
+
--cmake_extra_defines CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
|
93 |
+
--disable_contrib_ops \
|
94 |
+
--allow_running_as_root
|
95 |
+
```
|
96 |
+
* `pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl`就可以了
|
97 |
+
* 用onnxruntime测试pipeline:
|
98 |
+
```shell
|
99 |
+
python run.py \
|
100 |
+
--src_image assets/examples/source/s10.jpg \
|
101 |
+
--dri_video assets/examples/driving/d14.mp4 \
|
102 |
+
--cfg configs/onnx_infer.yaml
|
103 |
+
```
|
104 |
+
|
105 |
+
### Gradio WebUI
|
106 |
+
* onnxruntime: `python webui.py --mode onnx`
|
107 |
+
* tensorrt: `python webui.py --mode trt`
|
108 |
+
* 默认端口在9870,打开网页:`http://localhost:9870/`
|
109 |
+
|
110 |
+
Hotkeys for webcam mode (when render window is on focus)\
|
111 |
+
Q > exit\
|
112 |
+
S > Stitching\
|
113 |
+
Z > RelativeMotion\
|
114 |
+
X > AnimationRegion\
|
115 |
+
C > CropDrivingVideo\
|
116 |
+
K,L > AdjustSourceScale\
|
117 |
+
N,M > AdjustDriverScale
|
118 |
+
|
119 |
+
## 许可证
|
120 |
+
|
121 |
+
- **代码**: 本项目采用 MIT 许可证 - 详细信息请查看 [LICENSE](LICENSE) 文件。
|
122 |
+
- **模型**: 本项目中使用的任何机器学习模型均遵循其各自的许可证。请参考原始模型来源获取许可证信息。我们不承担模型许可证合规性的责任。
|
123 |
+
|
124 |
+
|
125 |
+
**日志**
|
126 |
+
- [x] **2025/06/29:** [LivePortrait animal v1.1 onnx模型](https://huggingface.co/warmshao/FasterLivePortrait/tree/main/liveportrait_animal_onnx_v1.1)。
|
127 |
+
- [x] **2024/12/22:** 增加api部署`python api.py`, 其他参考[教程](assets/docs/API_ZH.md)使用。
|
128 |
+
- [x] **2024/12/21:** 增加[Kokoro-82M](hhttps://huggingface.co/hexgrad/Kokoro-82M)的支持,可以用文本驱动视频或图片。
|
129 |
+
- 更新代码, `git pull origin master`并安装最新的python依赖 `pip install requirements.txt`, 或者 windows下直接双击 `update.bat`.
|
130 |
+
- 然后下载模型: `huggingface-cli download hexgrad/Kokoro-82M --local-dir .\checkpoints\Kokoro-82M`.
|
131 |
+
- 如果是Linux请安装`apt-get -qq -y install espeak-ng > /dev/null 2>&1`
|
132 |
+
- 如果是windows请参考[自行安装](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12)并配置好`espeak-ng`环境变量。我是在[这里](src/pipelines/gradio_live_portrait_pipeline.py:437)读取,如果你的位置变了,请自行修改。
|
133 |
+
- 然后就可以在Drive Text的标签页正常使用了。
|
134 |
+
- [x] **2024/12/16:** 增加[JoyVASA](https://github.com/jdh-algo/JoyVASA)的支持,可以用音频驱动视频或图片。非常酷!
|
135 |
+
- 更新代码,然后下载模型: `huggingface-cli download TencentGameMate/chinese-hubert-base --local-dir .\checkpoints\chinese-hubert-base` 和 ` huggingface-cli download jdh-algo/JoyVASA --local-dir ./checkpoints/JoyVASA`
|
136 |
+
- 启动webui后根据以下教程使用即可,建议source 是视频的情况下只驱动嘴部
|
137 |
+
|
138 |
+
<video src="https://github.com/user-attachments/assets/42fb24be-0cde-4138-9671-e52eec95e7f5" controls="controls" width="500" height="400">您的浏览器不支持播放该视频!</video>
|
139 |
+
|
140 |
+
- [x] **2024/12/14:** 增加pickle和image驱动以及区域驱动`animation_region`。
|
141 |
+
- 请更新最新的代码,windows用户可以直接双击`update.bat`更新,但请注意本地的代码将会被覆盖。
|
142 |
+
- `python run.py ` 现在运行 `driving video`会自动保存对应的pickle到跟`driving video`一样的目录,可以直接复用。
|
143 |
+
- 打开`webui`后即可体验新的pickle和image驱动以及区域驱动`animation_region`等功能。注意image驱动记得把`relative motion`取消掉。
|
144 |
+
- [x] **2024/08/11:** 优化paste_back的速度,修复一些bug。
|
145 |
+
- 用torchgeometry + cuda优化paste_back函数,现在速度提升了很多。示例:`python run.py --src_image assets/examples/source/s39.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_infer.yaml --paste_back --animal`
|
146 |
+
- 修复Xpose的ops在一些显卡运行报错的问题等bug。请使用最新的镜像:`docker pull shaoguo/faster_liveportrait:v3`
|
147 |
+
- [x] **2024/08/07:** 增加animal模型的支持,同时支持mediapipe模型,现在你不用再担心版权的问题。
|
148 |
+
- 增加对animal模型的支持。
|
149 |
+
- 需要下载animal的onnx文件:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`,然后转换成trt文件。
|
150 |
+
- 更新镜像`docker pull shaoguo/faster_liveportrait:v3`, 使用animal模型的示例:`python run.py --src_image assets/examples/source/s39.jpg --dri_video 0 --cfg configs/trt_infer.yaml --realtime --animal`
|
151 |
+
- windows系统可以从release页下载最新的[windows 整合包](https://github.com/warmshao/FasterLivePortrait/releases),解压后使用。
|
152 |
+
- 简单的使用教程:
|
153 |
+
|
154 |
+
<video src="https://github.com/user-attachments/assets/dc37e2dd-551a-43b0-8929-fc5d5fe16ec5" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
|
155 |
+
|
156 |
+
- 使用mediapipe模型替代insight_face
|
157 |
+
- 网页端使用: `python webui.py --mode trt --mp` 或 `python webui.py --mode onnx --mp`
|
158 |
+
- 本地摄像头运行: `python run.py --src_image assets/examples/source/s12.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_mp_infer.yaml`
|
159 |
+
- [x] **2024/07/24:** Windows的整合包, 免安装一键运行,支持TensorRT和OnnxruntimeGPU。感谢@zhanghongyong123456在[issue](https://github.com/warmshao/FasterLivePortrait/issues/22)的贡献。
|
160 |
+
- 【可选】如果你的windows电脑已经装过cuda和cudnn,请忽略这一步。我只在cuda12.2上验证过,如果没安装cuda或报cuda相关的错,你需要按照以下步骤进行安装:
|
161 |
+
- 下载[cuda12.2](https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Windows&target_arch=x86_64), 双击exe后按照默认设置一步步安装即可。
|
162 |
+
- 下载[cudnn](https://developer.nvidia.com/downloads/compute/cudnn/secure/8.9.7/local_installers/12.x/cudnn-windows-x86_64-8.9.7.29_cuda12-archive.zip) 压缩包,解压后将cudnn 文件夹下的lib、bin、include 文件夹复制到 CUDA12.2 文件夹下(默认为C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2)
|
163 |
+
- 从release页下载免安装[windows 整合包](https://github.com/warmshao/FasterLivePortrait/releases)并解压。
|
164 |
+
- 进入`FasterLivePortrait-windows`后双击`scripts/all_onnx2trt.bat`对onnx文件进行转换,这会等上一段时间。
|
165 |
+
- 网页端demo:双击`webui.bat`, 打开网页:`http://localhost:9870/`
|
166 |
+
- 摄像头实时运行,双击`camera.bat`,按`q`停止。如果你想更换目标图像,命令行运行:`camera.bat assets/examples/source/s9.jpg`。
|
167 |
+
- [x] **2024/07/18:** MacOS支持(不需要Docker,python就可以了),M1/M2的速度比较快,但还是很慢😟
|
168 |
+
- 安装ffmpeg: `brew install ffmpeg`
|
169 |
+
- 安装python=3.10的虚拟环境,推荐可以用[miniforge](https://github.com/conda-forge/miniforge).`conda create -n flip python=3.10 && conda activate flip`
|
170 |
+
- `pip install -r requirements_macos.txt`
|
171 |
+
- 下载onnx文件: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`
|
172 |
+
- 测试: `python webui.py --mode onnx`
|
173 |
+
- [x] **2024/07/17:** 增加docker环境的支持,提供可运行的镜像。
|
api.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/9/13 0:23
|
3 |
+
# @Project : FasterLivePortrait
|
4 |
+
# @FileName: api.py
|
5 |
+
import pdb
|
6 |
+
import shutil
|
7 |
+
from typing import Optional, Dict, Any
|
8 |
+
import io
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import uvicorn
|
12 |
+
import cv2
|
13 |
+
import time
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
import datetime
|
17 |
+
import platform
|
18 |
+
import pickle
|
19 |
+
from tqdm import tqdm
|
20 |
+
from pydantic import BaseModel
|
21 |
+
from fastapi import APIRouter, Depends, FastAPI, Request, Response, UploadFile
|
22 |
+
from fastapi import File, Body, Form
|
23 |
+
from omegaconf import OmegaConf
|
24 |
+
from fastapi.responses import StreamingResponse
|
25 |
+
from zipfile import ZipFile
|
26 |
+
from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
|
27 |
+
from src.utils.utils import video_has_audio
|
28 |
+
from src.utils import logger
|
29 |
+
|
30 |
+
# model dir
|
31 |
+
project_dir = os.path.dirname(__file__)
|
32 |
+
checkpoints_dir = os.environ.get("FLIP_CHECKPOINT_DIR", os.path.join(project_dir, "checkpoints"))
|
33 |
+
log_dir = os.path.join(project_dir, "logs")
|
34 |
+
os.makedirs(log_dir, exist_ok=True)
|
35 |
+
result_dir = os.path.join(project_dir, "results")
|
36 |
+
os.makedirs(result_dir, exist_ok=True)
|
37 |
+
|
38 |
+
logger_f = logger.get_logger("faster_liveportrait_api", log_file=os.path.join(log_dir, "log_run.log"))
|
39 |
+
|
40 |
+
app = FastAPI()
|
41 |
+
|
42 |
+
global pipe
|
43 |
+
|
44 |
+
if platform.system().lower() == 'windows':
|
45 |
+
FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
|
46 |
+
else:
|
47 |
+
FFMPEG = "ffmpeg"
|
48 |
+
|
49 |
+
|
50 |
+
def check_all_checkpoints_exist(infer_cfg):
|
51 |
+
"""
|
52 |
+
check whether all checkpoints exist
|
53 |
+
:return:
|
54 |
+
"""
|
55 |
+
ret = True
|
56 |
+
for name in infer_cfg.models:
|
57 |
+
if not isinstance(infer_cfg.models[name].model_path, str):
|
58 |
+
for i in range(len(infer_cfg.models[name].model_path)):
|
59 |
+
infer_cfg.models[name].model_path[i] = infer_cfg.models[name].model_path[i].replace("./checkpoints",
|
60 |
+
checkpoints_dir)
|
61 |
+
if not os.path.exists(infer_cfg.models[name].model_path[i]) and not os.path.exists(
|
62 |
+
infer_cfg.models[name].model_path[i][:-4] + ".onnx"):
|
63 |
+
return False
|
64 |
+
else:
|
65 |
+
infer_cfg.models[name].model_path = infer_cfg.models[name].model_path.replace("./checkpoints",
|
66 |
+
checkpoints_dir)
|
67 |
+
if not os.path.exists(infer_cfg.models[name].model_path) and not os.path.exists(
|
68 |
+
infer_cfg.models[name].model_path[:-4] + ".onnx"):
|
69 |
+
return False
|
70 |
+
for name in infer_cfg.animal_models:
|
71 |
+
if not isinstance(infer_cfg.animal_models[name].model_path, str):
|
72 |
+
for i in range(len(infer_cfg.animal_models[name].model_path)):
|
73 |
+
infer_cfg.animal_models[name].model_path[i] = infer_cfg.animal_models[name].model_path[i].replace(
|
74 |
+
"./checkpoints",
|
75 |
+
checkpoints_dir)
|
76 |
+
if not os.path.exists(infer_cfg.animal_models[name].model_path[i]) and not os.path.exists(
|
77 |
+
infer_cfg.animal_models[name].model_path[i][:-4] + ".onnx"):
|
78 |
+
return False
|
79 |
+
else:
|
80 |
+
infer_cfg.animal_models[name].model_path = infer_cfg.animal_models[name].model_path.replace("./checkpoints",
|
81 |
+
checkpoints_dir)
|
82 |
+
if not os.path.exists(infer_cfg.animal_models[name].model_path) and not os.path.exists(
|
83 |
+
infer_cfg.animal_models[name].model_path[:-4] + ".onnx"):
|
84 |
+
return False
|
85 |
+
|
86 |
+
# XPOSE
|
87 |
+
xpose_model_path = os.path.join(checkpoints_dir, "liveportrait_animal_onnx/xpose.pth")
|
88 |
+
if not os.path.exists(xpose_model_path):
|
89 |
+
return False
|
90 |
+
embeddings_cache_9_path = os.path.join(checkpoints_dir, "liveportrait_animal_onnx/clip_embedding_9.pkl")
|
91 |
+
if not os.path.exists(embeddings_cache_9_path):
|
92 |
+
return False
|
93 |
+
embeddings_cache_68_path = os.path.join(checkpoints_dir, "liveportrait_animal_onnx/clip_embedding_68.pkl")
|
94 |
+
if not os.path.exists(embeddings_cache_68_path):
|
95 |
+
return False
|
96 |
+
return ret
|
97 |
+
|
98 |
+
|
99 |
+
def convert_onnx_to_trt_models(infer_cfg):
|
100 |
+
ret = True
|
101 |
+
for name in infer_cfg.models:
|
102 |
+
if not isinstance(infer_cfg.models[name].model_path, str):
|
103 |
+
for i in range(len(infer_cfg.models[name].model_path)):
|
104 |
+
trt_path = infer_cfg.models[name].model_path[i]
|
105 |
+
onnx_path = trt_path[:-4] + ".onnx"
|
106 |
+
if not os.path.exists(trt_path):
|
107 |
+
convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
|
108 |
+
logger_f.info(f"convert onnx model: {onnx_path}")
|
109 |
+
result = subprocess.run(convert_cmd, shell=True, check=True)
|
110 |
+
# 检查结果
|
111 |
+
if result.returncode == 0:
|
112 |
+
logger_f.info(f"convert onnx model: {onnx_path} successful")
|
113 |
+
else:
|
114 |
+
logger_f.error(f"convert onnx model: {onnx_path} failed")
|
115 |
+
return False
|
116 |
+
else:
|
117 |
+
trt_path = infer_cfg.models[name].model_path
|
118 |
+
onnx_path = trt_path[:-4] + ".onnx"
|
119 |
+
if not os.path.exists(trt_path):
|
120 |
+
convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
|
121 |
+
logger_f.info(f"convert onnx model: {onnx_path}")
|
122 |
+
result = subprocess.run(convert_cmd, shell=True, check=True)
|
123 |
+
# 检查结果
|
124 |
+
if result.returncode == 0:
|
125 |
+
logger_f.info(f"convert onnx model: {onnx_path} successful")
|
126 |
+
else:
|
127 |
+
logger_f.error(f"convert onnx model: {onnx_path} failed")
|
128 |
+
return False
|
129 |
+
|
130 |
+
for name in infer_cfg.animal_models:
|
131 |
+
if not isinstance(infer_cfg.animal_models[name].model_path, str):
|
132 |
+
for i in range(len(infer_cfg.animal_models[name].model_path)):
|
133 |
+
trt_path = infer_cfg.animal_models[name].model_path[i]
|
134 |
+
onnx_path = trt_path[:-4] + ".onnx"
|
135 |
+
if not os.path.exists(trt_path):
|
136 |
+
convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
|
137 |
+
logger_f.info(f"convert onnx model: {onnx_path}")
|
138 |
+
result = subprocess.run(convert_cmd, shell=True, check=True)
|
139 |
+
# 检查结果
|
140 |
+
if result.returncode == 0:
|
141 |
+
logger_f.info(f"convert onnx model: {onnx_path} successful")
|
142 |
+
else:
|
143 |
+
logger_f.error(f"convert onnx model: {onnx_path} failed")
|
144 |
+
return False
|
145 |
+
else:
|
146 |
+
trt_path = infer_cfg.animal_models[name].model_path
|
147 |
+
onnx_path = trt_path[:-4] + ".onnx"
|
148 |
+
if not os.path.exists(trt_path):
|
149 |
+
convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
|
150 |
+
logger_f.info(f"convert onnx model: {onnx_path}")
|
151 |
+
result = subprocess.run(convert_cmd, shell=True, check=True)
|
152 |
+
# 检查结果
|
153 |
+
if result.returncode == 0:
|
154 |
+
logger_f.info(f"convert onnx model: {onnx_path} successful")
|
155 |
+
else:
|
156 |
+
logger_f.error(f"convert onnx model: {onnx_path} failed")
|
157 |
+
return False
|
158 |
+
return ret
|
159 |
+
|
160 |
+
|
161 |
+
@app.on_event("startup")
|
162 |
+
async def startup_event():
|
163 |
+
global pipe
|
164 |
+
# default use trt model
|
165 |
+
cfg_file = os.path.join(project_dir, "configs/trt_infer.yaml")
|
166 |
+
infer_cfg = OmegaConf.load(cfg_file)
|
167 |
+
checkpoints_exist = check_all_checkpoints_exist(infer_cfg)
|
168 |
+
|
169 |
+
# first: download model if not exist
|
170 |
+
if not checkpoints_exist:
|
171 |
+
download_cmd = f"huggingface-cli download warmshao/FasterLivePortrait --local-dir {checkpoints_dir}"
|
172 |
+
logger_f.info(f"download model: {download_cmd}")
|
173 |
+
result = subprocess.run(download_cmd, shell=True, check=True)
|
174 |
+
# 检查结果
|
175 |
+
if result.returncode == 0:
|
176 |
+
logger_f.info(f"Download checkpoints to {checkpoints_dir} successful")
|
177 |
+
else:
|
178 |
+
logger_f.error(f"Download checkpoints to {checkpoints_dir} failed")
|
179 |
+
exit(1)
|
180 |
+
# second: convert onnx model to trt
|
181 |
+
convert_ret = convert_onnx_to_trt_models(infer_cfg)
|
182 |
+
if not convert_ret:
|
183 |
+
logger_f.error(f"convert onnx model to trt failed")
|
184 |
+
exit(1)
|
185 |
+
|
186 |
+
infer_cfg.infer_params.flag_pasteback = True
|
187 |
+
pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=True)
|
188 |
+
|
189 |
+
|
190 |
+
def run_with_video(source_image_path, driving_video_path, save_dir):
|
191 |
+
global pipe
|
192 |
+
ret = pipe.prepare_source(source_image_path, realtime=False)
|
193 |
+
if not ret:
|
194 |
+
logger_f.warning(f"no face in {source_image_path}! exit!")
|
195 |
+
return
|
196 |
+
vcap = cv2.VideoCapture(driving_video_path)
|
197 |
+
fps = int(vcap.get(cv2.CAP_PROP_FPS))
|
198 |
+
h, w = pipe.src_imgs[0].shape[:2]
|
199 |
+
|
200 |
+
# render output video
|
201 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
202 |
+
vsave_crop_path = os.path.join(save_dir,
|
203 |
+
f"{os.path.basename(source_image_path)}-{os.path.basename(driving_video_path)}-crop.mp4")
|
204 |
+
vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512))
|
205 |
+
vsave_org_path = os.path.join(save_dir,
|
206 |
+
f"{os.path.basename(source_image_path)}-{os.path.basename(driving_video_path)}-org.mp4")
|
207 |
+
vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
|
208 |
+
|
209 |
+
infer_times = []
|
210 |
+
motion_lst = []
|
211 |
+
c_eyes_lst = []
|
212 |
+
c_lip_lst = []
|
213 |
+
|
214 |
+
frame_ind = 0
|
215 |
+
while vcap.isOpened():
|
216 |
+
ret, frame = vcap.read()
|
217 |
+
if not ret:
|
218 |
+
break
|
219 |
+
t0 = time.time()
|
220 |
+
first_frame = frame_ind == 0
|
221 |
+
dri_crop, out_crop, out_org, dri_motion_info = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
|
222 |
+
first_frame=first_frame)
|
223 |
+
frame_ind += 1
|
224 |
+
if out_crop is None:
|
225 |
+
logger_f.warning(f"no face in driving frame:{frame_ind}")
|
226 |
+
continue
|
227 |
+
|
228 |
+
motion_lst.append(dri_motion_info[0])
|
229 |
+
c_eyes_lst.append(dri_motion_info[1])
|
230 |
+
c_lip_lst.append(dri_motion_info[2])
|
231 |
+
|
232 |
+
infer_times.append(time.time() - t0)
|
233 |
+
# print(time.time() - t0)
|
234 |
+
dri_crop = cv2.resize(dri_crop, (512, 512))
|
235 |
+
out_crop = np.concatenate([dri_crop, out_crop], axis=1)
|
236 |
+
out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
|
237 |
+
vout_crop.write(out_crop)
|
238 |
+
out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
|
239 |
+
vout_org.write(out_org)
|
240 |
+
vcap.release()
|
241 |
+
vout_crop.release()
|
242 |
+
vout_org.release()
|
243 |
+
if video_has_audio(driving_video_path):
|
244 |
+
vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
|
245 |
+
subprocess.call(
|
246 |
+
[FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
|
247 |
+
"-b:v", "10M", "-c:v",
|
248 |
+
"libx264", "-map", "0:v", "-map", "1:a",
|
249 |
+
"-c:a", "aac",
|
250 |
+
"-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
|
251 |
+
vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
|
252 |
+
subprocess.call(
|
253 |
+
[FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
|
254 |
+
"-b:v", "10M", "-c:v",
|
255 |
+
"libx264", "-map", "0:v", "-map", "1:a",
|
256 |
+
"-c:a", "aac",
|
257 |
+
"-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
|
258 |
+
|
259 |
+
logger_f.info(vsave_crop_path_new)
|
260 |
+
logger_f.info(vsave_org_path_new)
|
261 |
+
else:
|
262 |
+
logger_f.info(vsave_crop_path)
|
263 |
+
logger_f.info(vsave_org_path)
|
264 |
+
|
265 |
+
logger_f.info(
|
266 |
+
"inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
|
267 |
+
np.mean(infer_times) * 1000))
|
268 |
+
# save driving motion to pkl
|
269 |
+
template_dct = {
|
270 |
+
'n_frames': len(motion_lst),
|
271 |
+
'output_fps': fps,
|
272 |
+
'motion': motion_lst,
|
273 |
+
'c_eyes_lst': c_eyes_lst,
|
274 |
+
'c_lip_lst': c_lip_lst,
|
275 |
+
}
|
276 |
+
template_pkl_path = os.path.join(save_dir,
|
277 |
+
f"{os.path.basename(driving_video_path)}.pkl")
|
278 |
+
with open(template_pkl_path, "wb") as fw:
|
279 |
+
pickle.dump(template_dct, fw)
|
280 |
+
logger_f.info(f"save driving motion pkl file at : {template_pkl_path}")
|
281 |
+
|
282 |
+
|
283 |
+
def run_with_pkl(source_image_path, driving_pickle_path, save_dir):
|
284 |
+
global pipe
|
285 |
+
ret = pipe.prepare_source(source_image_path, realtime=False)
|
286 |
+
if not ret:
|
287 |
+
logger_f.warning(f"no face in {source_image_path}! exit!")
|
288 |
+
return
|
289 |
+
|
290 |
+
with open(driving_pickle_path, "rb") as fin:
|
291 |
+
dri_motion_infos = pickle.load(fin)
|
292 |
+
|
293 |
+
fps = int(dri_motion_infos["output_fps"])
|
294 |
+
h, w = pipe.src_imgs[0].shape[:2]
|
295 |
+
|
296 |
+
# render output video
|
297 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
298 |
+
vsave_crop_path = os.path.join(save_dir,
|
299 |
+
f"{os.path.basename(source_image_path)}-{os.path.basename(driving_pickle_path)}-crop.mp4")
|
300 |
+
vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512, 512))
|
301 |
+
vsave_org_path = os.path.join(save_dir,
|
302 |
+
f"{os.path.basename(source_image_path)}-{os.path.basename(driving_pickle_path)}-org.mp4")
|
303 |
+
vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
|
304 |
+
|
305 |
+
infer_times = []
|
306 |
+
motion_lst = dri_motion_infos["motion"]
|
307 |
+
c_eyes_lst = dri_motion_infos["c_eyes_lst"] if "c_eyes_lst" in dri_motion_infos else dri_motion_infos[
|
308 |
+
"c_d_eyes_lst"]
|
309 |
+
c_lip_lst = dri_motion_infos["c_lip_lst"] if "c_lip_lst" in dri_motion_infos else dri_motion_infos["c_d_lip_lst"]
|
310 |
+
|
311 |
+
frame_num = len(motion_lst)
|
312 |
+
for frame_ind in tqdm(range(frame_num)):
|
313 |
+
t0 = time.time()
|
314 |
+
first_frame = frame_ind == 0
|
315 |
+
dri_motion_info_ = [motion_lst[frame_ind], c_eyes_lst[frame_ind], c_lip_lst[frame_ind]]
|
316 |
+
out_crop, out_org = pipe.run_with_pkl(dri_motion_info_, pipe.src_imgs[0], pipe.src_infos[0],
|
317 |
+
first_frame=first_frame)
|
318 |
+
if out_crop is None:
|
319 |
+
logger_f.warning(f"no face in driving frame:{frame_ind}")
|
320 |
+
continue
|
321 |
+
|
322 |
+
infer_times.append(time.time() - t0)
|
323 |
+
# print(time.time() - t0)
|
324 |
+
out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
|
325 |
+
vout_crop.write(out_crop)
|
326 |
+
out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
|
327 |
+
vout_org.write(out_org)
|
328 |
+
|
329 |
+
vout_crop.release()
|
330 |
+
vout_org.release()
|
331 |
+
logger_f.info(vsave_crop_path)
|
332 |
+
logger_f.info(vsave_org_path)
|
333 |
+
logger_f.info(
|
334 |
+
"inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
|
335 |
+
np.mean(infer_times) * 1000))
|
336 |
+
|
337 |
+
|
338 |
+
class LivePortraitParams(BaseModel):
|
339 |
+
flag_pickle: bool = False
|
340 |
+
flag_relative_input: bool = True
|
341 |
+
flag_do_crop_input: bool = True
|
342 |
+
flag_remap_input: bool = True
|
343 |
+
driving_multiplier: float = 1.0
|
344 |
+
flag_stitching: bool = True
|
345 |
+
flag_crop_driving_video_input: bool = True
|
346 |
+
flag_video_editing_head_rotation: bool = False
|
347 |
+
flag_is_animal: bool = True
|
348 |
+
scale: float = 2.3
|
349 |
+
vx_ratio: float = 0.0
|
350 |
+
vy_ratio: float = -0.125
|
351 |
+
scale_crop_driving_video: float = 2.2
|
352 |
+
vx_ratio_crop_driving_video: float = 0.0
|
353 |
+
vy_ratio_crop_driving_video: float = -0.1
|
354 |
+
driving_smooth_observation_variance: float = 1e-7
|
355 |
+
|
356 |
+
|
357 |
+
@app.post("/predict/")
|
358 |
+
async def upload_files(
|
359 |
+
source_image: Optional[UploadFile] = File(None),
|
360 |
+
driving_video: Optional[UploadFile] = File(None),
|
361 |
+
driving_pickle: Optional[UploadFile] = File(None),
|
362 |
+
flag_is_animal: bool = Form(...),
|
363 |
+
flag_pickle: bool = Form(...),
|
364 |
+
flag_relative_input: bool = Form(...),
|
365 |
+
flag_do_crop_input: bool = Form(...),
|
366 |
+
flag_remap_input: bool = Form(...),
|
367 |
+
driving_multiplier: float = Form(...),
|
368 |
+
flag_stitching: bool = Form(...),
|
369 |
+
flag_crop_driving_video_input: bool = Form(...),
|
370 |
+
flag_video_editing_head_rotation: bool = Form(...),
|
371 |
+
scale: float = Form(...),
|
372 |
+
vx_ratio: float = Form(...),
|
373 |
+
vy_ratio: float = Form(...),
|
374 |
+
scale_crop_driving_video: float = Form(...),
|
375 |
+
vx_ratio_crop_driving_video: float = Form(...),
|
376 |
+
vy_ratio_crop_driving_video: float = Form(...),
|
377 |
+
driving_smooth_observation_variance: float = Form(...)
|
378 |
+
):
|
379 |
+
# 根据传入的表单参数构建 infer_params
|
380 |
+
infer_params = LivePortraitParams(
|
381 |
+
flag_is_animal=flag_is_animal,
|
382 |
+
flag_pickle=flag_pickle,
|
383 |
+
flag_relative_input=flag_relative_input,
|
384 |
+
flag_do_crop_input=flag_do_crop_input,
|
385 |
+
flag_remap_input=flag_remap_input,
|
386 |
+
driving_multiplier=driving_multiplier,
|
387 |
+
flag_stitching=flag_stitching,
|
388 |
+
flag_crop_driving_video_input=flag_crop_driving_video_input,
|
389 |
+
flag_video_editing_head_rotation=flag_video_editing_head_rotation,
|
390 |
+
scale=scale,
|
391 |
+
vx_ratio=vx_ratio,
|
392 |
+
vy_ratio=vy_ratio,
|
393 |
+
scale_crop_driving_video=scale_crop_driving_video,
|
394 |
+
vx_ratio_crop_driving_video=vx_ratio_crop_driving_video,
|
395 |
+
vy_ratio_crop_driving_video=vy_ratio_crop_driving_video,
|
396 |
+
driving_smooth_observation_variance=driving_smooth_observation_variance
|
397 |
+
)
|
398 |
+
|
399 |
+
global pipe
|
400 |
+
pipe.init_vars()
|
401 |
+
if infer_params.flag_is_animal != pipe.is_animal:
|
402 |
+
pipe.init_models(is_animal=infer_params.flag_is_animal)
|
403 |
+
|
404 |
+
args_user = {
|
405 |
+
'flag_relative_motion': infer_params.flag_relative_input,
|
406 |
+
'flag_do_crop': infer_params.flag_do_crop_input,
|
407 |
+
'flag_pasteback': infer_params.flag_remap_input,
|
408 |
+
'driving_multiplier': infer_params.driving_multiplier,
|
409 |
+
'flag_stitching': infer_params.flag_stitching,
|
410 |
+
'flag_crop_driving_video': infer_params.flag_crop_driving_video_input,
|
411 |
+
'flag_video_editing_head_rotation': infer_params.flag_video_editing_head_rotation,
|
412 |
+
'src_scale': infer_params.scale,
|
413 |
+
'src_vx_ratio': infer_params.vx_ratio,
|
414 |
+
'src_vy_ratio': infer_params.vy_ratio,
|
415 |
+
'dri_scale': infer_params.scale_crop_driving_video,
|
416 |
+
'dri_vx_ratio': infer_params.vx_ratio_crop_driving_video,
|
417 |
+
'dri_vy_ratio': infer_params.vy_ratio_crop_driving_video,
|
418 |
+
}
|
419 |
+
# update config from user input
|
420 |
+
update_ret = pipe.update_cfg(args_user)
|
421 |
+
|
422 |
+
# 保存 source_image 到指定目录
|
423 |
+
temp_dir = os.path.join(result_dir, f"temp-{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}")
|
424 |
+
os.makedirs(temp_dir, exist_ok=True)
|
425 |
+
if source_image and source_image.filename:
|
426 |
+
source_image_path = os.path.join(temp_dir, source_image.filename)
|
427 |
+
with open(source_image_path, "wb") as buffer:
|
428 |
+
buffer.write(await source_image.read()) # 将内容写入文件
|
429 |
+
else:
|
430 |
+
source_image_path = None
|
431 |
+
|
432 |
+
if driving_video and driving_video.filename:
|
433 |
+
driving_video_path = os.path.join(temp_dir, driving_video.filename)
|
434 |
+
with open(driving_video_path, "wb") as buffer:
|
435 |
+
buffer.write(await driving_video.read()) # 将内容写入文件
|
436 |
+
else:
|
437 |
+
driving_video_path = None
|
438 |
+
|
439 |
+
if driving_pickle and driving_pickle.filename:
|
440 |
+
driving_pickle_path = os.path.join(temp_dir, driving_pickle.filename)
|
441 |
+
with open(driving_pickle_path, "wb") as buffer:
|
442 |
+
buffer.write(await driving_pickle.read()) # 将内容写入文件
|
443 |
+
else:
|
444 |
+
driving_pickle_path = None
|
445 |
+
|
446 |
+
save_dir = os.path.join(result_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}")
|
447 |
+
os.makedirs(save_dir, exist_ok=True)
|
448 |
+
|
449 |
+
if infer_params.flag_pickle:
|
450 |
+
if source_image_path and driving_pickle_path:
|
451 |
+
run_with_pkl(source_image_path, driving_pickle_path, save_dir)
|
452 |
+
else:
|
453 |
+
if source_image_path and driving_video_path:
|
454 |
+
run_with_video(source_image_path, driving_video_path, save_dir)
|
455 |
+
# zip all files and return
|
456 |
+
# 使用 BytesIO 在内存中创建一个字节流
|
457 |
+
zip_buffer = io.BytesIO()
|
458 |
+
|
459 |
+
# 使用 ZipFile 将文件夹内容压缩到 zip_buffer 中
|
460 |
+
with ZipFile(zip_buffer, "w") as zip_file:
|
461 |
+
for root, dirs, files in os.walk(save_dir):
|
462 |
+
for file in files:
|
463 |
+
file_path = os.path.join(root, file)
|
464 |
+
# 添加文件到 ZIP 文件中
|
465 |
+
zip_file.write(file_path, arcname=os.path.relpath(file_path, save_dir))
|
466 |
+
|
467 |
+
# 确保缓冲区指针在开始位置,以便读取整个内容
|
468 |
+
zip_buffer.seek(0)
|
469 |
+
shutil.rmtree(temp_dir)
|
470 |
+
shutil.rmtree(save_dir)
|
471 |
+
# 通过 StreamingResponse 返回 zip 文件
|
472 |
+
return StreamingResponse(zip_buffer, media_type="application/zip",
|
473 |
+
headers={"Content-Disposition": "attachment; filename=output.zip"})
|
474 |
+
|
475 |
+
|
476 |
+
if __name__ == "__main__":
|
477 |
+
import uvicorn
|
478 |
+
|
479 |
+
uvicorn.run(app, host=os.environ.get("FLIP_IP", "127.0.0.1"), port=os.environ.get("FLIP_PORT", 9871))
|
assets/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
examples/driving/*.pkl
|
2 |
+
examples/driving/*_crop.mp4
|
assets/docs/API.md
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FasterLivePortrait API Usage Guide
|
2 |
+
|
3 |
+
### Building the Image
|
4 |
+
* Decide on an image name, for example `shaoguo/faster_liveportrait_api:v1.0`. Replace the `-t` parameter in the following command with your chosen name.
|
5 |
+
* Run `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .`
|
6 |
+
|
7 |
+
### Running the Image
|
8 |
+
Ensure that your machine has Nvidia GPU drivers installed. CUDA version should be 12.0 or higher. Two scenarios are described below.
|
9 |
+
|
10 |
+
* Running on a Local Machine (typically for self-testing)
|
11 |
+
* Modify the image name according to what you defined above.
|
12 |
+
* Confirm the service port number, default is `9871`. You can define your own by changing the `SERVER_PORT` environment variable in the command below. Remember to also change `-p 9871:9871` to map the port.
|
13 |
+
* Set the model path environment variable `CHECKPOINT_DIR`. If you've previously downloaded FasterLivePortrait's onnx model and converted it to trt, I recommend mapping the model files into the container using `-v`, for example `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`. This avoids re-downloading the onnx model and doing trt conversion. Otherwise, I will check if `CHECKPOINT_DIR` has models, and if not, I will automatically download (ensure network connectivity) and do trt conversion, which will take considerable time.
|
14 |
+
* Run command (note: modify the following command according to your settings):
|
15 |
+
```shell
|
16 |
+
docker run -d --gpus=all \
|
17 |
+
--name faster_liveportrait_api \
|
18 |
+
-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \
|
19 |
+
-e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \
|
20 |
+
-e SERVER_PORT=9871 \
|
21 |
+
-p 9871:9871 \
|
22 |
+
--restart=always \
|
23 |
+
shaoguo/faster_liveportrait_api:v1.0 \
|
24 |
+
/bin/bash
|
25 |
+
```
|
26 |
+
* Normal operation should display the following information(docker logs $container_id). The running logs are saved in `/root/FasterLivePortrait/logs/log_run.log`:
|
27 |
+
```shell
|
28 |
+
INFO: Application startup complete.
|
29 |
+
INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit)
|
30 |
+
```
|
31 |
+
|
32 |
+
* Running on Cloud GPU Cluster (production environment)
|
33 |
+
* This needs to be configured according to different clusters, but the core is the configuration of docker image and environment variables.
|
34 |
+
* Load balancing may need to be set up.
|
35 |
+
|
36 |
+
### API Call Testing
|
37 |
+
Refer to `tests/test_api.py`. The default is the Animal model, but now it also supports the Human model.
|
38 |
+
The return is a compressed package, by default unzipped to `./results/api_*`. Confirm according to the actual printed log.
|
39 |
+
* `test_with_video_animal()`, image and video driving. Set `flag_pickle=False`. It will additionally return the driving video's pkl file, which can be called directly next time.
|
40 |
+
* `test_with_pkl_animal()`, image and pkl driving.
|
41 |
+
* `test_with_video_human()`, image and video driving under the Human model, set `flag_is_animal=False`
|
assets/docs/API_ZH.md
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## FasterLivePortrait API使用教程
|
2 |
+
|
3 |
+
### 构建镜像
|
4 |
+
|
5 |
+
* 确定镜像的名字,比如 `shaoguo/faster_liveportrait_api:v1.0`。确认后替换为下面命令 `-t` 的参数。
|
6 |
+
* 运行 `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .`
|
7 |
+
|
8 |
+
### 运行镜像
|
9 |
+
|
10 |
+
请确保你的机器已经装了Nvidia显卡的驱动。CUDA的版本在cuda12.0及以上。以下分两种情况介绍。
|
11 |
+
|
12 |
+
* 本地机器运行(一般自己测试使用)
|
13 |
+
* 镜像名称根据上面你自己定义的更改。
|
14 |
+
* 确认服务的端口号,默认为`9871`,你可以自己定义,更改下面命令里环境变量`SERVER_PORT`。同时要记得更改`-p 9871:9871`,
|
15 |
+
将端口映射出来。
|
16 |
+
* 设置模型路径环境变量 `CHECKPOINT_DIR`。如果你之前下载过FasterLivePortrait的onnx模型并做过trt的转换,我建议
|
17 |
+
是可以通过 `-v`把
|
18 |
+
模型文件映射进入容器,比如 `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`,
|
19 |
+
这样就避免重新下载onnx模型和做trt的转换。否则我将会检测`CHECKPOINT_DIR`是否有模型,没有的话,我将自动下载(确保有网络)和做trt的转换,这将耗时比较久的时间。
|
20 |
+
* 运行命令(注意你要根据自己的设置更改以下命令的信息):
|
21 |
+
```shell
|
22 |
+
docker run -d --gpus=all \
|
23 |
+
--name faster_liveportrait_api \
|
24 |
+
-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \
|
25 |
+
-e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \
|
26 |
+
-e SERVER_PORT=9871 \
|
27 |
+
-p 9871:9871 \
|
28 |
+
--restart=always \
|
29 |
+
shaoguo/faster_liveportrait_api:v1.0
|
30 |
+
```
|
31 |
+
* 正常运行应该会显示以下信息(docker logs container_id), 运行的日志保存在`/root/FasterLivePortrait/logs/log_run.log`:
|
32 |
+
```shell
|
33 |
+
INFO: Application startup complete.
|
34 |
+
INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit)
|
35 |
+
```
|
36 |
+
* 云端GPU集群运行(生产环境)
|
37 |
+
* 这需要根据不同的集群做配置,但核心就是镜像和环境变量的配置。
|
38 |
+
* 可能要设置负载均衡。
|
39 |
+
|
40 |
+
### API调用测试
|
41 |
+
|
42 |
+
可以参考`tests/test_api.py`, 默认是Animal的模型,但现在同时也支持Human的模型了。
|
43 |
+
返回的是压缩包,默认解压在`./results/api_*`, 根据实际打印出来的日志确认。
|
44 |
+
|
45 |
+
* `test_with_video_animal()`, 图像和视频的驱动。设置`flag_pickle=False`。会额外返回driving video的pkl文件,下次可以直接调用。
|
46 |
+
* `test_with_pkl_animal()`, 图像和pkl的驱动。
|
47 |
+
* `test_with_video_human()`, Human模型下图像和视频的驱动,设置`flag_is_animal=False`
|
assets/gradio/gradio_description_animate_clear.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div style="font-size: 1.2em; text-align: center;">
|
2 |
+
Step 3: Click the <strong>🚀 Animate</strong> button below to generate, or click <strong>🧹 Clear</strong> to erase the results
|
3 |
+
</div>
|
4 |
+
<!-- <div style="font-size: 1.1em; text-align: center;">
|
5 |
+
<strong style="color: red;">Note:</strong> If both <strong>Source Image</strong> and <strong>Video</strong> are uploaded, the <strong>Source Image</strong> will be used. Please click the <strong>🧹 Clear</strong> button, then re-upload the <strong>Source Image</strong> or <strong>Video</strong>.
|
6 |
+
</div> -->
|
assets/gradio/gradio_description_animation.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<span style="font-size: 1.2em;">🔥 To animate the source image or video with the driving video, please follow these steps:</span>
|
2 |
+
<div style="font-size: 1.2em; margin-left: 20px;">
|
3 |
+
1. In the <strong>Animation Options for Source Image or Video</strong> section, we recommend enabling the <code>do crop (source)</code> option if faces occupy a small portion of your source image or video.
|
4 |
+
</div>
|
5 |
+
<div style="font-size: 1.2em; margin-left: 20px;">
|
6 |
+
2. In the <strong>Animation Options for Driving Video</strong> section, the <code>relative head rotation</code> and <code>smooth strength</code> options only take effect if the source input is a video.
|
7 |
+
</div>
|
8 |
+
<div style="font-size: 1.2em; margin-left: 20px;">
|
9 |
+
3. Press the <strong>🚀 Animate</strong> button and wait for a moment. Your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the length of the source video and the driving video.
|
10 |
+
</div>
|
11 |
+
<div style="font-size: 1.2em; margin-left: 20px;">
|
12 |
+
4. If you want to upload your own driving video, <strong>the best practice</strong>:
|
13 |
+
|
14 |
+
- Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-driving by checking `do crop (driving video)`.
|
15 |
+
- Focus on the head area, similar to the example videos.
|
16 |
+
- Minimize shoulder movement.
|
17 |
+
- Make sure the first frame of driving video is a frontal face with **neutral expression**.
|
18 |
+
|
19 |
+
</div>
|
assets/gradio/gradio_description_retargeting.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<br>
|
2 |
+
|
3 |
+
<!-- ## Retargeting -->
|
4 |
+
<!-- <span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span> -->
|
5 |
+
|
6 |
+
|
7 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
|
8 |
+
<div>
|
9 |
+
<h2>Retargeting</h2>
|
10 |
+
<p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times.
|
11 |
+
<br>
|
12 |
+
<strong>😊 Set both ratios to 0.8 to see what's going on!</strong></p>
|
13 |
+
</div>
|
14 |
+
</div>
|
assets/gradio/gradio_description_upload.md
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<br>
|
2 |
+
<div style="font-size: 1.2em; display: flex; justify-content: space-between;">
|
3 |
+
<div style="flex: 1; text-align: center; margin-right: 20px;">
|
4 |
+
<div style="display: inline-block;">
|
5 |
+
Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
|
6 |
+
</div>
|
7 |
+
</div>
|
8 |
+
<div style="flex: 1; text-align: center; margin-left: 20px;">
|
9 |
+
<div style="display: inline-block;">
|
10 |
+
Step 2: Upload a <strong>Driving Video</strong> (any aspect ratio) ⬇️
|
11 |
+
</div>
|
12 |
+
<div style="display: inline-block; font-size: 0.8em;">
|
13 |
+
<strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
|
14 |
+
</div>
|
15 |
+
</div>
|
16 |
+
</div>
|
assets/gradio/gradio_title.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
|
2 |
+
<div>
|
3 |
+
<h1>FasterLivePortrait: Bring Portraits to Life in Real Time</h1>
|
4 |
+
<span>Built on <a href="https://github.com/KwaiVGI/LivePortrait">LivePortrait</a></span>
|
5 |
+
<div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-top: 10px;">
|
6 |
+
<a href="https://huggingface.co/warmshao/FasterLivePortrait">
|
7 |
+
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue" alt="Hugging Face Spaces">
|
8 |
+
</a>
|
9 |
+
|
10 |
+
<a href="https://github.com/warmshao/FasterLivePortrait">
|
11 |
+
<img src="https://img.shields.io/badge/Github-Code-blue" alt="Github Code">
|
12 |
+
</a>
|
13 |
+
|
14 |
+
<a href="https://github.com/warmshao/FasterLivePortrait">
|
15 |
+
<img src="https://img.shields.io/github/stars/warmshao/FasterLivePortrait" alt="Github Stars">
|
16 |
+
</a>
|
17 |
+
</div>
|
18 |
+
</div>
|
19 |
+
</div>
|
assets/mask_template.png
ADDED
![]() |
camera.bat
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
setlocal enabledelayedexpansion
|
3 |
+
|
4 |
+
REM 设置默认源图像路径
|
5 |
+
set "default_src_image=assets\examples\source\s12.jpg"
|
6 |
+
set "src_image=%default_src_image%"
|
7 |
+
set "animal_param="
|
8 |
+
set "paste_back="
|
9 |
+
|
10 |
+
REM 解析命名参数
|
11 |
+
:parse_args
|
12 |
+
if "%~1"=="" goto end_parse_args
|
13 |
+
if /i "%~1"=="--src_image" (
|
14 |
+
set "src_image=%~2"
|
15 |
+
shift
|
16 |
+
) else if /i "%~1"=="--animal" (
|
17 |
+
set "animal_param=--animal"
|
18 |
+
) else if /i "%~1"=="--paste_back" (
|
19 |
+
set "paste_back=--paste_back"
|
20 |
+
)
|
21 |
+
shift
|
22 |
+
goto parse_args
|
23 |
+
:end_parse_args
|
24 |
+
|
25 |
+
echo source image: [!src_image!]
|
26 |
+
echo use animal: [!animal_param!]
|
27 |
+
echo paste_back: [!paste_back!]
|
28 |
+
|
29 |
+
REM 执行Python命令
|
30 |
+
.\venv\python.exe .\run.py --cfg configs/trt_infer.yaml --realtime --dri_video 0 --src_image !src_image! !animal_param! !paste_back!
|
31 |
+
|
32 |
+
endlocal
|
configs/onnx_infer.yaml
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
models:
|
2 |
+
warping_spade:
|
3 |
+
name: "WarpingSpadeModel"
|
4 |
+
predict_type: "ort"
|
5 |
+
model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx"
|
6 |
+
motion_extractor:
|
7 |
+
name: "MotionExtractorModel"
|
8 |
+
predict_type: "ort"
|
9 |
+
model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx"
|
10 |
+
landmark:
|
11 |
+
name: "LandmarkModel"
|
12 |
+
predict_type: "ort"
|
13 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
|
14 |
+
face_analysis:
|
15 |
+
name: "FaceAnalysisModel"
|
16 |
+
predict_type: "ort"
|
17 |
+
model_path:
|
18 |
+
- "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx"
|
19 |
+
- "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx"
|
20 |
+
app_feat_extractor:
|
21 |
+
name: "AppearanceFeatureExtractorModel"
|
22 |
+
predict_type: "ort"
|
23 |
+
model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx"
|
24 |
+
stitching:
|
25 |
+
name: "StitchingModel"
|
26 |
+
predict_type: "ort"
|
27 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching.onnx"
|
28 |
+
stitching_eye_retarget:
|
29 |
+
name: "StitchingModel"
|
30 |
+
predict_type: "ort"
|
31 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx"
|
32 |
+
stitching_lip_retarget:
|
33 |
+
name: "StitchingModel"
|
34 |
+
predict_type: "ort"
|
35 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx"
|
36 |
+
|
37 |
+
animal_models:
|
38 |
+
warping_spade:
|
39 |
+
name: "WarpingSpadeModel"
|
40 |
+
predict_type: "ort"
|
41 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx"
|
42 |
+
motion_extractor:
|
43 |
+
name: "MotionExtractorModel"
|
44 |
+
predict_type: "ort"
|
45 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx"
|
46 |
+
app_feat_extractor:
|
47 |
+
name: "AppearanceFeatureExtractorModel"
|
48 |
+
predict_type: "ort"
|
49 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx"
|
50 |
+
stitching:
|
51 |
+
name: "StitchingModel"
|
52 |
+
predict_type: "ort"
|
53 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx"
|
54 |
+
stitching_eye_retarget:
|
55 |
+
name: "StitchingModel"
|
56 |
+
predict_type: "ort"
|
57 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx"
|
58 |
+
stitching_lip_retarget:
|
59 |
+
name: "StitchingModel"
|
60 |
+
predict_type: "ort"
|
61 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx"
|
62 |
+
landmark:
|
63 |
+
name: "LandmarkModel"
|
64 |
+
predict_type: "ort"
|
65 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
|
66 |
+
face_analysis:
|
67 |
+
name: "FaceAnalysisModel"
|
68 |
+
predict_type: "ort"
|
69 |
+
model_path:
|
70 |
+
- "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx"
|
71 |
+
- "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx"
|
72 |
+
|
73 |
+
joyvasa_models:
|
74 |
+
motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
|
75 |
+
audio_model_path: "checkpoints/chinese-hubert-base"
|
76 |
+
motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
|
77 |
+
|
78 |
+
crop_params:
|
79 |
+
src_dsize: 512
|
80 |
+
src_scale: 2.3
|
81 |
+
src_vx_ratio: 0.0
|
82 |
+
src_vy_ratio: -0.125
|
83 |
+
dri_scale: 2.2
|
84 |
+
dri_vx_ratio: 0.0
|
85 |
+
dri_vy_ratio: -0.1
|
86 |
+
|
87 |
+
|
88 |
+
infer_params:
|
89 |
+
flag_crop_driving_video: False
|
90 |
+
flag_normalize_lip: True
|
91 |
+
flag_source_video_eye_retargeting: False
|
92 |
+
flag_video_editing_head_rotation: False
|
93 |
+
flag_eye_retargeting: False
|
94 |
+
flag_lip_retargeting: False
|
95 |
+
flag_stitching: True
|
96 |
+
flag_relative_motion: True
|
97 |
+
flag_pasteback: True
|
98 |
+
flag_do_crop: True
|
99 |
+
flag_do_rot: True
|
100 |
+
|
101 |
+
# NOT EXPOERTED PARAMS
|
102 |
+
lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
|
103 |
+
source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
|
104 |
+
driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
105 |
+
anchor_frame: 0 # TO IMPLEMENT
|
106 |
+
mask_crop_path: "./assets/mask_template.png"
|
107 |
+
driving_multiplier: 1.0
|
108 |
+
animation_region: "all"
|
109 |
+
|
110 |
+
cfg_mode: "incremental"
|
111 |
+
cfg_scale: 1.2
|
112 |
+
|
113 |
+
source_max_dim: 1280 # the max dim of height and width of source image
|
114 |
+
source_division: 2 # make sure the height and width of source image can be divided by this number
|
configs/onnx_mp_infer.yaml
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
models:
|
2 |
+
warping_spade:
|
3 |
+
name: "WarpingSpadeModel"
|
4 |
+
predict_type: "ort"
|
5 |
+
model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx"
|
6 |
+
motion_extractor:
|
7 |
+
name: "MotionExtractorModel"
|
8 |
+
predict_type: "ort"
|
9 |
+
model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx"
|
10 |
+
landmark:
|
11 |
+
name: "LandmarkModel"
|
12 |
+
predict_type: "ort"
|
13 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
|
14 |
+
face_analysis:
|
15 |
+
name: "MediaPipeFaceModel"
|
16 |
+
predict_type: "mp"
|
17 |
+
app_feat_extractor:
|
18 |
+
name: "AppearanceFeatureExtractorModel"
|
19 |
+
predict_type: "ort"
|
20 |
+
model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx"
|
21 |
+
stitching:
|
22 |
+
name: "StitchingModel"
|
23 |
+
predict_type: "ort"
|
24 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching.onnx"
|
25 |
+
stitching_eye_retarget:
|
26 |
+
name: "StitchingModel"
|
27 |
+
predict_type: "ort"
|
28 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx"
|
29 |
+
stitching_lip_retarget:
|
30 |
+
name: "StitchingModel"
|
31 |
+
predict_type: "ort"
|
32 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx"
|
33 |
+
|
34 |
+
animal_models:
|
35 |
+
warping_spade:
|
36 |
+
name: "WarpingSpadeModel"
|
37 |
+
predict_type: "ort"
|
38 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx"
|
39 |
+
motion_extractor:
|
40 |
+
name: "MotionExtractorModel"
|
41 |
+
predict_type: "ort"
|
42 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx"
|
43 |
+
app_feat_extractor:
|
44 |
+
name: "AppearanceFeatureExtractorModel"
|
45 |
+
predict_type: "ort"
|
46 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx"
|
47 |
+
stitching:
|
48 |
+
name: "StitchingModel"
|
49 |
+
predict_type: "ort"
|
50 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx"
|
51 |
+
stitching_eye_retarget:
|
52 |
+
name: "StitchingModel"
|
53 |
+
predict_type: "ort"
|
54 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx"
|
55 |
+
stitching_lip_retarget:
|
56 |
+
name: "StitchingModel"
|
57 |
+
predict_type: "ort"
|
58 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx"
|
59 |
+
landmark:
|
60 |
+
name: "LandmarkModel"
|
61 |
+
predict_type: "ort"
|
62 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
|
63 |
+
face_analysis:
|
64 |
+
name: "MediaPipeFaceModel"
|
65 |
+
predict_type: "mp"
|
66 |
+
|
67 |
+
joyvasa_models:
|
68 |
+
motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
|
69 |
+
audio_model_path: "checkpoints/chinese-hubert-base"
|
70 |
+
motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
|
71 |
+
|
72 |
+
crop_params:
|
73 |
+
src_dsize: 512
|
74 |
+
src_scale: 2.3
|
75 |
+
src_vx_ratio: 0.0
|
76 |
+
src_vy_ratio: -0.125
|
77 |
+
dri_scale: 2.2
|
78 |
+
dri_vx_ratio: 0.0
|
79 |
+
dri_vy_ratio: -0.1
|
80 |
+
|
81 |
+
|
82 |
+
infer_params:
|
83 |
+
flag_crop_driving_video: False
|
84 |
+
flag_normalize_lip: True
|
85 |
+
flag_source_video_eye_retargeting: False
|
86 |
+
flag_video_editing_head_rotation: False
|
87 |
+
flag_eye_retargeting: False
|
88 |
+
flag_lip_retargeting: False
|
89 |
+
flag_stitching: True
|
90 |
+
flag_relative_motion: True
|
91 |
+
flag_pasteback: True
|
92 |
+
flag_do_crop: True
|
93 |
+
flag_do_rot: True
|
94 |
+
|
95 |
+
# NOT EXPOERTED PARAMS
|
96 |
+
lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
|
97 |
+
source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
|
98 |
+
driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
99 |
+
anchor_frame: 0 # TO IMPLEMENT
|
100 |
+
mask_crop_path: "./assets/mask_template.png"
|
101 |
+
driving_multiplier: 1.0
|
102 |
+
animation_region: "all"
|
103 |
+
|
104 |
+
cfg_mode: "incremental"
|
105 |
+
cfg_scale: 1.2
|
106 |
+
|
107 |
+
source_max_dim: 1280 # the max dim of height and width of source image
|
108 |
+
source_division: 2 # make sure the height and width of source image can be divided by this number
|
configs/trt_infer.yaml
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
models:
|
2 |
+
warping_spade:
|
3 |
+
name: "WarpingSpadeModel"
|
4 |
+
predict_type: "trt"
|
5 |
+
model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt"
|
6 |
+
motion_extractor:
|
7 |
+
name: "MotionExtractorModel"
|
8 |
+
predict_type: "trt"
|
9 |
+
model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt"
|
10 |
+
landmark:
|
11 |
+
name: "LandmarkModel"
|
12 |
+
predict_type: "trt"
|
13 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
|
14 |
+
face_analysis:
|
15 |
+
name: "FaceAnalysisModel"
|
16 |
+
predict_type: "trt"
|
17 |
+
model_path:
|
18 |
+
- "./checkpoints/liveportrait_onnx/retinaface_det_static.trt"
|
19 |
+
- "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt"
|
20 |
+
app_feat_extractor:
|
21 |
+
name: "AppearanceFeatureExtractorModel"
|
22 |
+
predict_type: "trt"
|
23 |
+
model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt"
|
24 |
+
stitching:
|
25 |
+
name: "StitchingModel"
|
26 |
+
predict_type: "trt"
|
27 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching.trt"
|
28 |
+
stitching_eye_retarget:
|
29 |
+
name: "StitchingModel"
|
30 |
+
predict_type: "trt"
|
31 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt"
|
32 |
+
stitching_lip_retarget:
|
33 |
+
name: "StitchingModel"
|
34 |
+
predict_type: "trt"
|
35 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt"
|
36 |
+
|
37 |
+
animal_models:
|
38 |
+
warping_spade:
|
39 |
+
name: "WarpingSpadeModel"
|
40 |
+
predict_type: "trt"
|
41 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt"
|
42 |
+
motion_extractor:
|
43 |
+
name: "MotionExtractorModel"
|
44 |
+
predict_type: "trt"
|
45 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt"
|
46 |
+
app_feat_extractor:
|
47 |
+
name: "AppearanceFeatureExtractorModel"
|
48 |
+
predict_type: "trt"
|
49 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt"
|
50 |
+
stitching:
|
51 |
+
name: "StitchingModel"
|
52 |
+
predict_type: "trt"
|
53 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt"
|
54 |
+
stitching_eye_retarget:
|
55 |
+
name: "StitchingModel"
|
56 |
+
predict_type: "trt"
|
57 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt"
|
58 |
+
stitching_lip_retarget:
|
59 |
+
name: "StitchingModel"
|
60 |
+
predict_type: "trt"
|
61 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt"
|
62 |
+
landmark:
|
63 |
+
name: "LandmarkModel"
|
64 |
+
predict_type: "trt"
|
65 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
|
66 |
+
face_analysis:
|
67 |
+
name: "FaceAnalysisModel"
|
68 |
+
predict_type: "trt"
|
69 |
+
model_path:
|
70 |
+
- "./checkpoints/liveportrait_onnx/retinaface_det_static.trt"
|
71 |
+
- "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt"
|
72 |
+
|
73 |
+
joyvasa_models:
|
74 |
+
motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
|
75 |
+
audio_model_path: "checkpoints/chinese-hubert-base"
|
76 |
+
motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
|
77 |
+
|
78 |
+
crop_params:
|
79 |
+
src_dsize: 512
|
80 |
+
src_scale: 2.3
|
81 |
+
src_vx_ratio: 0.0
|
82 |
+
src_vy_ratio: -0.125
|
83 |
+
dri_scale: 2.2
|
84 |
+
dri_vx_ratio: 0.0
|
85 |
+
dri_vy_ratio: -0.1
|
86 |
+
|
87 |
+
|
88 |
+
infer_params:
|
89 |
+
flag_crop_driving_video: False
|
90 |
+
flag_normalize_lip: True
|
91 |
+
flag_source_video_eye_retargeting: False
|
92 |
+
flag_video_editing_head_rotation: False
|
93 |
+
flag_eye_retargeting: False
|
94 |
+
flag_lip_retargeting: False
|
95 |
+
flag_stitching: True
|
96 |
+
flag_relative_motion: True
|
97 |
+
flag_pasteback: True
|
98 |
+
flag_do_crop: True
|
99 |
+
flag_do_rot: True
|
100 |
+
|
101 |
+
# NOT EXPOERTED PARAMS
|
102 |
+
lip_normalize_threshold: 0.1 # threshold for flag_normalize_lip
|
103 |
+
source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
|
104 |
+
driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
105 |
+
anchor_frame: 0 # TO IMPLEMENT
|
106 |
+
mask_crop_path: "./assets/mask_template.png"
|
107 |
+
driving_multiplier: 1.0
|
108 |
+
animation_region: "all"
|
109 |
+
|
110 |
+
cfg_mode: "incremental"
|
111 |
+
cfg_scale: 1.2
|
112 |
+
|
113 |
+
source_max_dim: 1280 # the max dim of height and width of source image
|
114 |
+
source_division: 2 # make sure the height and width of source image can be divided by this number
|
configs/trt_mp_infer.yaml
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
models:
|
2 |
+
warping_spade:
|
3 |
+
name: "WarpingSpadeModel"
|
4 |
+
predict_type: "trt"
|
5 |
+
model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt"
|
6 |
+
motion_extractor:
|
7 |
+
name: "MotionExtractorModel"
|
8 |
+
predict_type: "trt"
|
9 |
+
model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt"
|
10 |
+
landmark:
|
11 |
+
name: "LandmarkModel"
|
12 |
+
predict_type: "trt"
|
13 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
|
14 |
+
face_analysis:
|
15 |
+
name: "MediaPipeFaceModel"
|
16 |
+
predict_type: "mp"
|
17 |
+
app_feat_extractor:
|
18 |
+
name: "AppearanceFeatureExtractorModel"
|
19 |
+
predict_type: "trt"
|
20 |
+
model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt"
|
21 |
+
stitching:
|
22 |
+
name: "StitchingModel"
|
23 |
+
predict_type: "trt"
|
24 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching.trt"
|
25 |
+
stitching_eye_retarget:
|
26 |
+
name: "StitchingModel"
|
27 |
+
predict_type: "trt"
|
28 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt"
|
29 |
+
stitching_lip_retarget:
|
30 |
+
name: "StitchingModel"
|
31 |
+
predict_type: "trt"
|
32 |
+
model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt"
|
33 |
+
|
34 |
+
animal_models:
|
35 |
+
warping_spade:
|
36 |
+
name: "WarpingSpadeModel"
|
37 |
+
predict_type: "trt"
|
38 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt"
|
39 |
+
motion_extractor:
|
40 |
+
name: "MotionExtractorModel"
|
41 |
+
predict_type: "trt"
|
42 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt"
|
43 |
+
app_feat_extractor:
|
44 |
+
name: "AppearanceFeatureExtractorModel"
|
45 |
+
predict_type: "trt"
|
46 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt"
|
47 |
+
stitching:
|
48 |
+
name: "StitchingModel"
|
49 |
+
predict_type: "trt"
|
50 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt"
|
51 |
+
stitching_eye_retarget:
|
52 |
+
name: "StitchingModel"
|
53 |
+
predict_type: "trt"
|
54 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt"
|
55 |
+
stitching_lip_retarget:
|
56 |
+
name: "StitchingModel"
|
57 |
+
predict_type: "trt"
|
58 |
+
model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt"
|
59 |
+
landmark:
|
60 |
+
name: "LandmarkModel"
|
61 |
+
predict_type: "trt"
|
62 |
+
model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
|
63 |
+
face_analysis:
|
64 |
+
name: "MediaPipeFaceModel"
|
65 |
+
predict_type: "mp"
|
66 |
+
|
67 |
+
joyvasa_models:
|
68 |
+
motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
|
69 |
+
audio_model_path: "checkpoints/chinese-hubert-base"
|
70 |
+
motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
|
71 |
+
|
72 |
+
crop_params:
|
73 |
+
src_dsize: 512
|
74 |
+
src_scale: 2.3
|
75 |
+
src_vx_ratio: 0.0
|
76 |
+
src_vy_ratio: -0.125
|
77 |
+
dri_scale: 2.2
|
78 |
+
dri_vx_ratio: 0.0
|
79 |
+
dri_vy_ratio: -0.1
|
80 |
+
|
81 |
+
|
82 |
+
infer_params:
|
83 |
+
flag_crop_driving_video: False
|
84 |
+
flag_normalize_lip: True
|
85 |
+
flag_source_video_eye_retargeting: False
|
86 |
+
flag_video_editing_head_rotation: False
|
87 |
+
flag_eye_retargeting: False
|
88 |
+
flag_lip_retargeting: False
|
89 |
+
flag_stitching: True
|
90 |
+
flag_relative_motion: True
|
91 |
+
flag_pasteback: True
|
92 |
+
flag_do_crop: True
|
93 |
+
flag_do_rot: True
|
94 |
+
animation_region: "all"
|
95 |
+
|
96 |
+
# NOT EXPOERTED PARAMS
|
97 |
+
lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
|
98 |
+
source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
|
99 |
+
driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
|
100 |
+
anchor_frame: 0 # TO IMPLEMENT
|
101 |
+
mask_crop_path: "./assets/mask_template.png"
|
102 |
+
driving_multiplier: 1.0
|
103 |
+
|
104 |
+
cfg_mode: "incremental"
|
105 |
+
cfg_scale: 1.2
|
106 |
+
|
107 |
+
source_max_dim: 1280 # the max dim of height and width of source image
|
108 |
+
source_division: 2 # make sure the height and width of source image can be divided by this number
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg-python
|
2 |
+
omegaconf
|
3 |
+
onnx
|
4 |
+
pycuda
|
5 |
+
numpy
|
6 |
+
opencv-python
|
7 |
+
gradio
|
8 |
+
scikit-image
|
9 |
+
insightface
|
10 |
+
huggingface_hub[cli]
|
11 |
+
mediapipe
|
12 |
+
torchgeometry
|
13 |
+
soundfile
|
14 |
+
munch
|
15 |
+
phonemizer
|
16 |
+
kokoro>=0.3.4
|
17 |
+
misaki[ja]
|
18 |
+
misaki[zh]
|
requirements_macos.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg-python
|
2 |
+
omegaconf
|
3 |
+
onnx
|
4 |
+
onnxruntime
|
5 |
+
numpy
|
6 |
+
opencv-python
|
7 |
+
gradio
|
8 |
+
scikit-image
|
9 |
+
insightface
|
10 |
+
huggingface_hub[cli]
|
11 |
+
mediapipe
|
12 |
+
torchgeometry
|
13 |
+
soundfile
|
14 |
+
munch
|
15 |
+
phonemizer
|
16 |
+
kokoro>=0.3.4
|
17 |
+
misaki[ja]
|
18 |
+
misaki[zh]
|
requirements_win.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg-python
|
2 |
+
omegaconf
|
3 |
+
onnx
|
4 |
+
numpy
|
5 |
+
opencv-python
|
6 |
+
gradio
|
7 |
+
scikit-image
|
8 |
+
insightface
|
9 |
+
huggingface_hub[cli]
|
10 |
+
mediapipe
|
11 |
+
torchgeometry
|
12 |
+
soundfile
|
13 |
+
munch
|
14 |
+
phonemizer
|
15 |
+
kokoro>=0.3.4
|
16 |
+
misaki[ja]
|
17 |
+
misaki[zh]
|
run.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Author : wenshao
|
3 |
+
# @Email : [email protected]
|
4 |
+
# @Project : FasterLivePortrait
|
5 |
+
# @FileName: run.py
|
6 |
+
|
7 |
+
"""
|
8 |
+
# video
|
9 |
+
python run.py \
|
10 |
+
--src_image assets/examples/driving/d13.mp4 \
|
11 |
+
--dri_video assets/examples/driving/d11.mp4 \
|
12 |
+
--cfg configs/trt_infer.yaml \
|
13 |
+
--paste_back \
|
14 |
+
--animal
|
15 |
+
# pkl
|
16 |
+
python run.py \
|
17 |
+
--src_image assets/examples/source/s12.jpg \
|
18 |
+
--dri_video ./results/2024-09-13-081710/d0.mp4.pkl \
|
19 |
+
--cfg configs/trt_infer.yaml \
|
20 |
+
--paste_back \
|
21 |
+
--animal
|
22 |
+
"""
|
23 |
+
import os
|
24 |
+
import argparse
|
25 |
+
import pdb
|
26 |
+
import subprocess
|
27 |
+
import ffmpeg
|
28 |
+
import cv2
|
29 |
+
import time
|
30 |
+
import numpy as np
|
31 |
+
import os
|
32 |
+
import datetime
|
33 |
+
import platform
|
34 |
+
import pickle
|
35 |
+
from omegaconf import OmegaConf
|
36 |
+
from tqdm import tqdm
|
37 |
+
from colorama import Fore, Back, Style
|
38 |
+
from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
|
39 |
+
from src.utils.utils import video_has_audio
|
40 |
+
|
41 |
+
if platform.system().lower() == 'windows':
|
42 |
+
FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
|
43 |
+
else:
|
44 |
+
FFMPEG = "ffmpeg"
|
45 |
+
|
46 |
+
|
47 |
+
def run_with_video(args):
|
48 |
+
print(Fore.RED+'Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate'+Style.RESET_ALL)
|
49 |
+
infer_cfg = OmegaConf.load(args.cfg)
|
50 |
+
infer_cfg.infer_params.flag_pasteback = args.paste_back
|
51 |
+
|
52 |
+
pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=args.animal)
|
53 |
+
ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
|
54 |
+
if not ret:
|
55 |
+
print(f"no face in {args.src_image}! exit!")
|
56 |
+
exit(1)
|
57 |
+
if not args.dri_video or not os.path.exists(args.dri_video):
|
58 |
+
# read frame from camera if no driving video input
|
59 |
+
vcap = cv2.VideoCapture(0)
|
60 |
+
if not vcap.isOpened():
|
61 |
+
print("no camera found! exit!")
|
62 |
+
exit(1)
|
63 |
+
else:
|
64 |
+
vcap = cv2.VideoCapture(args.dri_video)
|
65 |
+
fps = int(vcap.get(cv2.CAP_PROP_FPS))
|
66 |
+
h, w = pipe.src_imgs[0].shape[:2]
|
67 |
+
save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
|
68 |
+
os.makedirs(save_dir, exist_ok=True)
|
69 |
+
|
70 |
+
# render output video
|
71 |
+
if not args.realtime:
|
72 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
73 |
+
vsave_crop_path = os.path.join(save_dir,
|
74 |
+
f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4")
|
75 |
+
vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512))
|
76 |
+
vsave_org_path = os.path.join(save_dir,
|
77 |
+
f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4")
|
78 |
+
vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
|
79 |
+
|
80 |
+
infer_times = []
|
81 |
+
motion_lst = []
|
82 |
+
c_eyes_lst = []
|
83 |
+
c_lip_lst = []
|
84 |
+
|
85 |
+
frame_ind = 0
|
86 |
+
while vcap.isOpened():
|
87 |
+
ret, frame = vcap.read()
|
88 |
+
if not ret:
|
89 |
+
break
|
90 |
+
t0 = time.time()
|
91 |
+
first_frame = frame_ind == 0
|
92 |
+
dri_crop, out_crop, out_org, dri_motion_info = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
|
93 |
+
first_frame=first_frame)
|
94 |
+
frame_ind += 1
|
95 |
+
if out_crop is None:
|
96 |
+
print(f"no face in driving frame:{frame_ind}")
|
97 |
+
continue
|
98 |
+
|
99 |
+
motion_lst.append(dri_motion_info[0])
|
100 |
+
c_eyes_lst.append(dri_motion_info[1])
|
101 |
+
c_lip_lst.append(dri_motion_info[2])
|
102 |
+
|
103 |
+
infer_times.append(time.time() - t0)
|
104 |
+
# print(time.time() - t0)
|
105 |
+
dri_crop = cv2.resize(dri_crop, (512, 512))
|
106 |
+
out_crop = np.concatenate([dri_crop, out_crop], axis=1)
|
107 |
+
out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
|
108 |
+
if not args.realtime:
|
109 |
+
vout_crop.write(out_crop)
|
110 |
+
out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
|
111 |
+
vout_org.write(out_org)
|
112 |
+
else:
|
113 |
+
if infer_cfg.infer_params.flag_pasteback:
|
114 |
+
out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
|
115 |
+
cv2.imshow('Render', out_org)
|
116 |
+
else:
|
117 |
+
# image show in realtime mode
|
118 |
+
cv2.imshow('Render', out_crop)
|
119 |
+
# 按下'q'键退出循环
|
120 |
+
if cv2.waitKey(1) & 0xFF == ord('q'):
|
121 |
+
break
|
122 |
+
vcap.release()
|
123 |
+
if not args.realtime:
|
124 |
+
vout_crop.release()
|
125 |
+
vout_org.release()
|
126 |
+
if video_has_audio(args.dri_video):
|
127 |
+
vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
|
128 |
+
subprocess.call(
|
129 |
+
[FFMPEG, "-i", vsave_crop_path, "-i", args.dri_video,
|
130 |
+
"-b:v", "10M", "-c:v",
|
131 |
+
"libx264", "-map", "0:v", "-map", "1:a",
|
132 |
+
"-c:a", "aac",
|
133 |
+
"-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
|
134 |
+
vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
|
135 |
+
subprocess.call(
|
136 |
+
[FFMPEG, "-i", vsave_org_path, "-i", args.dri_video,
|
137 |
+
"-b:v", "10M", "-c:v",
|
138 |
+
"libx264", "-map", "0:v", "-map", "1:a",
|
139 |
+
"-c:a", "aac",
|
140 |
+
"-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
|
141 |
+
|
142 |
+
print(vsave_crop_path_new)
|
143 |
+
print(vsave_org_path_new)
|
144 |
+
else:
|
145 |
+
print(vsave_crop_path)
|
146 |
+
print(vsave_org_path)
|
147 |
+
else:
|
148 |
+
cv2.destroyAllWindows()
|
149 |
+
|
150 |
+
print(
|
151 |
+
"inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
|
152 |
+
np.mean(infer_times) * 1000))
|
153 |
+
# save driving motion to pkl
|
154 |
+
template_dct = {
|
155 |
+
'n_frames': len(motion_lst),
|
156 |
+
'output_fps': fps,
|
157 |
+
'motion': motion_lst,
|
158 |
+
'c_eyes_lst': c_eyes_lst,
|
159 |
+
'c_lip_lst': c_lip_lst,
|
160 |
+
}
|
161 |
+
template_pkl_path = os.path.join(save_dir,
|
162 |
+
f"{os.path.basename(args.dri_video)}.pkl")
|
163 |
+
with open(template_pkl_path, "wb") as fw:
|
164 |
+
pickle.dump(template_dct, fw)
|
165 |
+
print(f"save driving motion pkl file at : {template_pkl_path}")
|
166 |
+
|
167 |
+
|
168 |
+
def run_with_pkl(args):
|
169 |
+
infer_cfg = OmegaConf.load(args.cfg)
|
170 |
+
infer_cfg.infer_params.flag_pasteback = args.paste_back
|
171 |
+
|
172 |
+
pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=args.animal)
|
173 |
+
ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
|
174 |
+
if not ret:
|
175 |
+
print(f"no face in {args.src_image}! exit!")
|
176 |
+
return
|
177 |
+
with open(args.dri_video, "rb") as fin:
|
178 |
+
dri_motion_infos = pickle.load(fin)
|
179 |
+
|
180 |
+
fps = int(dri_motion_infos["output_fps"])
|
181 |
+
h, w = pipe.src_imgs[0].shape[:2]
|
182 |
+
save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
|
183 |
+
os.makedirs(save_dir, exist_ok=True)
|
184 |
+
|
185 |
+
# render output video
|
186 |
+
if not args.realtime:
|
187 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
188 |
+
vsave_crop_path = os.path.join(save_dir,
|
189 |
+
f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4")
|
190 |
+
vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512, 512))
|
191 |
+
vsave_org_path = os.path.join(save_dir,
|
192 |
+
f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4")
|
193 |
+
vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
|
194 |
+
|
195 |
+
infer_times = []
|
196 |
+
motion_lst = dri_motion_infos["motion"]
|
197 |
+
c_eyes_lst = dri_motion_infos["c_eyes_lst"] if "c_eyes_lst" in dri_motion_infos else dri_motion_infos[
|
198 |
+
"c_d_eyes_lst"]
|
199 |
+
c_lip_lst = dri_motion_infos["c_lip_lst"] if "c_lip_lst" in dri_motion_infos else dri_motion_infos["c_d_lip_lst"]
|
200 |
+
|
201 |
+
frame_num = len(motion_lst)
|
202 |
+
for frame_ind in tqdm(range(frame_num)):
|
203 |
+
t0 = time.time()
|
204 |
+
first_frame = frame_ind == 0
|
205 |
+
dri_motion_info_ = [motion_lst[frame_ind], c_eyes_lst[frame_ind], c_lip_lst[frame_ind]]
|
206 |
+
out_crop, out_org = pipe.run_with_pkl(dri_motion_info_, pipe.src_imgs[0], pipe.src_infos[0],
|
207 |
+
first_frame=first_frame)
|
208 |
+
if out_crop is None:
|
209 |
+
print(f"no face in driving frame:{frame_ind}")
|
210 |
+
continue
|
211 |
+
|
212 |
+
infer_times.append(time.time() - t0)
|
213 |
+
# print(time.time() - t0)
|
214 |
+
out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
|
215 |
+
if not args.realtime:
|
216 |
+
vout_crop.write(out_crop)
|
217 |
+
out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
|
218 |
+
vout_org.write(out_org)
|
219 |
+
else:
|
220 |
+
if infer_cfg.infer_params.flag_pasteback:
|
221 |
+
out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
|
222 |
+
cv2.imshow('Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate',out_org)
|
223 |
+
else:
|
224 |
+
# image show in realtime mode
|
225 |
+
cv2.imshow('Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate', out_crop)
|
226 |
+
# Press the 'q' key to exit the loop, r to switch realtime src_webcam update, spacebar to switch sourceisWebcam
|
227 |
+
k = cv2.waitKey(1) & 0xFF
|
228 |
+
if k == ord('q'):
|
229 |
+
break
|
230 |
+
# Key for Interesting Params
|
231 |
+
if k == ord('s'):
|
232 |
+
infer_cfg.infer_params.flag_stitching = not infer_cfg.infer_params.flag_stitching
|
233 |
+
print('flag_stitching:'+str(infer_cfg.infer_params.flag_stitching))
|
234 |
+
if k == ord('z'):
|
235 |
+
infer_cfg.infer_params.flag_relative_motion = not infer_cfg.infer_params.flag_relative_motion
|
236 |
+
print('flag_relative_motion:'+str(infer_cfg.infer_params.flag_relative_motion))
|
237 |
+
if k == ord('x'):
|
238 |
+
if infer_cfg.infer_params.animation_region == "all": infer_cfg.infer_params.animation_region = "exp", print('animation_region = "exp"')
|
239 |
+
else:infer_cfg.infer_params.animation_region = "all", print('animation_region = "all"')
|
240 |
+
if k == ord('c'):
|
241 |
+
infer_cfg.infer_params.flag_crop_driving_video = not infer_cfg.infer_params.flag_crop_driving_video
|
242 |
+
print('flag_crop_driving_video:'+str(infer_cfg.infer_params.flag_crop_driving_video))
|
243 |
+
if k == ord('v'):
|
244 |
+
infer_cfg.infer_params.flag_pasteback = not infer_cfg.infer_params.flag_pasteback
|
245 |
+
print('flag_pasteback:'+str(infer_cfg.infer_params.flag_pasteback))
|
246 |
+
|
247 |
+
if k == ord('a'):
|
248 |
+
infer_cfg.infer_params.flag_normalize_lip = not infer_cfg.infer_params.flag_normalize_lip
|
249 |
+
print('flag_normalize_lip:'+str(infer_cfg.infer_params.flag_normalize_lip))
|
250 |
+
if k == ord('d'):
|
251 |
+
infer_cfg.infer_params.flag_source_video_eye_retargeting = not infer_cfg.infer_params.flag_source_video_eye_retargeting
|
252 |
+
print('flag_source_video_eye_retargeting:'+str(infer_cfg.infer_params.flag_source_video_eye_retargeting))
|
253 |
+
if k == ord('f'):
|
254 |
+
infer_cfg.infer_params.flag_video_editing_head_rotation = not infer_cfg.infer_params.flag_video_editing_head_rotation
|
255 |
+
print('flag_video_editing_head_rotation:'+str(infer_cfg.infer_params.flag_video_editing_head_rotation))
|
256 |
+
if k == ord('g'):
|
257 |
+
infer_cfg.infer_params.flag_eye_retargeting = not infer_cfg.infer_params.flag_eye_retargeting
|
258 |
+
print('flag_eye_retargeting:'+str(infer_cfg.infer_params.flag_eye_retargeting))
|
259 |
+
|
260 |
+
if k == ord('k'):
|
261 |
+
infer_cfg.crop_params.src_scale -= 0.1
|
262 |
+
ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
|
263 |
+
print('src_scale:'+str(infer_cfg.crop_params.src_scale))
|
264 |
+
if k == ord('l'):
|
265 |
+
infer_cfg.crop_params.src_scale += 0.1
|
266 |
+
ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
|
267 |
+
print('src_scale:'+str(infer_cfg.crop_params.src_scale))
|
268 |
+
if k == ord('n'):
|
269 |
+
infer_cfg.crop_params.dri_scale -= 0.1
|
270 |
+
print('dri_scale:'+str(infer_cfg.crop_params.dri_scale))
|
271 |
+
if k == ord('m'):
|
272 |
+
infer_cfg.crop_params.dri_scale += 0.1
|
273 |
+
print('dri_scale:'+str(infer_cfg.crop_params.dri_scale))
|
274 |
+
|
275 |
+
if not args.realtime:
|
276 |
+
vout_crop.release()
|
277 |
+
vout_org.release()
|
278 |
+
if video_has_audio(args.dri_video):
|
279 |
+
vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
|
280 |
+
subprocess.call(
|
281 |
+
[FFMPEG, "-i", vsave_crop_path, "-i", args.dri_video,
|
282 |
+
"-b:v", "10M", "-c:v",
|
283 |
+
"libx264", "-map", "0:v", "-map", "1:a",
|
284 |
+
"-c:a", "aac",
|
285 |
+
"-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
|
286 |
+
vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
|
287 |
+
subprocess.call(
|
288 |
+
[FFMPEG, "-i", vsave_org_path, "-i", args.dri_video,
|
289 |
+
"-b:v", "10M", "-c:v",
|
290 |
+
"libx264", "-map", "0:v", "-map", "1:a",
|
291 |
+
"-c:a", "aac",
|
292 |
+
"-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
|
293 |
+
|
294 |
+
print(vsave_crop_path_new)
|
295 |
+
print(vsave_org_path_new)
|
296 |
+
else:
|
297 |
+
print(vsave_crop_path)
|
298 |
+
print(vsave_org_path)
|
299 |
+
else:
|
300 |
+
cv2.destroyAllWindows()
|
301 |
+
|
302 |
+
print(
|
303 |
+
"inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
|
304 |
+
np.mean(infer_times) * 1000))
|
305 |
+
|
306 |
+
|
307 |
+
if __name__ == '__main__':
|
308 |
+
parser = argparse.ArgumentParser(description='Faster Live Portrait Pipeline')
|
309 |
+
parser.add_argument('--src_image', required=False, type=str, default="assets/examples/source/s12.jpg",
|
310 |
+
help='source image')
|
311 |
+
parser.add_argument('--dri_video', required=False, type=str, default="assets/examples/driving/d14.mp4",
|
312 |
+
help='driving video')
|
313 |
+
parser.add_argument('--cfg', required=False, type=str, default="configs/onnx_infer.yaml", help='inference config')
|
314 |
+
parser.add_argument('--realtime', action='store_true', help='realtime inference')
|
315 |
+
parser.add_argument('--animal', action='store_true', help='use animal model')
|
316 |
+
parser.add_argument('--paste_back', action='store_true', default=False, help='paste back to origin image')
|
317 |
+
args, unknown = parser.parse_known_args()
|
318 |
+
|
319 |
+
if args.dri_video.endswith(".pkl"):
|
320 |
+
run_with_pkl(args)
|
321 |
+
else:
|
322 |
+
run_with_video(args)
|
scripts/all_onnx2trt.bat
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
|
3 |
+
REM warping+spade model
|
4 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\warping_spade-fix.onnx
|
5 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\warping_spade-fix.onnx
|
6 |
+
|
7 |
+
REM landmark model
|
8 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\landmark.onnx
|
9 |
+
|
10 |
+
REM motion_extractor model
|
11 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\motion_extractor.onnx -p fp32
|
12 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\motion_extractor.onnx -p fp32
|
13 |
+
|
14 |
+
REM face_analysis model
|
15 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\retinaface_det_static.onnx
|
16 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\face_2dpose_106_static.onnx
|
17 |
+
|
18 |
+
REM appearance_extractor model
|
19 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\appearance_feature_extractor.onnx
|
20 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\appearance_feature_extractor.onnx
|
21 |
+
|
22 |
+
REM stitching model
|
23 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching.onnx
|
24 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_eye.onnx
|
25 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_lip.onnx
|
26 |
+
|
27 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching.onnx
|
28 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_eye.onnx
|
29 |
+
.\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_lip.onnx
|
scripts/all_onnx2trt.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# warping+spade model
|
4 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/warping_spade-fix.onnx
|
5 |
+
# landmark model
|
6 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/landmark.onnx
|
7 |
+
# motion_extractor model
|
8 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/motion_extractor.onnx -p fp32
|
9 |
+
# face_analysis model
|
10 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/retinaface_det_static.onnx
|
11 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx
|
12 |
+
# appearance_extractor model
|
13 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx
|
14 |
+
# stitching model
|
15 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching.onnx
|
16 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_eye.onnx
|
17 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_lip.onnx
|
scripts/all_onnx2trt_animal.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# warping+spade model
|
4 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.onnx
|
5 |
+
# motion_extractor model
|
6 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.onnx -p fp32
|
7 |
+
# appearance_extractor model
|
8 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.onnx
|
9 |
+
# stitching model
|
10 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching-v1.1.onnx
|
11 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.onnx
|
12 |
+
python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.onnx
|
scripts/onnx2trt.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
# SPDX-License-Identifier: Apache-2.0
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
#
|
17 |
+
|
18 |
+
import os
|
19 |
+
import pdb
|
20 |
+
import sys
|
21 |
+
import logging
|
22 |
+
import argparse
|
23 |
+
import platform
|
24 |
+
|
25 |
+
import tensorrt as trt
|
26 |
+
import ctypes
|
27 |
+
import numpy as np
|
28 |
+
|
29 |
+
logging.basicConfig(level=logging.INFO)
|
30 |
+
logging.getLogger("EngineBuilder").setLevel(logging.INFO)
|
31 |
+
log = logging.getLogger("EngineBuilder")
|
32 |
+
|
33 |
+
|
34 |
+
def load_plugins(logger: trt.Logger):
|
35 |
+
# 加载插件库
|
36 |
+
if platform.system().lower() == 'linux':
|
37 |
+
ctypes.CDLL("./checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
|
38 |
+
else:
|
39 |
+
ctypes.CDLL("./checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL, winmode=0)
|
40 |
+
# 初始化TensorRT的插件库
|
41 |
+
trt.init_libnvinfer_plugins(logger, "")
|
42 |
+
|
43 |
+
|
44 |
+
class EngineBuilder:
|
45 |
+
"""
|
46 |
+
Parses an ONNX graph and builds a TensorRT engine from it.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(self, verbose=False):
|
50 |
+
"""
|
51 |
+
:param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
|
52 |
+
"""
|
53 |
+
self.trt_logger = trt.Logger(trt.Logger.INFO)
|
54 |
+
if verbose:
|
55 |
+
self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE
|
56 |
+
|
57 |
+
trt.init_libnvinfer_plugins(self.trt_logger, namespace="")
|
58 |
+
|
59 |
+
self.builder = trt.Builder(self.trt_logger)
|
60 |
+
self.config = self.builder.create_builder_config()
|
61 |
+
self.config.max_workspace_size = 12 * (2 ** 30) # 12 GB
|
62 |
+
|
63 |
+
profile = self.builder.create_optimization_profile()
|
64 |
+
|
65 |
+
# for face_2dpose_106.onnx
|
66 |
+
# profile.set_shape("data", (1, 3, 192, 192), (1, 3, 192, 192), (1, 3, 192, 192))
|
67 |
+
# for retinaface_det.onnx
|
68 |
+
# profile.set_shape("input.1", (1, 3, 512, 512), (1, 3, 512, 512), (1, 3, 512, 512))
|
69 |
+
|
70 |
+
self.config.add_optimization_profile(profile)
|
71 |
+
# 严格类型约束
|
72 |
+
self.config.set_flag(trt.BuilderFlag.STRICT_TYPES)
|
73 |
+
|
74 |
+
self.batch_size = None
|
75 |
+
self.network = None
|
76 |
+
self.parser = None
|
77 |
+
|
78 |
+
# 加载自定义插件
|
79 |
+
load_plugins(self.trt_logger)
|
80 |
+
|
81 |
+
def create_network(self, onnx_path):
|
82 |
+
"""
|
83 |
+
Parse the ONNX graph and create the corresponding TensorRT network definition.
|
84 |
+
:param onnx_path: The path to the ONNX graph to load.
|
85 |
+
"""
|
86 |
+
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
87 |
+
|
88 |
+
self.network = self.builder.create_network(network_flags)
|
89 |
+
self.parser = trt.OnnxParser(self.network, self.trt_logger)
|
90 |
+
|
91 |
+
onnx_path = os.path.realpath(onnx_path)
|
92 |
+
with open(onnx_path, "rb") as f:
|
93 |
+
if not self.parser.parse(f.read()):
|
94 |
+
log.error("Failed to load ONNX file: {}".format(onnx_path))
|
95 |
+
for error in range(self.parser.num_errors):
|
96 |
+
log.error(self.parser.get_error(error))
|
97 |
+
sys.exit(1)
|
98 |
+
|
99 |
+
inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
|
100 |
+
outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
|
101 |
+
|
102 |
+
log.info("Network Description")
|
103 |
+
for input in inputs:
|
104 |
+
self.batch_size = input.shape[0]
|
105 |
+
log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
|
106 |
+
for output in outputs:
|
107 |
+
log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
|
108 |
+
# assert self.batch_size > 0
|
109 |
+
self.builder.max_batch_size = 1
|
110 |
+
|
111 |
+
def create_engine(
|
112 |
+
self,
|
113 |
+
engine_path,
|
114 |
+
precision
|
115 |
+
):
|
116 |
+
"""
|
117 |
+
Build the TensorRT engine and serialize it to disk.
|
118 |
+
:param engine_path: The path where to serialize the engine to.
|
119 |
+
:param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
|
120 |
+
"""
|
121 |
+
engine_path = os.path.realpath(engine_path)
|
122 |
+
engine_dir = os.path.dirname(engine_path)
|
123 |
+
os.makedirs(engine_dir, exist_ok=True)
|
124 |
+
log.info("Building {} Engine in {}".format(precision, engine_path))
|
125 |
+
|
126 |
+
if precision == "fp16":
|
127 |
+
if not self.builder.platform_has_fast_fp16:
|
128 |
+
log.warning("FP16 is not supported natively on this platform/device")
|
129 |
+
else:
|
130 |
+
self.config.set_flag(trt.BuilderFlag.FP16)
|
131 |
+
|
132 |
+
with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
|
133 |
+
log.info("Serializing engine to file: {:}".format(engine_path))
|
134 |
+
f.write(engine.serialize())
|
135 |
+
|
136 |
+
|
137 |
+
def main(args):
|
138 |
+
builder = EngineBuilder(args.verbose)
|
139 |
+
builder.create_network(args.onnx)
|
140 |
+
builder.create_engine(
|
141 |
+
args.engine,
|
142 |
+
args.precision
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
parser = argparse.ArgumentParser()
|
148 |
+
parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load")
|
149 |
+
parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
|
150 |
+
parser.add_argument(
|
151 |
+
"-p",
|
152 |
+
"--precision",
|
153 |
+
default="fp16",
|
154 |
+
choices=["fp32", "fp16", "int8"],
|
155 |
+
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
|
156 |
+
)
|
157 |
+
parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
|
158 |
+
args = parser.parse_args()
|
159 |
+
if args.engine is None:
|
160 |
+
args.engine = args.onnx.replace(".onnx", ".trt")
|
161 |
+
main(args)
|
scripts/start_api.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
source ~/.bashrc
|
3 |
+
python api.py
|
src/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Author : wenshao
|
3 |
+
# @Email : [email protected]
|
4 |
+
# @Project : FasterLivePortrait
|
5 |
+
# @FileName: __init__.py.py
|
src/models/JoyVASA/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/12/15
|
3 |
+
# @Author : wenshao
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Project : FasterLivePortrait
|
6 |
+
# @FileName: __init__.py
|
src/models/JoyVASA/common.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class PositionalEncoding(nn.Module):
|
9 |
+
def __init__(self, d_model, dropout=0.1, max_len=600):
|
10 |
+
super().__init__()
|
11 |
+
self.dropout = nn.Dropout(p=dropout)
|
12 |
+
# vanilla sinusoidal encoding
|
13 |
+
pe = torch.zeros(max_len, d_model)
|
14 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
15 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
16 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
17 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
18 |
+
pe = pe.unsqueeze(0)
|
19 |
+
self.register_buffer('pe', pe)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = x + self.pe[:, x.shape[1], :]
|
23 |
+
return self.dropout(x)
|
24 |
+
|
25 |
+
|
26 |
+
def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
|
27 |
+
mask = torch.ones(T, S)
|
28 |
+
for i in range(T):
|
29 |
+
mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
|
30 |
+
return (mask == 1).to(device=device)
|
31 |
+
|
32 |
+
|
33 |
+
def pad_audio(audio, audio_unit=320, pad_threshold=80):
|
34 |
+
batch_size, audio_len = audio.shape
|
35 |
+
n_units = audio_len // audio_unit
|
36 |
+
side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
|
37 |
+
if side_len >= 0:
|
38 |
+
reflect_len = side_len // 2
|
39 |
+
replicate_len = side_len % 2
|
40 |
+
if reflect_len > 0:
|
41 |
+
audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
|
42 |
+
audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
|
43 |
+
if replicate_len > 0:
|
44 |
+
audio = F.pad(audio, (1, 1), mode='replicate')
|
45 |
+
|
46 |
+
return audio
|
src/models/JoyVASA/dit_talking_head.py
ADDED
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import platform
|
7 |
+
from .common import PositionalEncoding, enc_dec_mask, pad_audio
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
class DiffusionSchedule(nn.Module):
|
12 |
+
def __init__(self, num_steps, mode='linear', beta_1=1e-4, beta_T=0.02, s=0.008):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
if mode == 'linear':
|
16 |
+
betas = torch.linspace(beta_1, beta_T, num_steps)
|
17 |
+
elif mode == 'quadratic':
|
18 |
+
betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
|
19 |
+
elif mode == 'sigmoid':
|
20 |
+
betas = torch.sigmoid(torch.linspace(-5, 5, num_steps)) * (beta_T - beta_1) + beta_1
|
21 |
+
elif mode == 'cosine':
|
22 |
+
steps = num_steps + 1
|
23 |
+
x = torch.linspace(0, num_steps, steps)
|
24 |
+
alpha_bars = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
25 |
+
alpha_bars = alpha_bars / alpha_bars[0]
|
26 |
+
betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
|
27 |
+
betas = torch.clip(betas, 0.0001, 0.999)
|
28 |
+
else:
|
29 |
+
raise ValueError(f'Unknown diffusion schedule {mode}!')
|
30 |
+
betas = torch.cat([torch.zeros(1), betas], dim=0) # Padding beta_0 = 0
|
31 |
+
|
32 |
+
alphas = 1 - betas
|
33 |
+
log_alphas = torch.log(alphas)
|
34 |
+
for i in range(1, log_alphas.shape[0]): # 1 to T
|
35 |
+
log_alphas[i] += log_alphas[i - 1]
|
36 |
+
alpha_bars = log_alphas.exp()
|
37 |
+
|
38 |
+
sigmas_flex = torch.sqrt(betas)
|
39 |
+
sigmas_inflex = torch.zeros_like(sigmas_flex)
|
40 |
+
for i in range(1, sigmas_flex.shape[0]):
|
41 |
+
sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[i]
|
42 |
+
sigmas_inflex = torch.sqrt(sigmas_inflex)
|
43 |
+
|
44 |
+
self.num_steps = num_steps
|
45 |
+
self.register_buffer('betas', betas)
|
46 |
+
self.register_buffer('alphas', alphas)
|
47 |
+
self.register_buffer('alpha_bars', alpha_bars)
|
48 |
+
self.register_buffer('sigmas_flex', sigmas_flex)
|
49 |
+
self.register_buffer('sigmas_inflex', sigmas_inflex)
|
50 |
+
|
51 |
+
def uniform_sample_t(self, batch_size):
|
52 |
+
ts = torch.randint(1, self.num_steps + 1, (batch_size,))
|
53 |
+
return ts.tolist()
|
54 |
+
|
55 |
+
def get_sigmas(self, t, flexibility=0):
|
56 |
+
assert 0 <= flexibility <= 1
|
57 |
+
sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
|
58 |
+
return sigmas
|
59 |
+
|
60 |
+
|
61 |
+
class DitTalkingHead(nn.Module):
|
62 |
+
def __init__(self, device='cuda', target="sample", architecture="decoder",
|
63 |
+
motion_feat_dim=76, fps=25, n_motions=100, n_prev_motions=10,
|
64 |
+
audio_model="hubert", feature_dim=512, n_diff_steps=500, diff_schedule="cosine",
|
65 |
+
cfg_mode="incremental", guiding_conditions="audio,", audio_encoder_path=''):
|
66 |
+
super().__init__()
|
67 |
+
|
68 |
+
# Model parameters
|
69 |
+
self.target = target # 预测原始图像还是预测噪声
|
70 |
+
self.architecture = architecture
|
71 |
+
self.motion_feat_dim = motion_feat_dim # motion 特征维度
|
72 |
+
self.fps = fps
|
73 |
+
self.n_motions = n_motions # 当前motion100个, window_length, T_w
|
74 |
+
self.n_prev_motions = n_prev_motions # 前续motion 10个, T_p
|
75 |
+
self.feature_dim = feature_dim
|
76 |
+
|
77 |
+
# Audio encoder
|
78 |
+
self.audio_model = audio_model
|
79 |
+
if self.audio_model == 'wav2vec2':
|
80 |
+
print("using wav2vec2 audio encoder ...")
|
81 |
+
from .wav2vec2 import Wav2Vec2Model
|
82 |
+
self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path)
|
83 |
+
# wav2vec 2.0 weights initialization
|
84 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
85 |
+
|
86 |
+
frozen_layers = [0, 1]
|
87 |
+
for name, param in self.audio_encoder.named_parameters():
|
88 |
+
if name.startswith("feature_projection"):
|
89 |
+
param.requires_grad = False
|
90 |
+
if name.startswith("encoder.layers"):
|
91 |
+
layer = int(name.split(".")[2])
|
92 |
+
if layer in frozen_layers:
|
93 |
+
param.requires_grad = False
|
94 |
+
elif self.audio_model == "wav2vec2_ori":
|
95 |
+
from .wav2vec2 import Wav2Vec2Model
|
96 |
+
self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path)
|
97 |
+
# wav2vec 2.0 weights initialization
|
98 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
99 |
+
elif self.audio_model == 'hubert': # 根据经验,hubert特征提取器效果更好
|
100 |
+
from .hubert import HubertModel
|
101 |
+
# from hubert import HubertModel
|
102 |
+
self.audio_encoder = HubertModel.from_pretrained(audio_encoder_path)
|
103 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
104 |
+
# print("hubert-en: ", self.audio_encoder)
|
105 |
+
|
106 |
+
frozen_layers = [0, 1]
|
107 |
+
for name, param in self.audio_encoder.named_parameters():
|
108 |
+
if name.startswith("feature_projection"):
|
109 |
+
param.requires_grad = False
|
110 |
+
if name.startswith("encoder.layers"):
|
111 |
+
layer = int(name.split(".")[2])
|
112 |
+
if layer in frozen_layers:
|
113 |
+
param.requires_grad = False
|
114 |
+
elif self.audio_model == 'hubert_zh': # 根据经验,hubert特征提取器效果更好
|
115 |
+
print("using hubert chinese")
|
116 |
+
from .hubert import HubertModel
|
117 |
+
# from hubert import HubertModel
|
118 |
+
self.audio_encoder = HubertModel.from_pretrained(audio_encoder_path)
|
119 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
120 |
+
|
121 |
+
frozen_layers = [0, 1]
|
122 |
+
for name, param in self.audio_encoder.named_parameters():
|
123 |
+
if name.startswith("feature_projection"):
|
124 |
+
param.requires_grad = False
|
125 |
+
if name.startswith("encoder.layers"):
|
126 |
+
layer = int(name.split(".")[2])
|
127 |
+
if layer in frozen_layers:
|
128 |
+
param.requires_grad = False
|
129 |
+
elif self.audio_model == 'hubert_zh_ori': # 根据经验,hubert特征提取器效果更好
|
130 |
+
print("using hubert chinese ori")
|
131 |
+
from .hubert import HubertModel
|
132 |
+
self.audio_encoder = HubertModel.from_pretrained(audio_encoder_path)
|
133 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
134 |
+
else:
|
135 |
+
raise ValueError(f'Unknown audio model {self.audio_model}!')
|
136 |
+
|
137 |
+
if architecture == 'decoder':
|
138 |
+
self.audio_feature_map = nn.Linear(768, feature_dim)
|
139 |
+
self.start_audio_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, feature_dim))
|
140 |
+
else:
|
141 |
+
raise ValueError(f'Unknown architecture {architecture}!')
|
142 |
+
|
143 |
+
self.start_motion_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, self.motion_feat_dim)) # 1, 10, 76
|
144 |
+
|
145 |
+
# Diffusion model
|
146 |
+
self.denoising_net = DenoisingNetwork(device=device, n_motions=self.n_motions,
|
147 |
+
n_prev_motions=self.n_prev_motions,
|
148 |
+
motion_feat_dim=self.motion_feat_dim, feature_dim=feature_dim)
|
149 |
+
# diffusion schedule
|
150 |
+
self.diffusion_sched = DiffusionSchedule(n_diff_steps, diff_schedule)
|
151 |
+
|
152 |
+
# Classifier-free settings
|
153 |
+
self.cfg_mode = cfg_mode
|
154 |
+
guiding_conditions = guiding_conditions.split(',') if guiding_conditions else []
|
155 |
+
self.guiding_conditions = [cond for cond in guiding_conditions if cond in ['audio']]
|
156 |
+
if 'audio' in self.guiding_conditions:
|
157 |
+
audio_feat_dim = feature_dim
|
158 |
+
self.null_audio_feat = nn.Parameter(torch.randn(1, 1, audio_feat_dim)) # 1, 1, 512
|
159 |
+
|
160 |
+
self.to(device)
|
161 |
+
|
162 |
+
@property
|
163 |
+
def device(self):
|
164 |
+
return next(self.parameters()).device
|
165 |
+
|
166 |
+
def forward(self, motion_feat, audio_or_feat, prev_motion_feat=None, prev_audio_feat=None, time_step=None,
|
167 |
+
indicator=None):
|
168 |
+
"""
|
169 |
+
Args:
|
170 |
+
motion_feat: (N, L, d_coef) motion coefficients or features
|
171 |
+
audio_or_feat: (N, L_audio) raw audio or audio feature
|
172 |
+
prev_motion_feat: (N, n_prev_motions, d_motion) previous motion coefficients or feature
|
173 |
+
prev_audio_feat: (N, n_prev_motions, d_audio) previous audio features
|
174 |
+
time_step: (N,)
|
175 |
+
indicator: (N, L) 0/1 indicator of real (unpadded) motion coefficients
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
motion_feat_noise: (N, L, d_motion)
|
179 |
+
"""
|
180 |
+
batch_size = motion_feat.shape[0]
|
181 |
+
|
182 |
+
# 加载语音特征
|
183 |
+
if audio_or_feat.ndim == 2: # 原始语音
|
184 |
+
# Extract audio features
|
185 |
+
assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
|
186 |
+
f'Incorrect audio length {audio_or_feat.shape[1]}'
|
187 |
+
audio_feat_saved = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
|
188 |
+
elif audio_or_feat.ndim == 3: # 语音特征
|
189 |
+
assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
|
190 |
+
audio_feat_saved = audio_or_feat
|
191 |
+
else:
|
192 |
+
raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
|
193 |
+
audio_feat = audio_feat_saved.clone()
|
194 |
+
|
195 |
+
# 前续motion特征
|
196 |
+
if prev_motion_feat is None:
|
197 |
+
prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
|
198 |
+
|
199 |
+
# 前续语音特征
|
200 |
+
if prev_audio_feat is None:
|
201 |
+
# (N, n_prev_motions, feature_dim)
|
202 |
+
prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
|
203 |
+
|
204 |
+
# Classifier-free guidance
|
205 |
+
if len(self.guiding_conditions) > 0:
|
206 |
+
assert len(self.guiding_conditions) <= 2, 'Only support 1 or 2 CFG conditions!'
|
207 |
+
if len(self.guiding_conditions) == 1 or self.cfg_mode == 'independent':
|
208 |
+
null_cond_prob = 0.5 if len(self.guiding_conditions) >= 2 else 0.1
|
209 |
+
if 'audio' in self.guiding_conditions:
|
210 |
+
mask_audio = torch.rand(batch_size, device=self.device) < null_cond_prob
|
211 |
+
audio_feat = torch.where(mask_audio.view(-1, 1, 1),
|
212 |
+
self.null_audio_feat.expand(batch_size, self.n_motions, -1),
|
213 |
+
audio_feat)
|
214 |
+
else:
|
215 |
+
# len(self.guiding_conditions) > 1 and self.cfg_mode == 'incremental'
|
216 |
+
# full (0.45), w/o style (0.45), w/o style or audio (0.1)
|
217 |
+
mask_flag = torch.rand(batch_size, device=self.device)
|
218 |
+
if 'audio' in self.guiding_conditions:
|
219 |
+
mask_audio = mask_flag > 0.9
|
220 |
+
audio_feat = torch.where(mask_audio.view(-1, 1, 1),
|
221 |
+
self.null_audio_feat.expand(batch_size, self.n_motions, -1),
|
222 |
+
audio_feat)
|
223 |
+
|
224 |
+
if time_step is None:
|
225 |
+
# Sample time step
|
226 |
+
time_step = self.diffusion_sched.uniform_sample_t(batch_size) # (N,)
|
227 |
+
|
228 |
+
# The forward diffusion process
|
229 |
+
alpha_bar = self.diffusion_sched.alpha_bars[time_step] # (N,)
|
230 |
+
c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (N, 1, 1)
|
231 |
+
c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (N, 1, 1)
|
232 |
+
|
233 |
+
eps = torch.randn_like(motion_feat) # (N, L, d_motion)
|
234 |
+
motion_feat_noisy = c0 * motion_feat + c1 * eps
|
235 |
+
|
236 |
+
# The reverse diffusion process
|
237 |
+
motion_feat_target = self.denoising_net(motion_feat_noisy, audio_feat,
|
238 |
+
prev_motion_feat, prev_audio_feat, time_step, indicator)
|
239 |
+
|
240 |
+
return eps, motion_feat_target, motion_feat.detach(), audio_feat_saved.detach()
|
241 |
+
|
242 |
+
def extract_audio_feature(self, audio, frame_num=None):
|
243 |
+
frame_num = frame_num or self.n_motions
|
244 |
+
|
245 |
+
# # Strategy 1: resample during audio feature extraction
|
246 |
+
# hidden_states = self.audio_encoder(pad_audio(audio), self.fps, frame_num=frame_num).last_hidden_state # (N, L, 768)
|
247 |
+
|
248 |
+
# Strategy 2: resample after audio feature extraction (BackResample)
|
249 |
+
hidden_states = self.audio_encoder(pad_audio(audio), self.fps,
|
250 |
+
frame_num=frame_num * 2).last_hidden_state # (N, 2L, 768)
|
251 |
+
hidden_states = hidden_states.transpose(1, 2) # (N, 768, 2L)
|
252 |
+
hidden_states = F.interpolate(hidden_states, size=frame_num, align_corners=False, mode='linear') # (N, 768, L)
|
253 |
+
hidden_states = hidden_states.transpose(1, 2) # (N, L, 768)
|
254 |
+
|
255 |
+
audio_feat = self.audio_feature_map(hidden_states) # (N, L, feature_dim)
|
256 |
+
return audio_feat
|
257 |
+
|
258 |
+
@torch.no_grad()
|
259 |
+
def sample(self, audio_or_feat, prev_motion_feat=None, prev_audio_feat=None,
|
260 |
+
motion_at_T=None, indicator=None, cfg_mode=None, cfg_cond=None, cfg_scale=1.15, flexibility=0,
|
261 |
+
dynamic_threshold=None, ret_traj=False):
|
262 |
+
# Check and convert inputs
|
263 |
+
batch_size = audio_or_feat.shape[0]
|
264 |
+
|
265 |
+
# Check CFG conditions
|
266 |
+
if cfg_mode is None: # Use default CFG mode
|
267 |
+
cfg_mode = self.cfg_mode
|
268 |
+
if cfg_cond is None: # Use default CFG conditions
|
269 |
+
cfg_cond = self.guiding_conditions
|
270 |
+
cfg_cond = [c for c in cfg_cond if c in ['audio', ]]
|
271 |
+
|
272 |
+
if not isinstance(cfg_scale, list):
|
273 |
+
cfg_scale = [cfg_scale] * len(cfg_cond)
|
274 |
+
|
275 |
+
# sort cfg_cond and cfg_scale
|
276 |
+
if len(cfg_cond) > 0:
|
277 |
+
cfg_cond, cfg_scale = zip(*sorted(zip(cfg_cond, cfg_scale), key=lambda x: ['audio', ].index(x[0])))
|
278 |
+
else:
|
279 |
+
cfg_cond, cfg_scale = [], []
|
280 |
+
|
281 |
+
if audio_or_feat.ndim == 2:
|
282 |
+
# Extract audio features
|
283 |
+
assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
|
284 |
+
f'Incorrect audio length {audio_or_feat.shape[1]}'
|
285 |
+
audio_feat = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
|
286 |
+
elif audio_or_feat.ndim == 3:
|
287 |
+
assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
|
288 |
+
audio_feat = audio_or_feat
|
289 |
+
else:
|
290 |
+
raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
|
291 |
+
|
292 |
+
if prev_motion_feat is None:
|
293 |
+
prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
|
294 |
+
if prev_audio_feat is None:
|
295 |
+
# (N, n_prev_motions, feature_dim)
|
296 |
+
prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
|
297 |
+
|
298 |
+
if motion_at_T is None:
|
299 |
+
motion_at_T = torch.randn((batch_size, self.n_motions, self.motion_feat_dim)).to(self.device)
|
300 |
+
|
301 |
+
# Prepare input for the reverse diffusion process (including optional classifier-free guidance)
|
302 |
+
if 'audio' in cfg_cond:
|
303 |
+
audio_feat_null = self.null_audio_feat.expand(batch_size, self.n_motions, -1)
|
304 |
+
else:
|
305 |
+
audio_feat_null = audio_feat
|
306 |
+
|
307 |
+
audio_feat_in = [audio_feat_null]
|
308 |
+
for cond in cfg_cond:
|
309 |
+
if cond == 'audio':
|
310 |
+
audio_feat_in.append(audio_feat)
|
311 |
+
|
312 |
+
n_entries = len(audio_feat_in)
|
313 |
+
audio_feat_in = torch.cat(audio_feat_in, dim=0)
|
314 |
+
prev_motion_feat_in = torch.cat([prev_motion_feat] * n_entries, dim=0)
|
315 |
+
prev_audio_feat_in = torch.cat([prev_audio_feat] * n_entries, dim=0)
|
316 |
+
indicator_in = torch.cat([indicator] * n_entries, dim=0) if indicator is not None else None
|
317 |
+
|
318 |
+
traj = {self.diffusion_sched.num_steps: motion_at_T}
|
319 |
+
for t in tqdm(range(self.diffusion_sched.num_steps, 0, -1)):
|
320 |
+
if t > 1:
|
321 |
+
z = torch.randn_like(motion_at_T)
|
322 |
+
else:
|
323 |
+
z = torch.zeros_like(motion_at_T)
|
324 |
+
|
325 |
+
alpha = self.diffusion_sched.alphas[t]
|
326 |
+
alpha_bar = self.diffusion_sched.alpha_bars[t]
|
327 |
+
alpha_bar_prev = self.diffusion_sched.alpha_bars[t - 1]
|
328 |
+
sigma = self.diffusion_sched.get_sigmas(t, flexibility)
|
329 |
+
|
330 |
+
motion_at_t = traj[t]
|
331 |
+
motion_in = torch.cat([motion_at_t] * n_entries, dim=0)
|
332 |
+
step_in = torch.tensor([t] * batch_size, device=self.device)
|
333 |
+
step_in = torch.cat([step_in] * n_entries, dim=0)
|
334 |
+
|
335 |
+
results = self.denoising_net(motion_in, audio_feat_in, prev_motion_feat_in,
|
336 |
+
prev_audio_feat_in, step_in, indicator_in)
|
337 |
+
|
338 |
+
# Apply thresholding if specified
|
339 |
+
if dynamic_threshold:
|
340 |
+
dt_ratio, dt_min, dt_max = dynamic_threshold
|
341 |
+
abs_results = results[:, -self.n_motions:].reshape(batch_size * n_entries, -1).abs()
|
342 |
+
s = torch.quantile(abs_results, dt_ratio, dim=1)
|
343 |
+
s = torch.clamp(s, min=dt_min, max=dt_max)
|
344 |
+
s = s[..., None, None]
|
345 |
+
results = torch.clamp(results, min=-s, max=s)
|
346 |
+
|
347 |
+
results = results.chunk(n_entries)
|
348 |
+
|
349 |
+
# Unconditional target (CFG) or the conditional target (non-CFG)
|
350 |
+
target_theta = results[0][:, -self.n_motions:]
|
351 |
+
# Classifier-free Guidance (optional)
|
352 |
+
for i in range(0, n_entries - 1):
|
353 |
+
if cfg_mode == 'independent':
|
354 |
+
target_theta += cfg_scale[i] * (
|
355 |
+
results[i + 1][:, -self.n_motions:] - results[0][:, -self.n_motions:])
|
356 |
+
elif cfg_mode == 'incremental':
|
357 |
+
target_theta += cfg_scale[i] * (
|
358 |
+
results[i + 1][:, -self.n_motions:] - results[i][:, -self.n_motions:])
|
359 |
+
else:
|
360 |
+
raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
|
361 |
+
|
362 |
+
if self.target == 'noise':
|
363 |
+
c0 = 1 / torch.sqrt(alpha)
|
364 |
+
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
|
365 |
+
motion_next = c0 * (motion_at_t - c1 * target_theta) + sigma * z
|
366 |
+
elif self.target == 'sample':
|
367 |
+
c0 = (1 - alpha_bar_prev) * torch.sqrt(alpha) / (1 - alpha_bar)
|
368 |
+
c1 = (1 - alpha) * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar)
|
369 |
+
motion_next = c0 * motion_at_t + c1 * target_theta + sigma * z
|
370 |
+
else:
|
371 |
+
raise ValueError('Unknown target type: {}'.format(self.target))
|
372 |
+
|
373 |
+
traj[t - 1] = motion_next.detach() # Stop gradient and save trajectory.
|
374 |
+
traj[t] = traj[t].cpu() # Move previous output to CPU memory.
|
375 |
+
if not ret_traj:
|
376 |
+
del traj[t]
|
377 |
+
|
378 |
+
if ret_traj:
|
379 |
+
return traj, motion_at_T, audio_feat
|
380 |
+
else:
|
381 |
+
return traj[0], motion_at_T, audio_feat
|
382 |
+
|
383 |
+
|
384 |
+
class DenoisingNetwork(nn.Module):
|
385 |
+
def __init__(self, device='cuda', motion_feat_dim=76,
|
386 |
+
use_indicator=None, architecture="decoder", feature_dim=512, n_heads=8,
|
387 |
+
n_layers=8, mlp_ratio=4, align_mask_width=1, no_use_learnable_pe=True, n_prev_motions=10,
|
388 |
+
n_motions=100, n_diff_steps=500, ):
|
389 |
+
super().__init__()
|
390 |
+
|
391 |
+
# Model parameters
|
392 |
+
self.motion_feat_dim = motion_feat_dim
|
393 |
+
self.use_indicator = use_indicator
|
394 |
+
|
395 |
+
# Transformer
|
396 |
+
self.architecture = architecture
|
397 |
+
self.feature_dim = feature_dim
|
398 |
+
self.n_heads = n_heads
|
399 |
+
self.n_layers = n_layers
|
400 |
+
self.mlp_ratio = mlp_ratio
|
401 |
+
self.align_mask_width = align_mask_width
|
402 |
+
self.use_learnable_pe = not no_use_learnable_pe
|
403 |
+
|
404 |
+
# sequence length
|
405 |
+
self.n_prev_motions = n_prev_motions
|
406 |
+
self.n_motions = n_motions
|
407 |
+
|
408 |
+
# Temporal embedding for the diffusion time step
|
409 |
+
self.TE = PositionalEncoding(self.feature_dim, max_len=n_diff_steps + 1)
|
410 |
+
self.diff_step_map = nn.Sequential(
|
411 |
+
nn.Linear(self.feature_dim, self.feature_dim),
|
412 |
+
nn.GELU(),
|
413 |
+
nn.Linear(self.feature_dim, self.feature_dim)
|
414 |
+
)
|
415 |
+
|
416 |
+
if self.use_learnable_pe:
|
417 |
+
# Learnable positional encoding
|
418 |
+
self.PE = nn.Parameter(torch.randn(1, 1 + self.n_prev_motions + self.n_motions, self.feature_dim))
|
419 |
+
else:
|
420 |
+
self.PE = PositionalEncoding(self.feature_dim)
|
421 |
+
|
422 |
+
# Transformer decoder
|
423 |
+
if self.architecture == 'decoder':
|
424 |
+
self.feature_proj = nn.Linear(self.motion_feat_dim + (1 if self.use_indicator else 0),
|
425 |
+
self.feature_dim)
|
426 |
+
decoder_layer = nn.TransformerDecoderLayer(
|
427 |
+
d_model=self.feature_dim, nhead=self.n_heads, dim_feedforward=self.mlp_ratio * self.feature_dim,
|
428 |
+
activation='gelu', batch_first=True
|
429 |
+
)
|
430 |
+
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=self.n_layers)
|
431 |
+
if self.align_mask_width > 0:
|
432 |
+
motion_len = self.n_prev_motions + self.n_motions
|
433 |
+
alignment_mask = enc_dec_mask(motion_len, motion_len, frame_width=1,
|
434 |
+
expansion=self.align_mask_width - 1)
|
435 |
+
# print(f"alignment_mask: ", alignment_mask.shape)
|
436 |
+
# alignment_mask = F.pad(alignment_mask, (0, 0, 1, 0), value=False)
|
437 |
+
self.register_buffer('alignment_mask', alignment_mask)
|
438 |
+
else:
|
439 |
+
self.alignment_mask = None
|
440 |
+
else:
|
441 |
+
raise ValueError(f'Unknown architecture: {self.architecture}')
|
442 |
+
|
443 |
+
# Motion decoder
|
444 |
+
self.motion_dec = nn.Sequential(
|
445 |
+
nn.Linear(self.feature_dim, self.feature_dim // 2),
|
446 |
+
nn.GELU(),
|
447 |
+
nn.Linear(self.feature_dim // 2, self.motion_feat_dim),
|
448 |
+
# nn.Tanh() # 增加了一个tanh
|
449 |
+
# nn.Softmax()
|
450 |
+
)
|
451 |
+
|
452 |
+
self.to(device)
|
453 |
+
|
454 |
+
@property
|
455 |
+
def device(self):
|
456 |
+
return next(self.parameters()).device
|
457 |
+
|
458 |
+
def forward(self, motion_feat, audio_feat, prev_motion_feat, prev_audio_feat, step, indicator=None):
|
459 |
+
"""
|
460 |
+
Args:
|
461 |
+
motion_feat: (N, L, d_motion). Noisy motion feature
|
462 |
+
audio_feat: (N, L, feature_dim)
|
463 |
+
prev_motion_feat: (N, L_p, d_motion). Padded previous motion coefficients or feature
|
464 |
+
prev_audio_feat: (N, L_p, d_audio). Padded previous motion coefficients or feature
|
465 |
+
step: (N,)
|
466 |
+
indicator: (N, L). 0/1 indicator for the real (unpadded) motion feature
|
467 |
+
|
468 |
+
Returns:
|
469 |
+
motion_feat_target: (N, L_p + L, d_motion)
|
470 |
+
"""
|
471 |
+
motion_feat = motion_feat.to(audio_feat.dtype)
|
472 |
+
# Diffusion time step embedding
|
473 |
+
diff_step_embedding = self.diff_step_map(self.TE.pe[0, step]).unsqueeze(1) # (N, 1, diff_step_dim)
|
474 |
+
|
475 |
+
if indicator is not None:
|
476 |
+
indicator = torch.cat([torch.zeros((indicator.shape[0], self.n_prev_motions), device=indicator.device),
|
477 |
+
indicator], dim=1) # (N, L_p + L)
|
478 |
+
indicator = indicator.unsqueeze(-1) # (N, L_p + L, 1)
|
479 |
+
|
480 |
+
# Concat features and embeddings
|
481 |
+
if self.architecture == 'decoder':
|
482 |
+
# print("prev_motion_feat: ", prev_motion_feat.shape, "motion_feat: ", motion_feat.shape)
|
483 |
+
feats_in = torch.cat([prev_motion_feat, motion_feat], dim=1) # (N, L_p + L, d_motion)
|
484 |
+
else:
|
485 |
+
raise ValueError(f'Unknown architecture: {self.architecture}')
|
486 |
+
if self.use_indicator:
|
487 |
+
feats_in = torch.cat([feats_in, indicator], dim=-1) # (N, L_p + L, d_motion + d_audio + 1)
|
488 |
+
feats_in = self.feature_proj(feats_in) # (N, L_p + L, feature_dim)
|
489 |
+
# feats_in = torch.cat([person_feat, feats_in], dim=1) # (N, 1 + L_p + L, feature_dim)
|
490 |
+
|
491 |
+
if self.use_learnable_pe:
|
492 |
+
# feats_in = feats_in + self.PE
|
493 |
+
feats_in = feats_in + self.PE + diff_step_embedding
|
494 |
+
else:
|
495 |
+
# feats_in = self.PE(feats_in)
|
496 |
+
feats_in = self.PE(feats_in) + diff_step_embedding
|
497 |
+
|
498 |
+
# Transformer
|
499 |
+
if self.architecture == 'decoder':
|
500 |
+
audio_feat_in = torch.cat([prev_audio_feat, audio_feat], dim=1) # (N, L_p + L, d_audio)
|
501 |
+
# print(f"feats_in: {feats_in.shape}, audio_feat_in: {audio_feat_in.shape}, memory_mask: {self.alignment_mask.shape}")
|
502 |
+
feat_out = self.transformer(feats_in, audio_feat_in, memory_mask=self.alignment_mask)
|
503 |
+
else:
|
504 |
+
raise ValueError(f'Unknown architecture: {self.architecture}')
|
505 |
+
|
506 |
+
# Decode predicted motion feature noise / sample
|
507 |
+
# motion_feat_target = self.motion_dec(feat_out[:, 1:]) # (N, L_p + L, d_motion)
|
508 |
+
motion_feat_target = self.motion_dec(feat_out) # (N, L_p + L, d_motion)
|
509 |
+
|
510 |
+
return motion_feat_target
|
511 |
+
|
512 |
+
|
513 |
+
if __name__ == "__main__":
|
514 |
+
device = "cuda"
|
515 |
+
motion_feat_dim = 76
|
516 |
+
n_motions = 100 # L
|
517 |
+
n_prev_motions = 10 # L_p
|
518 |
+
|
519 |
+
L_audio = int(16000 * n_motions / 25) # 64000
|
520 |
+
d_audio = 768
|
521 |
+
|
522 |
+
N = 5
|
523 |
+
feature_dim = 512
|
524 |
+
|
525 |
+
motion_feat = torch.ones((N, n_motions, motion_feat_dim)).to(device)
|
526 |
+
prev_motion_feat = torch.ones((N, n_prev_motions, motion_feat_dim)).to(device)
|
527 |
+
|
528 |
+
audio_or_feat = torch.ones((N, L_audio)).to(device)
|
529 |
+
prev_audio_feat = torch.ones((N, n_prev_motions, d_audio)).to(device)
|
530 |
+
|
531 |
+
time_step = torch.ones(N, dtype=torch.long).to(device)
|
532 |
+
|
533 |
+
model = DitTalkingHead().to(device)
|
534 |
+
|
535 |
+
z = model(motion_feat, audio_or_feat, prev_motion_feat=None,
|
536 |
+
prev_audio_feat=None, time_step=None, indicator=None)
|
537 |
+
traj, motion_at_T, audio_feat = z[0], z[1], z[2]
|
538 |
+
print(motion_at_T.shape, audio_feat.shape)
|
src/models/JoyVASA/helper.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/12/15
|
3 |
+
# @Author : wenshao
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Project : FasterLivePortrait
|
6 |
+
# @FileName: helper.py
|
7 |
+
import os.path as osp
|
8 |
+
|
9 |
+
|
10 |
+
class NullableArgs:
|
11 |
+
def __init__(self, namespace):
|
12 |
+
for key, value in namespace.__dict__.items():
|
13 |
+
setattr(self, key, value)
|
14 |
+
|
15 |
+
def __getattr__(self, key):
|
16 |
+
# when an attribute lookup has not found the attribute
|
17 |
+
if key == 'align_mask_width':
|
18 |
+
if 'use_alignment_mask' in self.__dict__:
|
19 |
+
return 1 if self.use_alignment_mask else 0
|
20 |
+
else:
|
21 |
+
return 0
|
22 |
+
if key == 'no_head_pose':
|
23 |
+
return not self.predict_head_pose
|
24 |
+
if key == 'no_use_learnable_pe':
|
25 |
+
return not self.use_learnable_pe
|
26 |
+
|
27 |
+
return None
|
28 |
+
|
29 |
+
|
30 |
+
def make_abs_path(fn):
|
31 |
+
# return osp.join(osp.dirname(osp.realpath(__file__)), fn)
|
32 |
+
return osp.abspath(osp.join(osp.dirname(osp.realpath(__file__)), fn))
|
src/models/JoyVASA/hubert.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import HubertModel
|
2 |
+
from transformers.modeling_outputs import BaseModelOutput
|
3 |
+
|
4 |
+
from .wav2vec2 import linear_interpolation
|
5 |
+
|
6 |
+
_CONFIG_FOR_DOC = 'HubertConfig'
|
7 |
+
|
8 |
+
|
9 |
+
class HubertModel(HubertModel):
|
10 |
+
def __init__(self, config):
|
11 |
+
super().__init__(config)
|
12 |
+
|
13 |
+
def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
|
14 |
+
output_hidden_states=None, return_dict=None, frame_num=None):
|
15 |
+
self.config.output_attentions = True
|
16 |
+
|
17 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
18 |
+
output_hidden_states = (
|
19 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
20 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
21 |
+
|
22 |
+
extract_features = self.feature_extractor(input_values) # (N, C, L)
|
23 |
+
# Resample the audio feature @ 50 fps to `output_fps`.
|
24 |
+
if frame_num is not None:
|
25 |
+
extract_features_len = round(frame_num * 50 / output_fps)
|
26 |
+
extract_features = extract_features[:, :, :extract_features_len]
|
27 |
+
extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
|
28 |
+
extract_features = extract_features.transpose(1, 2) # (N, L, C)
|
29 |
+
|
30 |
+
if attention_mask is not None:
|
31 |
+
# compute reduced attention_mask corresponding to feature vectors
|
32 |
+
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
33 |
+
|
34 |
+
hidden_states = self.feature_projection(extract_features)
|
35 |
+
hidden_states = self._mask_hidden_states(hidden_states)
|
36 |
+
|
37 |
+
encoder_outputs = self.encoder(
|
38 |
+
hidden_states,
|
39 |
+
attention_mask=attention_mask,
|
40 |
+
output_attentions=output_attentions,
|
41 |
+
output_hidden_states=output_hidden_states,
|
42 |
+
return_dict=return_dict,
|
43 |
+
)
|
44 |
+
|
45 |
+
hidden_states = encoder_outputs[0]
|
46 |
+
|
47 |
+
if not return_dict:
|
48 |
+
return (hidden_states,) + encoder_outputs[1:]
|
49 |
+
|
50 |
+
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
|
51 |
+
attentions=encoder_outputs.attentions, )
|
src/models/JoyVASA/wav2vec2.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from packaging import version
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import transformers
|
8 |
+
from transformers import Wav2Vec2Model
|
9 |
+
from transformers.modeling_outputs import BaseModelOutput
|
10 |
+
|
11 |
+
_CONFIG_FOR_DOC = 'Wav2Vec2Config'
|
12 |
+
|
13 |
+
|
14 |
+
# the implementation of Wav2Vec2Model is borrowed from
|
15 |
+
# https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
|
16 |
+
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
|
17 |
+
def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
|
18 |
+
attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
|
19 |
+
bsz, all_sz = shape
|
20 |
+
mask = np.full((bsz, all_sz), False)
|
21 |
+
|
22 |
+
all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
|
23 |
+
all_num_mask = max(min_masks, all_num_mask)
|
24 |
+
mask_idcs = []
|
25 |
+
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
26 |
+
for i in range(bsz):
|
27 |
+
if padding_mask is not None:
|
28 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
29 |
+
num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
|
30 |
+
num_mask = max(min_masks, num_mask)
|
31 |
+
else:
|
32 |
+
sz = all_sz
|
33 |
+
num_mask = all_num_mask
|
34 |
+
|
35 |
+
lengths = np.full(num_mask, mask_length)
|
36 |
+
|
37 |
+
if sum(lengths) == 0:
|
38 |
+
lengths[0] = min(mask_length, sz - 1)
|
39 |
+
|
40 |
+
min_len = min(lengths)
|
41 |
+
if sz - min_len <= num_mask:
|
42 |
+
min_len = sz - num_mask - 1
|
43 |
+
|
44 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
45 |
+
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
46 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
47 |
+
|
48 |
+
min_len = min([len(m) for m in mask_idcs])
|
49 |
+
for i, mask_idc in enumerate(mask_idcs):
|
50 |
+
if len(mask_idc) > min_len:
|
51 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
52 |
+
mask[i, mask_idc] = True
|
53 |
+
return mask
|
54 |
+
|
55 |
+
|
56 |
+
# linear interpolation layer
|
57 |
+
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
58 |
+
# features: (N, C, L)
|
59 |
+
seq_len = features.shape[2] / float(input_fps)
|
60 |
+
if output_len is None:
|
61 |
+
output_len = int(seq_len * output_fps)
|
62 |
+
output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
|
63 |
+
return output_features
|
64 |
+
|
65 |
+
|
66 |
+
class Wav2Vec2Model(Wav2Vec2Model):
|
67 |
+
def __init__(self, config):
|
68 |
+
super().__init__(config)
|
69 |
+
self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
|
70 |
+
|
71 |
+
def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
|
72 |
+
output_hidden_states=None, return_dict=None, frame_num=None):
|
73 |
+
self.config.output_attentions = True
|
74 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
75 |
+
output_hidden_states = (
|
76 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
77 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
78 |
+
|
79 |
+
hidden_states = self.feature_extractor(input_values) # (N, C, L)
|
80 |
+
# Resample the audio feature @ 50 fps to `output_fps`.
|
81 |
+
if frame_num is not None:
|
82 |
+
hidden_states_len = round(frame_num * 50 / output_fps)
|
83 |
+
hidden_states = hidden_states[:, :, :hidden_states_len]
|
84 |
+
hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
|
85 |
+
hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
|
86 |
+
|
87 |
+
if attention_mask is not None:
|
88 |
+
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
89 |
+
attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
|
90 |
+
device=hidden_states.device)
|
91 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
|
92 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
93 |
+
|
94 |
+
if self.is_old_version:
|
95 |
+
hidden_states = self.feature_projection(hidden_states)
|
96 |
+
else:
|
97 |
+
hidden_states = self.feature_projection(hidden_states)[0]
|
98 |
+
|
99 |
+
if self.config.apply_spec_augment and self.training:
|
100 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
101 |
+
if self.config.mask_time_prob > 0:
|
102 |
+
mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
|
103 |
+
self.config.mask_time_length, attention_mask=attention_mask,
|
104 |
+
min_masks=2, )
|
105 |
+
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
|
106 |
+
if self.config.mask_feature_prob > 0:
|
107 |
+
mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
|
108 |
+
self.config.mask_feature_length, )
|
109 |
+
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
|
110 |
+
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
111 |
+
encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
|
112 |
+
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
|
113 |
+
return_dict=return_dict, )
|
114 |
+
hidden_states = encoder_outputs[0]
|
115 |
+
if not return_dict:
|
116 |
+
return (hidden_states,) + encoder_outputs[1:]
|
117 |
+
|
118 |
+
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
|
119 |
+
attentions=encoder_outputs.attentions, )
|
src/models/XPose/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/8/5 21:58
|
3 |
+
# @Author : shaoguowen
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Project : FasterLivePortrait
|
6 |
+
# @FileName: __init__.py.py
|
src/models/XPose/config_model/UniPose_SwinT.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = ['coco_transformer.py']
|
2 |
+
|
3 |
+
use_label_enc = True
|
4 |
+
|
5 |
+
num_classes=2
|
6 |
+
|
7 |
+
lr = 0.0001
|
8 |
+
param_dict_type = 'default'
|
9 |
+
lr_backbone = 1e-05
|
10 |
+
lr_backbone_names = ['backbone.0']
|
11 |
+
lr_linear_proj_names = ['reference_points', 'sampling_offsets']
|
12 |
+
lr_linear_proj_mult = 0.1
|
13 |
+
ddetr_lr_param = False
|
14 |
+
batch_size = 2
|
15 |
+
weight_decay = 0.0001
|
16 |
+
epochs = 12
|
17 |
+
lr_drop = 11
|
18 |
+
save_checkpoint_interval = 100
|
19 |
+
clip_max_norm = 0.1
|
20 |
+
onecyclelr = False
|
21 |
+
multi_step_lr = False
|
22 |
+
lr_drop_list = [33, 45]
|
23 |
+
|
24 |
+
|
25 |
+
modelname = 'UniPose'
|
26 |
+
frozen_weights = None
|
27 |
+
backbone = 'swin_T_224_1k'
|
28 |
+
|
29 |
+
|
30 |
+
dilation = False
|
31 |
+
position_embedding = 'sine'
|
32 |
+
pe_temperatureH = 20
|
33 |
+
pe_temperatureW = 20
|
34 |
+
return_interm_indices = [1, 2, 3]
|
35 |
+
backbone_freeze_keywords = None
|
36 |
+
enc_layers = 6
|
37 |
+
dec_layers = 6
|
38 |
+
unic_layers = 0
|
39 |
+
pre_norm = False
|
40 |
+
dim_feedforward = 2048
|
41 |
+
hidden_dim = 256
|
42 |
+
dropout = 0.0
|
43 |
+
nheads = 8
|
44 |
+
num_queries = 900
|
45 |
+
query_dim = 4
|
46 |
+
num_patterns = 0
|
47 |
+
pdetr3_bbox_embed_diff_each_layer = False
|
48 |
+
pdetr3_refHW = -1
|
49 |
+
random_refpoints_xy = False
|
50 |
+
fix_refpoints_hw = -1
|
51 |
+
dabdetr_yolo_like_anchor_update = False
|
52 |
+
dabdetr_deformable_encoder = False
|
53 |
+
dabdetr_deformable_decoder = False
|
54 |
+
use_deformable_box_attn = False
|
55 |
+
box_attn_type = 'roi_align'
|
56 |
+
dec_layer_number = None
|
57 |
+
num_feature_levels = 4
|
58 |
+
enc_n_points = 4
|
59 |
+
dec_n_points = 4
|
60 |
+
decoder_layer_noise = False
|
61 |
+
dln_xy_noise = 0.2
|
62 |
+
dln_hw_noise = 0.2
|
63 |
+
add_channel_attention = False
|
64 |
+
add_pos_value = False
|
65 |
+
two_stage_type = 'standard'
|
66 |
+
two_stage_pat_embed = 0
|
67 |
+
two_stage_add_query_num = 0
|
68 |
+
two_stage_bbox_embed_share = False
|
69 |
+
two_stage_class_embed_share = False
|
70 |
+
two_stage_learn_wh = False
|
71 |
+
two_stage_default_hw = 0.05
|
72 |
+
two_stage_keep_all_tokens = False
|
73 |
+
num_select = 50
|
74 |
+
transformer_activation = 'relu'
|
75 |
+
batch_norm_type = 'FrozenBatchNorm2d'
|
76 |
+
masks = False
|
77 |
+
|
78 |
+
decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
|
79 |
+
matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
|
80 |
+
decoder_module_seq = ['sa', 'ca', 'ffn']
|
81 |
+
nms_iou_threshold = -1
|
82 |
+
|
83 |
+
dec_pred_bbox_embed_share = True
|
84 |
+
dec_pred_class_embed_share = True
|
85 |
+
|
86 |
+
|
87 |
+
use_dn = True
|
88 |
+
dn_number = 100
|
89 |
+
dn_box_noise_scale = 1.0
|
90 |
+
dn_label_noise_ratio = 0.5
|
91 |
+
dn_label_coef=1.0
|
92 |
+
dn_bbox_coef=1.0
|
93 |
+
embed_init_tgt = True
|
94 |
+
dn_labelbook_size = 2000
|
95 |
+
|
96 |
+
match_unstable_error = True
|
97 |
+
|
98 |
+
# for ema
|
99 |
+
use_ema = True
|
100 |
+
ema_decay = 0.9997
|
101 |
+
ema_epoch = 0
|
102 |
+
|
103 |
+
use_detached_boxes_dec_out = False
|
104 |
+
|
105 |
+
max_text_len = 256
|
106 |
+
shuffle_type = None
|
107 |
+
|
108 |
+
use_text_enhancer = True
|
109 |
+
use_fusion_layer = True
|
110 |
+
|
111 |
+
use_checkpoint = False # True
|
112 |
+
use_transformer_ckpt = True
|
113 |
+
text_encoder_type = 'bert-base-uncased'
|
114 |
+
|
115 |
+
use_text_cross_attention = True
|
116 |
+
text_dropout = 0.0
|
117 |
+
fusion_dropout = 0.0
|
118 |
+
fusion_droppath = 0.1
|
119 |
+
|
120 |
+
num_body_points=68
|
121 |
+
binary_query_selection = False
|
122 |
+
use_cdn = True
|
123 |
+
ffn_extra_layernorm = False
|
124 |
+
|
125 |
+
fix_size=False
|
src/models/XPose/config_model/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/8/5 21:58
|
3 |
+
# @Author : shaoguowen
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Project : FasterLivePortrait
|
6 |
+
# @FileName: __init__.py.py
|
src/models/XPose/config_model/coco_transformer.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
|
2 |
+
data_aug_max_size = 1333
|
3 |
+
data_aug_scales2_resize = [400, 500, 600]
|
4 |
+
data_aug_scales2_crop = [384, 600]
|
5 |
+
|
6 |
+
|
7 |
+
data_aug_scale_overlap = None
|
8 |
+
|
src/models/XPose/models/UniPose/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Conditional DETR
|
3 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------
|
6 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
7 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
8 |
+
# ------------------------------------------------------------------------
|
9 |
+
|
10 |
+
from .unipose import build_unipose
|
src/models/XPose/models/UniPose/attention.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# UniPose
|
3 |
+
# url: https://github.com/IDEA-Research/UniPose
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# ED-Pose
|
8 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
# Conditional DETR
|
12 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
13 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
# Modified from codes in torch.nn
|
16 |
+
# ------------------------------------------------------------------------
|
17 |
+
|
18 |
+
"""
|
19 |
+
MultiheadAttention that support query, key, and value to have different dimensions.
|
20 |
+
Query, key, and value projections are removed.
|
21 |
+
|
22 |
+
Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
|
23 |
+
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
|
24 |
+
"""
|
25 |
+
|
26 |
+
import warnings
|
27 |
+
import torch
|
28 |
+
from torch.nn.modules.linear import Linear
|
29 |
+
from torch.nn.init import constant_
|
30 |
+
from torch.nn.modules.module import Module
|
31 |
+
from torch._jit_internal import Optional, Tuple
|
32 |
+
try:
|
33 |
+
from torch.overrides import has_torch_function, handle_torch_function
|
34 |
+
except:
|
35 |
+
from torch._overrides import has_torch_function, handle_torch_function
|
36 |
+
from torch.nn.functional import linear, pad, softmax, dropout
|
37 |
+
Tensor = torch.Tensor
|
38 |
+
|
39 |
+
class MultiheadAttention(Module):
|
40 |
+
r"""Allows the model to jointly attend to information
|
41 |
+
from different representation subspaces.
|
42 |
+
See reference: Attention Is All You Need
|
43 |
+
.. math::
|
44 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
45 |
+
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
46 |
+
Args:
|
47 |
+
embed_dim: total dimension of the model.
|
48 |
+
num_heads: parallel attention heads.
|
49 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
50 |
+
bias: add bias as module parameter. Default: True.
|
51 |
+
add_bias_kv: add bias to the key and value sequences at dim=0.
|
52 |
+
add_zero_attn: add a new batch of zeros to the key and
|
53 |
+
value sequences at dim=1.
|
54 |
+
kdim: total number of features in key. Default: None.
|
55 |
+
vdim: total number of features in value. Default: None.
|
56 |
+
Note: if kdim and vdim are None, they will be set to embed_dim such that
|
57 |
+
query, key, and value have the same number of features.
|
58 |
+
Examples::
|
59 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
60 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
61 |
+
"""
|
62 |
+
bias_k: Optional[torch.Tensor]
|
63 |
+
bias_v: Optional[torch.Tensor]
|
64 |
+
|
65 |
+
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
|
66 |
+
super(MultiheadAttention, self).__init__()
|
67 |
+
self.embed_dim = embed_dim
|
68 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
69 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
70 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
71 |
+
|
72 |
+
self.num_heads = num_heads
|
73 |
+
self.dropout = dropout
|
74 |
+
self.head_dim = embed_dim // num_heads
|
75 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
76 |
+
|
77 |
+
vdim = vdim if vdim is not None else embed_dim
|
78 |
+
self.out_proj = Linear(vdim , vdim)
|
79 |
+
|
80 |
+
self.in_proj_bias = None
|
81 |
+
self.in_proj_weight = None
|
82 |
+
self.bias_k = self.bias_v = None
|
83 |
+
self.q_proj_weight = None
|
84 |
+
self.k_proj_weight = None
|
85 |
+
self.v_proj_weight = None
|
86 |
+
|
87 |
+
self.add_zero_attn = add_zero_attn
|
88 |
+
|
89 |
+
self._reset_parameters()
|
90 |
+
|
91 |
+
def _reset_parameters(self):
|
92 |
+
constant_(self.out_proj.bias, 0.)
|
93 |
+
|
94 |
+
def __setstate__(self, state):
|
95 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
96 |
+
if '_qkv_same_embed_dim' not in state:
|
97 |
+
state['_qkv_same_embed_dim'] = True
|
98 |
+
|
99 |
+
super(MultiheadAttention, self).__setstate__(state)
|
100 |
+
|
101 |
+
def forward(self, query, key, value, key_padding_mask=None,
|
102 |
+
need_weights=True, attn_mask=None):
|
103 |
+
# type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
|
104 |
+
r"""
|
105 |
+
Args:
|
106 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
107 |
+
See "Attention Is All You Need" for more details.
|
108 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
109 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
110 |
+
the corresponding value on the attention layer will be ignored. When given
|
111 |
+
a byte mask and a value is non-zero, the corresponding value on the attention
|
112 |
+
layer will be ignored
|
113 |
+
need_weights: output attn_output_weights.
|
114 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
115 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
116 |
+
Shape:
|
117 |
+
- Inputs:
|
118 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
119 |
+
the embedding dimension.
|
120 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
121 |
+
the embedding dimension.
|
122 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
123 |
+
the embedding dimension.
|
124 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
125 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
126 |
+
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
127 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
128 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
129 |
+
3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
|
130 |
+
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
131 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
132 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
133 |
+
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
134 |
+
is provided, it will be added to the attention weight.
|
135 |
+
- Outputs:
|
136 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
137 |
+
E is the embedding dimension.
|
138 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
139 |
+
L is the target sequence length, S is the source sequence length.
|
140 |
+
"""
|
141 |
+
if not self._qkv_same_embed_dim:
|
142 |
+
return multi_head_attention_forward(
|
143 |
+
query, key, value, self.embed_dim, self.num_heads,
|
144 |
+
self.in_proj_weight, self.in_proj_bias,
|
145 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
146 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
147 |
+
training=self.training,
|
148 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
149 |
+
attn_mask=attn_mask, use_separate_proj_weight=True,
|
150 |
+
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
151 |
+
v_proj_weight=self.v_proj_weight, out_dim=self.vdim)
|
152 |
+
else:
|
153 |
+
return multi_head_attention_forward(
|
154 |
+
query, key, value, self.embed_dim, self.num_heads,
|
155 |
+
self.in_proj_weight, self.in_proj_bias,
|
156 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
157 |
+
self.dropout, self.out_proj.weight, self.out_proj.bias,
|
158 |
+
training=self.training,
|
159 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
160 |
+
attn_mask=attn_mask, out_dim=self.vdim)
|
161 |
+
|
162 |
+
|
163 |
+
def multi_head_attention_forward(query: Tensor,
|
164 |
+
key: Tensor,
|
165 |
+
value: Tensor,
|
166 |
+
embed_dim_to_check: int,
|
167 |
+
num_heads: int,
|
168 |
+
in_proj_weight: Tensor,
|
169 |
+
in_proj_bias: Tensor,
|
170 |
+
bias_k: Optional[Tensor],
|
171 |
+
bias_v: Optional[Tensor],
|
172 |
+
add_zero_attn: bool,
|
173 |
+
dropout_p: float,
|
174 |
+
out_proj_weight: Tensor,
|
175 |
+
out_proj_bias: Tensor,
|
176 |
+
training: bool = True,
|
177 |
+
key_padding_mask: Optional[Tensor] = None,
|
178 |
+
need_weights: bool = True,
|
179 |
+
attn_mask: Optional[Tensor] = None,
|
180 |
+
use_separate_proj_weight: bool = False,
|
181 |
+
q_proj_weight: Optional[Tensor] = None,
|
182 |
+
k_proj_weight: Optional[Tensor] = None,
|
183 |
+
v_proj_weight: Optional[Tensor] = None,
|
184 |
+
static_k: Optional[Tensor] = None,
|
185 |
+
static_v: Optional[Tensor] = None,
|
186 |
+
out_dim: Optional[Tensor] = None
|
187 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
188 |
+
r"""
|
189 |
+
Args:
|
190 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
191 |
+
See "Attention Is All You Need" for more details.
|
192 |
+
embed_dim_to_check: total dimension of the model.
|
193 |
+
num_heads: parallel attention heads.
|
194 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
195 |
+
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
196 |
+
add_zero_attn: add a new batch of zeros to the key and
|
197 |
+
value sequences at dim=1.
|
198 |
+
dropout_p: probability of an element to be zeroed.
|
199 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
200 |
+
training: apply dropout if is ``True``.
|
201 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
202 |
+
be ignored by the attention. This is an binary mask. When the value is True,
|
203 |
+
the corresponding value on the attention layer will be filled with -inf.
|
204 |
+
need_weights: output attn_output_weights.
|
205 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
206 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
207 |
+
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
208 |
+
and value in different forms. If false, in_proj_weight will be used, which is
|
209 |
+
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
210 |
+
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
211 |
+
static_k, static_v: static key and value used for attention operators.
|
212 |
+
Shape:
|
213 |
+
Inputs:
|
214 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
215 |
+
the embedding dimension.
|
216 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
217 |
+
the embedding dimension.
|
218 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
219 |
+
the embedding dimension.
|
220 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
221 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
222 |
+
will be unchanged. If a BoolTensor is provided, the positions with the
|
223 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
224 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
225 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
226 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
227 |
+
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
228 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
229 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
230 |
+
is provided, it will be added to the attention weight.
|
231 |
+
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
232 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
233 |
+
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
234 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
235 |
+
Outputs:
|
236 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
237 |
+
E is the embedding dimension.
|
238 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
239 |
+
L is the target sequence length, S is the source sequence length.
|
240 |
+
"""
|
241 |
+
if not torch.jit.is_scripting():
|
242 |
+
tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
|
243 |
+
out_proj_weight, out_proj_bias)
|
244 |
+
if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
|
245 |
+
return handle_torch_function(
|
246 |
+
multi_head_attention_forward, tens_ops, query, key, value,
|
247 |
+
embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
|
248 |
+
bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
|
249 |
+
out_proj_bias, training=training, key_padding_mask=key_padding_mask,
|
250 |
+
need_weights=need_weights, attn_mask=attn_mask,
|
251 |
+
use_separate_proj_weight=use_separate_proj_weight,
|
252 |
+
q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
|
253 |
+
v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
|
254 |
+
tgt_len, bsz, embed_dim = query.size()
|
255 |
+
assert embed_dim == embed_dim_to_check
|
256 |
+
# allow MHA to have different sizes for the feature dimension
|
257 |
+
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
258 |
+
|
259 |
+
head_dim = embed_dim // num_heads
|
260 |
+
v_head_dim = out_dim // num_heads
|
261 |
+
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
|
262 |
+
scaling = float(head_dim) ** -0.5
|
263 |
+
|
264 |
+
q = query * scaling
|
265 |
+
k = key
|
266 |
+
v = value
|
267 |
+
|
268 |
+
if attn_mask is not None:
|
269 |
+
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
|
270 |
+
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
|
271 |
+
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
|
272 |
+
if attn_mask.dtype == torch.uint8:
|
273 |
+
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
274 |
+
attn_mask = attn_mask.to(torch.bool)
|
275 |
+
|
276 |
+
if attn_mask.dim() == 2:
|
277 |
+
attn_mask = attn_mask.unsqueeze(0)
|
278 |
+
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
279 |
+
raise RuntimeError('The size of the 2D attn_mask is not correct.')
|
280 |
+
elif attn_mask.dim() == 3:
|
281 |
+
if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
|
282 |
+
raise RuntimeError('The size of the 3D attn_mask is not correct.')
|
283 |
+
else:
|
284 |
+
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
|
285 |
+
# attn_mask's dim is 3 now.
|
286 |
+
|
287 |
+
# convert ByteTensor key_padding_mask to bool
|
288 |
+
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
289 |
+
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
290 |
+
key_padding_mask = key_padding_mask.to(torch.bool)
|
291 |
+
|
292 |
+
if bias_k is not None and bias_v is not None:
|
293 |
+
if static_k is None and static_v is None:
|
294 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
295 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
296 |
+
if attn_mask is not None:
|
297 |
+
attn_mask = pad(attn_mask, (0, 1))
|
298 |
+
if key_padding_mask is not None:
|
299 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
300 |
+
else:
|
301 |
+
assert static_k is None, "bias cannot be added to static key."
|
302 |
+
assert static_v is None, "bias cannot be added to static value."
|
303 |
+
else:
|
304 |
+
assert bias_k is None
|
305 |
+
assert bias_v is None
|
306 |
+
|
307 |
+
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
308 |
+
if k is not None:
|
309 |
+
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
310 |
+
if v is not None:
|
311 |
+
v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
|
312 |
+
|
313 |
+
if static_k is not None:
|
314 |
+
assert static_k.size(0) == bsz * num_heads
|
315 |
+
assert static_k.size(2) == head_dim
|
316 |
+
k = static_k
|
317 |
+
|
318 |
+
if static_v is not None:
|
319 |
+
assert static_v.size(0) == bsz * num_heads
|
320 |
+
assert static_v.size(2) == v_head_dim
|
321 |
+
v = static_v
|
322 |
+
|
323 |
+
src_len = k.size(1)
|
324 |
+
|
325 |
+
if key_padding_mask is not None:
|
326 |
+
assert key_padding_mask.size(0) == bsz
|
327 |
+
assert key_padding_mask.size(1) == src_len
|
328 |
+
|
329 |
+
if add_zero_attn:
|
330 |
+
src_len += 1
|
331 |
+
k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
|
332 |
+
v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
|
333 |
+
if attn_mask is not None:
|
334 |
+
attn_mask = pad(attn_mask, (0, 1))
|
335 |
+
if key_padding_mask is not None:
|
336 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
337 |
+
|
338 |
+
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
339 |
+
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
|
340 |
+
|
341 |
+
if attn_mask is not None:
|
342 |
+
if attn_mask.dtype == torch.bool:
|
343 |
+
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
|
344 |
+
else:
|
345 |
+
attn_output_weights += attn_mask
|
346 |
+
|
347 |
+
|
348 |
+
if key_padding_mask is not None:
|
349 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
350 |
+
attn_output_weights = attn_output_weights.masked_fill(
|
351 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
352 |
+
float('-inf'),
|
353 |
+
)
|
354 |
+
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
|
355 |
+
|
356 |
+
# attn_output_weights = softmax(
|
357 |
+
# attn_output_weights, dim=-1)
|
358 |
+
attn_output_weights = softmax(
|
359 |
+
attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1)
|
360 |
+
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
|
361 |
+
|
362 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
363 |
+
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
|
364 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
|
365 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
366 |
+
|
367 |
+
if need_weights:
|
368 |
+
# average attention weights over heads
|
369 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
370 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
371 |
+
else:
|
372 |
+
return attn_output, None
|
373 |
+
|
src/models/XPose/models/UniPose/backbone.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# UniPose
|
3 |
+
# url: https://github.com/IDEA-Research/UniPose
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# Conditional DETR
|
8 |
+
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
# Copied from DETR (https://github.com/facebookresearch/detr)
|
12 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
13 |
+
# ------------------------------------------------------------------------
|
14 |
+
|
15 |
+
"""
|
16 |
+
Backbone modules.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torchvision
|
22 |
+
from torch import nn
|
23 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
24 |
+
from typing import Dict, List
|
25 |
+
|
26 |
+
from ...util.misc import NestedTensor, is_main_process
|
27 |
+
|
28 |
+
from .position_encoding import build_position_encoding
|
29 |
+
from .swin_transformer import build_swin_transformer
|
30 |
+
|
31 |
+
class FrozenBatchNorm2d(torch.nn.Module):
|
32 |
+
"""
|
33 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
34 |
+
|
35 |
+
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
36 |
+
without which any other models than torchvision.models.resnet[18,34,50,101]
|
37 |
+
produce nans.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, n):
|
41 |
+
super(FrozenBatchNorm2d, self).__init__()
|
42 |
+
self.register_buffer("weight", torch.ones(n))
|
43 |
+
self.register_buffer("bias", torch.zeros(n))
|
44 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
45 |
+
self.register_buffer("running_var", torch.ones(n))
|
46 |
+
|
47 |
+
def _load_from_state_dict(
|
48 |
+
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
49 |
+
):
|
50 |
+
num_batches_tracked_key = prefix + "num_batches_tracked"
|
51 |
+
if num_batches_tracked_key in state_dict:
|
52 |
+
del state_dict[num_batches_tracked_key]
|
53 |
+
|
54 |
+
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
55 |
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
56 |
+
)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
# move reshapes to the beginning
|
60 |
+
# to make it fuser-friendly
|
61 |
+
w = self.weight.reshape(1, -1, 1, 1)
|
62 |
+
b = self.bias.reshape(1, -1, 1, 1)
|
63 |
+
rv = self.running_var.reshape(1, -1, 1, 1)
|
64 |
+
rm = self.running_mean.reshape(1, -1, 1, 1)
|
65 |
+
eps = 1e-5
|
66 |
+
scale = w * (rv + eps).rsqrt()
|
67 |
+
bias = b - rm * scale
|
68 |
+
return x * scale + bias
|
69 |
+
|
70 |
+
|
71 |
+
class BackboneBase(nn.Module):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
backbone: nn.Module,
|
75 |
+
train_backbone: bool,
|
76 |
+
num_channels: int,
|
77 |
+
return_interm_indices: list,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
for name, parameter in backbone.named_parameters():
|
81 |
+
if (
|
82 |
+
not train_backbone
|
83 |
+
or "layer2" not in name
|
84 |
+
and "layer3" not in name
|
85 |
+
and "layer4" not in name
|
86 |
+
):
|
87 |
+
parameter.requires_grad_(False)
|
88 |
+
|
89 |
+
return_layers = {}
|
90 |
+
for idx, layer_index in enumerate(return_interm_indices):
|
91 |
+
return_layers.update(
|
92 |
+
{"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
|
93 |
+
)
|
94 |
+
|
95 |
+
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
96 |
+
self.num_channels = num_channels
|
97 |
+
|
98 |
+
def forward(self, tensor_list: NestedTensor):
|
99 |
+
xs = self.body(tensor_list.tensors)
|
100 |
+
out: Dict[str, NestedTensor] = {}
|
101 |
+
for name, x in xs.items():
|
102 |
+
m = tensor_list.mask
|
103 |
+
assert m is not None
|
104 |
+
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
105 |
+
out[name] = NestedTensor(x, mask)
|
106 |
+
# import ipdb; ipdb.set_trace()
|
107 |
+
return out
|
108 |
+
|
109 |
+
|
110 |
+
class Backbone(BackboneBase):
|
111 |
+
"""ResNet backbone with frozen BatchNorm."""
|
112 |
+
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
name: str,
|
116 |
+
train_backbone: bool,
|
117 |
+
dilation: bool,
|
118 |
+
return_interm_indices: list,
|
119 |
+
batch_norm=FrozenBatchNorm2d,
|
120 |
+
):
|
121 |
+
if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
|
122 |
+
backbone = getattr(torchvision.models, name)(
|
123 |
+
replace_stride_with_dilation=[False, False, dilation],
|
124 |
+
pretrained=is_main_process(),
|
125 |
+
norm_layer=batch_norm,
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
raise NotImplementedError("Why you can get here with name {}".format(name))
|
129 |
+
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
|
130 |
+
assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
|
131 |
+
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
132 |
+
num_channels_all = [256, 512, 1024, 2048]
|
133 |
+
num_channels = num_channels_all[4 - len(return_interm_indices) :]
|
134 |
+
super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
|
135 |
+
|
136 |
+
|
137 |
+
class Joiner(nn.Sequential):
|
138 |
+
def __init__(self, backbone, position_embedding):
|
139 |
+
super().__init__(backbone, position_embedding)
|
140 |
+
|
141 |
+
def forward(self, tensor_list: NestedTensor):
|
142 |
+
xs = self[0](tensor_list)
|
143 |
+
out: List[NestedTensor] = []
|
144 |
+
pos = []
|
145 |
+
for name, x in xs.items():
|
146 |
+
out.append(x)
|
147 |
+
# position encoding
|
148 |
+
pos.append(self[1](x).to(x.tensors.dtype))
|
149 |
+
|
150 |
+
return out, pos
|
151 |
+
|
152 |
+
|
153 |
+
def build_backbone(args):
|
154 |
+
"""
|
155 |
+
Useful args:
|
156 |
+
- backbone: backbone name
|
157 |
+
- lr_backbone:
|
158 |
+
- dilation
|
159 |
+
- return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
|
160 |
+
- backbone_freeze_keywords:
|
161 |
+
- use_checkpoint: for swin only for now
|
162 |
+
|
163 |
+
"""
|
164 |
+
position_embedding = build_position_encoding(args)
|
165 |
+
train_backbone = True
|
166 |
+
if not train_backbone:
|
167 |
+
raise ValueError("Please set lr_backbone > 0")
|
168 |
+
return_interm_indices = args.return_interm_indices
|
169 |
+
assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
|
170 |
+
args.backbone_freeze_keywords
|
171 |
+
use_checkpoint = getattr(args, "use_checkpoint", False)
|
172 |
+
|
173 |
+
if args.backbone in ["resnet50", "resnet101"]:
|
174 |
+
backbone = Backbone(
|
175 |
+
args.backbone,
|
176 |
+
train_backbone,
|
177 |
+
args.dilation,
|
178 |
+
return_interm_indices,
|
179 |
+
batch_norm=FrozenBatchNorm2d,
|
180 |
+
)
|
181 |
+
bb_num_channels = backbone.num_channels
|
182 |
+
elif args.backbone in [
|
183 |
+
"swin_T_224_1k",
|
184 |
+
"swin_B_224_22k",
|
185 |
+
"swin_B_384_22k",
|
186 |
+
"swin_L_224_22k",
|
187 |
+
"swin_L_384_22k",
|
188 |
+
]:
|
189 |
+
pretrain_img_size = int(args.backbone.split("_")[-2])
|
190 |
+
backbone = build_swin_transformer(
|
191 |
+
args.backbone,
|
192 |
+
pretrain_img_size=pretrain_img_size,
|
193 |
+
out_indices=tuple(return_interm_indices),
|
194 |
+
dilation=False,
|
195 |
+
use_checkpoint=use_checkpoint,
|
196 |
+
)
|
197 |
+
|
198 |
+
bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
|
199 |
+
else:
|
200 |
+
raise NotImplementedError("Unknown backbone {}".format(args.backbone))
|
201 |
+
|
202 |
+
assert len(bb_num_channels) == len(
|
203 |
+
return_interm_indices
|
204 |
+
), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
|
205 |
+
|
206 |
+
model = Joiner(backbone, position_embedding)
|
207 |
+
model.num_channels = bb_num_channels
|
208 |
+
assert isinstance(
|
209 |
+
bb_num_channels, List
|
210 |
+
), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
|
211 |
+
return model
|
src/models/XPose/models/UniPose/deformable_transformer.py
ADDED
@@ -0,0 +1,1230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# UniPose
|
3 |
+
# url: https://github.com/IDEA-Research/UniPose
|
4 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
5 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
# ------------------------------------------------------------------------
|
7 |
+
# ED-Pose
|
8 |
+
# Copyright (c) 2023 IDEA. All Rights Reserved.
|
9 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
# DINO
|
12 |
+
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
13 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
16 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
17 |
+
# ------------------------------------------------------------------------
|
18 |
+
|
19 |
+
import math
|
20 |
+
import copy
|
21 |
+
import torch
|
22 |
+
import torch.utils.checkpoint as checkpoint
|
23 |
+
from torch import nn, Tensor
|
24 |
+
from typing import Optional
|
25 |
+
from ...util.misc import inverse_sigmoid
|
26 |
+
|
27 |
+
from .transformer_vanilla import TransformerEncoderLayer
|
28 |
+
from .fuse_modules import BiAttentionBlock
|
29 |
+
from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
|
30 |
+
from .ops.modules import MSDeformAttn
|
31 |
+
|
32 |
+
|
33 |
+
class DeformableTransformer(nn.Module):
|
34 |
+
|
35 |
+
def __init__(self, d_model=256, nhead=8,
|
36 |
+
num_queries=300,
|
37 |
+
num_encoder_layers=6,
|
38 |
+
num_unicoder_layers=0,
|
39 |
+
num_decoder_layers=6,
|
40 |
+
dim_feedforward=2048, dropout=0.0,
|
41 |
+
activation="relu", normalize_before=False,
|
42 |
+
return_intermediate_dec=False, query_dim=4,
|
43 |
+
num_patterns=0,
|
44 |
+
modulate_hw_attn=False,
|
45 |
+
# for deformable encoder
|
46 |
+
deformable_encoder=False,
|
47 |
+
deformable_decoder=False,
|
48 |
+
num_feature_levels=1,
|
49 |
+
enc_n_points=4,
|
50 |
+
dec_n_points=4,
|
51 |
+
use_deformable_box_attn=False,
|
52 |
+
box_attn_type='roi_align',
|
53 |
+
# init query
|
54 |
+
learnable_tgt_init=False,
|
55 |
+
decoder_query_perturber=None,
|
56 |
+
add_channel_attention=False,
|
57 |
+
add_pos_value=False,
|
58 |
+
random_refpoints_xy=False,
|
59 |
+
# two stage
|
60 |
+
two_stage_type='no',
|
61 |
+
two_stage_pat_embed=0,
|
62 |
+
two_stage_add_query_num=0,
|
63 |
+
two_stage_learn_wh=False,
|
64 |
+
two_stage_keep_all_tokens=False,
|
65 |
+
# evo of #anchors
|
66 |
+
dec_layer_number=None,
|
67 |
+
rm_enc_query_scale=True,
|
68 |
+
rm_dec_query_scale=True,
|
69 |
+
rm_self_attn_layers=None,
|
70 |
+
key_aware_type=None,
|
71 |
+
# layer share
|
72 |
+
layer_share_type=None,
|
73 |
+
# for detach
|
74 |
+
rm_detach=None,
|
75 |
+
decoder_sa_type='ca',
|
76 |
+
module_seq=['sa', 'ca', 'ffn'],
|
77 |
+
# for dn
|
78 |
+
embed_init_tgt=False,
|
79 |
+
|
80 |
+
use_detached_boxes_dec_out=False,
|
81 |
+
use_text_enhancer=False,
|
82 |
+
use_fusion_layer=False,
|
83 |
+
use_checkpoint=False,
|
84 |
+
use_transformer_ckpt=False,
|
85 |
+
use_text_cross_attention=False,
|
86 |
+
text_dropout=0.1,
|
87 |
+
fusion_dropout=0.1,
|
88 |
+
fusion_droppath=0.0,
|
89 |
+
|
90 |
+
binary_query_selection=False,
|
91 |
+
ffn_extra_layernorm=False,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
self.num_feature_levels = num_feature_levels
|
95 |
+
self.num_encoder_layers = num_encoder_layers
|
96 |
+
self.num_unicoder_layers = num_unicoder_layers
|
97 |
+
self.num_decoder_layers = num_decoder_layers
|
98 |
+
self.deformable_encoder = deformable_encoder
|
99 |
+
self.deformable_decoder = deformable_decoder
|
100 |
+
self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
|
101 |
+
self.num_queries = num_queries
|
102 |
+
self.random_refpoints_xy = random_refpoints_xy
|
103 |
+
self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
|
104 |
+
self.ffn_extra_layernorm = ffn_extra_layernorm
|
105 |
+
assert query_dim == 4
|
106 |
+
|
107 |
+
self.binary_query_selection = binary_query_selection
|
108 |
+
if self.binary_query_selection:
|
109 |
+
self.binary_query_selection_layer = nn.Linear(d_model, 1)
|
110 |
+
# assert not binary_query_selection, 'binary_query_selection not implemented yet'
|
111 |
+
|
112 |
+
if num_feature_levels > 1:
|
113 |
+
assert deformable_encoder, "only support deformable_encoder for num_feature_levels > 1"
|
114 |
+
if use_deformable_box_attn:
|
115 |
+
assert deformable_encoder or deformable_encoder
|
116 |
+
|
117 |
+
assert layer_share_type in [None, 'encoder', 'decoder', 'both']
|
118 |
+
if layer_share_type in ['encoder', 'both']:
|
119 |
+
enc_layer_share = True
|
120 |
+
else:
|
121 |
+
enc_layer_share = False
|
122 |
+
if layer_share_type in ['decoder', 'both']:
|
123 |
+
dec_layer_share = True
|
124 |
+
else:
|
125 |
+
dec_layer_share = False
|
126 |
+
assert layer_share_type is None
|
127 |
+
|
128 |
+
self.decoder_sa_type = decoder_sa_type
|
129 |
+
assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
|
130 |
+
|
131 |
+
# choose encoder layer type
|
132 |
+
if deformable_encoder:
|
133 |
+
encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
|
134 |
+
dropout, activation,
|
135 |
+
num_feature_levels, nhead, enc_n_points,
|
136 |
+
add_channel_attention=add_channel_attention,
|
137 |
+
use_deformable_box_attn=use_deformable_box_attn,
|
138 |
+
box_attn_type=box_attn_type)
|
139 |
+
else:
|
140 |
+
raise NotImplementedError
|
141 |
+
|
142 |
+
if use_text_enhancer:
|
143 |
+
text_enhance_layer = TransformerEncoderLayer(
|
144 |
+
d_model=d_model,
|
145 |
+
nhead=nhead // 2,
|
146 |
+
dim_feedforward=dim_feedforward // 2,
|
147 |
+
dropout=text_dropout
|
148 |
+
)
|
149 |
+
else:
|
150 |
+
text_enhance_layer = None
|
151 |
+
|
152 |
+
if use_fusion_layer:
|
153 |
+
feature_fusion_layer = BiAttentionBlock(
|
154 |
+
v_dim=d_model,
|
155 |
+
l_dim=d_model,
|
156 |
+
embed_dim=dim_feedforward // 2,
|
157 |
+
num_heads=nhead // 2,
|
158 |
+
dropout=fusion_dropout,
|
159 |
+
drop_path=fusion_droppath
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
feature_fusion_layer = None
|
163 |
+
|
164 |
+
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
165 |
+
assert encoder_norm is None
|
166 |
+
self.encoder = TransformerEncoder(
|
167 |
+
encoder_layer, num_encoder_layers, d_model=d_model,
|
168 |
+
num_queries=num_queries,
|
169 |
+
enc_layer_share=enc_layer_share,
|
170 |
+
text_enhance_layer=text_enhance_layer,
|
171 |
+
feature_fusion_layer=feature_fusion_layer,
|
172 |
+
use_checkpoint=use_checkpoint,
|
173 |
+
use_transformer_ckpt=use_transformer_ckpt,
|
174 |
+
)
|
175 |
+
|
176 |
+
# choose decoder layer type
|
177 |
+
if deformable_decoder:
|
178 |
+
decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
|
179 |
+
dropout, activation,
|
180 |
+
num_feature_levels, nhead, dec_n_points,
|
181 |
+
use_text_cross_attention=use_text_cross_attention,
|
182 |
+
ffn_extra_layernorm=ffn_extra_layernorm, )
|
183 |
+
|
184 |
+
else:
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
decoder_norm = nn.LayerNorm(d_model)
|
188 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
189 |
+
return_intermediate=return_intermediate_dec,
|
190 |
+
d_model=d_model, query_dim=query_dim,
|
191 |
+
modulate_hw_attn=modulate_hw_attn,
|
192 |
+
num_feature_levels=num_feature_levels,
|
193 |
+
deformable_decoder=deformable_decoder,
|
194 |
+
decoder_query_perturber=decoder_query_perturber,
|
195 |
+
dec_layer_number=dec_layer_number, rm_dec_query_scale=rm_dec_query_scale,
|
196 |
+
dec_layer_share=dec_layer_share,
|
197 |
+
use_detached_boxes_dec_out=use_detached_boxes_dec_out
|
198 |
+
)
|
199 |
+
|
200 |
+
self.d_model = d_model
|
201 |
+
self.nhead = nhead
|
202 |
+
self.dec_layers = num_decoder_layers
|
203 |
+
self.num_queries = num_queries # useful for single stage model only
|
204 |
+
self.num_patterns = num_patterns
|
205 |
+
if not isinstance(num_patterns, int):
|
206 |
+
Warning("num_patterns should be int but {}".format(type(num_patterns)))
|
207 |
+
self.num_patterns = 0
|
208 |
+
|
209 |
+
if num_feature_levels > 1:
|
210 |
+
if self.num_encoder_layers > 0:
|
211 |
+
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
|
212 |
+
else:
|
213 |
+
self.level_embed = None
|
214 |
+
|
215 |
+
self.learnable_tgt_init = learnable_tgt_init
|
216 |
+
assert learnable_tgt_init, "why not learnable_tgt_init"
|
217 |
+
self.embed_init_tgt = embed_init_tgt
|
218 |
+
if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
|
219 |
+
self.tgt_embed = nn.Embedding(self.num_queries, d_model)
|
220 |
+
nn.init.normal_(self.tgt_embed.weight.data)
|
221 |
+
else:
|
222 |
+
self.tgt_embed = None
|
223 |
+
|
224 |
+
# for two stage
|
225 |
+
self.two_stage_type = two_stage_type
|
226 |
+
self.two_stage_pat_embed = two_stage_pat_embed
|
227 |
+
self.two_stage_add_query_num = two_stage_add_query_num
|
228 |
+
self.two_stage_learn_wh = two_stage_learn_wh
|
229 |
+
assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type)
|
230 |
+
if two_stage_type == 'standard':
|
231 |
+
# anchor selection at the output of encoder
|
232 |
+
self.enc_output = nn.Linear(d_model, d_model)
|
233 |
+
self.enc_output_norm = nn.LayerNorm(d_model)
|
234 |
+
|
235 |
+
if two_stage_pat_embed > 0:
|
236 |
+
self.pat_embed_for_2stage = nn.Parameter(torch.Tensor(two_stage_pat_embed, d_model))
|
237 |
+
nn.init.normal_(self.pat_embed_for_2stage)
|
238 |
+
|
239 |
+
if two_stage_add_query_num > 0:
|
240 |
+
self.tgt_embed = nn.Embedding(self.two_stage_add_query_num, d_model)
|
241 |
+
|
242 |
+
if two_stage_learn_wh:
|
243 |
+
# import ipdb; ipdb.set_trace()
|
244 |
+
self.two_stage_wh_embedding = nn.Embedding(1, 2)
|
245 |
+
else:
|
246 |
+
self.two_stage_wh_embedding = None
|
247 |
+
|
248 |
+
if two_stage_type == 'no':
|
249 |
+
self.init_ref_points(num_queries) # init self.refpoint_embed
|
250 |
+
|
251 |
+
self.enc_out_class_embed = None
|
252 |
+
self.enc_out_bbox_embed = None
|
253 |
+
|
254 |
+
# evolution of anchors
|
255 |
+
self.dec_layer_number = dec_layer_number
|
256 |
+
if dec_layer_number is not None:
|
257 |
+
if self.two_stage_type != 'no' or num_patterns == 0:
|
258 |
+
assert dec_layer_number[
|
259 |
+
0] == num_queries, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})"
|
260 |
+
else:
|
261 |
+
assert dec_layer_number[
|
262 |
+
0] == num_queries * num_patterns, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})"
|
263 |
+
|
264 |
+
self._reset_parameters()
|
265 |
+
|
266 |
+
self.rm_self_attn_layers = rm_self_attn_layers
|
267 |
+
if rm_self_attn_layers is not None:
|
268 |
+
# assert len(rm_self_attn_layers) == num_decoder_layers
|
269 |
+
print("Removing the self-attn in {} decoder layers".format(rm_self_attn_layers))
|
270 |
+
for lid, dec_layer in enumerate(self.decoder.layers):
|
271 |
+
if lid in rm_self_attn_layers:
|
272 |
+
dec_layer.rm_self_attn_modules()
|
273 |
+
|
274 |
+
self.rm_detach = rm_detach
|
275 |
+
if self.rm_detach:
|
276 |
+
assert isinstance(rm_detach, list)
|
277 |
+
assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
|
278 |
+
self.decoder.rm_detach = rm_detach
|
279 |
+
|
280 |
+
def _reset_parameters(self):
|
281 |
+
for p in self.parameters():
|
282 |
+
if p.dim() > 1:
|
283 |
+
nn.init.xavier_uniform_(p)
|
284 |
+
for m in self.modules():
|
285 |
+
if isinstance(m, MSDeformAttn):
|
286 |
+
m._reset_parameters()
|
287 |
+
if self.num_feature_levels > 1 and self.level_embed is not None:
|
288 |
+
nn.init.normal_(self.level_embed)
|
289 |
+
|
290 |
+
if self.two_stage_learn_wh:
|
291 |
+
nn.init.constant_(self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05)))
|
292 |
+
|
293 |
+
def get_valid_ratio(self, mask):
|
294 |
+
_, H, W = mask.shape
|
295 |
+
valid_H = torch.sum(~mask[:, :, 0], 1)
|
296 |
+
valid_W = torch.sum(~mask[:, 0, :], 1)
|
297 |
+
valid_ratio_h = valid_H.float() / H
|
298 |
+
valid_ratio_w = valid_W.float() / W
|
299 |
+
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
|
300 |
+
return valid_ratio
|
301 |
+
|
302 |
+
def init_ref_points(self, use_num_queries):
|
303 |
+
self.refpoint_embed = nn.Embedding(use_num_queries, 4)
|
304 |
+
|
305 |
+
if self.random_refpoints_xy:
|
306 |
+
# import ipdb; ipdb.set_trace()
|
307 |
+
self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
|
308 |
+
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
|
309 |
+
self.refpoint_embed.weight.data[:, :2].requires_grad = False
|
310 |
+
|
311 |
+
def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, attn_mask2=None, text_dict=None,
|
312 |
+
dn_meta=None,targets=None,kpt_embed=None):
|
313 |
+
"""
|
314 |
+
Input:
|
315 |
+
- srcs: List of multi features [bs, ci, hi, wi]
|
316 |
+
- masks: List of multi masks [bs, hi, wi]
|
317 |
+
- refpoint_embed: [bs, num_dn, 4]. None in infer
|
318 |
+
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
|
319 |
+
- tgt: [bs, num_dn, d_model]. None in infer
|
320 |
+
|
321 |
+
"""
|
322 |
+
# if self.two_stage_type != 'no' and self.two_stage_add_query_num == 0:
|
323 |
+
# assert refpoint_embed is None
|
324 |
+
|
325 |
+
# prepare input for encoder
|
326 |
+
src_flatten = []
|
327 |
+
mask_flatten = []
|
328 |
+
lvl_pos_embed_flatten = []
|
329 |
+
spatial_shapes = []
|
330 |
+
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
|
331 |
+
bs, c, h, w = src.shape
|
332 |
+
spatial_shape = (h, w)
|
333 |
+
spatial_shapes.append(spatial_shape)
|
334 |
+
|
335 |
+
src = src.flatten(2).transpose(1, 2) # bs, hw, c
|
336 |
+
mask = mask.flatten(1) # bs, hw
|
337 |
+
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
|
338 |
+
if self.num_feature_levels > 1 and self.level_embed is not None:
|
339 |
+
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
|
340 |
+
else:
|
341 |
+
lvl_pos_embed = pos_embed
|
342 |
+
lvl_pos_embed_flatten.append(lvl_pos_embed)
|
343 |
+
src_flatten.append(src)
|
344 |
+
mask_flatten.append(mask)
|
345 |
+
src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
|
346 |
+
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
|
347 |
+
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
|
348 |
+
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
|
349 |
+
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
350 |
+
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
|
351 |
+
|
352 |
+
# two stage
|
353 |
+
enc_topk_proposals = enc_refpoint_embed = None
|
354 |
+
|
355 |
+
#########################################################
|
356 |
+
# Begin Encoder
|
357 |
+
#########################################################
|
358 |
+
memory, memory_text = self.encoder(
|
359 |
+
src_flatten,
|
360 |
+
pos=lvl_pos_embed_flatten,
|
361 |
+
level_start_index=level_start_index,
|
362 |
+
spatial_shapes=spatial_shapes,
|
363 |
+
valid_ratios=valid_ratios,
|
364 |
+
key_padding_mask=mask_flatten,
|
365 |
+
memory_text=text_dict['encoded_text'],
|
366 |
+
text_attention_mask=~text_dict['text_token_mask'],
|
367 |
+
# we ~ the mask . False means use the token; True means pad the token
|
368 |
+
position_ids=text_dict['position_ids'],
|
369 |
+
text_self_attention_masks=text_dict['text_self_attention_masks'],
|
370 |
+
)
|
371 |
+
#########################################################
|
372 |
+
# End Encoder
|
373 |
+
# - memory: bs, \sum{hw}, c
|
374 |
+
# - mask_flatten: bs, \sum{hw}
|
375 |
+
# - lvl_pos_embed_flatten: bs, \sum{hw}, c
|
376 |
+
# - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
|
377 |
+
# - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
|
378 |
+
#########################################################
|
379 |
+
text_dict['encoded_text'] = memory_text
|
380 |
+
|
381 |
+
if self.two_stage_type == 'standard':
|
382 |
+
if self.two_stage_learn_wh:
|
383 |
+
input_hw = self.two_stage_wh_embedding.weight[0]
|
384 |
+
else:
|
385 |
+
input_hw = None
|
386 |
+
output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes,
|
387 |
+
input_hw)
|
388 |
+
output_memory = self.enc_output_norm(self.enc_output(output_memory))
|
389 |
+
|
390 |
+
if self.two_stage_pat_embed > 0:
|
391 |
+
bs, nhw, _ = output_memory.shape
|
392 |
+
# output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
|
393 |
+
output_memory = output_memory.repeat(1, self.two_stage_pat_embed, 1)
|
394 |
+
_pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
|
395 |
+
output_memory = output_memory + _pats
|
396 |
+
output_proposals = output_proposals.repeat(1, self.two_stage_pat_embed, 1)
|
397 |
+
|
398 |
+
if self.two_stage_add_query_num > 0:
|
399 |
+
assert refpoint_embed is not None
|
400 |
+
output_memory = torch.cat((output_memory, tgt), dim=1)
|
401 |
+
output_proposals = torch.cat((output_proposals, refpoint_embed), dim=1)
|
402 |
+
|
403 |
+
if self.binary_query_selection:
|
404 |
+
topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
|
405 |
+
else:
|
406 |
+
if text_dict is not None:
|
407 |
+
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
|
408 |
+
else:
|
409 |
+
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
|
410 |
+
|
411 |
+
topk_logits = enc_outputs_class_unselected.max(-1)[0]
|
412 |
+
enc_outputs_coord_unselected = self.enc_out_bbox_embed(
|
413 |
+
output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
|
414 |
+
topk = self.num_queries
|
415 |
+
|
416 |
+
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
|
417 |
+
|
418 |
+
# gather boxes
|
419 |
+
refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,
|
420 |
+
topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
|
421 |
+
refpoint_embed_ = refpoint_embed_undetach.detach()
|
422 |
+
init_box_proposal = torch.gather(output_proposals, 1,
|
423 |
+
topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid() # sigmoid
|
424 |
+
|
425 |
+
# gather tgt
|
426 |
+
tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
|
427 |
+
if self.embed_init_tgt:
|
428 |
+
tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
|
429 |
+
else:
|
430 |
+
tgt_ = tgt_undetach.detach()
|
431 |
+
|
432 |
+
if refpoint_embed is not None:
|
433 |
+
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
|
434 |
+
tgt = torch.cat([tgt, tgt_], dim=1)
|
435 |
+
else:
|
436 |
+
refpoint_embed, tgt = refpoint_embed_, tgt_
|
437 |
+
|
438 |
+
elif self.two_stage_type == 'no':
|
439 |
+
tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
|
440 |
+
refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, 4
|
441 |
+
|
442 |
+
if refpoint_embed is not None:
|
443 |
+
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
|
444 |
+
tgt = torch.cat([tgt, tgt_], dim=1)
|
445 |
+
else:
|
446 |
+
refpoint_embed, tgt = refpoint_embed_, tgt_
|
447 |
+
|
448 |
+
if self.num_patterns > 0:
|
449 |
+
tgt_embed = tgt.repeat(1, self.num_patterns, 1)
|
450 |
+
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
|
451 |
+
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(self.num_queries,
|
452 |
+
1) # 1, n_q*n_pat, d_model
|
453 |
+
tgt = tgt_embed + tgt_pat
|
454 |
+
|
455 |
+
init_box_proposal = refpoint_embed_.sigmoid()
|
456 |
+
|
457 |
+
else:
|
458 |
+
raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
|
459 |
+
#########################################################
|
460 |
+
# End preparing tgt
|
461 |
+
# - tgt: bs, NQ, d_model
|
462 |
+
# - refpoint_embed(unsigmoid): bs, NQ, d_model
|
463 |
+
#########################################################
|
464 |
+
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
465 |
+
# if refpoint_embed.isnan().any() | refpoint_embed.isinf().any():
|
466 |
+
# import ipdb; ipdb.set_trace()
|
467 |
+
# if tgt.isnan().any() | tgt.isinf().any():
|
468 |
+
# import ipdb; ipdb.set_trace()
|
469 |
+
|
470 |
+
#########################################################
|
471 |
+
# Begin Decoder
|
472 |
+
#########################################################
|
473 |
+
hs, references = self.decoder(
|
474 |
+
tgt=tgt.transpose(0, 1),
|
475 |
+
memory=memory.transpose(0, 1),
|
476 |
+
memory_key_padding_mask=mask_flatten,
|
477 |
+
pos=lvl_pos_embed_flatten.transpose(0, 1),
|
478 |
+
refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
|
479 |
+
level_start_index=level_start_index,
|
480 |
+
spatial_shapes=spatial_shapes,
|
481 |
+
valid_ratios=valid_ratios, tgt_mask=attn_mask,
|
482 |
+
tgt_mask2=attn_mask2,
|
483 |
+
memory_text=text_dict['encoded_text'],
|
484 |
+
text_attention_mask=~text_dict['text_token_mask'],
|
485 |
+
text_dict=text_dict,
|
486 |
+
dn_meta=dn_meta,
|
487 |
+
targets=targets,
|
488 |
+
kpt_embed=kpt_embed
|
489 |
+
# we ~ the mask . False means use the token; True means pad the token
|
490 |
+
)
|
491 |
+
#########################################################
|
492 |
+
# End Decoder
|
493 |
+
# hs: n_dec, bs, nq, d_model
|
494 |
+
# references: n_dec+1, bs, nq, query_dim
|
495 |
+
#########################################################
|
496 |
+
|
497 |
+
#########################################################
|
498 |
+
# Begin postprocess
|
499 |
+
#########################################################
|
500 |
+
if self.two_stage_type == 'standard':
|
501 |
+
if self.two_stage_keep_all_tokens:
|
502 |
+
hs_enc = output_memory.unsqueeze(0)
|
503 |
+
ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
|
504 |
+
init_box_proposal = output_proposals
|
505 |
+
# import ipdb; ipdb.set_trace()
|
506 |
+
else:
|
507 |
+
hs_enc = tgt_undetach.unsqueeze(0)
|
508 |
+
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
|
509 |
+
else:
|
510 |
+
hs_enc = ref_enc = None
|
511 |
+
#########################################################
|
512 |
+
# End postprocess
|
513 |
+
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
|
514 |
+
# ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
|
515 |
+
#########################################################
|
516 |
+
|
517 |
+
return hs, references, hs_enc, ref_enc, init_box_proposal
|
518 |
+
# hs: (n_dec, bs, nq, d_model)
|
519 |
+
# references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
|
520 |
+
# hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
|
521 |
+
# ref_enc: sigmoid coordinates. \
|
522 |
+
# (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
|
523 |
+
|
524 |
+
|
525 |
+
class TransformerEncoder(nn.Module):
|
526 |
+
|
527 |
+
def __init__(self,
|
528 |
+
encoder_layer, num_layers, d_model=256,
|
529 |
+
num_queries=300,
|
530 |
+
enc_layer_share=False,
|
531 |
+
text_enhance_layer=None,
|
532 |
+
feature_fusion_layer=None,
|
533 |
+
use_checkpoint=False,
|
534 |
+
use_transformer_ckpt=False,
|
535 |
+
):
|
536 |
+
"""_summary_
|
537 |
+
|
538 |
+
Args:
|
539 |
+
encoder_layer (_type_): _description_
|
540 |
+
num_layers (_type_): _description_
|
541 |
+
norm (_type_, optional): _description_. Defaults to None.
|
542 |
+
d_model (int, optional): _description_. Defaults to 256.
|
543 |
+
num_queries (int, optional): _description_. Defaults to 300.
|
544 |
+
enc_layer_share (bool, optional): _description_. Defaults to False.
|
545 |
+
|
546 |
+
"""
|
547 |
+
super().__init__()
|
548 |
+
# prepare layers
|
549 |
+
self.layers = []
|
550 |
+
self.text_layers = []
|
551 |
+
self.fusion_layers = []
|
552 |
+
if num_layers > 0:
|
553 |
+
self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
|
554 |
+
|
555 |
+
if text_enhance_layer is not None:
|
556 |
+
self.text_layers = _get_clones(text_enhance_layer, num_layers, layer_share=enc_layer_share)
|
557 |
+
if feature_fusion_layer is not None:
|
558 |
+
self.fusion_layers = _get_clones(feature_fusion_layer, num_layers, layer_share=enc_layer_share)
|
559 |
+
else:
|
560 |
+
self.layers = []
|
561 |
+
del encoder_layer
|
562 |
+
|
563 |
+
if text_enhance_layer is not None:
|
564 |
+
self.text_layers = []
|
565 |
+
del text_enhance_layer
|
566 |
+
if feature_fusion_layer is not None:
|
567 |
+
self.fusion_layers = []
|
568 |
+
del feature_fusion_layer
|
569 |
+
|
570 |
+
self.query_scale = None
|
571 |
+
self.num_queries = num_queries
|
572 |
+
self.num_layers = num_layers
|
573 |
+
self.d_model = d_model
|
574 |
+
|
575 |
+
self.use_checkpoint = use_checkpoint
|
576 |
+
self.use_transformer_ckpt = use_transformer_ckpt
|
577 |
+
|
578 |
+
@staticmethod
|
579 |
+
def get_reference_points(spatial_shapes, valid_ratios, device):
|
580 |
+
reference_points_list = []
|
581 |
+
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
582 |
+
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
583 |
+
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),)
|
584 |
+
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
585 |
+
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
586 |
+
ref = torch.stack((ref_x, ref_y), -1)
|
587 |
+
reference_points_list.append(ref)
|
588 |
+
reference_points = torch.cat(reference_points_list, 1)
|
589 |
+
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
|
590 |
+
return reference_points
|
591 |
+
|
592 |
+
def forward(self,
|
593 |
+
# for images
|
594 |
+
src: Tensor,
|
595 |
+
pos: Tensor,
|
596 |
+
spatial_shapes: Tensor,
|
597 |
+
level_start_index: Tensor,
|
598 |
+
valid_ratios: Tensor,
|
599 |
+
key_padding_mask: Tensor,
|
600 |
+
# for texts
|
601 |
+
memory_text: Tensor = None,
|
602 |
+
text_attention_mask: Tensor = None,
|
603 |
+
pos_text: Tensor = None,
|
604 |
+
text_self_attention_masks: Tensor = None,
|
605 |
+
position_ids: Tensor = None,
|
606 |
+
):
|
607 |
+
"""
|
608 |
+
Input:
|
609 |
+
- src: [bs, sum(hi*wi), 256]
|
610 |
+
- pos: pos embed for src. [bs, sum(hi*wi), 256]
|
611 |
+
- spatial_shapes: h,w of each level [num_level, 2]
|
612 |
+
- level_start_index: [num_level] start point of level in sum(hi*wi).
|
613 |
+
- valid_ratios: [bs, num_level, 2]
|
614 |
+
- key_padding_mask: [bs, sum(hi*wi)]
|
615 |
+
|
616 |
+
- memory_text: bs, n_text, 256
|
617 |
+
- text_attention_mask: bs, n_text
|
618 |
+
False for no padding; True for padding
|
619 |
+
- pos_text: bs, n_text, 256
|
620 |
+
|
621 |
+
- position_ids: bs, n_text
|
622 |
+
Intermedia:
|
623 |
+
- reference_points: [bs, sum(hi*wi), num_level, 2]
|
624 |
+
Outpus:
|
625 |
+
- output: [bs, sum(hi*wi), 256]
|
626 |
+
"""
|
627 |
+
|
628 |
+
output = src
|
629 |
+
|
630 |
+
# preparation and reshape
|
631 |
+
if self.num_layers > 0:
|
632 |
+
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
|
633 |
+
|
634 |
+
if self.text_layers:
|
635 |
+
# generate pos_text
|
636 |
+
bs, n_text, text_dim = memory_text.shape
|
637 |
+
if pos_text is None and position_ids is None:
|
638 |
+
pos_text = torch.arange(n_text, device=memory_text.device).float().unsqueeze(0).unsqueeze(-1).repeat(bs,
|
639 |
+
1,
|
640 |
+
1)
|
641 |
+
pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
|
642 |
+
if position_ids is not None:
|
643 |
+
pos_text = get_sine_pos_embed(position_ids[..., None], num_pos_feats=256, exchange_xy=False)
|
644 |
+
|
645 |
+
# main process
|
646 |
+
for layer_id, layer in enumerate(self.layers):
|
647 |
+
# if output.isnan().any() or memory_text.isnan().any():
|
648 |
+
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
649 |
+
# import ipdb; ipdb.set_trace()
|
650 |
+
if self.fusion_layers:
|
651 |
+
if self.use_checkpoint:
|
652 |
+
output, memory_text = checkpoint.checkpoint(
|
653 |
+
self.fusion_layers[layer_id],
|
654 |
+
output,
|
655 |
+
memory_text,
|
656 |
+
key_padding_mask,
|
657 |
+
text_attention_mask
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
output, memory_text = self.fusion_layers[layer_id](v=output, l=memory_text,
|
661 |
+
attention_mask_v=key_padding_mask,
|
662 |
+
attention_mask_l=text_attention_mask)
|
663 |
+
|
664 |
+
if self.text_layers:
|
665 |
+
memory_text = self.text_layers[layer_id](
|
666 |
+
src=memory_text.transpose(0, 1),
|
667 |
+
src_mask=~text_self_attention_masks, # note we use ~ for mask here
|
668 |
+
src_key_padding_mask=text_attention_mask,
|
669 |
+
pos=(pos_text.transpose(0, 1) if pos_text is not None else None)
|
670 |
+
).transpose(0, 1)
|
671 |
+
|
672 |
+
# main process
|
673 |
+
if self.use_transformer_ckpt:
|
674 |
+
output = checkpoint.checkpoint(
|
675 |
+
layer,
|
676 |
+
output,
|
677 |
+
pos,
|
678 |
+
reference_points,
|
679 |
+
spatial_shapes,
|
680 |
+
level_start_index,
|
681 |
+
key_padding_mask
|
682 |
+
)
|
683 |
+
else:
|
684 |
+
output = layer(src=output, pos=pos, reference_points=reference_points, spatial_shapes=spatial_shapes,
|
685 |
+
level_start_index=level_start_index, key_padding_mask=key_padding_mask)
|
686 |
+
|
687 |
+
return output, memory_text
|
688 |
+
|
689 |
+
|
690 |
+
class TransformerDecoder(nn.Module):
|
691 |
+
|
692 |
+
def __init__(self, decoder_layer, num_layers, norm=None,
|
693 |
+
return_intermediate=False,
|
694 |
+
d_model=256, query_dim=4,
|
695 |
+
modulate_hw_attn=False,
|
696 |
+
num_feature_levels=1,
|
697 |
+
deformable_decoder=False,
|
698 |
+
decoder_query_perturber=None,
|
699 |
+
dec_layer_number=None, # number of queries each layer in decoder
|
700 |
+
rm_dec_query_scale=False,
|
701 |
+
dec_layer_share=False,
|
702 |
+
dec_layer_dropout_prob=None,
|
703 |
+
use_detached_boxes_dec_out=False,
|
704 |
+
num_box_decoder_layers=2,
|
705 |
+
num_body_points=68,
|
706 |
+
):
|
707 |
+
super().__init__()
|
708 |
+
if num_layers > 0:
|
709 |
+
self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
|
710 |
+
else:
|
711 |
+
self.layers = []
|
712 |
+
self.num_layers = num_layers
|
713 |
+
self.norm = norm
|
714 |
+
self.return_intermediate = return_intermediate
|
715 |
+
assert return_intermediate, "support return_intermediate only"
|
716 |
+
self.query_dim = query_dim
|
717 |
+
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
|
718 |
+
self.num_feature_levels = num_feature_levels
|
719 |
+
self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
|
720 |
+
|
721 |
+
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
|
722 |
+
if not deformable_decoder:
|
723 |
+
self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
|
724 |
+
else:
|
725 |
+
self.query_pos_sine_scale = None
|
726 |
+
|
727 |
+
if rm_dec_query_scale:
|
728 |
+
self.query_scale = None
|
729 |
+
else:
|
730 |
+
raise NotImplementedError
|
731 |
+
self.query_scale = MLP(d_model, d_model, d_model, 2)
|
732 |
+
self.bbox_embed = None
|
733 |
+
self.class_embed = None
|
734 |
+
self.pose_embed = None
|
735 |
+
self.pose_hw_embed = None
|
736 |
+
self.d_model = d_model
|
737 |
+
self.modulate_hw_attn = modulate_hw_attn
|
738 |
+
self.deformable_decoder = deformable_decoder
|
739 |
+
|
740 |
+
if not deformable_decoder and modulate_hw_attn:
|
741 |
+
self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
|
742 |
+
else:
|
743 |
+
self.ref_anchor_head = None
|
744 |
+
|
745 |
+
self.decoder_query_perturber = decoder_query_perturber
|
746 |
+
self.box_pred_damping = None
|
747 |
+
|
748 |
+
self.dec_layer_number = dec_layer_number
|
749 |
+
if dec_layer_number is not None:
|
750 |
+
assert isinstance(dec_layer_number, list)
|
751 |
+
assert len(dec_layer_number) == num_layers
|
752 |
+
# assert dec_layer_number[0] ==
|
753 |
+
|
754 |
+
self.dec_layer_dropout_prob = dec_layer_dropout_prob
|
755 |
+
if dec_layer_dropout_prob is not None:
|
756 |
+
assert isinstance(dec_layer_dropout_prob, list)
|
757 |
+
assert len(dec_layer_dropout_prob) == num_layers
|
758 |
+
for i in dec_layer_dropout_prob:
|
759 |
+
assert 0.0 <= i <= 1.0
|
760 |
+
|
761 |
+
self.rm_detach = None
|
762 |
+
self.num_body_points = num_body_points
|
763 |
+
|
764 |
+
self.hw = nn.Embedding(17, 2)
|
765 |
+
self.num_box_decoder_layers = num_box_decoder_layers
|
766 |
+
self.kpt_index = [x for x in range(50 * (self.num_body_points + 1)) if x % (self.num_body_points + 1) != 0]
|
767 |
+
self.hw_append = nn.Embedding(self.num_body_points-17, 2)
|
768 |
+
|
769 |
+
def forward(self, tgt, memory,
|
770 |
+
tgt_mask: Optional[Tensor] = None,
|
771 |
+
tgt_mask2: Optional[Tensor] = None,
|
772 |
+
memory_mask: Optional[Tensor] = None,
|
773 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
774 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
775 |
+
pos: Optional[Tensor] = None,
|
776 |
+
refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
|
777 |
+
# for memory
|
778 |
+
level_start_index: Optional[Tensor] = None, # num_levels
|
779 |
+
spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
|
780 |
+
valid_ratios: Optional[Tensor] = None,
|
781 |
+
# for text
|
782 |
+
memory_text: Optional[Tensor] = None,
|
783 |
+
text_attention_mask: Optional[Tensor] = None,
|
784 |
+
text_dict: Optional[Tensor] = None,
|
785 |
+
dn_meta: Optional[Tensor] = None,
|
786 |
+
targets: Optional[Tensor] = None,
|
787 |
+
kpt_embed: Optional[Tensor] = None
|
788 |
+
):
|
789 |
+
"""
|
790 |
+
Input:
|
791 |
+
- tgt: nq, bs, d_model
|
792 |
+
- memory: hw, bs, d_model
|
793 |
+
- pos: hw, bs, d_model
|
794 |
+
- refpoints_unsigmoid: nq, bs, 2/4
|
795 |
+
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
796 |
+
"""
|
797 |
+
|
798 |
+
output = tgt
|
799 |
+
output += self.hw.weight[0, 0] * 0.0
|
800 |
+
|
801 |
+
|
802 |
+
intermediate = []
|
803 |
+
reference_points = refpoints_unsigmoid.sigmoid()
|
804 |
+
ref_points = [reference_points]
|
805 |
+
effect_num_dn = dn_meta['pad_size'] if self.training else 0
|
806 |
+
inter_select_number = 50
|
807 |
+
for layer_id, layer in enumerate(self.layers):
|
808 |
+
|
809 |
+
if reference_points.shape[-1] == 4:
|
810 |
+
reference_points_input = reference_points[:, :, None] \
|
811 |
+
* torch.cat([valid_ratios, valid_ratios], -1)[None, :] # nq, bs, nlevel, 4
|
812 |
+
else:
|
813 |
+
assert reference_points.shape[-1] == 2
|
814 |
+
reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
|
815 |
+
query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # nq, bs, 256*2
|
816 |
+
|
817 |
+
# conditional query
|
818 |
+
raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
|
819 |
+
pos_scale = self.query_scale(output) if self.query_scale is not None else 1
|
820 |
+
query_pos = pos_scale * raw_query_pos
|
821 |
+
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
822 |
+
# if query_pos.isnan().any() | query_pos.isinf().any():
|
823 |
+
# import ipdb; ipdb.set_trace()
|
824 |
+
|
825 |
+
# main process
|
826 |
+
output = layer(
|
827 |
+
tgt=output,
|
828 |
+
tgt_query_pos=query_pos,
|
829 |
+
tgt_query_sine_embed=query_sine_embed,
|
830 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
831 |
+
tgt_reference_points=reference_points_input,
|
832 |
+
|
833 |
+
memory_text=memory_text,
|
834 |
+
text_attention_mask=text_attention_mask,
|
835 |
+
|
836 |
+
memory=memory,
|
837 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
838 |
+
memory_level_start_index=level_start_index,
|
839 |
+
memory_spatial_shapes=spatial_shapes,
|
840 |
+
memory_pos=pos,
|
841 |
+
|
842 |
+
self_attn_mask=tgt_mask,
|
843 |
+
cross_attn_mask=memory_mask
|
844 |
+
)
|
845 |
+
if output.isnan().any() | output.isinf().any():
|
846 |
+
print(f"output layer_id {layer_id} is nan")
|
847 |
+
try:
|
848 |
+
num_nan = output.isnan().sum().item()
|
849 |
+
num_inf = output.isinf().sum().item()
|
850 |
+
print(f"num_nan {num_nan}, num_inf {num_inf}")
|
851 |
+
except Exception as e:
|
852 |
+
print(e)
|
853 |
+
|
854 |
+
|
855 |
+
|
856 |
+
|
857 |
+
intermediate.append(self.norm(output))
|
858 |
+
# iter update
|
859 |
+
if layer_id < self.num_box_decoder_layers:
|
860 |
+
reference_before_sigmoid = inverse_sigmoid(reference_points)
|
861 |
+
delta_unsig = self.bbox_embed[layer_id](output)
|
862 |
+
outputs_unsig = delta_unsig + reference_before_sigmoid
|
863 |
+
new_reference_points = outputs_unsig.sigmoid()
|
864 |
+
|
865 |
+
# select # ref points as anchors
|
866 |
+
if layer_id == self.num_box_decoder_layers - 1:
|
867 |
+
dn_output = output[:effect_num_dn]
|
868 |
+
dn_new_reference_points = new_reference_points[:effect_num_dn]
|
869 |
+
class_unselected = self.class_embed[layer_id](output.transpose(0, 1), text_dict)[:,
|
870 |
+
effect_num_dn:].transpose(0, 1)
|
871 |
+
topk_proposals = torch.topk(class_unselected.max(-1)[0], inter_select_number, dim=0)[1]
|
872 |
+
new_reference_points_for_box = torch.gather(new_reference_points[effect_num_dn:], 0,
|
873 |
+
topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
|
874 |
+
new_output_for_box = torch.gather(output[effect_num_dn:], 0,
|
875 |
+
topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
|
876 |
+
keypoint_embed=kpt_embed.transpose(0, 1)
|
877 |
+
|
878 |
+
new_output_for_keypoint = keypoint_embed[None, :, :, :].repeat(new_output_for_box.shape[0],1,1,1)
|
879 |
+
delta_xy = self.pose_embed[-1](new_output_for_keypoint)[..., :2]
|
880 |
+
keypoint_xy = (inverse_sigmoid(new_reference_points_for_box[..., :2][:, None]) + delta_xy).sigmoid()
|
881 |
+
num_queries, _, bs, _ = keypoint_xy.shape
|
882 |
+
aa = torch.cat((self.hw.weight,self.hw_append.weight),dim=0)
|
883 |
+
keypoint_wh_weight = aa.unsqueeze(0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
|
884 |
+
keypoint_wh = keypoint_wh_weight * new_reference_points_for_box[..., 2:][:, None]
|
885 |
+
new_reference_points_for_keypoint = torch.cat((keypoint_xy, keypoint_wh), dim=-1)
|
886 |
+
new_reference_points = torch.cat(
|
887 |
+
(new_reference_points_for_box.unsqueeze(1), new_reference_points_for_keypoint), dim=1).flatten(0, 1)
|
888 |
+
output = torch.cat((new_output_for_box.unsqueeze(1), new_output_for_keypoint), dim=1).flatten(0, 1)
|
889 |
+
new_reference_points = torch.cat((dn_new_reference_points, new_reference_points), dim=0)
|
890 |
+
output = torch.cat((dn_output, output), dim=0)
|
891 |
+
tgt_mask = tgt_mask2
|
892 |
+
|
893 |
+
if layer_id >= self.num_box_decoder_layers:
|
894 |
+
reference_before_sigmoid = inverse_sigmoid(reference_points)
|
895 |
+
output_bbox_dn = output[:effect_num_dn]
|
896 |
+
output_bbox_norm = output[effect_num_dn:][0::(self.num_body_points + 1)]
|
897 |
+
reference_before_sigmoid_bbox_dn = reference_before_sigmoid[:effect_num_dn]
|
898 |
+
reference_before_sigmoid_bbox_norm = reference_before_sigmoid[effect_num_dn:][
|
899 |
+
0::(self.num_body_points + 1)]
|
900 |
+
delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn)
|
901 |
+
delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm)
|
902 |
+
outputs_unsig_dn = delta_unsig_dn + reference_before_sigmoid_bbox_dn
|
903 |
+
outputs_unsig_norm = delta_unsig_norm + reference_before_sigmoid_bbox_norm
|
904 |
+
new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid()
|
905 |
+
new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid()
|
906 |
+
output_kpt = output[effect_num_dn:].index_select(0, torch.tensor(self.kpt_index, device=output.device))
|
907 |
+
delta_xy_unsig = self.pose_embed[layer_id - self.num_box_decoder_layers](output_kpt)
|
908 |
+
outputs_unsig = reference_before_sigmoid[effect_num_dn:].index_select(0, torch.tensor(self.kpt_index,
|
909 |
+
device=output.device)).clone() ##
|
910 |
+
delta_hw_unsig = self.pose_hw_embed[layer_id - self.num_box_decoder_layers](output_kpt)
|
911 |
+
outputs_unsig[..., :2] += delta_xy_unsig[..., :2]
|
912 |
+
outputs_unsig[..., 2:] += delta_hw_unsig
|
913 |
+
new_reference_points_for_keypoint = outputs_unsig.sigmoid()
|
914 |
+
bs = new_reference_points_for_box_norm.shape[1]
|
915 |
+
new_reference_points_norm = torch.cat((new_reference_points_for_box_norm.unsqueeze(1),
|
916 |
+
new_reference_points_for_keypoint.view(-1, self.num_body_points,
|
917 |
+
bs, 4)), dim=1).flatten(0,
|
918 |
+
1)
|
919 |
+
new_reference_points = torch.cat((new_reference_points_for_box_dn, new_reference_points_norm), dim=0)
|
920 |
+
|
921 |
+
if self.rm_detach and 'dec' in self.rm_detach:
|
922 |
+
reference_points = new_reference_points
|
923 |
+
else:
|
924 |
+
reference_points = new_reference_points.detach()
|
925 |
+
|
926 |
+
# if layer_id != self.num_layers - 1:
|
927 |
+
if self.use_detached_boxes_dec_out:
|
928 |
+
ref_points.append(reference_points)
|
929 |
+
else:
|
930 |
+
ref_points.append(new_reference_points)
|
931 |
+
|
932 |
+
return [
|
933 |
+
[itm_out.transpose(0, 1) for itm_out in intermediate],
|
934 |
+
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]
|
935 |
+
]
|
936 |
+
|
937 |
+
|
938 |
+
class DeformableTransformerEncoderLayer(nn.Module):
|
939 |
+
def __init__(self,
|
940 |
+
d_model=256, d_ffn=1024,
|
941 |
+
dropout=0.1, activation="relu",
|
942 |
+
n_levels=4, n_heads=8, n_points=4,
|
943 |
+
add_channel_attention=False,
|
944 |
+
use_deformable_box_attn=False,
|
945 |
+
box_attn_type='roi_align',
|
946 |
+
):
|
947 |
+
super().__init__()
|
948 |
+
|
949 |
+
# self attention
|
950 |
+
self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
951 |
+
self.dropout1 = nn.Dropout(dropout)
|
952 |
+
self.norm1 = nn.LayerNorm(d_model)
|
953 |
+
|
954 |
+
# ffn
|
955 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
956 |
+
self.activation = _get_activation_fn(activation, d_model=d_ffn)
|
957 |
+
self.dropout2 = nn.Dropout(dropout)
|
958 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
959 |
+
self.dropout3 = nn.Dropout(dropout)
|
960 |
+
self.norm2 = nn.LayerNorm(d_model)
|
961 |
+
|
962 |
+
# channel attention
|
963 |
+
self.add_channel_attention = add_channel_attention
|
964 |
+
if add_channel_attention:
|
965 |
+
self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model)
|
966 |
+
self.norm_channel = nn.LayerNorm(d_model)
|
967 |
+
|
968 |
+
@staticmethod
|
969 |
+
def with_pos_embed(tensor, pos):
|
970 |
+
return tensor if pos is None else tensor + pos
|
971 |
+
|
972 |
+
def forward_ffn(self, src):
|
973 |
+
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
|
974 |
+
src = src + self.dropout3(src2)
|
975 |
+
src = self.norm2(src)
|
976 |
+
return src
|
977 |
+
|
978 |
+
def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
|
979 |
+
# self attention
|
980 |
+
# import ipdb; ipdb.set_trace()
|
981 |
+
src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index,
|
982 |
+
key_padding_mask)
|
983 |
+
src = src + self.dropout1(src2)
|
984 |
+
src = self.norm1(src)
|
985 |
+
|
986 |
+
# ffn
|
987 |
+
src = self.forward_ffn(src)
|
988 |
+
|
989 |
+
# channel attn
|
990 |
+
if self.add_channel_attention:
|
991 |
+
src = self.norm_channel(src + self.activ_channel(src))
|
992 |
+
|
993 |
+
return src
|
994 |
+
|
995 |
+
|
996 |
+
class DeformableTransformerDecoderLayer(nn.Module):
|
997 |
+
def __init__(self, d_model=256, d_ffn=1024,
|
998 |
+
dropout=0.1, activation="relu",
|
999 |
+
n_levels=4, n_heads=8, n_points=4,
|
1000 |
+
use_text_feat_guide=False,
|
1001 |
+
use_text_cross_attention=False,
|
1002 |
+
ffn_extra_layernorm=False
|
1003 |
+
):
|
1004 |
+
super().__init__()
|
1005 |
+
|
1006 |
+
# cross attention
|
1007 |
+
# self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
1008 |
+
self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
|
1009 |
+
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
1010 |
+
self.norm1 = nn.LayerNorm(d_model)
|
1011 |
+
|
1012 |
+
# cross attention text
|
1013 |
+
if use_text_cross_attention:
|
1014 |
+
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
1015 |
+
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
1016 |
+
self.catext_norm = nn.LayerNorm(d_model)
|
1017 |
+
|
1018 |
+
# self attention
|
1019 |
+
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
|
1020 |
+
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
1021 |
+
self.norm2 = nn.LayerNorm(d_model)
|
1022 |
+
|
1023 |
+
# ffn
|
1024 |
+
self.linear1 = nn.Linear(d_model, d_ffn)
|
1025 |
+
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
|
1026 |
+
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
1027 |
+
self.linear2 = nn.Linear(d_ffn, d_model)
|
1028 |
+
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
1029 |
+
self.norm3 = nn.LayerNorm(d_model)
|
1030 |
+
if ffn_extra_layernorm:
|
1031 |
+
raise NotImplementedError('ffn_extra_layernorm not implemented')
|
1032 |
+
self.norm_ext = nn.LayerNorm(d_ffn)
|
1033 |
+
else:
|
1034 |
+
self.norm_ext = None
|
1035 |
+
|
1036 |
+
self.key_aware_proj = None
|
1037 |
+
self.use_text_feat_guide = use_text_feat_guide
|
1038 |
+
assert not use_text_feat_guide
|
1039 |
+
self.use_text_cross_attention = use_text_cross_attention
|
1040 |
+
|
1041 |
+
def rm_self_attn_modules(self):
|
1042 |
+
self.self_attn = None
|
1043 |
+
self.dropout2 = None
|
1044 |
+
self.norm2 = None
|
1045 |
+
|
1046 |
+
@staticmethod
|
1047 |
+
def with_pos_embed(tensor, pos):
|
1048 |
+
return tensor if pos is None else tensor + pos
|
1049 |
+
|
1050 |
+
def forward_ffn(self, tgt, ipdb_flag=False):
|
1051 |
+
|
1052 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1053 |
+
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
1054 |
+
|
1055 |
+
tgt = tgt + self.dropout4(tgt2)
|
1056 |
+
tgt = self.norm3(tgt)
|
1057 |
+
return tgt
|
1058 |
+
|
1059 |
+
def forward(self,
|
1060 |
+
# for tgt
|
1061 |
+
tgt: Optional[Tensor], # nq, bs, d_model
|
1062 |
+
tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
|
1063 |
+
tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
|
1064 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
1065 |
+
tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
|
1066 |
+
|
1067 |
+
memory_text: Optional[Tensor] = None, # bs, num_token, d_model
|
1068 |
+
text_attention_mask: Optional[Tensor] = None, # bs, num_token
|
1069 |
+
|
1070 |
+
# for memory
|
1071 |
+
memory: Optional[Tensor] = None, # hw, bs, d_model
|
1072 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
1073 |
+
memory_level_start_index: Optional[Tensor] = None, # num_levels
|
1074 |
+
memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
|
1075 |
+
memory_pos: Optional[Tensor] = None, # pos for memory
|
1076 |
+
|
1077 |
+
# sa
|
1078 |
+
self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
|
1079 |
+
cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
|
1080 |
+
):
|
1081 |
+
"""
|
1082 |
+
Input:
|
1083 |
+
- tgt/tgt_query_pos: nq, bs, d_model
|
1084 |
+
-
|
1085 |
+
"""
|
1086 |
+
assert cross_attn_mask is None
|
1087 |
+
|
1088 |
+
# self attention
|
1089 |
+
if self.self_attn is not None:
|
1090 |
+
# import ipdb; ipdb.set_trace()
|
1091 |
+
q = k = self.with_pos_embed(tgt, tgt_query_pos)
|
1092 |
+
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
|
1093 |
+
tgt = tgt + self.dropout2(tgt2)
|
1094 |
+
tgt = self.norm2(tgt)
|
1095 |
+
|
1096 |
+
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
1097 |
+
# if tgt.isnan().any() | tgt.isinf().any() :
|
1098 |
+
# import ipdb; ipdb.set_trace()
|
1099 |
+
|
1100 |
+
if self.use_text_cross_attention:
|
1101 |
+
tgt2 = self.ca_text(self.with_pos_embed(tgt, tgt_query_pos), memory_text.transpose(0, 1),
|
1102 |
+
memory_text.transpose(0, 1), key_padding_mask=text_attention_mask)[0]
|
1103 |
+
tgt = tgt + self.catext_dropout(tgt2)
|
1104 |
+
tgt = self.catext_norm(tgt)
|
1105 |
+
|
1106 |
+
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
1107 |
+
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
1108 |
+
# import ipdb; ipdb.set_trace()
|
1109 |
+
|
1110 |
+
# if tgt.isnan().any() | tgt.isinf().any() :
|
1111 |
+
# import ipdb; ipdb.set_trace()
|
1112 |
+
|
1113 |
+
tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
|
1114 |
+
tgt_reference_points.transpose(0, 1).contiguous(),
|
1115 |
+
memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index,
|
1116 |
+
memory_key_padding_mask).transpose(0, 1)
|
1117 |
+
tgt = tgt + self.dropout1(tgt2)
|
1118 |
+
tgt = self.norm1(tgt)
|
1119 |
+
|
1120 |
+
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
1121 |
+
# tgtk = tgt.clone()
|
1122 |
+
# if tgt.isnan().any() | tgt.isinf().any() :
|
1123 |
+
# import ipdb; ipdb.set_trace()
|
1124 |
+
|
1125 |
+
# ffn
|
1126 |
+
tgt = self.forward_ffn(tgt)
|
1127 |
+
# if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
|
1128 |
+
# if tgt.isnan().any() | tgt.isinf().any() :
|
1129 |
+
# tgtk = self.forward_ffn(tgtk, ipdb_flag=True)
|
1130 |
+
# import ipdb; ipdb.set_trace()
|
1131 |
+
|
1132 |
+
return tgt
|
1133 |
+
|
1134 |
+
|
1135 |
+
def _get_clones(module, N, layer_share=False):
|
1136 |
+
# import ipdb; ipdb.set_trace()
|
1137 |
+
if layer_share:
|
1138 |
+
return nn.ModuleList([module for i in range(N)])
|
1139 |
+
else:
|
1140 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
1141 |
+
|
1142 |
+
|
1143 |
+
def build_deformable_transformer(args):
|
1144 |
+
decoder_query_perturber = None
|
1145 |
+
if args.decoder_layer_noise:
|
1146 |
+
from .utils import RandomBoxPerturber
|
1147 |
+
decoder_query_perturber = RandomBoxPerturber(
|
1148 |
+
x_noise_scale=args.dln_xy_noise, y_noise_scale=args.dln_xy_noise,
|
1149 |
+
w_noise_scale=args.dln_hw_noise, h_noise_scale=args.dln_hw_noise)
|
1150 |
+
|
1151 |
+
use_detached_boxes_dec_out = False
|
1152 |
+
try:
|
1153 |
+
use_detached_boxes_dec_out = args.use_detached_boxes_dec_out
|
1154 |
+
except:
|
1155 |
+
use_detached_boxes_dec_out = False
|
1156 |
+
|
1157 |
+
binary_query_selection = False
|
1158 |
+
try:
|
1159 |
+
binary_query_selection = args.binary_query_selection
|
1160 |
+
except:
|
1161 |
+
binary_query_selection = False
|
1162 |
+
|
1163 |
+
ffn_extra_layernorm = False
|
1164 |
+
try:
|
1165 |
+
ffn_extra_layernorm = args.ffn_extra_layernorm
|
1166 |
+
except:
|
1167 |
+
print('ffn_extra_layernorm not found, set to False')
|
1168 |
+
ffn_extra_layernorm = False
|
1169 |
+
|
1170 |
+
return DeformableTransformer(
|
1171 |
+
d_model=args.hidden_dim,
|
1172 |
+
dropout=args.dropout,
|
1173 |
+
nhead=args.nheads,
|
1174 |
+
num_queries=args.num_queries,
|
1175 |
+
dim_feedforward=args.dim_feedforward,
|
1176 |
+
num_encoder_layers=args.enc_layers,
|
1177 |
+
num_unicoder_layers=args.unic_layers,
|
1178 |
+
num_decoder_layers=args.dec_layers,
|
1179 |
+
normalize_before=args.pre_norm,
|
1180 |
+
return_intermediate_dec=True,
|
1181 |
+
query_dim=args.query_dim,
|
1182 |
+
activation=args.transformer_activation,
|
1183 |
+
num_patterns=args.num_patterns,
|
1184 |
+
modulate_hw_attn=True,
|
1185 |
+
|
1186 |
+
deformable_encoder=True,
|
1187 |
+
deformable_decoder=True,
|
1188 |
+
num_feature_levels=args.num_feature_levels,
|
1189 |
+
enc_n_points=args.enc_n_points,
|
1190 |
+
dec_n_points=args.dec_n_points,
|
1191 |
+
use_deformable_box_attn=args.use_deformable_box_attn,
|
1192 |
+
box_attn_type=args.box_attn_type,
|
1193 |
+
|
1194 |
+
learnable_tgt_init=True,
|
1195 |
+
decoder_query_perturber=decoder_query_perturber,
|
1196 |
+
|
1197 |
+
add_channel_attention=args.add_channel_attention,
|
1198 |
+
add_pos_value=args.add_pos_value,
|
1199 |
+
random_refpoints_xy=args.random_refpoints_xy,
|
1200 |
+
|
1201 |
+
# two stage
|
1202 |
+
two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
|
1203 |
+
two_stage_pat_embed=args.two_stage_pat_embed,
|
1204 |
+
two_stage_add_query_num=args.two_stage_add_query_num,
|
1205 |
+
two_stage_learn_wh=args.two_stage_learn_wh,
|
1206 |
+
two_stage_keep_all_tokens=args.two_stage_keep_all_tokens,
|
1207 |
+
dec_layer_number=args.dec_layer_number,
|
1208 |
+
rm_self_attn_layers=None,
|
1209 |
+
key_aware_type=None,
|
1210 |
+
layer_share_type=None,
|
1211 |
+
|
1212 |
+
rm_detach=None,
|
1213 |
+
decoder_sa_type=args.decoder_sa_type,
|
1214 |
+
module_seq=args.decoder_module_seq,
|
1215 |
+
|
1216 |
+
embed_init_tgt=args.embed_init_tgt,
|
1217 |
+
use_detached_boxes_dec_out=use_detached_boxes_dec_out,
|
1218 |
+
use_text_enhancer=args.use_text_enhancer,
|
1219 |
+
use_fusion_layer=args.use_fusion_layer,
|
1220 |
+
use_checkpoint=args.use_checkpoint,
|
1221 |
+
use_transformer_ckpt=args.use_transformer_ckpt,
|
1222 |
+
use_text_cross_attention=args.use_text_cross_attention,
|
1223 |
+
|
1224 |
+
text_dropout=args.text_dropout,
|
1225 |
+
fusion_dropout=args.fusion_dropout,
|
1226 |
+
fusion_droppath=args.fusion_droppath,
|
1227 |
+
|
1228 |
+
binary_query_selection=binary_query_selection,
|
1229 |
+
ffn_extra_layernorm=ffn_extra_layernorm,
|
1230 |
+
)
|
src/models/XPose/models/UniPose/fuse_modules.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
# from timm.models.layers import DropPath
|
6 |
+
from src.models.util import DropPath
|
7 |
+
|
8 |
+
|
9 |
+
class FeatureResizer(nn.Module):
|
10 |
+
"""
|
11 |
+
This class takes as input a set of embeddings of dimension C1 and outputs a set of
|
12 |
+
embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
|
16 |
+
super().__init__()
|
17 |
+
self.do_ln = do_ln
|
18 |
+
# Object feature encoding
|
19 |
+
self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
|
20 |
+
self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
|
21 |
+
self.dropout = nn.Dropout(dropout)
|
22 |
+
|
23 |
+
def forward(self, encoder_features):
|
24 |
+
x = self.fc(encoder_features)
|
25 |
+
if self.do_ln:
|
26 |
+
x = self.layer_norm(x)
|
27 |
+
output = self.dropout(x)
|
28 |
+
return output
|
29 |
+
|
30 |
+
|
31 |
+
def l1norm(X, dim, eps=1e-8):
|
32 |
+
"""L1-normalize columns of X
|
33 |
+
"""
|
34 |
+
norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
|
35 |
+
X = torch.div(X, norm)
|
36 |
+
return X
|
37 |
+
|
38 |
+
|
39 |
+
def l2norm(X, dim, eps=1e-8):
|
40 |
+
"""L2-normalize columns of X
|
41 |
+
"""
|
42 |
+
norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
|
43 |
+
X = torch.div(X, norm)
|
44 |
+
return X
|
45 |
+
|
46 |
+
|
47 |
+
def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
|
48 |
+
"""
|
49 |
+
query: (n_context, queryL, d)
|
50 |
+
context: (n_context, sourceL, d)
|
51 |
+
"""
|
52 |
+
batch_size_q, queryL = query.size(0), query.size(1)
|
53 |
+
batch_size, sourceL = context.size(0), context.size(1)
|
54 |
+
|
55 |
+
# Get attention
|
56 |
+
# --> (batch, d, queryL)
|
57 |
+
queryT = torch.transpose(query, 1, 2)
|
58 |
+
|
59 |
+
# (batch, sourceL, d)(batch, d, queryL)
|
60 |
+
# --> (batch, sourceL, queryL)
|
61 |
+
attn = torch.bmm(context, queryT)
|
62 |
+
if raw_feature_norm == "softmax":
|
63 |
+
# --> (batch*sourceL, queryL)
|
64 |
+
attn = attn.view(batch_size * sourceL, queryL)
|
65 |
+
attn = nn.Softmax()(attn)
|
66 |
+
# --> (batch, sourceL, queryL)
|
67 |
+
attn = attn.view(batch_size, sourceL, queryL)
|
68 |
+
elif raw_feature_norm == "l2norm":
|
69 |
+
attn = l2norm(attn, 2)
|
70 |
+
elif raw_feature_norm == "clipped_l2norm":
|
71 |
+
attn = nn.LeakyReLU(0.1)(attn)
|
72 |
+
attn = l2norm(attn, 2)
|
73 |
+
else:
|
74 |
+
raise ValueError("unknown first norm type:", raw_feature_norm)
|
75 |
+
# --> (batch, queryL, sourceL)
|
76 |
+
attn = torch.transpose(attn, 1, 2).contiguous()
|
77 |
+
# --> (batch*queryL, sourceL)
|
78 |
+
attn = attn.view(batch_size * queryL, sourceL)
|
79 |
+
attn = nn.Softmax()(attn * smooth)
|
80 |
+
# --> (batch, queryL, sourceL)
|
81 |
+
attn = attn.view(batch_size, queryL, sourceL)
|
82 |
+
# --> (batch, sourceL, queryL)
|
83 |
+
attnT = torch.transpose(attn, 1, 2).contiguous()
|
84 |
+
|
85 |
+
# --> (batch, d, sourceL)
|
86 |
+
contextT = torch.transpose(context, 1, 2)
|
87 |
+
# (batch x d x sourceL)(batch x sourceL x queryL)
|
88 |
+
# --> (batch, d, queryL)
|
89 |
+
weightedContext = torch.bmm(contextT, attnT)
|
90 |
+
# --> (batch, queryL, d)
|
91 |
+
weightedContext = torch.transpose(weightedContext, 1, 2)
|
92 |
+
|
93 |
+
return weightedContext, attnT
|
94 |
+
|
95 |
+
|
96 |
+
class BiMultiHeadAttention(nn.Module):
|
97 |
+
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
|
98 |
+
super(BiMultiHeadAttention, self).__init__()
|
99 |
+
|
100 |
+
self.embed_dim = embed_dim
|
101 |
+
self.num_heads = num_heads
|
102 |
+
self.head_dim = embed_dim // num_heads
|
103 |
+
self.v_dim = v_dim
|
104 |
+
self.l_dim = l_dim
|
105 |
+
|
106 |
+
assert (
|
107 |
+
self.head_dim * self.num_heads == self.embed_dim
|
108 |
+
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
|
109 |
+
self.scale = self.head_dim ** (-0.5)
|
110 |
+
self.dropout = dropout
|
111 |
+
|
112 |
+
self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
113 |
+
self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
114 |
+
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
|
115 |
+
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
|
116 |
+
|
117 |
+
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
|
118 |
+
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
|
119 |
+
|
120 |
+
self.stable_softmax_2d = True
|
121 |
+
self.clamp_min_for_underflow = True
|
122 |
+
self.clamp_max_for_overflow = True
|
123 |
+
|
124 |
+
self._reset_parameters()
|
125 |
+
|
126 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
127 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
128 |
+
|
129 |
+
def _reset_parameters(self):
|
130 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
131 |
+
self.v_proj.bias.data.fill_(0)
|
132 |
+
nn.init.xavier_uniform_(self.l_proj.weight)
|
133 |
+
self.l_proj.bias.data.fill_(0)
|
134 |
+
nn.init.xavier_uniform_(self.values_v_proj.weight)
|
135 |
+
self.values_v_proj.bias.data.fill_(0)
|
136 |
+
nn.init.xavier_uniform_(self.values_l_proj.weight)
|
137 |
+
self.values_l_proj.bias.data.fill_(0)
|
138 |
+
nn.init.xavier_uniform_(self.out_v_proj.weight)
|
139 |
+
self.out_v_proj.bias.data.fill_(0)
|
140 |
+
nn.init.xavier_uniform_(self.out_l_proj.weight)
|
141 |
+
self.out_l_proj.bias.data.fill_(0)
|
142 |
+
|
143 |
+
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
144 |
+
"""_summary_
|
145 |
+
|
146 |
+
Args:
|
147 |
+
v (_type_): bs, n_img, dim
|
148 |
+
l (_type_): bs, n_text, dim
|
149 |
+
attention_mask_v (_type_, optional): _description_. bs, n_img
|
150 |
+
attention_mask_l (_type_, optional): _description_. bs, n_text
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
_type_: _description_
|
154 |
+
"""
|
155 |
+
# if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
|
156 |
+
# import ipdb; ipdb.set_trace()
|
157 |
+
bsz, tgt_len, _ = v.size()
|
158 |
+
|
159 |
+
query_states = self.v_proj(v) * self.scale
|
160 |
+
key_states = self._shape(self.l_proj(l), -1, bsz)
|
161 |
+
value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
|
162 |
+
value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
|
163 |
+
|
164 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
165 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
166 |
+
key_states = key_states.view(*proj_shape)
|
167 |
+
value_v_states = value_v_states.view(*proj_shape)
|
168 |
+
value_l_states = value_l_states.view(*proj_shape)
|
169 |
+
|
170 |
+
src_len = key_states.size(1)
|
171 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
|
172 |
+
|
173 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
174 |
+
raise ValueError(
|
175 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
176 |
+
)
|
177 |
+
|
178 |
+
if self.stable_softmax_2d:
|
179 |
+
attn_weights = attn_weights - attn_weights.max()
|
180 |
+
|
181 |
+
if self.clamp_min_for_underflow:
|
182 |
+
attn_weights = torch.clamp(attn_weights,
|
183 |
+
min=-50000) # Do not increase -50000, data type half has quite limited range
|
184 |
+
if self.clamp_max_for_overflow:
|
185 |
+
attn_weights = torch.clamp(attn_weights,
|
186 |
+
max=50000) # Do not increase 50000, data type half has quite limited range
|
187 |
+
|
188 |
+
attn_weights_T = attn_weights.transpose(1, 2)
|
189 |
+
attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[
|
190 |
+
0])
|
191 |
+
if self.clamp_min_for_underflow:
|
192 |
+
attn_weights_l = torch.clamp(attn_weights_l,
|
193 |
+
min=-50000) # Do not increase -50000, data type half has quite limited range
|
194 |
+
if self.clamp_max_for_overflow:
|
195 |
+
attn_weights_l = torch.clamp(attn_weights_l,
|
196 |
+
max=50000) # Do not increase 50000, data type half has quite limited range
|
197 |
+
|
198 |
+
# mask vison for language
|
199 |
+
if attention_mask_v is not None:
|
200 |
+
attention_mask_v = attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
201 |
+
attn_weights_l.masked_fill_(attention_mask_v, float('-inf'))
|
202 |
+
|
203 |
+
attn_weights_l = attn_weights_l.softmax(dim=-1)
|
204 |
+
|
205 |
+
# mask language for vision
|
206 |
+
if attention_mask_l is not None:
|
207 |
+
attention_mask_l = attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
|
208 |
+
attn_weights.masked_fill_(attention_mask_l, float('-inf'))
|
209 |
+
attn_weights_v = attn_weights.softmax(dim=-1)
|
210 |
+
|
211 |
+
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
|
212 |
+
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
|
213 |
+
|
214 |
+
attn_output_v = torch.bmm(attn_probs_v, value_l_states)
|
215 |
+
attn_output_l = torch.bmm(attn_probs_l, value_v_states)
|
216 |
+
|
217 |
+
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
218 |
+
raise ValueError(
|
219 |
+
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
|
220 |
+
)
|
221 |
+
|
222 |
+
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
|
223 |
+
raise ValueError(
|
224 |
+
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
|
225 |
+
)
|
226 |
+
|
227 |
+
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
228 |
+
attn_output_v = attn_output_v.transpose(1, 2)
|
229 |
+
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
|
230 |
+
|
231 |
+
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
|
232 |
+
attn_output_l = attn_output_l.transpose(1, 2)
|
233 |
+
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
|
234 |
+
|
235 |
+
attn_output_v = self.out_v_proj(attn_output_v)
|
236 |
+
attn_output_l = self.out_l_proj(attn_output_l)
|
237 |
+
|
238 |
+
return attn_output_v, attn_output_l
|
239 |
+
|
240 |
+
|
241 |
+
# Bi-Direction MHA (text->image, image->text)
|
242 |
+
class BiAttentionBlock(nn.Module):
|
243 |
+
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
|
244 |
+
drop_path=.0, init_values=1e-4, cfg=None):
|
245 |
+
"""
|
246 |
+
Inputs:
|
247 |
+
embed_dim - Dimensionality of input and attention feature vectors
|
248 |
+
hidden_dim - Dimensionality of hidden layer in feed-forward network
|
249 |
+
(usually 2-4x larger than embed_dim)
|
250 |
+
num_heads - Number of heads to use in the Multi-Head Attention block
|
251 |
+
dropout - Amount of dropout to apply in the feed-forward network
|
252 |
+
"""
|
253 |
+
super(BiAttentionBlock, self).__init__()
|
254 |
+
|
255 |
+
# pre layer norm
|
256 |
+
self.layer_norm_v = nn.LayerNorm(v_dim)
|
257 |
+
self.layer_norm_l = nn.LayerNorm(l_dim)
|
258 |
+
self.attn = BiMultiHeadAttention(v_dim=v_dim,
|
259 |
+
l_dim=l_dim,
|
260 |
+
embed_dim=embed_dim,
|
261 |
+
num_heads=num_heads,
|
262 |
+
dropout=dropout)
|
263 |
+
|
264 |
+
# add layer scale for training stability
|
265 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
266 |
+
self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=False)
|
267 |
+
self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=False)
|
268 |
+
|
269 |
+
def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
|
270 |
+
v = self.layer_norm_v(v)
|
271 |
+
l = self.layer_norm_l(l)
|
272 |
+
delta_v, delta_l = self.attn(v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l)
|
273 |
+
# v, l = v + delta_v, l + delta_l
|
274 |
+
v = v + self.drop_path(self.gamma_v * delta_v)
|
275 |
+
l = l + self.drop_path(self.gamma_l * delta_l)
|
276 |
+
return v, l
|
src/models/XPose/models/UniPose/mask_generate.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def prepare_for_mask(kpt_mask):
|
5 |
+
|
6 |
+
|
7 |
+
tgt_size2 = 50 * 69
|
8 |
+
attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2).to('cuda') < 0
|
9 |
+
group_bbox_kpt = 69
|
10 |
+
num_group=50
|
11 |
+
for matchj in range(num_group * group_bbox_kpt):
|
12 |
+
sj = (matchj // group_bbox_kpt) * group_bbox_kpt
|
13 |
+
ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
|
14 |
+
if sj > 0:
|
15 |
+
attn_mask2[:,:,matchj, :sj] = True
|
16 |
+
if ej < num_group * group_bbox_kpt:
|
17 |
+
attn_mask2[:,:,matchj, ej:] = True
|
18 |
+
|
19 |
+
|
20 |
+
bs, length = kpt_mask.shape
|
21 |
+
equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :]
|
22 |
+
equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1)
|
23 |
+
for idx in range(num_group):
|
24 |
+
start_idx = idx * length
|
25 |
+
end_idx = (idx + 1) * length
|
26 |
+
attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False
|
27 |
+
attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
input_query_label = None
|
33 |
+
input_query_bbox = None
|
34 |
+
attn_mask = None
|
35 |
+
dn_meta = None
|
36 |
+
|
37 |
+
return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta
|
38 |
+
|
39 |
+
|
40 |
+
def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
|
41 |
+
|
42 |
+
if dn_meta and dn_meta['pad_size'] > 0:
|
43 |
+
|
44 |
+
output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class]
|
45 |
+
output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord]
|
46 |
+
|
47 |
+
outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class]
|
48 |
+
outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord]
|
49 |
+
|
50 |
+
out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]}
|
51 |
+
if aux_loss:
|
52 |
+
out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
|
53 |
+
dn_meta['output_known_lbs_bboxes'] = out
|
54 |
+
return outputs_class, outputs_coord
|
55 |
+
|
56 |
+
|
src/models/XPose/models/UniPose/ops/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/8/5 21:58
|
3 |
+
# @Author : shaoguowen
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @Project : FasterLivePortrait
|
6 |
+
# @FileName: __init__.py.py
|
src/models/XPose/models/UniPose/ops/functions/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from .ms_deform_attn_func import MSDeformAttnFunction
|
10 |
+
|
src/models/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
from __future__ import absolute_import
|
10 |
+
from __future__ import print_function
|
11 |
+
from __future__ import division
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch.autograd import Function
|
16 |
+
from torch.autograd.function import once_differentiable
|
17 |
+
|
18 |
+
import MultiScaleDeformableAttention as MSDA
|
19 |
+
|
20 |
+
|
21 |
+
class MSDeformAttnFunction(Function):
|
22 |
+
@staticmethod
|
23 |
+
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
|
24 |
+
ctx.im2col_step = im2col_step
|
25 |
+
output = MSDA.ms_deform_attn_forward(
|
26 |
+
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
|
27 |
+
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
|
28 |
+
return output
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
@once_differentiable
|
32 |
+
def backward(ctx, grad_output):
|
33 |
+
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
|
34 |
+
grad_value, grad_sampling_loc, grad_attn_weight = \
|
35 |
+
MSDA.ms_deform_attn_backward(
|
36 |
+
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
|
37 |
+
|
38 |
+
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
|
39 |
+
|
40 |
+
|
41 |
+
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
|
42 |
+
# for debug and test only,
|
43 |
+
# need to use cuda version instead
|
44 |
+
N_, S_, M_, D_ = value.shape
|
45 |
+
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
|
46 |
+
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
47 |
+
sampling_grids = 2 * sampling_locations - 1
|
48 |
+
sampling_value_list = []
|
49 |
+
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
|
50 |
+
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
|
51 |
+
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
|
52 |
+
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
|
53 |
+
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
|
54 |
+
# N_*M_, D_, Lq_, P_
|
55 |
+
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
|
56 |
+
mode='bilinear', padding_mode='zeros', align_corners=False)
|
57 |
+
sampling_value_list.append(sampling_value_l_)
|
58 |
+
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
|
59 |
+
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
|
60 |
+
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
|
61 |
+
return output.transpose(1, 2).contiguous()
|