AndroidGuy committed
Commit 8dc9718 · 1 Parent(s): 736c8f2

Add files with Git LFS support

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +14 -0
  3. DockerfileAPI +7 -0
  4. LICENSE +33 -0
  5. README.md +183 -11
  6. README_ZH.md +173 -0
  7. api.py +479 -0
  8. assets/.gitignore +2 -0
  9. assets/docs/API.md +41 -0
  10. assets/docs/API_ZH.md +47 -0
  11. assets/gradio/gradio_description_animate_clear.md +6 -0
  12. assets/gradio/gradio_description_animation.md +19 -0
  13. assets/gradio/gradio_description_retargeting.md +14 -0
  14. assets/gradio/gradio_description_upload.md +16 -0
  15. assets/gradio/gradio_title.md +19 -0
  16. assets/mask_template.png +0 -0
  17. camera.bat +32 -0
  18. configs/onnx_infer.yaml +114 -0
  19. configs/onnx_mp_infer.yaml +108 -0
  20. configs/trt_infer.yaml +114 -0
  21. configs/trt_mp_infer.yaml +108 -0
  22. requirements.txt +18 -0
  23. requirements_macos.txt +18 -0
  24. requirements_win.txt +17 -0
  25. run.py +322 -0
  26. scripts/all_onnx2trt.bat +29 -0
  27. scripts/all_onnx2trt.sh +17 -0
  28. scripts/all_onnx2trt_animal.sh +12 -0
  29. scripts/onnx2trt.py +161 -0
  30. scripts/start_api.sh +3 -0
  31. src/__init__.py +5 -0
  32. src/models/JoyVASA/__init__.py +6 -0
  33. src/models/JoyVASA/common.py +46 -0
  34. src/models/JoyVASA/dit_talking_head.py +538 -0
  35. src/models/JoyVASA/helper.py +32 -0
  36. src/models/JoyVASA/hubert.py +51 -0
  37. src/models/JoyVASA/wav2vec2.py +119 -0
  38. src/models/XPose/__init__.py +6 -0
  39. src/models/XPose/config_model/UniPose_SwinT.py +125 -0
  40. src/models/XPose/config_model/__init__.py +6 -0
  41. src/models/XPose/config_model/coco_transformer.py +8 -0
  42. src/models/XPose/models/UniPose/__init__.py +10 -0
  43. src/models/XPose/models/UniPose/attention.py +373 -0
  44. src/models/XPose/models/UniPose/backbone.py +211 -0
  45. src/models/XPose/models/UniPose/deformable_transformer.py +1230 -0
  46. src/models/XPose/models/UniPose/fuse_modules.py +276 -0
  47. src/models/XPose/models/UniPose/mask_generate.py +56 -0
  48. src/models/XPose/models/UniPose/ops/__init__.py +6 -0
  49. src/models/XPose/models/UniPose/ops/functions/__init__.py +10 -0
  50. src/models/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py +61 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
37
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ *.wav filter=lfs diff=lfs merge=lfs -text
39
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
1
+ __pycache__
2
+ .idea
3
+ *.pyc
4
+ .DS_Store
5
+ checkpoints
6
+ results
7
+ venv
8
+ *.egg-info
9
+ build
10
+ dist
11
+ *.eg
12
+ checkpoints_test
13
+ logs
14
+ third_party
DockerfileAPI ADDED
@@ -0,0 +1,7 @@
1
+ FROM shaoguo/faster_liveportrait:v3
2
+ USER root
3
+ RUN mkdir -p /root/FasterLiveportrait
4
+ RUN chown -R root:root /root/FasterLiveportrait
5
+ COPY . /root/FasterLiveportrait
6
+ WORKDIR /root/FasterLiveportrait
7
+ CMD ["/bin/bash", "-c", "bash scripts/start_api.sh"]
LICENSE ADDED
@@ -0,0 +1,33 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 warmshao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ ---
24
+
25
+ ADDITIONAL NOTICE FOR MODELS:
26
+
27
+ This repository may contain or reference machine learning models. These models
28
+ are subject to their respective licenses, which may differ from the MIT license
29
+ applied to the code in this repository. Users are responsible for complying
30
+ with the license terms of any models they use. This repository and its
31
+ maintainers assume no responsibility for model licensing compliance.
32
+
33
+ Please check the original source and license of each model before use.
README.md CHANGED
@@ -1,11 +1,183 @@
1
- ---
2
- title: FasterLivepotrait
3
- emoji: 💻
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ## FasterLivePortrait: Bring portraits to life in Real Time!
2
+ <a href="README.md">English</a> | <a href="README_ZH.md">中文</a>
3
+
4
+ **Original repository: [LivePortrait](https://github.com/KwaiVGI/LivePortrait), thanks to the authors for sharing**
5
+
6
+ **New features:**
7
+ * Achieved real-time running of LivePortrait on RTX 3090 GPU using TensorRT, reaching speeds of 30+ FPS. This is the speed for rendering a single frame, including pre- and post-processing, not just the model inference speed.
8
+ * Seamless support for the native Gradio app, running several times faster, with simultaneous inference on multiple faces and support for the animal model.
9
+ * Added support for [JoyVASA](https://github.com/jdh-algo/JoyVASA), which can drive videos or images with audio.
10
+
11
+ **If you find this project useful, please give it a star ✨✨**
12
+
13
+ ### Demo (Explore more features)
14
+ * Anyone want this? Feel free to contact me.
15
+
16
+ <video src="https://github.com/user-attachments/assets/554c37fc-d098-4938-a638-1660d85d222e" controls="controls" width="500" height="300">Your browser does not support this video!</video>
17
+
18
+
19
+ * Text-driven video, based on kokoro-82M:
20
+
21
+ <video src="https://github.com/user-attachments/assets/04e962e2-6c57-4d01-ae4a-2f6d2d501c5a" controls="controls" width="500" height="300">Your browser does not support this video!</video>
22
+
23
+ * Audio-driven video (real-time):
24
+
25
+ <video src="https://github.com/user-attachments/assets/98bb5ff7-0796-42db-9d7b-e04ddd2c3c14" controls="controls" width="500" height="300">Your browser does not support this video!</video>
26
+
27
+ * Animal-driven:
28
+
29
+ <video src="https://github.com/user-attachments/assets/dada0a92-593a-480b-a034-cbcce16e38b9" controls="controls" width="500" height="300">Your browser does not support this video!</video>
30
+
31
+ * Multiple faces driven simultaneously:
32
+
33
+ <video src="https://github.com/KwaiVGI/LivePortrait/assets/138360003/b37de35d-6feb-4100-b73f-58ac23121483" controls="controls" width="500" height="300">Your browser does not support this video!</video>
34
+
35
+
36
+ ### Environment Setup
37
+ * Option 1 (recommended): If you are a Windows user, you can directly download the [integrated package](https://github.com/warmshao/FasterLivePortrait/releases/tag/v1.8).
38
+ * You need to install [git](https://git-scm.com/downloads) first, then double-click `update.bat` to update the code.
39
+ * Double-click `scripts/all_onnx2trt.bat` to convert onnx files to tensorrt files.
40
+ * Double-click `webui.bat` to open the webpage, or double-click `camera.bat` to open the camera for real-time operation.
41
+ * Option 2: Docker. A Docker image is provided that eliminates the need to install onnxruntime-gpu and TensorRT manually.
42
+ * Install [Docker](https://docs.docker.com/desktop/install/windows-install/) according to your system
43
+ * Download the image: `docker pull shaoguo/faster_liveportrait:v3`
44
+ * Execute the command, replace `$FasterLivePortrait_ROOT` with the local directory where you downloaded FasterLivePortrait:
45
+ ```shell
46
+ docker run -it --gpus=all \
47
+ --name faster_liveportrait \
48
+ -v $FasterLivePortrait_ROOT:/root/FasterLivePortrait \
49
+ --restart=always \
50
+ -p 9870:9870 \
51
+ shaoguo/faster_liveportrait:v3 \
52
+ /bin/bash
53
+ ```
54
+ * Option 3: Create a new Python virtual environment and install the necessary Python packages manually.
55
+ * First, install [ffmpeg](https://www.ffmpeg.org/download.html)
56
+ * Run `pip install -r requirements.txt`
57
+ * Then follow the tutorials below to install onnxruntime-gpu or TensorRT. Note that this has only been tested on Linux systems.
58
+
59
+ ### Usage
60
+ #### 1. TensorRT Inference (Recommended)
61
+ * (Ignored in Docker) Install TensorRT 8.x (versions >=10.x are not compatible). Remember the installation path of [TensorRT](https://developer.nvidia.com/tensorrt).
62
+ * (Ignored in Docker) Install the grid_sample TensorRT plugin, because the model uses a grid_sample operation with 5D input, which the native grid_sample operator does not support.
63
+ * `git clone https://github.com/SeanWangJS/grid-sample3d-trt-plugin`
64
+ * Modify line 30 in `CMakeLists.txt` to: `set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "60;70;75;80;86")`
65
+ * `export PATH=/usr/local/cuda/bin:$PATH`
66
+ * `mkdir build && cd build`
67
+ * `cmake .. -DTensorRT_ROOT=$TENSORRT_HOME`, replace $TENSORRT_HOME with your own TensorRT root directory.
68
+ * `make`, then note the path of the generated `.so` file and replace `/opt/grid-sample3d-trt-plugin/build/libgrid_sample_3d_plugin.so` in `scripts/onnx2trt.py` and `src/models/predictor.py` with your own `.so` path
69
+ * Download the ONNX model files: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`. Convert all ONNX models to TensorRT by running `sh scripts/all_onnx2trt.sh` and `sh scripts/all_onnx2trt_animal.sh`
70
+ * Test the pipeline using tensorrt:
71
+ ```shell
72
+ python run.py \
73
+ --src_image assets/examples/source/s10.jpg \
74
+ --dri_video assets/examples/driving/d14.mp4 \
75
+ --cfg configs/trt_infer.yaml
+ ```
76
+ * To run in real-time using a camera:
77
+ ```shell
78
+ python run.py \
79
+ --src_image assets/examples/source/s10.jpg \
80
+ --dri_video 0 \
81
+ --cfg configs/trt_infer.yaml \
82
+ --realtime
83
+ ```
84
+
85
+ #### 2. Onnxruntime Inference
86
+ * First, download the converted onnx model files:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`.
87
+ * (Ignored in Docker) If you want to use onnxruntime CPU inference, simply `pip install onnxruntime`. However, CPU inference is extremely slow and not recommended. The latest onnxruntime-gpu still doesn't support grid_sample on CUDA, but I found a branch that supports it. Follow these steps to install `onnxruntime-gpu` from source:
88
+ * `git clone https://github.com/microsoft/onnxruntime`
89
+ * `git checkout liqun/ImageDecoder-cuda`. Thanks to liqun for the grid_sample with cuda implementation!
90
+ * Run the following commands to compile, changing `cuda_version` and `CMAKE_CUDA_ARCHITECTURES` according to your machine (your cuDNN version must be 8.x, 9.x is not compatible):
91
+ ```shell
92
+ ./build.sh --parallel \
93
+ --build_shared_lib --use_cuda \
94
+ --cuda_version 11.8 \
95
+ --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda/ \
96
+ --config Release --build_wheel --skip_tests \
97
+ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="60;70;75;80;86" \
98
+ --cmake_extra_defines CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
99
+ --disable_contrib_ops \
100
+ --allow_running_as_root
101
+ ```
102
+ * `pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl`
103
+ * Test the pipeline using onnxruntime:
104
+ ```shell
105
+ python run.py \
106
+ --src_image assets/examples/source/s10.jpg \
107
+ --dri_video assets/examples/driving/d14.mp4 \
108
+ --cfg configs/onnx_infer.yaml
109
+ ```
110
+
111
+
112
+ ### Gradio WebUI
113
+ * onnxruntime: `python webui.py --mode onnx`
114
+ * tensorrt: `python webui.py --mode trt`
115
+ * The default port is 9870. Open the webpage: `http://localhost:9870/`
116
+
117
+ Hotkeys for webcam mode (when the render window is in focus)\
118
+ Q > exit\
119
+ S > Stitching\
120
+ Z > RelativeMotion\
121
+ X > AnimationRegion\
122
+ C > CropDrivingVideo\
123
+ K,L > AdjustSourceScale\
124
+ N,M > AdjustDriverScale
125
+
126
+ ## License
127
+
128
+ - **Code**: This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
129
+ - **Models**: Any machine learning models used in this project are subject to their respective licenses. Please refer to the original model sources for license information. We do not take responsibility for model license compliance.
130
+
131
+
132
+ **Changelog**
133
+ - [x] **2025/06/29:** LivePortrait animal v1.1 ONNX models are available. Download them from [here](https://huggingface.co/warmshao/FasterLivePortrait/tree/main/liveportrait_animal_onnx_v1.1).
134
+ - [x] **2024/12/22:** Added API deployment (`python api.py`). For more information, please refer to the [tutorial](assets/docs/API.md).
135
+ - [x] **2024/12/21:** Added support for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M), enabling text-driven video or image generation.
136
+ - Update the code with `git pull origin master` and install the latest Python dependencies with `pip install -r requirements.txt`, or simply double-click `update.bat` on Windows.
137
+ - Download the model: `huggingface-cli download hexgrad/Kokoro-82M --local-dir .\checkpoints\Kokoro-82M`.
138
+ - For Linux, install `espeak-ng`: `apt-get -qq -y install espeak-ng > /dev/null 2>&1`
139
+ - For Windows, refer to [manual installation instructions](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12) and configure the `espeak-ng` environment variables. The current read location is [here](src/pipelines/gradio_live_portrait_pipeline.py:437); modify it if your installation path differs.
140
+ - Now you can use it normally in the "Drive Text" tab.
141
+ - [x] **2024/12/16:** Added support for [JoyVASA](https://github.com/jdh-algo/JoyVASA), which can drive videos or images with audio. Very cool!
142
+ - Update code, then download the models: `huggingface-cli download TencentGameMate/chinese-hubert-base --local-dir .\checkpoints\chinese-hubert-base` and `huggingface-cli download jdh-algo/JoyVASA --local-dir ./checkpoints/JoyVASA`
143
+ - After launching the webui, follow the tutorial below. When the source is a video, it's recommended to drive only the mouth movements.
144
+
145
+ <video src="https://github.com/user-attachments/assets/42fb24be-0cde-4138-9671-e52eec95e7f5" controls="controls" width="500" height="400">Your browser does not support this video!</video>
146
+
147
+ - [x] **2024/12/14:** Added pickle and image driving, as well as region-based driving (`animation_region`).
148
+ - Please update the latest code. Windows users can directly double-click `update.bat` to update, but note that your local code will be overwritten.
149
+ - Running `python run.py` now automatically saves the corresponding pickle to the same directory as the driving video, allowing for direct reuse.
150
+ - After opening webui, you can experience the new pickle and image driving, as well as the region driving animation_region features. Note that for image driving, remember to disable `relative motion`.
151
+ - [x] **2024/08/11:** Optimized paste_back speed and fixed some bugs.
152
+ - Used torchgeometry + cuda to optimize the paste_back function, significantly improving speed. Example: `python run.py --src_image assets/examples/source/s39.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_infer.yaml --paste_back --animal`
153
+ - Fixed issues with Xpose ops causing errors on some GPUs and other bugs. Please use the latest docker image: `docker pull shaoguo/faster_liveportrait:v3`
157
+ - [x] **2024/08/07:** Added support for animal models and MediaPipe models, so you no longer need to worry about copyright issues.
158
+ - Added support for animal models.
159
+ - Download the animal ONNX file: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`, then convert it to TRT format.
160
+ - Update the Docker image: `docker pull shaoguo/faster_liveportrait:v3`. Using animal model:`python run.py --src_image assets/examples/source/s39.jpg --dri_video 0 --cfg configs/trt_infer.yaml --realtime --animal`
161
+ - Windows users can download the latest [Windows all-in-one package](https://github.com/warmshao/FasterLivePortrait/releases) from the release page, then unzip and use it.
162
+ - Simple usage tutorial:
163
+
164
+ <video src="https://github.com/user-attachments/assets/dc37e2dd-551a-43b0-8929-fc5d5fe16ec5" controls="controls" width="500" height="300">Your browser does not support this video!</video>
165
+
166
+ - Using MediaPipe model to replace InsightFace
167
+ - For web usage: `python webui.py --mode trt --mp` or `python webui.py --mode onnx --mp`
168
+ - For local webcam: `python run.py --src_image assets/examples/source/s12.jpg --dri_video 0 --cfg configs/trt_mp_infer.yaml`
169
+ - [x] **2024/07/24:** Windows integration package, no installation required, one-click run, supports TensorRT and OnnxruntimeGPU. Thanks to @zhanghongyong123456 for their contribution in this [issue](https://github.com/warmshao/FasterLivePortrait/issues/22).
170
+ - [Optional] If you have already installed CUDA and cuDNN on your Windows computer, please skip this step. I have only verified on CUDA 12.2. If you haven't installed CUDA or encounter CUDA-related errors, you need to follow these steps:
171
+ - Download [CUDA 12.2](https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Windows&target_arch=x86_64), double-click the exe and install following the default settings step by step.
172
+ - Download the [cuDNN](https://developer.nvidia.com/downloads/compute/cudnn/secure/8.9.7/local_installers/12.x/cudnn-windows-x86_64-8.9.7.29_cuda12-archive.zip) zip file, extract it, and copy the lib, bin, and include folders from the cuDNN folder to the CUDA 12.2 folder (default is C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2)
173
+ - Download the installation-free [Windows integration package](https://github.com/warmshao/FasterLivePortrait/releases) from the release page and extract it.
174
+ - Enter `FasterLivePortrait-windows` and double-click `scripts/all_onnx2trt.bat` to convert onnx files, which will take some time.
175
+ - For web demo: Double-click `webui.bat`, open the webpage: `http://localhost:9870/`
176
+ - For real-time camera operation, double-click `camera.bat`; press `q` to stop. If you want to change the target image, run from the command line: `camera.bat assets/examples/source/s9.jpg`
177
+ - [x] **2024/07/18:** macOS support added (no need for Docker; Python is enough). M1/M2 chips are faster, but it's still quite slow 😟
178
+ - Install ffmpeg: `brew install ffmpeg`
179
+ - Set up a Python 3.10 virtual environment. Recommend using [miniforge](https://github.com/conda-forge/miniforge): `conda create -n flip python=3.10 && conda activate flip`
180
+ - Install requirements: `pip install -r requirements_macos.txt`
181
+ - Download ONNX files: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`
182
+ - Test: `python webui.py --mode onnx`
183
+ - [x] **2024/07/17:** Added support for Docker environment, providing a runnable image.
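
The WebUI and `run.py` entry points above cover most uses; for readers who want to script the pipeline directly, the `api.py` added in this commit shows the underlying call pattern (`prepare_source` once, then `run` per driving frame). The sketch below follows that pattern; the source image, driving video, and output paths are illustrative, and it assumes the checkpoints and TensorRT engines prepared above already exist under `./checkpoints`.

```python
# Minimal sketch of driving FasterLivePortraitPipeline from Python,
# mirroring the loop in api.py's run_with_video().
import os
import cv2
from omegaconf import OmegaConf
from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline

infer_cfg = OmegaConf.load("configs/trt_infer.yaml")
infer_cfg.infer_params.flag_pasteback = True  # paste the animated crop back into the full frame

pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=False)
if not pipe.prepare_source("assets/examples/source/s10.jpg", realtime=False):
    raise RuntimeError("no face detected in the source image")

os.makedirs("results/frames", exist_ok=True)
vcap = cv2.VideoCapture("assets/examples/driving/d14.mp4")
frame_ind = 0
while vcap.isOpened():
    ret, frame = vcap.read()
    if not ret:
        break
    # run() returns the cropped driving frame, the animated crop,
    # the full-resolution composite, and the driving motion info
    dri_crop, out_crop, out_org, dri_motion_info = pipe.run(
        frame, pipe.src_imgs[0], pipe.src_infos[0], first_frame=(frame_ind == 0))
    frame_ind += 1
    if out_crop is None:
        continue  # no face found in this driving frame
    out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)  # pipeline outputs RGB
    cv2.imwrite(f"results/frames/{frame_ind:05d}.jpg", out_org)
vcap.release()
```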
README_ZH.md ADDED
@@ -0,0 +1,173 @@
1
+ ## FasterLivePortrait:Bring portrait to life in Real Time!
2
+ <a href="README.md">English</a> | <a href="README_ZH.md">中文</a>
3
+
4
+ **原仓库: [LivePortrait](https://github.com/KwaiVGI/LivePortrait),感谢作者的分享**
5
+
6
+ **新增功能:**
7
+ * 通过TensorRT实现在RTX 3090显卡上**实时**运行LivePortrait,速度达到 30+ FPS. 这个速度是实测渲染出一帧的速度,而不仅仅是模型的推理时间。
8
+ * 无缝支持原生的gradio app, 速度快了好几倍,同时支持多张人脸、Animal模型。
9
+ * 增加[JoyVASA](https://github.com/jdh-algo/JoyVASA)的支持,可以用音频驱动视频或图片。
10
+
11
+ **如果你觉得这个项目有用,帮我点个star吧✨✨**
12
+
13
+ ### Demo(还有很多功能等你探索)
14
+ * 文本驱动视频,基于kokoro-82M:
15
+
16
+ <video src="https://github.com/user-attachments/assets/04e962e2-6c57-4d01-ae4a-2f6d2d501c5a" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
17
+ * 声音驱动视频(可以实时):
18
+
19
+ <video src="https://github.com/user-attachments/assets/98bb5ff7-0796-42db-9d7b-e04ddd2c3c14" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
20
+ * 动物驱动:
21
+
22
+ <video src="https://github.com/user-attachments/assets/dada0a92-593a-480b-a034-cbcce16e38b9" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
23
+ * 多张人脸同时驱动:
24
+
25
+ <video src="https://github.com/KwaiVGI/LivePortrait/assets/138360003/b37de35d-6feb-4100-b73f-58ac23121483" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
26
+
27
+
28
+ ### 环境安装
29
+ * 方式1:如果你是Windows用户,推荐可以直接下载[整合包](https://github.com/warmshao/FasterLivePortrait/releases/tag/v1.8)。
30
+ * 需要先安装好[git](https://git-scm.com/downloads), 双击`update.bat`更新代码。
31
+ * 双击`scripts/all_onnx2trt.bat`转换onnx文件为tensorrt文件。
32
+ * 双击`webui.bat`打开网页,或者双击`camera.bat`打开摄像头实时运行。
33
+ * 方式2:Docker,提供了一个镜像,不用再自己安装onnxruntime-gpu和TensorRT。
34
+ * 根据自己的系统安装[docker](https://docs.docker.com/desktop/install/windows-install/)
35
+ * 下载镜像:`docker pull shaoguo/faster_liveportrait:v3`
36
+ * 执行命令, `$FasterLivePortrait_ROOT`要替换成你下载的FasterLivePortrait在本地的目录:
37
+ ```shell
38
+ docker run -it --gpus=all \
39
+ --name faster_liveportrait \
40
+ -v $FasterLivePortrait_ROOT:/root/FasterLivePortrait \
41
+ --restart=always \
42
+ -p 9870:9870 \
43
+ shaoguo/faster_liveportrait:v3 \
44
+ /bin/bash
45
+ ```
46
+ * 然后可以根据下面Onnxruntime 推理和TensorRT 推理教程进行使用。
47
+
48
+ * 方式3:新建一个python虚拟环境,自己安装必要的python包
49
+ * 请先安装[ffmpeg](https://www.ffmpeg.org/download.html)
50
+ * `pip install -r requirements.txt`
51
+ * 再根据以下教程安装onnxruntime-gpu或TensorRT。
52
+
53
+ ### 使用方法
54
+ #### 1. TensorRT 推理(推荐, 可以实时)
55
+ * (Docker环境可忽略)安装TensorRT,请记住[TensorRT](https://developer.nvidia.com/tensorrt)安装的路径。
56
+ * (Docker环境可忽略)安装 grid_sample的tensorrt插件,因为模型用到的grid sample需要有5d的输入,原生的grid_sample 算子不支持。
57
+ * `git clone https://github.com/SeanWangJS/grid-sample3d-trt-plugin`
58
+ * 修改`CMakeLists.txt`中第30行为:`set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "60;70;75;80;86")`
59
+ * `export PATH=/usr/local/cuda/bin:$PATH`
60
+ * `mkdir build && cd build`
61
+ * `cmake .. -DTensorRT_ROOT=$TENSORRT_HOME`,$TENSORRT_HOME 替换成你自己TensorRT的根目录。
62
+ * `make`,记住so文件的地址,将`scripts/onnx2trt.py`和`src/models/predictor.py`里`/opt/grid-sample3d-trt-plugin/build/libgrid_sample_3d_plugin.so`替换成自己的so路径
63
+ * 下载Onnx文件:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`。将onnx模型转为tensorrt,运行`sh scripts/all_onnx2trt.sh`和`sh scripts/all_onnx2trt_animal.sh`
64
+ * 用tensorrt测试pipeline:
65
+ ```shell
66
+ python run.py \
67
+ --src_image assets/examples/source/s10.jpg \
68
+ --dri_video assets/examples/driving/d14.mp4 \
69
+ --cfg configs/trt_infer.yaml
70
+ ```
71
+ 如果要使用摄像头实时运行:
72
+ ```shell
73
+ python run.py \
74
+ --src_image assets/examples/source/s10.jpg \
75
+ --dri_video 0 \
76
+ --cfg configs/trt_infer.yaml \
77
+ --realtime
78
+ ```
79
+ #### 2. Onnxruntime 推理
80
+ * 首先下载我转换好的[模型onnx文件](https://huggingface.co/warmshao/FasterLivePortrait): `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`。
81
+ * (Docker环境可忽略)如果你要用onnxruntime cpu推理的话,直接`pip install onnxruntime`即可,但是cpu推理超级慢。但是最新的onnxruntime-gpu仍然无法支持grid_sample cuda,好在我看到一位大佬在分支上支持了,按照以下步骤源码安装`onnxruntime-gpu`:
82
+ * `git clone https://github.com/microsoft/onnxruntime`
83
+ * `git checkout liqun/ImageDecoder-cuda`. Thanks for liqun's grid_sample with cuda implementation!
84
+ * 运行以下命令编译,`cuda_version`和`CMAKE_CUDA_ARCHITECTURES`根据自己的机器更改:
85
+ ```shell
86
+ ./build.sh --parallel \
87
+ --build_shared_lib --use_cuda \
88
+ --cuda_version 11.8 \
89
+ --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda/ \
90
+ --config Release --build_wheel --skip_tests \
91
+ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="60;70;75;80;86" \
92
+ --cmake_extra_defines CMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
93
+ --disable_contrib_ops \
94
+ --allow_running_as_root
95
+ ```
96
+ * `pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl`就可以了
97
+ * 用onnxruntime测试pipeline:
98
+ ```shell
99
+ python run.py \
100
+ --src_image assets/examples/source/s10.jpg \
101
+ --dri_video assets/examples/driving/d14.mp4 \
102
+ --cfg configs/onnx_infer.yaml
103
+ ```
104
+
105
+ ### Gradio WebUI
106
+ * onnxruntime: `python webui.py --mode onnx`
107
+ * tensorrt: `python webui.py --mode trt`
108
+ * 默认端口在9870,打开网页:`http://localhost:9870/`
109
+
110
+ Hotkeys for webcam mode (when render window is on focus)\
111
+ Q > exit\
112
+ S > Stitching\
113
+ Z > RelativeMotion\
114
+ X > AnimationRegion\
115
+ C > CropDrivingVideo\
116
+ K,L > AdjustSourceScale\
117
+ N,M > AdjustDriverScale
118
+
119
+ ## 许可证
120
+
121
+ - **代码**: 本项目采用 MIT 许可证 - 详细信息请查看 [LICENSE](LICENSE) 文件。
122
+ - **模型**: 本项目中使用的任何机器学习模型均遵循其各自的许可证。请参考原始模型来源获取许可证信息。我们不承担模型许可证合规性的责任。
123
+
124
+
125
+ **日志**
126
+ - [x] **2025/06/29:** [LivePortrait animal v1.1 onnx模型](https://huggingface.co/warmshao/FasterLivePortrait/tree/main/liveportrait_animal_onnx_v1.1)。
127
+ - [x] **2024/12/22:** 增加api部署`python api.py`, 其他参考[教程](assets/docs/API_ZH.md)使用。
128
+ - [x] **2024/12/21:** 增加[Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M)的支持,可以用文本驱动视频或图片。
129
+ - 更新代码, `git pull origin master`并安装最新的python依赖 `pip install -r requirements.txt`, 或者 windows下直接双击 `update.bat`.
130
+ - 然后下载模型: `huggingface-cli download hexgrad/Kokoro-82M --local-dir .\checkpoints\Kokoro-82M`.
131
+ - 如果是Linux请安装`apt-get -qq -y install espeak-ng > /dev/null 2>&1`
132
+ - 如果是windows请参考[自行安装](https://huggingface.co/hexgrad/Kokoro-82M/discussions/12)并配置好`espeak-ng`环境变量。我是在[这里](src/pipelines/gradio_live_portrait_pipeline.py:437)读取,如果你的位置变了,请自行修改。
133
+ - 然后就可以在Drive Text的标签页正常使用了。
134
+ - [x] **2024/12/16:** 增加[JoyVASA](https://github.com/jdh-algo/JoyVASA)的支持,可以用音频驱动视频或图片。非常酷!
135
+ - 更新代码,然后下载模型: `huggingface-cli download TencentGameMate/chinese-hubert-base --local-dir .\checkpoints\chinese-hubert-base` 和 ` huggingface-cli download jdh-algo/JoyVASA --local-dir ./checkpoints/JoyVASA`
136
+ - 启动webui后根据以下教程使用即可,建议source 是视频的情况下只驱动嘴部
137
+
138
+ <video src="https://github.com/user-attachments/assets/42fb24be-0cde-4138-9671-e52eec95e7f5" controls="controls" width="500" height="400">您的浏览器不支持播放该视频!</video>
139
+
140
+ - [x] **2024/12/14:** 增加pickle和image驱动以及区域驱动`animation_region`。
141
+ - 请更新最新的代码,windows用户可以直接双击`update.bat`更新,但请注意本地的代码将会被覆盖。
142
+ - `python run.py ` 现在运行 `driving video`会自动保存对应的pickle到跟`driving video`一样的目录,可以直接复用。
143
+ - 打开`webui`后即可体验新的pickle和image驱动以及区域驱动`animation_region`等功能。注意image驱动记得把`relative motion`取消掉。
144
+ - [x] **2024/08/11:** 优化paste_back的速度,修复一些bug。
145
+ - 用torchgeometry + cuda优化paste_back函数,现在速度提升了很多。示例:`python run.py --src_image assets/examples/source/s39.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_infer.yaml --paste_back --animal`
146
+ - 修复Xpose的ops在一些显卡运行报错的问题等bug。请使用最新的镜像:`docker pull shaoguo/faster_liveportrait:v3`
147
+ - [x] **2024/08/07:** 增加animal模型的支持,同时支持mediapipe模型,现在你不用再担心版权的问题。
148
+ - 增加对animal模型的支持。
149
+ - 需要下载animal的onnx文件:`huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`,然后转换成trt文件。
150
+ - 更新镜像`docker pull shaoguo/faster_liveportrait:v3`, 使用animal模型的示例:`python run.py --src_image assets/examples/source/s39.jpg --dri_video 0 --cfg configs/trt_infer.yaml --realtime --animal`
151
+ - windows系统可以从release页下载最新的[windows 整合包](https://github.com/warmshao/FasterLivePortrait/releases),解压后使用。
152
+ - 简单的使用教程:
153
+
154
+ <video src="https://github.com/user-attachments/assets/dc37e2dd-551a-43b0-8929-fc5d5fe16ec5" controls="controls" width="500" height="300">您的浏览器不支持播放该视频!</video>
155
+
156
+ - 使用mediapipe模型替代insight_face
157
+ - 网页端使用: `python webui.py --mode trt --mp` 或 `python webui.py --mode onnx --mp`
158
+ - 本地摄像头运行: `python run.py --src_image assets/examples/source/s12.jpg --dri_video assets/examples/driving/d0.mp4 --cfg configs/trt_mp_infer.yaml`
159
+ - [x] **2024/07/24:** Windows的整合包, 免安装一键运行,支持TensorRT和OnnxruntimeGPU。感谢@zhanghongyong123456在[issue](https://github.com/warmshao/FasterLivePortrait/issues/22)的贡献。
160
+ - 【可选】如果你的windows电脑已经装过cuda和cudnn,请忽略这一步。我只在cuda12.2上验证过,如果没安装cuda或报cuda相关的错,你需要按照以下步骤进行安装:
161
+ - 下载[cuda12.2](https://developer.nvidia.com/cuda-12-2-0-download-archive?target_os=Windows&target_arch=x86_64), 双击exe后按照默认设置一步步安装即可。
162
+ - 下载[cudnn](https://developer.nvidia.com/downloads/compute/cudnn/secure/8.9.7/local_installers/12.x/cudnn-windows-x86_64-8.9.7.29_cuda12-archive.zip) 压缩包,解压后将cudnn 文件夹下的lib、bin、include 文件夹复制到 CUDA12.2 文件夹下(默认为C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2)
163
+ - 从release页下载免安装[windows 整合包](https://github.com/warmshao/FasterLivePortrait/releases)并解压。
164
+ - 进入`FasterLivePortrait-windows`后双击`scripts/all_onnx2trt.bat`对onnx文件进行转换,这会等上一段时间。
165
+ - 网页端demo:双击`webui.bat`, 打开网页:`http://localhost:9870/`
166
+ - 摄像头实时运行,双击`camera.bat`,按`q`停止。如果你想更换目标图像,命令行运行:`camera.bat assets/examples/source/s9.jpg`。
167
+ - [x] **2024/07/18:** MacOS支持(不需要Docker,python就可以了),M1/M2的速度比较快,但还是很慢😟
168
+ - 安装ffmpeg: `brew install ffmpeg`
169
+ - 安装python=3.10的虚拟环境,推荐可以用[miniforge](https://github.com/conda-forge/miniforge).`conda create -n flip python=3.10 && conda activate flip`
170
+ - `pip install -r requirements_macos.txt`
171
+ - 下载onnx文件: `huggingface-cli download warmshao/FasterLivePortrait --local-dir ./checkpoints`
172
+ - 测试: `python webui.py --mode onnx`
173
+ - [x] **2024/07/17:** 增加docker环境的支持,提供可运行的镜像。
api.py ADDED
@@ -0,0 +1,479 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/9/13 0:23
3
+ # @Project : FasterLivePortrait
4
+ # @FileName: api.py
5
+ import pdb
6
+ import shutil
7
+ from typing import Optional, Dict, Any
8
+ import io
9
+ import os
10
+ import subprocess
11
+ import uvicorn
12
+ import cv2
13
+ import time
14
+ import numpy as np
15
+ import os
16
+ import datetime
17
+ import platform
18
+ import pickle
19
+ from tqdm import tqdm
20
+ from pydantic import BaseModel
21
+ from fastapi import APIRouter, Depends, FastAPI, Request, Response, UploadFile
22
+ from fastapi import File, Body, Form
23
+ from omegaconf import OmegaConf
24
+ from fastapi.responses import StreamingResponse
25
+ from zipfile import ZipFile
26
+ from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
27
+ from src.utils.utils import video_has_audio
28
+ from src.utils import logger
29
+
30
+ # model dir
31
+ project_dir = os.path.dirname(__file__)
32
+ checkpoints_dir = os.environ.get("FLIP_CHECKPOINT_DIR", os.path.join(project_dir, "checkpoints"))
33
+ log_dir = os.path.join(project_dir, "logs")
34
+ os.makedirs(log_dir, exist_ok=True)
35
+ result_dir = os.path.join(project_dir, "results")
36
+ os.makedirs(result_dir, exist_ok=True)
37
+
38
+ logger_f = logger.get_logger("faster_liveportrait_api", log_file=os.path.join(log_dir, "log_run.log"))
39
+
40
+ app = FastAPI()
41
+
42
+ global pipe
43
+
44
+ if platform.system().lower() == 'windows':
45
+ FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
46
+ else:
47
+ FFMPEG = "ffmpeg"
48
+
49
+
50
+ def check_all_checkpoints_exist(infer_cfg):
51
+ """
52
+ check whether all checkpoints exist
53
+ :return:
54
+ """
55
+ ret = True
56
+ for name in infer_cfg.models:
57
+ if not isinstance(infer_cfg.models[name].model_path, str):
58
+ for i in range(len(infer_cfg.models[name].model_path)):
59
+ infer_cfg.models[name].model_path[i] = infer_cfg.models[name].model_path[i].replace("./checkpoints",
60
+ checkpoints_dir)
61
+ if not os.path.exists(infer_cfg.models[name].model_path[i]) and not os.path.exists(
62
+ infer_cfg.models[name].model_path[i][:-4] + ".onnx"):
63
+ return False
64
+ else:
65
+ infer_cfg.models[name].model_path = infer_cfg.models[name].model_path.replace("./checkpoints",
66
+ checkpoints_dir)
67
+ if not os.path.exists(infer_cfg.models[name].model_path) and not os.path.exists(
68
+ infer_cfg.models[name].model_path[:-4] + ".onnx"):
69
+ return False
70
+ for name in infer_cfg.animal_models:
71
+ if not isinstance(infer_cfg.animal_models[name].model_path, str):
72
+ for i in range(len(infer_cfg.animal_models[name].model_path)):
73
+ infer_cfg.animal_models[name].model_path[i] = infer_cfg.animal_models[name].model_path[i].replace(
74
+ "./checkpoints",
75
+ checkpoints_dir)
76
+ if not os.path.exists(infer_cfg.animal_models[name].model_path[i]) and not os.path.exists(
77
+ infer_cfg.animal_models[name].model_path[i][:-4] + ".onnx"):
78
+ return False
79
+ else:
80
+ infer_cfg.animal_models[name].model_path = infer_cfg.animal_models[name].model_path.replace("./checkpoints",
81
+ checkpoints_dir)
82
+ if not os.path.exists(infer_cfg.animal_models[name].model_path) and not os.path.exists(
83
+ infer_cfg.animal_models[name].model_path[:-4] + ".onnx"):
84
+ return False
85
+
86
+ # XPOSE
87
+ xpose_model_path = os.path.join(checkpoints_dir, "liveportrait_animal_onnx/xpose.pth")
88
+ if not os.path.exists(xpose_model_path):
89
+ return False
90
+ embeddings_cache_9_path = os.path.join(checkpoints_dir, "liveportrait_animal_onnx/clip_embedding_9.pkl")
91
+ if not os.path.exists(embeddings_cache_9_path):
92
+ return False
93
+ embeddings_cache_68_path = os.path.join(checkpoints_dir, "liveportrait_animal_onnx/clip_embedding_68.pkl")
94
+ if not os.path.exists(embeddings_cache_68_path):
95
+ return False
96
+ return ret
97
+
98
+
99
+ def convert_onnx_to_trt_models(infer_cfg):
100
+ ret = True
101
+ for name in infer_cfg.models:
102
+ if not isinstance(infer_cfg.models[name].model_path, str):
103
+ for i in range(len(infer_cfg.models[name].model_path)):
104
+ trt_path = infer_cfg.models[name].model_path[i]
105
+ onnx_path = trt_path[:-4] + ".onnx"
106
+ if not os.path.exists(trt_path):
107
+ convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
108
+ logger_f.info(f"convert onnx model: {onnx_path}")
109
+ result = subprocess.run(convert_cmd, shell=True, check=True)
110
+ # check the conversion result
111
+ if result.returncode == 0:
112
+ logger_f.info(f"convert onnx model: {onnx_path} successful")
113
+ else:
114
+ logger_f.error(f"convert onnx model: {onnx_path} failed")
115
+ return False
116
+ else:
117
+ trt_path = infer_cfg.models[name].model_path
118
+ onnx_path = trt_path[:-4] + ".onnx"
119
+ if not os.path.exists(trt_path):
120
+ convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
121
+ logger_f.info(f"convert onnx model: {onnx_path}")
122
+ result = subprocess.run(convert_cmd, shell=True, check=True)
123
+ # check the conversion result
124
+ if result.returncode == 0:
125
+ logger_f.info(f"convert onnx model: {onnx_path} successful")
126
+ else:
127
+ logger_f.error(f"convert onnx model: {onnx_path} failed")
128
+ return False
129
+
130
+ for name in infer_cfg.animal_models:
131
+ if not isinstance(infer_cfg.animal_models[name].model_path, str):
132
+ for i in range(len(infer_cfg.animal_models[name].model_path)):
133
+ trt_path = infer_cfg.animal_models[name].model_path[i]
134
+ onnx_path = trt_path[:-4] + ".onnx"
135
+ if not os.path.exists(trt_path):
136
+ convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
137
+ logger_f.info(f"convert onnx model: {onnx_path}")
138
+ result = subprocess.run(convert_cmd, shell=True, check=True)
139
+ # check the conversion result
140
+ if result.returncode == 0:
141
+ logger_f.info(f"convert onnx model: {onnx_path} successful")
142
+ else:
143
+ logger_f.error(f"convert onnx model: {onnx_path} failed")
144
+ return False
145
+ else:
146
+ trt_path = infer_cfg.animal_models[name].model_path
147
+ onnx_path = trt_path[:-4] + ".onnx"
148
+ if not os.path.exists(trt_path):
149
+ convert_cmd = f"python scripts/onnx2trt.py -o {onnx_path}"
150
+ logger_f.info(f"convert onnx model: {onnx_path}")
151
+ result = subprocess.run(convert_cmd, shell=True, check=True)
152
+ # check the conversion result
153
+ if result.returncode == 0:
154
+ logger_f.info(f"convert onnx model: {onnx_path} successful")
155
+ else:
156
+ logger_f.error(f"convert onnx model: {onnx_path} failed")
157
+ return False
158
+ return ret
159
+
160
+
161
+ @app.on_event("startup")
162
+ async def startup_event():
163
+ global pipe
164
+ # default use trt model
165
+ cfg_file = os.path.join(project_dir, "configs/trt_infer.yaml")
166
+ infer_cfg = OmegaConf.load(cfg_file)
167
+ checkpoints_exist = check_all_checkpoints_exist(infer_cfg)
168
+
169
+ # first: download model if not exist
170
+ if not checkpoints_exist:
171
+ download_cmd = f"huggingface-cli download warmshao/FasterLivePortrait --local-dir {checkpoints_dir}"
172
+ logger_f.info(f"download model: {download_cmd}")
173
+ result = subprocess.run(download_cmd, shell=True, check=True)
174
+ # check the download result
175
+ if result.returncode == 0:
176
+ logger_f.info(f"Download checkpoints to {checkpoints_dir} successful")
177
+ else:
178
+ logger_f.error(f"Download checkpoints to {checkpoints_dir} failed")
179
+ exit(1)
180
+ # second: convert onnx model to trt
181
+ convert_ret = convert_onnx_to_trt_models(infer_cfg)
182
+ if not convert_ret:
183
+ logger_f.error(f"convert onnx model to trt failed")
184
+ exit(1)
185
+
186
+ infer_cfg.infer_params.flag_pasteback = True
187
+ pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=True)
188
+
189
+
190
+ def run_with_video(source_image_path, driving_video_path, save_dir):
191
+ global pipe
192
+ ret = pipe.prepare_source(source_image_path, realtime=False)
193
+ if not ret:
194
+ logger_f.warning(f"no face in {source_image_path}! exit!")
195
+ return
196
+ vcap = cv2.VideoCapture(driving_video_path)
197
+ fps = int(vcap.get(cv2.CAP_PROP_FPS))
198
+ h, w = pipe.src_imgs[0].shape[:2]
199
+
200
+ # render output video
201
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
202
+ vsave_crop_path = os.path.join(save_dir,
203
+ f"{os.path.basename(source_image_path)}-{os.path.basename(driving_video_path)}-crop.mp4")
204
+ vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512))
205
+ vsave_org_path = os.path.join(save_dir,
206
+ f"{os.path.basename(source_image_path)}-{os.path.basename(driving_video_path)}-org.mp4")
207
+ vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
208
+
209
+ infer_times = []
210
+ motion_lst = []
211
+ c_eyes_lst = []
212
+ c_lip_lst = []
213
+
214
+ frame_ind = 0
215
+ while vcap.isOpened():
216
+ ret, frame = vcap.read()
217
+ if not ret:
218
+ break
219
+ t0 = time.time()
220
+ first_frame = frame_ind == 0
221
+ dri_crop, out_crop, out_org, dri_motion_info = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
222
+ first_frame=first_frame)
223
+ frame_ind += 1
224
+ if out_crop is None:
225
+ logger_f.warning(f"no face in driving frame:{frame_ind}")
226
+ continue
227
+
228
+ motion_lst.append(dri_motion_info[0])
229
+ c_eyes_lst.append(dri_motion_info[1])
230
+ c_lip_lst.append(dri_motion_info[2])
231
+
232
+ infer_times.append(time.time() - t0)
233
+ # print(time.time() - t0)
234
+ dri_crop = cv2.resize(dri_crop, (512, 512))
235
+ out_crop = np.concatenate([dri_crop, out_crop], axis=1)
236
+ out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
237
+ vout_crop.write(out_crop)
238
+ out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
239
+ vout_org.write(out_org)
240
+ vcap.release()
241
+ vout_crop.release()
242
+ vout_org.release()
243
+ if video_has_audio(driving_video_path):
244
+ vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
245
+ subprocess.call(
246
+ [FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
247
+ "-b:v", "10M", "-c:v",
248
+ "libx264", "-map", "0:v", "-map", "1:a",
249
+ "-c:a", "aac",
250
+ "-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
251
+ vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
252
+ subprocess.call(
253
+ [FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
254
+ "-b:v", "10M", "-c:v",
255
+ "libx264", "-map", "0:v", "-map", "1:a",
256
+ "-c:a", "aac",
257
+ "-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
258
+
259
+ logger_f.info(vsave_crop_path_new)
260
+ logger_f.info(vsave_org_path_new)
261
+ else:
262
+ logger_f.info(vsave_crop_path)
263
+ logger_f.info(vsave_org_path)
264
+
265
+ logger_f.info(
266
+ "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
267
+ np.mean(infer_times) * 1000))
268
+ # save driving motion to pkl
269
+ template_dct = {
270
+ 'n_frames': len(motion_lst),
271
+ 'output_fps': fps,
272
+ 'motion': motion_lst,
273
+ 'c_eyes_lst': c_eyes_lst,
274
+ 'c_lip_lst': c_lip_lst,
275
+ }
276
+ template_pkl_path = os.path.join(save_dir,
277
+ f"{os.path.basename(driving_video_path)}.pkl")
278
+ with open(template_pkl_path, "wb") as fw:
279
+ pickle.dump(template_dct, fw)
280
+ logger_f.info(f"save driving motion pkl file at : {template_pkl_path}")
281
+
282
+
283
+ def run_with_pkl(source_image_path, driving_pickle_path, save_dir):
284
+ global pipe
285
+ ret = pipe.prepare_source(source_image_path, realtime=False)
286
+ if not ret:
287
+ logger_f.warning(f"no face in {source_image_path}! exit!")
288
+ return
289
+
290
+ with open(driving_pickle_path, "rb") as fin:
291
+ dri_motion_infos = pickle.load(fin)
292
+
293
+ fps = int(dri_motion_infos["output_fps"])
294
+ h, w = pipe.src_imgs[0].shape[:2]
295
+
296
+ # render output video
297
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
298
+ vsave_crop_path = os.path.join(save_dir,
299
+ f"{os.path.basename(source_image_path)}-{os.path.basename(driving_pickle_path)}-crop.mp4")
300
+ vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512, 512))
301
+ vsave_org_path = os.path.join(save_dir,
302
+ f"{os.path.basename(source_image_path)}-{os.path.basename(driving_pickle_path)}-org.mp4")
303
+ vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
304
+
305
+ infer_times = []
306
+ motion_lst = dri_motion_infos["motion"]
307
+ c_eyes_lst = dri_motion_infos["c_eyes_lst"] if "c_eyes_lst" in dri_motion_infos else dri_motion_infos[
308
+ "c_d_eyes_lst"]
309
+ c_lip_lst = dri_motion_infos["c_lip_lst"] if "c_lip_lst" in dri_motion_infos else dri_motion_infos["c_d_lip_lst"]
310
+
311
+ frame_num = len(motion_lst)
312
+ for frame_ind in tqdm(range(frame_num)):
313
+ t0 = time.time()
314
+ first_frame = frame_ind == 0
315
+ dri_motion_info_ = [motion_lst[frame_ind], c_eyes_lst[frame_ind], c_lip_lst[frame_ind]]
316
+ out_crop, out_org = pipe.run_with_pkl(dri_motion_info_, pipe.src_imgs[0], pipe.src_infos[0],
317
+ first_frame=first_frame)
318
+ if out_crop is None:
319
+ logger_f.warning(f"no face in driving frame:{frame_ind}")
320
+ continue
321
+
322
+ infer_times.append(time.time() - t0)
323
+ # print(time.time() - t0)
324
+ out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
325
+ vout_crop.write(out_crop)
326
+ out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
327
+ vout_org.write(out_org)
328
+
329
+ vout_crop.release()
330
+ vout_org.release()
331
+ logger_f.info(vsave_crop_path)
332
+ logger_f.info(vsave_org_path)
333
+ logger_f.info(
334
+ "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
335
+ np.mean(infer_times) * 1000))
336
+
337
+
338
+ class LivePortraitParams(BaseModel):
339
+ flag_pickle: bool = False
340
+ flag_relative_input: bool = True
341
+ flag_do_crop_input: bool = True
342
+ flag_remap_input: bool = True
343
+ driving_multiplier: float = 1.0
344
+ flag_stitching: bool = True
345
+ flag_crop_driving_video_input: bool = True
346
+ flag_video_editing_head_rotation: bool = False
347
+ flag_is_animal: bool = True
348
+ scale: float = 2.3
349
+ vx_ratio: float = 0.0
350
+ vy_ratio: float = -0.125
351
+ scale_crop_driving_video: float = 2.2
352
+ vx_ratio_crop_driving_video: float = 0.0
353
+ vy_ratio_crop_driving_video: float = -0.1
354
+ driving_smooth_observation_variance: float = 1e-7
355
+
356
+
357
+ @app.post("/predict/")
358
+ async def upload_files(
359
+ source_image: Optional[UploadFile] = File(None),
360
+ driving_video: Optional[UploadFile] = File(None),
361
+ driving_pickle: Optional[UploadFile] = File(None),
362
+ flag_is_animal: bool = Form(...),
363
+ flag_pickle: bool = Form(...),
364
+ flag_relative_input: bool = Form(...),
365
+ flag_do_crop_input: bool = Form(...),
366
+ flag_remap_input: bool = Form(...),
367
+ driving_multiplier: float = Form(...),
368
+ flag_stitching: bool = Form(...),
369
+ flag_crop_driving_video_input: bool = Form(...),
370
+ flag_video_editing_head_rotation: bool = Form(...),
371
+ scale: float = Form(...),
372
+ vx_ratio: float = Form(...),
373
+ vy_ratio: float = Form(...),
374
+ scale_crop_driving_video: float = Form(...),
375
+ vx_ratio_crop_driving_video: float = Form(...),
376
+ vy_ratio_crop_driving_video: float = Form(...),
377
+ driving_smooth_observation_variance: float = Form(...)
378
+ ):
379
+ # build infer_params from the submitted form fields
380
+ infer_params = LivePortraitParams(
381
+ flag_is_animal=flag_is_animal,
382
+ flag_pickle=flag_pickle,
383
+ flag_relative_input=flag_relative_input,
384
+ flag_do_crop_input=flag_do_crop_input,
385
+ flag_remap_input=flag_remap_input,
386
+ driving_multiplier=driving_multiplier,
387
+ flag_stitching=flag_stitching,
388
+ flag_crop_driving_video_input=flag_crop_driving_video_input,
389
+ flag_video_editing_head_rotation=flag_video_editing_head_rotation,
390
+ scale=scale,
391
+ vx_ratio=vx_ratio,
392
+ vy_ratio=vy_ratio,
393
+ scale_crop_driving_video=scale_crop_driving_video,
394
+ vx_ratio_crop_driving_video=vx_ratio_crop_driving_video,
395
+ vy_ratio_crop_driving_video=vy_ratio_crop_driving_video,
396
+ driving_smooth_observation_variance=driving_smooth_observation_variance
397
+ )
398
+
399
+ global pipe
400
+ pipe.init_vars()
401
+ if infer_params.flag_is_animal != pipe.is_animal:
402
+ pipe.init_models(is_animal=infer_params.flag_is_animal)
403
+
404
+ args_user = {
405
+ 'flag_relative_motion': infer_params.flag_relative_input,
406
+ 'flag_do_crop': infer_params.flag_do_crop_input,
407
+ 'flag_pasteback': infer_params.flag_remap_input,
408
+ 'driving_multiplier': infer_params.driving_multiplier,
409
+ 'flag_stitching': infer_params.flag_stitching,
410
+ 'flag_crop_driving_video': infer_params.flag_crop_driving_video_input,
411
+ 'flag_video_editing_head_rotation': infer_params.flag_video_editing_head_rotation,
412
+ 'src_scale': infer_params.scale,
413
+ 'src_vx_ratio': infer_params.vx_ratio,
414
+ 'src_vy_ratio': infer_params.vy_ratio,
415
+ 'dri_scale': infer_params.scale_crop_driving_video,
416
+ 'dri_vx_ratio': infer_params.vx_ratio_crop_driving_video,
417
+ 'dri_vy_ratio': infer_params.vy_ratio_crop_driving_video,
418
+ }
419
+ # update config from user input
420
+ update_ret = pipe.update_cfg(args_user)
421
+
422
+ # save source_image to a temporary directory
423
+ temp_dir = os.path.join(result_dir, f"temp-{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}")
424
+ os.makedirs(temp_dir, exist_ok=True)
425
+ if source_image and source_image.filename:
426
+ source_image_path = os.path.join(temp_dir, source_image.filename)
427
+ with open(source_image_path, "wb") as buffer:
428
+ buffer.write(await source_image.read())  # write the uploaded content to disk
429
+ else:
430
+ source_image_path = None
431
+
432
+ if driving_video and driving_video.filename:
433
+ driving_video_path = os.path.join(temp_dir, driving_video.filename)
434
+ with open(driving_video_path, "wb") as buffer:
435
+ buffer.write(await driving_video.read())  # write the uploaded content to disk
436
+ else:
437
+ driving_video_path = None
438
+
439
+ if driving_pickle and driving_pickle.filename:
440
+ driving_pickle_path = os.path.join(temp_dir, driving_pickle.filename)
441
+ with open(driving_pickle_path, "wb") as buffer:
442
+ buffer.write(await driving_pickle.read())  # write the uploaded content to disk
443
+ else:
444
+ driving_pickle_path = None
445
+
446
+ save_dir = os.path.join(result_dir, f"{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}")
447
+ os.makedirs(save_dir, exist_ok=True)
448
+
449
+ if infer_params.flag_pickle:
450
+ if source_image_path and driving_pickle_path:
451
+ run_with_pkl(source_image_path, driving_pickle_path, save_dir)
452
+ else:
453
+ if source_image_path and driving_video_path:
454
+ run_with_video(source_image_path, driving_video_path, save_dir)
455
+ # zip all files and return
456
+ # create an in-memory byte stream with BytesIO
457
+ zip_buffer = io.BytesIO()
458
+
459
+ # compress the contents of save_dir into zip_buffer with ZipFile
460
+ with ZipFile(zip_buffer, "w") as zip_file:
461
+ for root, dirs, files in os.walk(save_dir):
462
+ for file in files:
463
+ file_path = os.path.join(root, file)
464
+ # add the file to the ZIP archive
465
+ zip_file.write(file_path, arcname=os.path.relpath(file_path, save_dir))
466
+
467
+ # rewind the buffer so the whole content can be read
468
+ zip_buffer.seek(0)
469
+ shutil.rmtree(temp_dir)
470
+ shutil.rmtree(save_dir)
471
+ # return the zip file via StreamingResponse
472
+ return StreamingResponse(zip_buffer, media_type="application/zip",
473
+ headers={"Content-Disposition": "attachment; filename=output.zip"})
474
+
475
+
476
+ if __name__ == "__main__":
477
+ import uvicorn
478
+
479
+ uvicorn.run(app, host=os.environ.get("FLIP_IP", "127.0.0.1"), port=int(os.environ.get("FLIP_PORT", 9871)))
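
A note on the driving-motion template written by `run_with_video` above: it is a plain pickle whose keys match `template_dct`, and `run_with_pkl` consumes the three per-frame lists in the same order. A small sketch of reloading one (the path is illustrative):

```python
# Sketch: inspect a driving-motion pickle produced by run_with_video()
import pickle

with open("results/d14.mp4.pkl", "rb") as fin:  # illustrative path
    dri_motion_infos = pickle.load(fin)

print(dri_motion_infos["n_frames"], "frames at", dri_motion_infos["output_fps"], "fps")
# per-frame lists consumed by run_with_pkl(), one entry per driving frame
motion_lst = dri_motion_infos["motion"]
c_eyes_lst = dri_motion_infos["c_eyes_lst"]
c_lip_lst = dri_motion_infos["c_lip_lst"]
first_frame_info = [motion_lst[0], c_eyes_lst[0], c_lip_lst[0]]
```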
assets/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ examples/driving/*.pkl
2
+ examples/driving/*_crop.mp4
assets/docs/API.md ADDED
@@ -0,0 +1,41 @@
1
+ ## FasterLivePortrait API Usage Guide
2
+
3
+ ### Building the Image
4
+ * Decide on an image name, for example `shaoguo/faster_liveportrait_api:v1.0`. Replace the `-t` parameter in the following command with your chosen name.
5
+ * Run `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .`
6
+
7
+ ### Running the Image
8
+ Ensure that your machine has Nvidia GPU drivers installed. CUDA version should be 12.0 or higher. Two scenarios are described below.
9
+
10
+ * Running on a Local Machine (typically for self-testing)
11
+ * Modify the image name according to what you defined above.
12
+ * Confirm the service port number, default is `9871`. You can define your own by changing the `SERVER_PORT` environment variable in the command below. Remember to also change `-p 9871:9871` to map the port.
13
+ * Set the model path environment variable `CHECKPOINT_DIR`. If you have previously downloaded FasterLivePortrait's ONNX models and converted them to TRT, I recommend mapping the model files into the container with `-v`, for example `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`. This avoids re-downloading the ONNX models and redoing the TRT conversion. Otherwise, the service checks whether `CHECKPOINT_DIR` contains models; if not, it automatically downloads them (ensure network connectivity) and runs the TRT conversion, which takes considerable time.
14
+ * Run command (note: modify the following command according to your settings):
15
+ ```shell
16
+ docker run -d --gpus=all \
17
+ --name faster_liveportrait_api \
18
+ -v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \
19
+ -e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \
20
+ -e SERVER_PORT=9871 \
21
+ -p 9871:9871 \
22
+ --restart=always \
23
+ shaoguo/faster_liveportrait_api:v1.0
25
+ ```
26
+ * Normal operation should display the following information (check with `docker logs $container_id`). The running logs are saved in `/root/FasterLivePortrait/logs/log_run.log`:
27
+ ```shell
28
+ INFO: Application startup complete.
29
+ INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit)
30
+ ```
31
+
32
+ * Running on Cloud GPU Cluster (production environment)
33
+ * This needs to be configured according to different clusters, but the core is the configuration of docker image and environment variables.
34
+ * Load balancing may need to be set up.
35
+
36
+ ### API Call Testing
37
+ Refer to `tests/test_api.py`. The default is the Animal model, but now it also supports the Human model.
38
+ The response is a compressed package, which is unzipped to `./results/api_*` by default; confirm the exact path from the printed log.
39
+ * `test_with_video_animal()`, image and video driving. Set `flag_pickle=False`. It will additionally return the driving video's pkl file, which can be reused directly next time.
40
+ * `test_with_pkl_animal()`, image and pkl driving.
41
+ * `test_with_video_human()`, image and video driving under the Human model, set `flag_is_animal=False`
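
`tests/test_api.py` is not among the 50 files shown in this view, so the sketch below is a minimal stand-in client for the `/predict/` endpoint defined in `api.py`. The host, port, file paths, and output directory are illustrative; the form fields mirror the parameters of `upload_files` and the defaults of `LivePortraitParams`.

```python
# Sketch: call the FasterLivePortrait API and unpack the returned zip
import io
import zipfile
import requests

url = "http://127.0.0.1:9871/predict/"  # adjust host/port to your deployment
files = {
    "source_image": open("assets/examples/source/s10.jpg", "rb"),
    "driving_video": open("assets/examples/driving/d14.mp4", "rb"),
}
data = {  # mirrors LivePortraitParams defaults in api.py
    "flag_is_animal": True,
    "flag_pickle": False,
    "flag_relative_input": True,
    "flag_do_crop_input": True,
    "flag_remap_input": True,
    "driving_multiplier": 1.0,
    "flag_stitching": True,
    "flag_crop_driving_video_input": True,
    "flag_video_editing_head_rotation": False,
    "scale": 2.3,
    "vx_ratio": 0.0,
    "vy_ratio": -0.125,
    "scale_crop_driving_video": 2.2,
    "vx_ratio_crop_driving_video": 0.0,
    "vy_ratio_crop_driving_video": -0.1,
    "driving_smooth_observation_variance": 1e-7,
}
resp = requests.post(url, files=files, data=data, timeout=600)
resp.raise_for_status()
with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
    zf.extractall("results/api_output")  # illustrative output directory
for f in files.values():
    f.close()
```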
assets/docs/API_ZH.md ADDED
@@ -0,0 +1,47 @@
1
+ ## FasterLivePortrait API使用教程
2
+
3
+ ### 构建镜像
4
+
5
+ * 确定镜像的名字,比如 `shaoguo/faster_liveportrait_api:v1.0`。确认后替换为下面命令 `-t` 的参数。
6
+ * 运行 `docker build -t shaoguo/faster_liveportrait_api:v1.0 -f DockerfileAPI .`
7
+
8
+ ### 运行镜像
9
+
10
+ 请确保你的机器已经装了Nvidia显卡的驱动。CUDA的版本在cuda12.0及以上。以下分两种情况介绍。
11
+
12
+ * 本地机器运行(一般自己测试使用)
13
+ * 镜像名称根据上面你自己定义的更改。
14
+ * 确认服务的端口号,默认为`9871`,你可以自己定义,更改下面命令里环境变量`SERVER_PORT`。同时要记得更改`-p 9871:9871`,
15
+ 将端口映射出来。
16
+ * 设置模型路径环境变量 `CHECKPOINT_DIR`。如果你之前下载过FasterLivePortrait的onnx模型并做过trt的转换,我建议
17
+ 是可以通过 `-v`把
18
+ 模型文件映射进入容器,比如 `-v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints`,
19
+ 这样就避免重新下载onnx模型和做trt的转换。否则我将会检测`CHECKPOINT_DIR`是否有模型,没有的话,我将自动下载(确保有网络)和做trt的转换,这将耗时比较久的时间。
20
+ * 运行命令(注意你要根据自己的设置更改以下命令的信息):
21
+ ```shell
22
+ docker run -d --gpus=all \
23
+ --name faster_liveportrait_api \
24
+ -v E:\my_projects\FasterLivePortrait\checkpoints:/root/FasterLivePortrait/checkpoints \
25
+ -e CHECKPOINT_DIR=/root/FasterLivePortrait/checkpoints \
26
+ -e SERVER_PORT=9871 \
27
+ -p 9871:9871 \
28
+ --restart=always \
29
+ shaoguo/faster_liveportrait_api:v1.0
30
+ ```
31
+ * 正常运行应该会显示以下信息(docker logs container_id), 运行的日志保存在`/root/FasterLivePortrait/logs/log_run.log`:
32
+ ```shell
33
+ INFO: Application startup complete.
34
+ INFO: Uvicorn running on http://0.0.0.0:9871 (Press CTRL+C to quit)
35
+ ```
36
+ * 云端GPU集群运行(生产环境)
37
+ * 这需要根据不同的集群做配置,但核心就是镜像和环境变量的配置。
38
+ * 可能要设置负载均衡。
39
+
40
+ ### API调用测试
41
+
42
+ 可以参考`tests/test_api.py`, 默认是Animal的模型,但现在同时也支持Human的模型了。
43
+ 返回的是压缩包,默认解压在`./results/api_*`, 根据实际打印出来的日志确认。
44
+
45
+ * `test_with_video_animal()`, 图像和视频的驱动。设置`flag_pickle=False`。会额外返回driving video的pkl文件,下次可以直接调用。
46
+ * `test_with_pkl_animal()`, 图像和pkl的驱动。
47
+ * `test_with_video_human()`, Human模型下图像和视频的驱动,设置`flag_is_animal=False`
assets/gradio/gradio_description_animate_clear.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <div style="font-size: 1.2em; text-align: center;">
2
+ Step 3: Click the <strong>🚀 Animate</strong> button below to generate, or click <strong>🧹 Clear</strong> to erase the results
3
+ </div>
4
+ <!-- <div style="font-size: 1.1em; text-align: center;">
5
+ <strong style="color: red;">Note:</strong> If both <strong>Source Image</strong> and <strong>Video</strong> are uploaded, the <strong>Source Image</strong> will be used. Please click the <strong>🧹 Clear</strong> button, then re-upload the <strong>Source Image</strong> or <strong>Video</strong>.
6
+ </div> -->
assets/gradio/gradio_description_animation.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <span style="font-size: 1.2em;">🔥 To animate the source image or video with the driving video, please follow these steps:</span>
2
+ <div style="font-size: 1.2em; margin-left: 20px;">
3
+ 1. In the <strong>Animation Options for Source Image or Video</strong> section, we recommend enabling the <code>do crop (source)</code> option if faces occupy a small portion of your source image or video.
4
+ </div>
5
+ <div style="font-size: 1.2em; margin-left: 20px;">
6
+ 2. In the <strong>Animation Options for Driving Video</strong> section, the <code>relative head rotation</code> and <code>smooth strength</code> options only take effect if the source input is a video.
7
+ </div>
8
+ <div style="font-size: 1.2em; margin-left: 20px;">
9
+ 3. Press the <strong>🚀 Animate</strong> button and wait a moment; your animated video will appear in the result block. This may take a few moments. If the input is a source video, the length of the animated video is the minimum of the lengths of the source video and the driving video.
10
+ </div>
11
+ <div style="font-size: 1.2em; margin-left: 20px;">
12
+ 4. If you want to upload your own driving video, <strong>the best practice</strong>:
13
+
14
+ - Crop it to a 1:1 aspect ratio (e.g., 512x512 or 256x256 pixels), or enable auto-cropping by checking `do crop (driving video)`.
15
+ - Focus on the head area, similar to the example videos.
16
+ - Minimize shoulder movement.
17
+ - Make sure the first frame of the driving video is a frontal face with a **neutral expression**.
18
+
19
+ </div>
assets/gradio/gradio_description_retargeting.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <br>
2
+
3
+ <!-- ## Retargeting -->
4
+ <!-- <span style="font-size: 1.2em;">🔥 To edit the eyes and lip open ratio of the source portrait, drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span> -->
5
+
6
+
7
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 1.2em;">
8
+ <div>
9
+ <h2>Retargeting</h2>
10
+ <p>Upload a Source Portrait as Retargeting Input, then drag the sliders and click the <strong>🚗 Retargeting</strong> button. You can try running it multiple times.
11
+ <br>
12
+ <strong>😊 Set both ratios to 0.8 to see what's going on!</strong></p>
13
+ </div>
14
+ </div>
assets/gradio/gradio_description_upload.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <br>
2
+ <div style="font-size: 1.2em; display: flex; justify-content: space-between;">
3
+ <div style="flex: 1; text-align: center; margin-right: 20px;">
4
+ <div style="display: inline-block;">
5
+ Step 1: Upload a <strong>Source Image</strong> or <strong>Video</strong> (any aspect ratio) ⬇️
6
+ </div>
7
+ </div>
8
+ <div style="flex: 1; text-align: center; margin-left: 20px;">
9
+ <div style="display: inline-block;">
10
+ Step 2: Upload a <strong>Driving Video</strong> (any aspect ratio) ⬇️
11
+ </div>
12
+ <div style="display: inline-block; font-size: 0.8em;">
13
+ <strong>Tips:</strong> Focus on the head, minimize shoulder movement, <strong>neutral expression</strong> in first frame.
14
+ </div>
15
+ </div>
16
+ </div>
assets/gradio/gradio_title.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
2
+ <div>
3
+ <h1>FasterLivePortrait: Bring Portraits to Life in Real Time</h1>
4
+ <span>Built on <a href="https://github.com/KwaiVGI/LivePortrait">LivePortrait</a></span>
5
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center; margin-top: 10px;">
6
+ <a href="https://huggingface.co/warmshao/FasterLivePortrait">
7
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue" alt="Hugging Face Spaces">
8
+ </a>
9
+ &nbsp;
10
+ <a href="https://github.com/warmshao/FasterLivePortrait">
11
+ <img src="https://img.shields.io/badge/Github-Code-blue" alt="Github Code">
12
+ </a>
13
+ &nbsp;
14
+ <a href="https://github.com/warmshao/FasterLivePortrait">
15
+ <img src="https://img.shields.io/github/stars/warmshao/FasterLivePortrait" alt="Github Stars">
16
+ </a>
17
+ </div>
18
+ </div>
19
+ </div>
assets/mask_template.png ADDED
camera.bat ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ setlocal enabledelayedexpansion
3
+
4
+ REM Set the default source image path
5
+ set "default_src_image=assets\examples\source\s12.jpg"
6
+ set "src_image=%default_src_image%"
7
+ set "animal_param="
8
+ set "paste_back="
9
+
10
+ REM Parse named arguments
11
+ :parse_args
12
+ if "%~1"=="" goto end_parse_args
13
+ if /i "%~1"=="--src_image" (
14
+ set "src_image=%~2"
15
+ shift
16
+ ) else if /i "%~1"=="--animal" (
17
+ set "animal_param=--animal"
18
+ ) else if /i "%~1"=="--paste_back" (
19
+ set "paste_back=--paste_back"
20
+ )
21
+ shift
22
+ goto parse_args
23
+ :end_parse_args
24
+
25
+ echo source image: [!src_image!]
26
+ echo use animal: [!animal_param!]
27
+ echo paste_back: [!paste_back!]
28
+
29
+ REM Run the Python command
30
+ .\venv\python.exe .\run.py --cfg configs/trt_infer.yaml --realtime --dri_video 0 --src_image !src_image! !animal_param! !paste_back!
31
+
32
+ endlocal
configs/onnx_infer.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models:
2
+ warping_spade:
3
+ name: "WarpingSpadeModel"
4
+ predict_type: "ort"
5
+ model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx"
6
+ motion_extractor:
7
+ name: "MotionExtractorModel"
8
+ predict_type: "ort"
9
+ model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx"
10
+ landmark:
11
+ name: "LandmarkModel"
12
+ predict_type: "ort"
13
+ model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
14
+ face_analysis:
15
+ name: "FaceAnalysisModel"
16
+ predict_type: "ort"
17
+ model_path:
18
+ - "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx"
19
+ - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx"
20
+ app_feat_extractor:
21
+ name: "AppearanceFeatureExtractorModel"
22
+ predict_type: "ort"
23
+ model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx"
24
+ stitching:
25
+ name: "StitchingModel"
26
+ predict_type: "ort"
27
+ model_path: "./checkpoints/liveportrait_onnx/stitching.onnx"
28
+ stitching_eye_retarget:
29
+ name: "StitchingModel"
30
+ predict_type: "ort"
31
+ model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx"
32
+ stitching_lip_retarget:
33
+ name: "StitchingModel"
34
+ predict_type: "ort"
35
+ model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx"
36
+
37
+ animal_models:
38
+ warping_spade:
39
+ name: "WarpingSpadeModel"
40
+ predict_type: "ort"
41
+ model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx"
42
+ motion_extractor:
43
+ name: "MotionExtractorModel"
44
+ predict_type: "ort"
45
+ model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx"
46
+ app_feat_extractor:
47
+ name: "AppearanceFeatureExtractorModel"
48
+ predict_type: "ort"
49
+ model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx"
50
+ stitching:
51
+ name: "StitchingModel"
52
+ predict_type: "ort"
53
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx"
54
+ stitching_eye_retarget:
55
+ name: "StitchingModel"
56
+ predict_type: "ort"
57
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx"
58
+ stitching_lip_retarget:
59
+ name: "StitchingModel"
60
+ predict_type: "ort"
61
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx"
62
+ landmark:
63
+ name: "LandmarkModel"
64
+ predict_type: "ort"
65
+ model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
66
+ face_analysis:
67
+ name: "FaceAnalysisModel"
68
+ predict_type: "ort"
69
+ model_path:
70
+ - "./checkpoints/liveportrait_onnx/retinaface_det_static.onnx"
71
+ - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx"
72
+
73
+ joyvasa_models:
74
+ motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
75
+ audio_model_path: "checkpoints/chinese-hubert-base"
76
+ motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
77
+
78
+ crop_params:
79
+ src_dsize: 512
80
+ src_scale: 2.3
81
+ src_vx_ratio: 0.0
82
+ src_vy_ratio: -0.125
83
+ dri_scale: 2.2
84
+ dri_vx_ratio: 0.0
85
+ dri_vy_ratio: -0.1
86
+
87
+
88
+ infer_params:
89
+ flag_crop_driving_video: False
90
+ flag_normalize_lip: True
91
+ flag_source_video_eye_retargeting: False
92
+ flag_video_editing_head_rotation: False
93
+ flag_eye_retargeting: False
94
+ flag_lip_retargeting: False
95
+ flag_stitching: True
96
+ flag_relative_motion: True
97
+ flag_pasteback: True
98
+ flag_do_crop: True
99
+ flag_do_rot: True
100
+
101
+ # NOT EXPORTED PARAMS
102
+ lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
103
+ source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
104
+ driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
105
+ anchor_frame: 0 # TO IMPLEMENT
106
+ mask_crop_path: "./assets/mask_template.png"
107
+ driving_multiplier: 1.0
108
+ animation_region: "all"
109
+
110
+ cfg_mode: "incremental"
111
+ cfg_scale: 1.2
112
+
113
+ source_max_dim: 1280 # the max dim of height and width of source image
114
+ source_division: 2 # make sure the height and width of source image can be divided by this number
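
These YAML files are consumed with OmegaConf attribute access, as `run.py` below does (`OmegaConf.load(args.cfg)`); a short sketch of inspecting and overriding the values above, assuming it is run from the repository root:

```python
# Sketch: reading this config the way run.py does (OmegaConf attribute access).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/onnx_infer.yaml")
print(cfg.models.warping_spade.model_path)   # ./checkpoints/liveportrait_onnx/warping_spade.onnx
print(cfg.crop_params.src_scale)             # 2.3
print(cfg.infer_params.flag_pasteback)       # True

# Flags can be overridden in code before the pipeline is built,
# e.g. run.py sets flag_pasteback from its --paste_back argument:
cfg.infer_params.flag_pasteback = False
```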
configs/onnx_mp_infer.yaml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models:
2
+ warping_spade:
3
+ name: "WarpingSpadeModel"
4
+ predict_type: "ort"
5
+ model_path: "./checkpoints/liveportrait_onnx/warping_spade.onnx"
6
+ motion_extractor:
7
+ name: "MotionExtractorModel"
8
+ predict_type: "ort"
9
+ model_path: "./checkpoints/liveportrait_onnx/motion_extractor.onnx"
10
+ landmark:
11
+ name: "LandmarkModel"
12
+ predict_type: "ort"
13
+ model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
14
+ face_analysis:
15
+ name: "MediaPipeFaceModel"
16
+ predict_type: "mp"
17
+ app_feat_extractor:
18
+ name: "AppearanceFeatureExtractorModel"
19
+ predict_type: "ort"
20
+ model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx"
21
+ stitching:
22
+ name: "StitchingModel"
23
+ predict_type: "ort"
24
+ model_path: "./checkpoints/liveportrait_onnx/stitching.onnx"
25
+ stitching_eye_retarget:
26
+ name: "StitchingModel"
27
+ predict_type: "ort"
28
+ model_path: "./checkpoints/liveportrait_onnx/stitching_eye.onnx"
29
+ stitching_lip_retarget:
30
+ name: "StitchingModel"
31
+ predict_type: "ort"
32
+ model_path: "./checkpoints/liveportrait_onnx/stitching_lip.onnx"
33
+
34
+ animal_models:
35
+ warping_spade:
36
+ name: "WarpingSpadeModel"
37
+ predict_type: "ort"
38
+ model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade.onnx"
39
+ motion_extractor:
40
+ name: "MotionExtractorModel"
41
+ predict_type: "ort"
42
+ model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor.onnx"
43
+ app_feat_extractor:
44
+ name: "AppearanceFeatureExtractorModel"
45
+ predict_type: "ort"
46
+ model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor.onnx"
47
+ stitching:
48
+ name: "StitchingModel"
49
+ predict_type: "ort"
50
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching.onnx"
51
+ stitching_eye_retarget:
52
+ name: "StitchingModel"
53
+ predict_type: "ort"
54
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye.onnx"
55
+ stitching_lip_retarget:
56
+ name: "StitchingModel"
57
+ predict_type: "ort"
58
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip.onnx"
59
+ landmark:
60
+ name: "LandmarkModel"
61
+ predict_type: "ort"
62
+ model_path: "./checkpoints/liveportrait_onnx/landmark.onnx"
63
+ face_analysis:
64
+ name: "MediaPipeFaceModel"
65
+ predict_type: "mp"
66
+
67
+ joyvasa_models:
68
+ motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
69
+ audio_model_path: "checkpoints/chinese-hubert-base"
70
+ motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
71
+
72
+ crop_params:
73
+ src_dsize: 512
74
+ src_scale: 2.3
75
+ src_vx_ratio: 0.0
76
+ src_vy_ratio: -0.125
77
+ dri_scale: 2.2
78
+ dri_vx_ratio: 0.0
79
+ dri_vy_ratio: -0.1
80
+
81
+
82
+ infer_params:
83
+ flag_crop_driving_video: False
84
+ flag_normalize_lip: True
85
+ flag_source_video_eye_retargeting: False
86
+ flag_video_editing_head_rotation: False
87
+ flag_eye_retargeting: False
88
+ flag_lip_retargeting: False
89
+ flag_stitching: True
90
+ flag_relative_motion: True
91
+ flag_pasteback: True
92
+ flag_do_crop: True
93
+ flag_do_rot: True
94
+
95
+ # NOT EXPORTED PARAMS
96
+ lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
97
+ source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
98
+ driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
99
+ anchor_frame: 0 # TO IMPLEMENT
100
+ mask_crop_path: "./assets/mask_template.png"
101
+ driving_multiplier: 1.0
102
+ animation_region: "all"
103
+
104
+ cfg_mode: "incremental"
105
+ cfg_scale: 1.2
106
+
107
+ source_max_dim: 1280 # the max dim of height and width of source image
108
+ source_division: 2 # make sure the height and width of source image can be divided by this number
configs/trt_infer.yaml ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models:
2
+ warping_spade:
3
+ name: "WarpingSpadeModel"
4
+ predict_type: "trt"
5
+ model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt"
6
+ motion_extractor:
7
+ name: "MotionExtractorModel"
8
+ predict_type: "trt"
9
+ model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt"
10
+ landmark:
11
+ name: "LandmarkModel"
12
+ predict_type: "trt"
13
+ model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
14
+ face_analysis:
15
+ name: "FaceAnalysisModel"
16
+ predict_type: "trt"
17
+ model_path:
18
+ - "./checkpoints/liveportrait_onnx/retinaface_det_static.trt"
19
+ - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt"
20
+ app_feat_extractor:
21
+ name: "AppearanceFeatureExtractorModel"
22
+ predict_type: "trt"
23
+ model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt"
24
+ stitching:
25
+ name: "StitchingModel"
26
+ predict_type: "trt"
27
+ model_path: "./checkpoints/liveportrait_onnx/stitching.trt"
28
+ stitching_eye_retarget:
29
+ name: "StitchingModel"
30
+ predict_type: "trt"
31
+ model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt"
32
+ stitching_lip_retarget:
33
+ name: "StitchingModel"
34
+ predict_type: "trt"
35
+ model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt"
36
+
37
+ animal_models:
38
+ warping_spade:
39
+ name: "WarpingSpadeModel"
40
+ predict_type: "trt"
41
+ model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt"
42
+ motion_extractor:
43
+ name: "MotionExtractorModel"
44
+ predict_type: "trt"
45
+ model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt"
46
+ app_feat_extractor:
47
+ name: "AppearanceFeatureExtractorModel"
48
+ predict_type: "trt"
49
+ model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt"
50
+ stitching:
51
+ name: "StitchingModel"
52
+ predict_type: "trt"
53
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt"
54
+ stitching_eye_retarget:
55
+ name: "StitchingModel"
56
+ predict_type: "trt"
57
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt"
58
+ stitching_lip_retarget:
59
+ name: "StitchingModel"
60
+ predict_type: "trt"
61
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt"
62
+ landmark:
63
+ name: "LandmarkModel"
64
+ predict_type: "trt"
65
+ model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
66
+ face_analysis:
67
+ name: "FaceAnalysisModel"
68
+ predict_type: "trt"
69
+ model_path:
70
+ - "./checkpoints/liveportrait_onnx/retinaface_det_static.trt"
71
+ - "./checkpoints/liveportrait_onnx/face_2dpose_106_static.trt"
72
+
73
+ joyvasa_models:
74
+ motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
75
+ audio_model_path: "checkpoints/chinese-hubert-base"
76
+ motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
77
+
78
+ crop_params:
79
+ src_dsize: 512
80
+ src_scale: 2.3
81
+ src_vx_ratio: 0.0
82
+ src_vy_ratio: -0.125
83
+ dri_scale: 2.2
84
+ dri_vx_ratio: 0.0
85
+ dri_vy_ratio: -0.1
86
+
87
+
88
+ infer_params:
89
+ flag_crop_driving_video: False
90
+ flag_normalize_lip: True
91
+ flag_source_video_eye_retargeting: False
92
+ flag_video_editing_head_rotation: False
93
+ flag_eye_retargeting: False
94
+ flag_lip_retargeting: False
95
+ flag_stitching: True
96
+ flag_relative_motion: True
97
+ flag_pasteback: True
98
+ flag_do_crop: True
99
+ flag_do_rot: True
100
+
101
+ # NOT EXPORTED PARAMS
102
+ lip_normalize_threshold: 0.1 # threshold for flag_normalize_lip
103
+ source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
104
+ driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
105
+ anchor_frame: 0 # TO IMPLEMENT
106
+ mask_crop_path: "./assets/mask_template.png"
107
+ driving_multiplier: 1.0
108
+ animation_region: "all"
109
+
110
+ cfg_mode: "incremental"
111
+ cfg_scale: 1.2
112
+
113
+ source_max_dim: 1280 # the max dim of height and width of source image
114
+ source_division: 2 # make sure the height and width of source image can be divided by this number
configs/trt_mp_infer.yaml ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models:
2
+ warping_spade:
3
+ name: "WarpingSpadeModel"
4
+ predict_type: "trt"
5
+ model_path: "./checkpoints/liveportrait_onnx/warping_spade-fix.trt"
6
+ motion_extractor:
7
+ name: "MotionExtractorModel"
8
+ predict_type: "trt"
9
+ model_path: "./checkpoints/liveportrait_onnx/motion_extractor.trt"
10
+ landmark:
11
+ name: "LandmarkModel"
12
+ predict_type: "trt"
13
+ model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
14
+ face_analysis:
15
+ name: "MediaPipeFaceModel"
16
+ predict_type: "mp"
17
+ app_feat_extractor:
18
+ name: "AppearanceFeatureExtractorModel"
19
+ predict_type: "trt"
20
+ model_path: "./checkpoints/liveportrait_onnx/appearance_feature_extractor.trt"
21
+ stitching:
22
+ name: "StitchingModel"
23
+ predict_type: "trt"
24
+ model_path: "./checkpoints/liveportrait_onnx/stitching.trt"
25
+ stitching_eye_retarget:
26
+ name: "StitchingModel"
27
+ predict_type: "trt"
28
+ model_path: "./checkpoints/liveportrait_onnx/stitching_eye.trt"
29
+ stitching_lip_retarget:
30
+ name: "StitchingModel"
31
+ predict_type: "trt"
32
+ model_path: "./checkpoints/liveportrait_onnx/stitching_lip.trt"
33
+
34
+ animal_models:
35
+ warping_spade:
36
+ name: "WarpingSpadeModel"
37
+ predict_type: "trt"
38
+ model_path: "./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.trt"
39
+ motion_extractor:
40
+ name: "MotionExtractorModel"
41
+ predict_type: "trt"
42
+ model_path: "./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.trt"
43
+ app_feat_extractor:
44
+ name: "AppearanceFeatureExtractorModel"
45
+ predict_type: "trt"
46
+ model_path: "./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.trt"
47
+ stitching:
48
+ name: "StitchingModel"
49
+ predict_type: "trt"
50
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching-v1.1.trt"
51
+ stitching_eye_retarget:
52
+ name: "StitchingModel"
53
+ predict_type: "trt"
54
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.trt"
55
+ stitching_lip_retarget:
56
+ name: "StitchingModel"
57
+ predict_type: "trt"
58
+ model_path: "./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.trt"
59
+ landmark:
60
+ name: "LandmarkModel"
61
+ predict_type: "trt"
62
+ model_path: "./checkpoints/liveportrait_onnx/landmark.trt"
63
+ face_analysis:
64
+ name: "MediaPipeFaceModel"
65
+ predict_type: "mp"
66
+
67
+ joyvasa_models:
68
+ motion_model_path: "checkpoints/JoyVASA/motion_generator/motion_generator_hubert_chinese.pt"
69
+ audio_model_path: "checkpoints/chinese-hubert-base"
70
+ motion_template_path: "checkpoints/JoyVASA/motion_template/motion_template.pkl"
71
+
72
+ crop_params:
73
+ src_dsize: 512
74
+ src_scale: 2.3
75
+ src_vx_ratio: 0.0
76
+ src_vy_ratio: -0.125
77
+ dri_scale: 2.2
78
+ dri_vx_ratio: 0.0
79
+ dri_vy_ratio: -0.1
80
+
81
+
82
+ infer_params:
83
+ flag_crop_driving_video: False
84
+ flag_normalize_lip: True
85
+ flag_source_video_eye_retargeting: False
86
+ flag_video_editing_head_rotation: False
87
+ flag_eye_retargeting: False
88
+ flag_lip_retargeting: False
89
+ flag_stitching: True
90
+ flag_relative_motion: True
91
+ flag_pasteback: True
92
+ flag_do_crop: True
93
+ flag_do_rot: True
94
+ animation_region: "all"
95
+
96
+ # NOT EXPORTED PARAMS
97
+ lip_normalize_threshold: 0.03 # threshold for flag_normalize_lip
98
+ source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video
99
+ driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy
100
+ anchor_frame: 0 # TO IMPLEMENT
101
+ mask_crop_path: "./assets/mask_template.png"
102
+ driving_multiplier: 1.0
103
+
104
+ cfg_mode: "incremental"
105
+ cfg_scale: 1.2
106
+
107
+ source_max_dim: 1280 # the max dim of height and width of source image
108
+ source_division: 2 # make sure the height and width of source image can be divided by this number
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ffmpeg-python
2
+ omegaconf
3
+ onnx
4
+ pycuda
5
+ numpy
6
+ opencv-python
7
+ gradio
8
+ scikit-image
9
+ insightface
10
+ huggingface_hub[cli]
11
+ mediapipe
12
+ torchgeometry
13
+ soundfile
14
+ munch
15
+ phonemizer
16
+ kokoro>=0.3.4
17
+ misaki[ja]
18
+ misaki[zh]
requirements_macos.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ffmpeg-python
2
+ omegaconf
3
+ onnx
4
+ onnxruntime
5
+ numpy
6
+ opencv-python
7
+ gradio
8
+ scikit-image
9
+ insightface
10
+ huggingface_hub[cli]
11
+ mediapipe
12
+ torchgeometry
13
+ soundfile
14
+ munch
15
+ phonemizer
16
+ kokoro>=0.3.4
17
+ misaki[ja]
18
+ misaki[zh]
requirements_win.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ffmpeg-python
2
+ omegaconf
3
+ onnx
4
+ numpy
5
+ opencv-python
6
+ gradio
7
+ scikit-image
8
+ insightface
9
+ huggingface_hub[cli]
10
+ mediapipe
11
+ torchgeometry
12
+ soundfile
13
+ munch
14
+ phonemizer
15
+ kokoro>=0.3.4
16
+ misaki[ja]
17
+ misaki[zh]
run.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : wenshao
3
+ # @Email : [email protected]
4
+ # @Project : FasterLivePortrait
5
+ # @FileName: run.py
6
+
7
+ """
8
+ # video
9
+ python run.py \
10
+ --src_image assets/examples/driving/d13.mp4 \
11
+ --dri_video assets/examples/driving/d11.mp4 \
12
+ --cfg configs/trt_infer.yaml \
13
+ --paste_back \
14
+ --animal
15
+ # pkl
16
+ python run.py \
17
+ --src_image assets/examples/source/s12.jpg \
18
+ --dri_video ./results/2024-09-13-081710/d0.mp4.pkl \
19
+ --cfg configs/trt_infer.yaml \
20
+ --paste_back \
21
+ --animal
22
+ """
23
+ import os
24
+ import argparse
25
+ import pdb
26
+ import subprocess
27
+ import ffmpeg
28
+ import cv2
29
+ import time
30
+ import numpy as np
31
+ import os
32
+ import datetime
33
+ import platform
34
+ import pickle
35
+ from omegaconf import OmegaConf
36
+ from tqdm import tqdm
37
+ from colorama import Fore, Back, Style
38
+ from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
39
+ from src.utils.utils import video_has_audio
40
+
41
+ if platform.system().lower() == 'windows':
42
+ FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
43
+ else:
44
+ FFMPEG = "ffmpeg"
45
+
46
+
47
+ def run_with_video(args):
48
+ print(Fore.RED + 'Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcam as source, R > SwitchRealtimeWebcamUpdate' + Style.RESET_ALL)
49
+ infer_cfg = OmegaConf.load(args.cfg)
50
+ infer_cfg.infer_params.flag_pasteback = args.paste_back
51
+
52
+ pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=args.animal)
53
+ ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
54
+ if not ret:
55
+ print(f"no face in {args.src_image}! exit!")
56
+ exit(1)
57
+ if not args.dri_video or not os.path.exists(args.dri_video):
58
+ # read frame from camera if no driving video input
59
+ vcap = cv2.VideoCapture(0)
60
+ if not vcap.isOpened():
61
+ print("no camera found! exit!")
62
+ exit(1)
63
+ else:
64
+ vcap = cv2.VideoCapture(args.dri_video)
65
+ fps = int(vcap.get(cv2.CAP_PROP_FPS))
66
+ h, w = pipe.src_imgs[0].shape[:2]
67
+ save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
68
+ os.makedirs(save_dir, exist_ok=True)
69
+
70
+ # render output video
71
+ if not args.realtime:
72
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
73
+ vsave_crop_path = os.path.join(save_dir,
74
+ f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4")
75
+ vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512))
76
+ vsave_org_path = os.path.join(save_dir,
77
+ f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4")
78
+ vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
79
+
80
+ infer_times = []
81
+ motion_lst = []
82
+ c_eyes_lst = []
83
+ c_lip_lst = []
84
+
85
+ frame_ind = 0
86
+ while vcap.isOpened():
87
+ ret, frame = vcap.read()
88
+ if not ret:
89
+ break
90
+ t0 = time.time()
91
+ first_frame = frame_ind == 0
92
+ dri_crop, out_crop, out_org, dri_motion_info = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
93
+ first_frame=first_frame)
94
+ frame_ind += 1
95
+ if out_crop is None:
96
+ print(f"no face in driving frame:{frame_ind}")
97
+ continue
98
+
99
+ motion_lst.append(dri_motion_info[0])
100
+ c_eyes_lst.append(dri_motion_info[1])
101
+ c_lip_lst.append(dri_motion_info[2])
102
+
103
+ infer_times.append(time.time() - t0)
104
+ # print(time.time() - t0)
105
+ dri_crop = cv2.resize(dri_crop, (512, 512))
106
+ out_crop = np.concatenate([dri_crop, out_crop], axis=1)
107
+ out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
108
+ if not args.realtime:
109
+ vout_crop.write(out_crop)
110
+ out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
111
+ vout_org.write(out_org)
112
+ else:
113
+ if infer_cfg.infer_params.flag_pasteback:
114
+ out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
115
+ cv2.imshow('Render', out_org)
116
+ else:
117
+ # image show in realtime mode
118
+ cv2.imshow('Render', out_crop)
119
+ # Press the 'q' key to exit the loop
120
+ if cv2.waitKey(1) & 0xFF == ord('q'):
121
+ break
122
+ vcap.release()
123
+ if not args.realtime:
124
+ vout_crop.release()
125
+ vout_org.release()
126
+ if video_has_audio(args.dri_video):
127
+ vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
128
+ subprocess.call(
129
+ [FFMPEG, "-i", vsave_crop_path, "-i", args.dri_video,
130
+ "-b:v", "10M", "-c:v",
131
+ "libx264", "-map", "0:v", "-map", "1:a",
132
+ "-c:a", "aac",
133
+ "-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
134
+ vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
135
+ subprocess.call(
136
+ [FFMPEG, "-i", vsave_org_path, "-i", args.dri_video,
137
+ "-b:v", "10M", "-c:v",
138
+ "libx264", "-map", "0:v", "-map", "1:a",
139
+ "-c:a", "aac",
140
+ "-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
141
+
142
+ print(vsave_crop_path_new)
143
+ print(vsave_org_path_new)
144
+ else:
145
+ print(vsave_crop_path)
146
+ print(vsave_org_path)
147
+ else:
148
+ cv2.destroyAllWindows()
149
+
150
+ print(
151
+ "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
152
+ np.mean(infer_times) * 1000))
153
+ # save driving motion to pkl
154
+ template_dct = {
155
+ 'n_frames': len(motion_lst),
156
+ 'output_fps': fps,
157
+ 'motion': motion_lst,
158
+ 'c_eyes_lst': c_eyes_lst,
159
+ 'c_lip_lst': c_lip_lst,
160
+ }
161
+ template_pkl_path = os.path.join(save_dir,
162
+ f"{os.path.basename(args.dri_video)}.pkl")
163
+ with open(template_pkl_path, "wb") as fw:
164
+ pickle.dump(template_dct, fw)
165
+ print(f"save driving motion pkl file at : {template_pkl_path}")
166
+
167
+
168
+ def run_with_pkl(args):
169
+ infer_cfg = OmegaConf.load(args.cfg)
170
+ infer_cfg.infer_params.flag_pasteback = args.paste_back
171
+
172
+ pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=args.animal)
173
+ ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
174
+ if not ret:
175
+ print(f"no face in {args.src_image}! exit!")
176
+ return
177
+ with open(args.dri_video, "rb") as fin:
178
+ dri_motion_infos = pickle.load(fin)
179
+
180
+ fps = int(dri_motion_infos["output_fps"])
181
+ h, w = pipe.src_imgs[0].shape[:2]
182
+ save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
183
+ os.makedirs(save_dir, exist_ok=True)
184
+
185
+ # render output video
186
+ if not args.realtime:
187
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
188
+ vsave_crop_path = os.path.join(save_dir,
189
+ f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4")
190
+ vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512, 512))
191
+ vsave_org_path = os.path.join(save_dir,
192
+ f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4")
193
+ vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))
194
+
195
+ infer_times = []
196
+ motion_lst = dri_motion_infos["motion"]
197
+ c_eyes_lst = dri_motion_infos["c_eyes_lst"] if "c_eyes_lst" in dri_motion_infos else dri_motion_infos[
198
+ "c_d_eyes_lst"]
199
+ c_lip_lst = dri_motion_infos["c_lip_lst"] if "c_lip_lst" in dri_motion_infos else dri_motion_infos["c_d_lip_lst"]
200
+
201
+ frame_num = len(motion_lst)
202
+ for frame_ind in tqdm(range(frame_num)):
203
+ t0 = time.time()
204
+ first_frame = frame_ind == 0
205
+ dri_motion_info_ = [motion_lst[frame_ind], c_eyes_lst[frame_ind], c_lip_lst[frame_ind]]
206
+ out_crop, out_org = pipe.run_with_pkl(dri_motion_info_, pipe.src_imgs[0], pipe.src_infos[0],
207
+ first_frame=first_frame)
208
+ if out_crop is None:
209
+ print(f"no face in driving frame:{frame_ind}")
210
+ continue
211
+
212
+ infer_times.append(time.time() - t0)
213
+ # print(time.time() - t0)
214
+ out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
215
+ if not args.realtime:
216
+ vout_crop.write(out_crop)
217
+ out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
218
+ vout_org.write(out_org)
219
+ else:
220
+ if infer_cfg.infer_params.flag_pasteback:
221
+ out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
222
+ cv2.imshow('Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcam as source, R > SwitchRealtimeWebcamUpdate', out_org)
223
+ else:
224
+ # image show in realtime mode
225
+ cv2.imshow('Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcam as source, R > SwitchRealtimeWebcamUpdate', out_crop)
226
+ # Press 'q' to exit the loop, 'r' to toggle realtime webcam-source updates, and the spacebar to toggle using the webcam as the source
227
+ k = cv2.waitKey(1) & 0xFF
228
+ if k == ord('q'):
229
+ break
230
+ # Key for Interesting Params
231
+ if k == ord('s'):
232
+ infer_cfg.infer_params.flag_stitching = not infer_cfg.infer_params.flag_stitching
233
+ print('flag_stitching:'+str(infer_cfg.infer_params.flag_stitching))
234
+ if k == ord('z'):
235
+ infer_cfg.infer_params.flag_relative_motion = not infer_cfg.infer_params.flag_relative_motion
236
+ print('flag_relative_motion:'+str(infer_cfg.infer_params.flag_relative_motion))
237
+ if k == ord('x'):
238
+ if infer_cfg.infer_params.animation_region == "all": infer_cfg.infer_params.animation_region = "exp"; print('animation_region = "exp"')
239
+ else: infer_cfg.infer_params.animation_region = "all"; print('animation_region = "all"')
240
+ if k == ord('c'):
241
+ infer_cfg.infer_params.flag_crop_driving_video = not infer_cfg.infer_params.flag_crop_driving_video
242
+ print('flag_crop_driving_video:'+str(infer_cfg.infer_params.flag_crop_driving_video))
243
+ if k == ord('v'):
244
+ infer_cfg.infer_params.flag_pasteback = not infer_cfg.infer_params.flag_pasteback
245
+ print('flag_pasteback:'+str(infer_cfg.infer_params.flag_pasteback))
246
+
247
+ if k == ord('a'):
248
+ infer_cfg.infer_params.flag_normalize_lip = not infer_cfg.infer_params.flag_normalize_lip
249
+ print('flag_normalize_lip:'+str(infer_cfg.infer_params.flag_normalize_lip))
250
+ if k == ord('d'):
251
+ infer_cfg.infer_params.flag_source_video_eye_retargeting = not infer_cfg.infer_params.flag_source_video_eye_retargeting
252
+ print('flag_source_video_eye_retargeting:'+str(infer_cfg.infer_params.flag_source_video_eye_retargeting))
253
+ if k == ord('f'):
254
+ infer_cfg.infer_params.flag_video_editing_head_rotation = not infer_cfg.infer_params.flag_video_editing_head_rotation
255
+ print('flag_video_editing_head_rotation:'+str(infer_cfg.infer_params.flag_video_editing_head_rotation))
256
+ if k == ord('g'):
257
+ infer_cfg.infer_params.flag_eye_retargeting = not infer_cfg.infer_params.flag_eye_retargeting
258
+ print('flag_eye_retargeting:'+str(infer_cfg.infer_params.flag_eye_retargeting))
259
+
260
+ if k == ord('k'):
261
+ infer_cfg.crop_params.src_scale -= 0.1
262
+ ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
263
+ print('src_scale:'+str(infer_cfg.crop_params.src_scale))
264
+ if k == ord('l'):
265
+ infer_cfg.crop_params.src_scale += 0.1
266
+ ret = pipe.prepare_source(args.src_image, realtime=args.realtime)
267
+ print('src_scale:'+str(infer_cfg.crop_params.src_scale))
268
+ if k == ord('n'):
269
+ infer_cfg.crop_params.dri_scale -= 0.1
270
+ print('dri_scale:'+str(infer_cfg.crop_params.dri_scale))
271
+ if k == ord('m'):
272
+ infer_cfg.crop_params.dri_scale += 0.1
273
+ print('dri_scale:'+str(infer_cfg.crop_params.dri_scale))
274
+
275
+ if not args.realtime:
276
+ vout_crop.release()
277
+ vout_org.release()
278
+ if video_has_audio(args.dri_video):
279
+ vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
280
+ subprocess.call(
281
+ [FFMPEG, "-i", vsave_crop_path, "-i", args.dri_video,
282
+ "-b:v", "10M", "-c:v",
283
+ "libx264", "-map", "0:v", "-map", "1:a",
284
+ "-c:a", "aac",
285
+ "-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
286
+ vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
287
+ subprocess.call(
288
+ [FFMPEG, "-i", vsave_org_path, "-i", args.dri_video,
289
+ "-b:v", "10M", "-c:v",
290
+ "libx264", "-map", "0:v", "-map", "1:a",
291
+ "-c:a", "aac",
292
+ "-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
293
+
294
+ print(vsave_crop_path_new)
295
+ print(vsave_org_path_new)
296
+ else:
297
+ print(vsave_crop_path)
298
+ print(vsave_org_path)
299
+ else:
300
+ cv2.destroyAllWindows()
301
+
302
+ print(
303
+ "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000,
304
+ np.mean(infer_times) * 1000))
305
+
306
+
307
+ if __name__ == '__main__':
308
+ parser = argparse.ArgumentParser(description='Faster Live Portrait Pipeline')
309
+ parser.add_argument('--src_image', required=False, type=str, default="assets/examples/source/s12.jpg",
310
+ help='source image')
311
+ parser.add_argument('--dri_video', required=False, type=str, default="assets/examples/driving/d14.mp4",
312
+ help='driving video')
313
+ parser.add_argument('--cfg', required=False, type=str, default="configs/onnx_infer.yaml", help='inference config')
314
+ parser.add_argument('--realtime', action='store_true', help='realtime inference')
315
+ parser.add_argument('--animal', action='store_true', help='use animal model')
316
+ parser.add_argument('--paste_back', action='store_true', default=False, help='paste back to origin image')
317
+ args, unknown = parser.parse_known_args()
318
+
319
+ if args.dri_video.endswith(".pkl"):
320
+ run_with_pkl(args)
321
+ else:
322
+ run_with_video(args)
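
The driving-motion template that `run_with_video` writes is a plain pickle whose keys come from `template_dct` above; a small sketch of reading one back (the path is just the example used in the module docstring):

```python
# Sketch: inspecting a driving-motion template saved by run_with_video().
import pickle

with open("./results/2024-09-13-081710/d0.mp4.pkl", "rb") as f:  # example path
    tpl = pickle.load(f)

print(tpl["n_frames"], tpl["output_fps"])
motion = tpl["motion"]        # per-frame driving motion info
c_eyes = tpl["c_eyes_lst"]    # per-frame eye ratios
c_lip = tpl["c_lip_lst"]      # per-frame lip ratios
assert len(motion) == len(c_eyes) == len(c_lip) == tpl["n_frames"]
# run_with_pkl() consumes these three lists frame by frame, in this order.
```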
scripts/all_onnx2trt.bat ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ REM warping+spade model
4
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\warping_spade-fix.onnx
5
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\warping_spade-fix.onnx
6
+
7
+ REM landmark model
8
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\landmark.onnx
9
+
10
+ REM motion_extractor model
11
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\motion_extractor.onnx -p fp32
12
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\motion_extractor.onnx -p fp32
13
+
14
+ REM face_analysis model
15
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\retinaface_det_static.onnx
16
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\face_2dpose_106_static.onnx
17
+
18
+ REM appearance_extractor model
19
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\appearance_feature_extractor.onnx
20
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\appearance_feature_extractor.onnx
21
+
22
+ REM stitching model
23
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching.onnx
24
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_eye.onnx
25
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_onnx\stitching_lip.onnx
26
+
27
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching.onnx
28
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_eye.onnx
29
+ .\venv\python.exe scripts\onnx2trt.py -o .\checkpoints\liveportrait_animal_onnx\stitching_lip.onnx
scripts/all_onnx2trt.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # warping+spade model
4
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/warping_spade-fix.onnx
5
+ # landmark model
6
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/landmark.onnx
7
+ # motion_extractor model
8
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/motion_extractor.onnx -p fp32
9
+ # face_analysis model
10
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/retinaface_det_static.onnx
11
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/face_2dpose_106_static.onnx
12
+ # appearance_extractor model
13
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/appearance_feature_extractor.onnx
14
+ # stitching model
15
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching.onnx
16
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_eye.onnx
17
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_onnx/stitching_lip.onnx
scripts/all_onnx2trt_animal.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # warping+spade model
4
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/warping_spade-fix-v1.1.onnx
5
+ # motion_extractor model
6
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/motion_extractor-v1.1.onnx -p fp32
7
+ # appearance_extractor model
8
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/appearance_feature_extractor-v1.1.onnx
9
+ # stitching model
10
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching-v1.1.onnx
11
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_eye-v1.1.onnx
12
+ python scripts/onnx2trt.py -o ./checkpoints/liveportrait_animal_onnx/stitching_lip-v1.1.onnx
scripts/onnx2trt.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ import os
19
+ import pdb
20
+ import sys
21
+ import logging
22
+ import argparse
23
+ import platform
24
+
25
+ import tensorrt as trt
26
+ import ctypes
27
+ import numpy as np
28
+
29
+ logging.basicConfig(level=logging.INFO)
30
+ logging.getLogger("EngineBuilder").setLevel(logging.INFO)
31
+ log = logging.getLogger("EngineBuilder")
32
+
33
+
34
+ def load_plugins(logger: trt.Logger):
35
+ # Load the plugin library
36
+ if platform.system().lower() == 'linux':
37
+ ctypes.CDLL("./checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
38
+ else:
39
+ ctypes.CDLL("./checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL, winmode=0)
40
+ # Initialize TensorRT's built-in plugin registry
41
+ trt.init_libnvinfer_plugins(logger, "")
42
+
43
+
44
+ class EngineBuilder:
45
+ """
46
+ Parses an ONNX graph and builds a TensorRT engine from it.
47
+ """
48
+
49
+ def __init__(self, verbose=False):
50
+ """
51
+ :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
52
+ """
53
+ self.trt_logger = trt.Logger(trt.Logger.INFO)
54
+ if verbose:
55
+ self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE
56
+
57
+ trt.init_libnvinfer_plugins(self.trt_logger, namespace="")
58
+
59
+ self.builder = trt.Builder(self.trt_logger)
60
+ self.config = self.builder.create_builder_config()
61
+ self.config.max_workspace_size = 12 * (2 ** 30) # 12 GB
62
+
63
+ profile = self.builder.create_optimization_profile()
64
+
65
+ # for face_2dpose_106.onnx
66
+ # profile.set_shape("data", (1, 3, 192, 192), (1, 3, 192, 192), (1, 3, 192, 192))
67
+ # for retinaface_det.onnx
68
+ # profile.set_shape("input.1", (1, 3, 512, 512), (1, 3, 512, 512), (1, 3, 512, 512))
69
+
70
+ self.config.add_optimization_profile(profile)
71
+ # Enforce strict type constraints
72
+ self.config.set_flag(trt.BuilderFlag.STRICT_TYPES)
73
+
74
+ self.batch_size = None
75
+ self.network = None
76
+ self.parser = None
77
+
78
+ # Load the custom plugins
79
+ load_plugins(self.trt_logger)
80
+
81
+ def create_network(self, onnx_path):
82
+ """
83
+ Parse the ONNX graph and create the corresponding TensorRT network definition.
84
+ :param onnx_path: The path to the ONNX graph to load.
85
+ """
86
+ network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
87
+
88
+ self.network = self.builder.create_network(network_flags)
89
+ self.parser = trt.OnnxParser(self.network, self.trt_logger)
90
+
91
+ onnx_path = os.path.realpath(onnx_path)
92
+ with open(onnx_path, "rb") as f:
93
+ if not self.parser.parse(f.read()):
94
+ log.error("Failed to load ONNX file: {}".format(onnx_path))
95
+ for error in range(self.parser.num_errors):
96
+ log.error(self.parser.get_error(error))
97
+ sys.exit(1)
98
+
99
+ inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
100
+ outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
101
+
102
+ log.info("Network Description")
103
+ for input in inputs:
104
+ self.batch_size = input.shape[0]
105
+ log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
106
+ for output in outputs:
107
+ log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
108
+ # assert self.batch_size > 0
109
+ self.builder.max_batch_size = 1
110
+
111
+ def create_engine(
112
+ self,
113
+ engine_path,
114
+ precision
115
+ ):
116
+ """
117
+ Build the TensorRT engine and serialize it to disk.
118
+ :param engine_path: The path where to serialize the engine to.
119
+ :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
120
+ """
121
+ engine_path = os.path.realpath(engine_path)
122
+ engine_dir = os.path.dirname(engine_path)
123
+ os.makedirs(engine_dir, exist_ok=True)
124
+ log.info("Building {} Engine in {}".format(precision, engine_path))
125
+
126
+ if precision == "fp16":
127
+ if not self.builder.platform_has_fast_fp16:
128
+ log.warning("FP16 is not supported natively on this platform/device")
129
+ else:
130
+ self.config.set_flag(trt.BuilderFlag.FP16)
131
+
132
+ with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
133
+ log.info("Serializing engine to file: {:}".format(engine_path))
134
+ f.write(engine.serialize())
135
+
136
+
137
+ def main(args):
138
+ builder = EngineBuilder(args.verbose)
139
+ builder.create_network(args.onnx)
140
+ builder.create_engine(
141
+ args.engine,
142
+ args.precision
143
+ )
144
+
145
+
146
+ if __name__ == "__main__":
147
+ parser = argparse.ArgumentParser()
148
+ parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load")
149
+ parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
150
+ parser.add_argument(
151
+ "-p",
152
+ "--precision",
153
+ default="fp16",
154
+ choices=["fp32", "fp16", "int8"],
155
+ help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
156
+ )
157
+ parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
158
+ args = parser.parse_args()
159
+ if args.engine is None:
160
+ args.engine = args.onnx.replace(".onnx", ".trt")
161
+ main(args)
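
A minimal sketch of loading an engine serialized by this script, using the standard TensorRT Python runtime API. It is illustrative only; at runtime the engines are consumed through the repository's own predictor classes (the `predict_type: "trt"` entries in the configs).

```python
# Sketch: deserializing an engine built by onnx2trt.py (standard TensorRT runtime API).
import ctypes
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
# Engines that use the grid_sample_3d plugin (e.g. warping_spade) need the same
# plugin library loaded at inference time as at build time.
ctypes.CDLL("./checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
trt.init_libnvinfer_plugins(logger, "")

with open("./checkpoints/liveportrait_onnx/landmark.trt", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()  # bind buffers and execute via the usual TRT API
```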
scripts/start_api.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ #!/bin/bash
2
+ source ~/.bashrc
3
+ python api.py
src/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Author : wenshao
3
+ # @Email : [email protected]
4
+ # @Project : FasterLivePortrait
5
+ # @FileName: __init__.py.py
src/models/JoyVASA/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/12/15
3
+ # @Author : wenshao
4
+ # @Email : [email protected]
5
+ # @Project : FasterLivePortrait
6
+ # @FileName: __init__.py
src/models/JoyVASA/common.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class PositionalEncoding(nn.Module):
9
+ def __init__(self, d_model, dropout=0.1, max_len=600):
10
+ super().__init__()
11
+ self.dropout = nn.Dropout(p=dropout)
12
+ # vanilla sinusoidal encoding
13
+ pe = torch.zeros(max_len, d_model)
14
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
15
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
16
+ pe[:, 0::2] = torch.sin(position * div_term)
17
+ pe[:, 1::2] = torch.cos(position * div_term)
18
+ pe = pe.unsqueeze(0)
19
+ self.register_buffer('pe', pe)
20
+
21
+ def forward(self, x):
22
+ x = x + self.pe[:, :x.shape[1], :]
23
+ return self.dropout(x)
24
+
25
+
26
+ def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
27
+ mask = torch.ones(T, S)
28
+ for i in range(T):
29
+ mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
30
+ return (mask == 1).to(device=device)
31
+
32
+
33
+ def pad_audio(audio, audio_unit=320, pad_threshold=80):
34
+ batch_size, audio_len = audio.shape
35
+ n_units = audio_len // audio_unit
36
+ side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
37
+ if side_len >= 0:
38
+ reflect_len = side_len // 2
39
+ replicate_len = side_len % 2
40
+ if reflect_len > 0:
41
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
42
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
43
+ if replicate_len > 0:
44
+ audio = F.pad(audio, (1, 1), mode='replicate')
45
+
46
+ return audio
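
A small usage sketch (not part of the file) showing what the two helpers used by the denoising model produce; the shapes follow the defaults in `dit_talking_head.py` (100 motion frames, feature dim 512).

```python
# Sketch: shapes produced by the helpers above.
import torch
from src.models.JoyVASA.common import PositionalEncoding, enc_dec_mask

pos_enc = PositionalEncoding(d_model=512)
x = torch.zeros(2, 100, 512)                 # (batch, n_motions, feature_dim)
print(pos_enc(x).shape)                      # torch.Size([2, 100, 512])

# Alignment mask between 100 motion frames and 200 audio units
# (frame_width=2 audio units per motion frame); True marks blocked positions.
mask = enc_dec_mask(T=100, S=200, frame_width=2, expansion=0, device='cpu')
print(mask.shape, mask.dtype)                # torch.Size([100, 200]) torch.bool
```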
src/models/JoyVASA/dit_talking_head.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import platform
7
+ from .common import PositionalEncoding, enc_dec_mask, pad_audio
8
+ from tqdm import tqdm
9
+
10
+
11
+ class DiffusionSchedule(nn.Module):
12
+ def __init__(self, num_steps, mode='linear', beta_1=1e-4, beta_T=0.02, s=0.008):
13
+ super().__init__()
14
+
15
+ if mode == 'linear':
16
+ betas = torch.linspace(beta_1, beta_T, num_steps)
17
+ elif mode == 'quadratic':
18
+ betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
19
+ elif mode == 'sigmoid':
20
+ betas = torch.sigmoid(torch.linspace(-5, 5, num_steps)) * (beta_T - beta_1) + beta_1
21
+ elif mode == 'cosine':
22
+ steps = num_steps + 1
23
+ x = torch.linspace(0, num_steps, steps)
24
+ alpha_bars = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
25
+ alpha_bars = alpha_bars / alpha_bars[0]
26
+ betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
27
+ betas = torch.clip(betas, 0.0001, 0.999)
28
+ else:
29
+ raise ValueError(f'Unknown diffusion schedule {mode}!')
30
+ betas = torch.cat([torch.zeros(1), betas], dim=0) # Padding beta_0 = 0
31
+
32
+ alphas = 1 - betas
33
+ log_alphas = torch.log(alphas)
34
+ for i in range(1, log_alphas.shape[0]): # 1 to T
35
+ log_alphas[i] += log_alphas[i - 1]
36
+ alpha_bars = log_alphas.exp()
37
+
38
+ sigmas_flex = torch.sqrt(betas)
39
+ sigmas_inflex = torch.zeros_like(sigmas_flex)
40
+ for i in range(1, sigmas_flex.shape[0]):
41
+ sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[i]
42
+ sigmas_inflex = torch.sqrt(sigmas_inflex)
43
+
44
+ self.num_steps = num_steps
45
+ self.register_buffer('betas', betas)
46
+ self.register_buffer('alphas', alphas)
47
+ self.register_buffer('alpha_bars', alpha_bars)
48
+ self.register_buffer('sigmas_flex', sigmas_flex)
49
+ self.register_buffer('sigmas_inflex', sigmas_inflex)
50
+
51
+ def uniform_sample_t(self, batch_size):
52
+ ts = torch.randint(1, self.num_steps + 1, (batch_size,))
53
+ return ts.tolist()
54
+
55
+ def get_sigmas(self, t, flexibility=0):
56
+ assert 0 <= flexibility <= 1
57
+ sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
58
+ return sigmas
59
+
60
+
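
A usage sketch for the schedule above (not part of the file); it only shows the shapes involved, and assumes the package import path `src.models.JoyVASA.dit_talking_head`.

```python
# Sketch: sampling diffusion steps and noise levels from DiffusionSchedule.
import torch
from src.models.JoyVASA.dit_talking_head import DiffusionSchedule

sched = DiffusionSchedule(num_steps=500, mode='cosine')
t = sched.uniform_sample_t(batch_size=4)               # e.g. [382, 17, 250, 499]
sigmas = sched.get_sigmas(torch.tensor(t), flexibility=0.5)
print(len(t), sigmas.shape)                            # 4 torch.Size([4])
```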
61
+ class DitTalkingHead(nn.Module):
62
+ def __init__(self, device='cuda', target="sample", architecture="decoder",
63
+ motion_feat_dim=76, fps=25, n_motions=100, n_prev_motions=10,
64
+ audio_model="hubert", feature_dim=512, n_diff_steps=500, diff_schedule="cosine",
65
+ cfg_mode="incremental", guiding_conditions="audio,", audio_encoder_path=''):
66
+ super().__init__()
67
+
68
+ # Model parameters
69
+ self.target = target # whether to predict the original signal or the noise
70
+ self.architecture = architecture
71
+ self.motion_feat_dim = motion_feat_dim # motion feature dimension
72
+ self.fps = fps
73
+ self.n_motions = n_motions # current window of 100 motion frames (window_length, T_w)
74
+ self.n_prev_motions = n_prev_motions # 10 previous motion frames (T_p)
75
+ self.feature_dim = feature_dim
76
+
77
+ # Audio encoder
78
+ self.audio_model = audio_model
79
+ if self.audio_model == 'wav2vec2':
80
+ print("using wav2vec2 audio encoder ...")
81
+ from .wav2vec2 import Wav2Vec2Model
82
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path)
83
+ # wav2vec 2.0 weights initialization
84
+ self.audio_encoder.feature_extractor._freeze_parameters()
85
+
86
+ frozen_layers = [0, 1]
87
+ for name, param in self.audio_encoder.named_parameters():
88
+ if name.startswith("feature_projection"):
89
+ param.requires_grad = False
90
+ if name.startswith("encoder.layers"):
91
+ layer = int(name.split(".")[2])
92
+ if layer in frozen_layers:
93
+ param.requires_grad = False
94
+ elif self.audio_model == "wav2vec2_ori":
95
+ from .wav2vec2 import Wav2Vec2Model
96
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_encoder_path)
97
+ # wav2vec 2.0 weights initialization
98
+ self.audio_encoder.feature_extractor._freeze_parameters()
99
+ elif self.audio_model == 'hubert':  # empirically, the HuBERT feature extractor works better
100
+ from .hubert import HubertModel
101
+ # from hubert import HubertModel
102
+ self.audio_encoder = HubertModel.from_pretrained(audio_encoder_path)
103
+ self.audio_encoder.feature_extractor._freeze_parameters()
104
+ # print("hubert-en: ", self.audio_encoder)
105
+
106
+ frozen_layers = [0, 1]
107
+ for name, param in self.audio_encoder.named_parameters():
108
+ if name.startswith("feature_projection"):
109
+ param.requires_grad = False
110
+ if name.startswith("encoder.layers"):
111
+ layer = int(name.split(".")[2])
112
+ if layer in frozen_layers:
113
+ param.requires_grad = False
114
+ elif self.audio_model == 'hubert_zh':  # empirically, the HuBERT feature extractor works better
115
+ print("using hubert chinese")
116
+ from .hubert import HubertModel
117
+ # from hubert import HubertModel
118
+ self.audio_encoder = HubertModel.from_pretrained(audio_encoder_path)
119
+ self.audio_encoder.feature_extractor._freeze_parameters()
120
+
121
+ frozen_layers = [0, 1]
122
+ for name, param in self.audio_encoder.named_parameters():
123
+ if name.startswith("feature_projection"):
124
+ param.requires_grad = False
125
+ if name.startswith("encoder.layers"):
126
+ layer = int(name.split(".")[2])
127
+ if layer in frozen_layers:
128
+ param.requires_grad = False
129
+ elif self.audio_model == 'hubert_zh_ori':  # empirically, the HuBERT feature extractor works better
130
+ print("using hubert chinese ori")
131
+ from .hubert import HubertModel
132
+ self.audio_encoder = HubertModel.from_pretrained(audio_encoder_path)
133
+ self.audio_encoder.feature_extractor._freeze_parameters()
134
+ else:
135
+ raise ValueError(f'Unknown audio model {self.audio_model}!')
136
+
137
+ if architecture == 'decoder':
138
+ self.audio_feature_map = nn.Linear(768, feature_dim)
139
+ self.start_audio_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, feature_dim))
140
+ else:
141
+ raise ValueError(f'Unknown architecture {architecture}!')
142
+
143
+ self.start_motion_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, self.motion_feat_dim)) # 1, 10, 76
144
+
145
+ # Diffusion model
146
+ self.denoising_net = DenoisingNetwork(device=device, n_motions=self.n_motions,
147
+ n_prev_motions=self.n_prev_motions,
148
+ motion_feat_dim=self.motion_feat_dim, feature_dim=feature_dim)
149
+ # diffusion schedule
150
+ self.diffusion_sched = DiffusionSchedule(n_diff_steps, diff_schedule)
151
+
152
+ # Classifier-free settings
153
+ self.cfg_mode = cfg_mode
154
+ guiding_conditions = guiding_conditions.split(',') if guiding_conditions else []
155
+ self.guiding_conditions = [cond for cond in guiding_conditions if cond in ['audio']]
156
+ if 'audio' in self.guiding_conditions:
157
+ audio_feat_dim = feature_dim
158
+ self.null_audio_feat = nn.Parameter(torch.randn(1, 1, audio_feat_dim)) # 1, 1, 512
159
+
160
+ self.to(device)
161
+
162
+ @property
163
+ def device(self):
164
+ return next(self.parameters()).device
165
+
166
+ def forward(self, motion_feat, audio_or_feat, prev_motion_feat=None, prev_audio_feat=None, time_step=None,
167
+ indicator=None):
168
+ """
169
+ Args:
170
+ motion_feat: (N, L, d_coef) motion coefficients or features
171
+ audio_or_feat: (N, L_audio) raw audio or audio feature
172
+ prev_motion_feat: (N, n_prev_motions, d_motion) previous motion coefficients or feature
173
+ prev_audio_feat: (N, n_prev_motions, d_audio) previous audio features
174
+ time_step: (N,)
175
+ indicator: (N, L) 0/1 indicator of real (unpadded) motion coefficients
176
+
177
+ Returns:
178
+ eps: (N, L, d_motion) noise used in the forward process; motion_feat_target: (N, L_p + L, d_motion) denoising-network prediction; detached motion_feat and audio features
179
+ """
180
+ batch_size = motion_feat.shape[0]
181
+
182
+ # Load audio features
183
+ if audio_or_feat.ndim == 2:  # raw audio
184
+ # Extract audio features
185
+ assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
186
+ f'Incorrect audio length {audio_or_feat.shape[1]}'
187
+ audio_feat_saved = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
188
+ elif audio_or_feat.ndim == 3:  # precomputed audio features
189
+ assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
190
+ audio_feat_saved = audio_or_feat
191
+ else:
192
+ raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
193
+ audio_feat = audio_feat_saved.clone()
194
+
195
+ # Previous motion features
196
+ if prev_motion_feat is None:
197
+ prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
198
+
199
+ # Previous audio features
200
+ if prev_audio_feat is None:
201
+ # (N, n_prev_motions, feature_dim)
202
+ prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
203
+
204
+ # Classifier-free guidance
205
+ if len(self.guiding_conditions) > 0:
206
+ assert len(self.guiding_conditions) <= 2, 'Only support 1 or 2 CFG conditions!'
207
+ if len(self.guiding_conditions) == 1 or self.cfg_mode == 'independent':
208
+ null_cond_prob = 0.5 if len(self.guiding_conditions) >= 2 else 0.1
209
+ if 'audio' in self.guiding_conditions:
210
+ mask_audio = torch.rand(batch_size, device=self.device) < null_cond_prob
211
+ audio_feat = torch.where(mask_audio.view(-1, 1, 1),
212
+ self.null_audio_feat.expand(batch_size, self.n_motions, -1),
213
+ audio_feat)
214
+ else:
215
+ # len(self.guiding_conditions) > 1 and self.cfg_mode == 'incremental'
216
+ # full (0.45), w/o style (0.45), w/o style or audio (0.1)
217
+ mask_flag = torch.rand(batch_size, device=self.device)
218
+ if 'audio' in self.guiding_conditions:
219
+ mask_audio = mask_flag > 0.9
220
+ audio_feat = torch.where(mask_audio.view(-1, 1, 1),
221
+ self.null_audio_feat.expand(batch_size, self.n_motions, -1),
222
+ audio_feat)
223
+
224
+ if time_step is None:
225
+ # Sample time step
226
+ time_step = self.diffusion_sched.uniform_sample_t(batch_size) # (N,)
227
+
228
+ # The forward diffusion process
229
+ alpha_bar = self.diffusion_sched.alpha_bars[time_step] # (N,)
230
+ c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (N, 1, 1)
231
+ c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (N, 1, 1)
232
+
233
+ eps = torch.randn_like(motion_feat) # (N, L, d_motion)
234
+ motion_feat_noisy = c0 * motion_feat + c1 * eps
235
+
236
+ # The reverse diffusion process
237
+ motion_feat_target = self.denoising_net(motion_feat_noisy, audio_feat,
238
+ prev_motion_feat, prev_audio_feat, time_step, indicator)
239
+
240
+ return eps, motion_feat_target, motion_feat.detach(), audio_feat_saved.detach()
241
+
242
+ def extract_audio_feature(self, audio, frame_num=None):
243
+ frame_num = frame_num or self.n_motions
244
+
245
+ # # Strategy 1: resample during audio feature extraction
246
+ # hidden_states = self.audio_encoder(pad_audio(audio), self.fps, frame_num=frame_num).last_hidden_state # (N, L, 768)
247
+
248
+ # Strategy 2: resample after audio feature extraction (BackResample)
249
+ hidden_states = self.audio_encoder(pad_audio(audio), self.fps,
250
+ frame_num=frame_num * 2).last_hidden_state # (N, 2L, 768)
251
+ hidden_states = hidden_states.transpose(1, 2) # (N, 768, 2L)
252
+ hidden_states = F.interpolate(hidden_states, size=frame_num, align_corners=False, mode='linear') # (N, 768, L)
253
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, 768)
254
+
255
+ audio_feat = self.audio_feature_map(hidden_states) # (N, L, feature_dim)
256
+ return audio_feat
257
+
258
+ @torch.no_grad()
259
+ def sample(self, audio_or_feat, prev_motion_feat=None, prev_audio_feat=None,
260
+ motion_at_T=None, indicator=None, cfg_mode=None, cfg_cond=None, cfg_scale=1.15, flexibility=0,
261
+ dynamic_threshold=None, ret_traj=False):
262
+ # Check and convert inputs
263
+ batch_size = audio_or_feat.shape[0]
264
+
265
+ # Check CFG conditions
266
+ if cfg_mode is None: # Use default CFG mode
267
+ cfg_mode = self.cfg_mode
268
+ if cfg_cond is None: # Use default CFG conditions
269
+ cfg_cond = self.guiding_conditions
270
+ cfg_cond = [c for c in cfg_cond if c in ['audio', ]]
271
+
272
+ if not isinstance(cfg_scale, list):
273
+ cfg_scale = [cfg_scale] * len(cfg_cond)
274
+
275
+ # sort cfg_cond and cfg_scale
276
+ if len(cfg_cond) > 0:
277
+ cfg_cond, cfg_scale = zip(*sorted(zip(cfg_cond, cfg_scale), key=lambda x: ['audio', ].index(x[0])))
278
+ else:
279
+ cfg_cond, cfg_scale = [], []
280
+
281
+ if audio_or_feat.ndim == 2:
282
+ # Extract audio features
283
+ assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
284
+ f'Incorrect audio length {audio_or_feat.shape[1]}'
285
+ audio_feat = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
286
+ elif audio_or_feat.ndim == 3:
287
+ assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
288
+ audio_feat = audio_or_feat
289
+ else:
290
+ raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
291
+
292
+ if prev_motion_feat is None:
293
+ prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
294
+ if prev_audio_feat is None:
295
+ # (N, n_prev_motions, feature_dim)
296
+ prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
297
+
298
+ if motion_at_T is None:
299
+ motion_at_T = torch.randn((batch_size, self.n_motions, self.motion_feat_dim)).to(self.device)
300
+
301
+ # Prepare input for the reverse diffusion process (including optional classifier-free guidance)
302
+ if 'audio' in cfg_cond:
303
+ audio_feat_null = self.null_audio_feat.expand(batch_size, self.n_motions, -1)
304
+ else:
305
+ audio_feat_null = audio_feat
306
+
307
+ audio_feat_in = [audio_feat_null]
308
+ for cond in cfg_cond:
309
+ if cond == 'audio':
310
+ audio_feat_in.append(audio_feat)
311
+
312
+ n_entries = len(audio_feat_in)
313
+ audio_feat_in = torch.cat(audio_feat_in, dim=0)
314
+ prev_motion_feat_in = torch.cat([prev_motion_feat] * n_entries, dim=0)
315
+ prev_audio_feat_in = torch.cat([prev_audio_feat] * n_entries, dim=0)
316
+ indicator_in = torch.cat([indicator] * n_entries, dim=0) if indicator is not None else None
317
+
318
+ traj = {self.diffusion_sched.num_steps: motion_at_T}
319
+ for t in tqdm(range(self.diffusion_sched.num_steps, 0, -1)):
320
+ if t > 1:
321
+ z = torch.randn_like(motion_at_T)
322
+ else:
323
+ z = torch.zeros_like(motion_at_T)
324
+
325
+ alpha = self.diffusion_sched.alphas[t]
326
+ alpha_bar = self.diffusion_sched.alpha_bars[t]
327
+ alpha_bar_prev = self.diffusion_sched.alpha_bars[t - 1]
328
+ sigma = self.diffusion_sched.get_sigmas(t, flexibility)
329
+
330
+ motion_at_t = traj[t]
331
+ motion_in = torch.cat([motion_at_t] * n_entries, dim=0)
332
+ step_in = torch.tensor([t] * batch_size, device=self.device)
333
+ step_in = torch.cat([step_in] * n_entries, dim=0)
334
+
335
+ results = self.denoising_net(motion_in, audio_feat_in, prev_motion_feat_in,
336
+ prev_audio_feat_in, step_in, indicator_in)
337
+
338
+ # Apply thresholding if specified
339
+ if dynamic_threshold:
340
+ dt_ratio, dt_min, dt_max = dynamic_threshold
341
+ abs_results = results[:, -self.n_motions:].reshape(batch_size * n_entries, -1).abs()
342
+ s = torch.quantile(abs_results, dt_ratio, dim=1)
343
+ s = torch.clamp(s, min=dt_min, max=dt_max)
344
+ s = s[..., None, None]
345
+ results = torch.clamp(results, min=-s, max=s)
346
+
347
+ results = results.chunk(n_entries)
348
+
349
+ # Unconditional target (CFG) or the conditional target (non-CFG)
350
+ target_theta = results[0][:, -self.n_motions:]
351
+ # Classifier-free Guidance (optional)
352
+ for i in range(0, n_entries - 1):
353
+ if cfg_mode == 'independent':
354
+ target_theta += cfg_scale[i] * (
355
+ results[i + 1][:, -self.n_motions:] - results[0][:, -self.n_motions:])
356
+ elif cfg_mode == 'incremental':
357
+ target_theta += cfg_scale[i] * (
358
+ results[i + 1][:, -self.n_motions:] - results[i][:, -self.n_motions:])
359
+ else:
360
+ raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
361
+
362
+ if self.target == 'noise':
363
+ c0 = 1 / torch.sqrt(alpha)
364
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
365
+ motion_next = c0 * (motion_at_t - c1 * target_theta) + sigma * z
366
+ elif self.target == 'sample':
367
+ c0 = (1 - alpha_bar_prev) * torch.sqrt(alpha) / (1 - alpha_bar)
368
+ c1 = (1 - alpha) * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar)
369
+ motion_next = c0 * motion_at_t + c1 * target_theta + sigma * z
370
+ else:
371
+ raise ValueError('Unknown target type: {}'.format(self.target))
372
+
373
+ traj[t - 1] = motion_next.detach() # Stop gradient and save trajectory.
374
+ traj[t] = traj[t].cpu() # Move previous output to CPU memory.
375
+ if not ret_traj:
376
+ del traj[t]
377
+
378
+ if ret_traj:
379
+ return traj, motion_at_T, audio_feat
380
+ else:
381
+ return traj[0], motion_at_T, audio_feat
382
+
383
+
384
+ class DenoisingNetwork(nn.Module):
385
+ def __init__(self, device='cuda', motion_feat_dim=76,
386
+ use_indicator=None, architecture="decoder", feature_dim=512, n_heads=8,
387
+ n_layers=8, mlp_ratio=4, align_mask_width=1, no_use_learnable_pe=True, n_prev_motions=10,
388
+ n_motions=100, n_diff_steps=500, ):
389
+ super().__init__()
390
+
391
+ # Model parameters
392
+ self.motion_feat_dim = motion_feat_dim
393
+ self.use_indicator = use_indicator
394
+
395
+ # Transformer
396
+ self.architecture = architecture
397
+ self.feature_dim = feature_dim
398
+ self.n_heads = n_heads
399
+ self.n_layers = n_layers
400
+ self.mlp_ratio = mlp_ratio
401
+ self.align_mask_width = align_mask_width
402
+ self.use_learnable_pe = not no_use_learnable_pe
403
+
404
+ # sequence length
405
+ self.n_prev_motions = n_prev_motions
406
+ self.n_motions = n_motions
407
+
408
+ # Temporal embedding for the diffusion time step
409
+ self.TE = PositionalEncoding(self.feature_dim, max_len=n_diff_steps + 1)
410
+ self.diff_step_map = nn.Sequential(
411
+ nn.Linear(self.feature_dim, self.feature_dim),
412
+ nn.GELU(),
413
+ nn.Linear(self.feature_dim, self.feature_dim)
414
+ )
415
+
416
+ if self.use_learnable_pe:
417
+ # Learnable positional encoding
418
+ self.PE = nn.Parameter(torch.randn(1, 1 + self.n_prev_motions + self.n_motions, self.feature_dim))
419
+ else:
420
+ self.PE = PositionalEncoding(self.feature_dim)
421
+
422
+ # Transformer decoder
423
+ if self.architecture == 'decoder':
424
+ self.feature_proj = nn.Linear(self.motion_feat_dim + (1 if self.use_indicator else 0),
425
+ self.feature_dim)
426
+ decoder_layer = nn.TransformerDecoderLayer(
427
+ d_model=self.feature_dim, nhead=self.n_heads, dim_feedforward=self.mlp_ratio * self.feature_dim,
428
+ activation='gelu', batch_first=True
429
+ )
430
+ self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=self.n_layers)
431
+ if self.align_mask_width > 0:
432
+ motion_len = self.n_prev_motions + self.n_motions
433
+ alignment_mask = enc_dec_mask(motion_len, motion_len, frame_width=1,
434
+ expansion=self.align_mask_width - 1)
435
+ # print(f"alignment_mask: ", alignment_mask.shape)
436
+ # alignment_mask = F.pad(alignment_mask, (0, 0, 1, 0), value=False)
437
+ self.register_buffer('alignment_mask', alignment_mask)
438
+ else:
439
+ self.alignment_mask = None
440
+ else:
441
+ raise ValueError(f'Unknown architecture: {self.architecture}')
442
+
443
+ # Motion decoder
444
+ self.motion_dec = nn.Sequential(
445
+ nn.Linear(self.feature_dim, self.feature_dim // 2),
446
+ nn.GELU(),
447
+ nn.Linear(self.feature_dim // 2, self.motion_feat_dim),
448
+ # nn.Tanh()  # a tanh activation was also tried here
449
+ # nn.Softmax()
450
+ )
451
+
452
+ self.to(device)
453
+
454
+ @property
455
+ def device(self):
456
+ return next(self.parameters()).device
457
+
458
+ def forward(self, motion_feat, audio_feat, prev_motion_feat, prev_audio_feat, step, indicator=None):
459
+ """
460
+ Args:
461
+ motion_feat: (N, L, d_motion). Noisy motion feature
462
+ audio_feat: (N, L, feature_dim)
463
+ prev_motion_feat: (N, L_p, d_motion). Padded previous motion coefficients or feature
464
+ prev_audio_feat: (N, L_p, d_audio). Padded previous audio features
465
+ step: (N,)
466
+ indicator: (N, L). 0/1 indicator for the real (unpadded) motion feature
467
+
468
+ Returns:
469
+ motion_feat_target: (N, L_p + L, d_motion)
470
+ """
471
+ motion_feat = motion_feat.to(audio_feat.dtype)
472
+ # Diffusion time step embedding
473
+ diff_step_embedding = self.diff_step_map(self.TE.pe[0, step]).unsqueeze(1) # (N, 1, diff_step_dim)
474
+
475
+ if indicator is not None:
476
+ indicator = torch.cat([torch.zeros((indicator.shape[0], self.n_prev_motions), device=indicator.device),
477
+ indicator], dim=1) # (N, L_p + L)
478
+ indicator = indicator.unsqueeze(-1) # (N, L_p + L, 1)
479
+
480
+ # Concat features and embeddings
481
+ if self.architecture == 'decoder':
482
+ # print("prev_motion_feat: ", prev_motion_feat.shape, "motion_feat: ", motion_feat.shape)
483
+ feats_in = torch.cat([prev_motion_feat, motion_feat], dim=1) # (N, L_p + L, d_motion)
484
+ else:
485
+ raise ValueError(f'Unknown architecture: {self.architecture}')
486
+ if self.use_indicator:
487
+ feats_in = torch.cat([feats_in, indicator], dim=-1) # (N, L_p + L, d_motion + d_audio + 1)
488
+ feats_in = self.feature_proj(feats_in) # (N, L_p + L, feature_dim)
489
+ # feats_in = torch.cat([person_feat, feats_in], dim=1) # (N, 1 + L_p + L, feature_dim)
490
+
491
+ if self.use_learnable_pe:
492
+ # feats_in = feats_in + self.PE
493
+ feats_in = feats_in + self.PE + diff_step_embedding
494
+ else:
495
+ # feats_in = self.PE(feats_in)
496
+ feats_in = self.PE(feats_in) + diff_step_embedding
497
+
498
+ # Transformer
499
+ if self.architecture == 'decoder':
500
+ audio_feat_in = torch.cat([prev_audio_feat, audio_feat], dim=1) # (N, L_p + L, d_audio)
501
+ # print(f"feats_in: {feats_in.shape}, audio_feat_in: {audio_feat_in.shape}, memory_mask: {self.alignment_mask.shape}")
502
+ feat_out = self.transformer(feats_in, audio_feat_in, memory_mask=self.alignment_mask)
503
+ else:
504
+ raise ValueError(f'Unknown architecture: {self.architecture}')
505
+
506
+ # Decode predicted motion feature noise / sample
507
+ # motion_feat_target = self.motion_dec(feat_out[:, 1:]) # (N, L_p + L, d_motion)
508
+ motion_feat_target = self.motion_dec(feat_out) # (N, L_p + L, d_motion)
509
+
510
+ return motion_feat_target
511
+
512
+
513
+ if __name__ == "__main__":
514
+ device = "cuda"
515
+ motion_feat_dim = 76
516
+ n_motions = 100 # L
517
+ n_prev_motions = 10 # L_p
518
+
519
+ L_audio = int(16000 * n_motions / 25) # 64000
520
+ d_audio = 768
521
+
522
+ N = 5
523
+ feature_dim = 512
524
+
525
+ motion_feat = torch.ones((N, n_motions, motion_feat_dim)).to(device)
526
+ prev_motion_feat = torch.ones((N, n_prev_motions, motion_feat_dim)).to(device)
527
+
528
+ audio_or_feat = torch.ones((N, L_audio)).to(device)
529
+ prev_audio_feat = torch.ones((N, n_prev_motions, d_audio)).to(device)
530
+
531
+ time_step = torch.ones(N, dtype=torch.long).to(device)
532
+
533
+ model = DitTalkingHead().to(device)
534
+
535
+ z = model(motion_feat, audio_or_feat, prev_motion_feat=None,
536
+ prev_audio_feat=None, time_step=None, indicator=None)
537
+ traj, motion_at_T, audio_feat = z[0], z[1], z[2]
538
+ print(motion_at_T.shape, audio_feat.shape)
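For readers skimming this diff, the core of `DiffusionSchedule` and `DitTalkingHead.forward` above is a cosine beta schedule plus the standard closed-form forward-diffusion step. Below is a minimal standalone sketch of that math (assuming only PyTorch; shapes mirror the defaults `n_motions=100`, `motion_feat_dim=76`, `n_diff_steps=500`), not a piece of the repo's API:

```python
import torch

# Cosine beta schedule, as in DiffusionSchedule(mode='cosine') above.
def cosine_betas(num_steps: int, s: float = 0.008) -> torch.Tensor:
    x = torch.linspace(0, num_steps, num_steps + 1)
    alpha_bars = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alpha_bars = alpha_bars / alpha_bars[0]
    betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
    return torch.clip(betas, 0.0001, 0.999)

betas = torch.cat([torch.zeros(1), cosine_betas(500)])  # pad beta_0 = 0
alpha_bars = torch.cumprod(1 - betas, dim=0)            # equivalent to the cumulative-log form above

# Forward diffusion q(x_t | x_0): mix clean motion features with Gaussian noise,
# matching c0 * motion_feat + c1 * eps in DitTalkingHead.forward.
x0 = torch.randn(2, 100, 76)                            # (N, L, d_motion)
t = torch.randint(1, 501, (2,))                         # sampled diffusion steps
c0 = alpha_bars[t].sqrt().view(-1, 1, 1)
c1 = (1 - alpha_bars[t]).sqrt().view(-1, 1, 1)
x_t = c0 * x0 + c1 * torch.randn_like(x0)               # noisy motion features
```

The reverse loop in `sample()` then walks t = T … 1 with either the noise-prediction or sample-prediction update and, when enabled, blends conditional and unconditional predictions for classifier-free guidance.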
src/models/JoyVASA/helper.py ADDED
@@ -0,0 +1,32 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/12/15
3
+ # @Author : wenshao
4
+ # @Email : [email protected]
5
+ # @Project : FasterLivePortrait
6
+ # @FileName: helper.py
7
+ import os.path as osp
8
+
9
+
10
+ class NullableArgs:
11
+ def __init__(self, namespace):
12
+ for key, value in namespace.__dict__.items():
13
+ setattr(self, key, value)
14
+
15
+ def __getattr__(self, key):
16
+ # when an attribute lookup has not found the attribute
17
+ if key == 'align_mask_width':
18
+ if 'use_alignment_mask' in self.__dict__:
19
+ return 1 if self.use_alignment_mask else 0
20
+ else:
21
+ return 0
22
+ if key == 'no_head_pose':
23
+ return not self.predict_head_pose
24
+ if key == 'no_use_learnable_pe':
25
+ return not self.use_learnable_pe
26
+
27
+ return None
28
+
29
+
30
+ def make_abs_path(fn):
31
+ # return osp.join(osp.dirname(osp.realpath(__file__)), fn)
32
+ return osp.abspath(osp.join(osp.dirname(osp.realpath(__file__)), fn))
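`NullableArgs` above is a small compatibility shim: it copies an argparse-style namespace and, via `__getattr__`, maps options that older checkpoints lack onto derived defaults (or `None`). A usage sketch with hypothetical values, assuming the module is importable under the path added in this commit:

```python
from argparse import Namespace
from src.models.JoyVASA.helper import NullableArgs  # path as added in this commit

legacy = Namespace(use_alignment_mask=True, predict_head_pose=False, use_learnable_pe=False)
args = NullableArgs(legacy)

print(args.use_alignment_mask)  # True  (copied from the namespace)
print(args.align_mask_width)    # 1     (derived from use_alignment_mask)
print(args.no_head_pose)        # True  (derived from predict_head_pose)
print(args.some_new_option)     # None  (unknown attributes fall back to None)
```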
src/models/JoyVASA/hubert.py ADDED
@@ -0,0 +1,51 @@
1
+ from transformers import HubertModel
2
+ from transformers.modeling_outputs import BaseModelOutput
3
+
4
+ from .wav2vec2 import linear_interpolation
5
+
6
+ _CONFIG_FOR_DOC = 'HubertConfig'
7
+
8
+
9
+ class HubertModel(HubertModel):
10
+ def __init__(self, config):
11
+ super().__init__(config)
12
+
13
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
14
+ output_hidden_states=None, return_dict=None, frame_num=None):
15
+ self.config.output_attentions = True
16
+
17
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
18
+ output_hidden_states = (
19
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
20
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
21
+
22
+ extract_features = self.feature_extractor(input_values) # (N, C, L)
23
+ # Resample the audio feature @ 50 fps to `output_fps`.
24
+ if frame_num is not None:
25
+ extract_features_len = round(frame_num * 50 / output_fps)
26
+ extract_features = extract_features[:, :, :extract_features_len]
27
+ extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
28
+ extract_features = extract_features.transpose(1, 2) # (N, L, C)
29
+
30
+ if attention_mask is not None:
31
+ # compute reduced attention_mask corresponding to feature vectors
32
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
33
+
34
+ hidden_states = self.feature_projection(extract_features)
35
+ hidden_states = self._mask_hidden_states(hidden_states)
36
+
37
+ encoder_outputs = self.encoder(
38
+ hidden_states,
39
+ attention_mask=attention_mask,
40
+ output_attentions=output_attentions,
41
+ output_hidden_states=output_hidden_states,
42
+ return_dict=return_dict,
43
+ )
44
+
45
+ hidden_states = encoder_outputs[0]
46
+
47
+ if not return_dict:
48
+ return (hidden_states,) + encoder_outputs[1:]
49
+
50
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
51
+ attentions=encoder_outputs.attentions, )
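`hubert.py` above imports `linear_interpolation` from the sibling `wav2vec2.py` (added next): HuBERT frame features arrive at roughly 50 Hz and are linearly resampled to the video frame rate so that one feature row corresponds to one video frame. A self-contained sketch of that resampling step (assuming only PyTorch; `resample_features` is a hypothetical name mirroring `linear_interpolation`):

```python
import torch
import torch.nn.functional as F

def resample_features(features: torch.Tensor, input_fps: float, output_fps: float,
                      output_len: int = None) -> torch.Tensor:
    """features: (N, C, L) sampled at input_fps; returns (N, C, output_len)."""
    if output_len is None:
        output_len = int(features.shape[2] / float(input_fps) * output_fps)
    return F.interpolate(features, size=output_len, align_corners=False, mode='linear')

feats_50hz = torch.randn(1, 768, 200)                     # ~4 s of HuBERT features @ 50 Hz
feats_25fps = resample_features(feats_50hz, 50, 25, output_len=100)
print(feats_25fps.shape)                                  # torch.Size([1, 768, 100])
```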
src/models/JoyVASA/wav2vec2.py ADDED
@@ -0,0 +1,119 @@
1
+ from packaging import version
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from transformers import Wav2Vec2Model
9
+ from transformers.modeling_outputs import BaseModelOutput
10
+
11
+ _CONFIG_FOR_DOC = 'Wav2Vec2Config'
12
+
13
+
14
+ # the implementation of Wav2Vec2Model is borrowed from
15
+ # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
16
+ # initialize our encoder with the pre-trained wav2vec 2.0 weights.
17
+ def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
18
+ attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
19
+ bsz, all_sz = shape
20
+ mask = np.full((bsz, all_sz), False)
21
+
22
+ all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
23
+ all_num_mask = max(min_masks, all_num_mask)
24
+ mask_idcs = []
25
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
26
+ for i in range(bsz):
27
+ if padding_mask is not None:
28
+ sz = all_sz - padding_mask[i].long().sum().item()
29
+ num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
30
+ num_mask = max(min_masks, num_mask)
31
+ else:
32
+ sz = all_sz
33
+ num_mask = all_num_mask
34
+
35
+ lengths = np.full(num_mask, mask_length)
36
+
37
+ if sum(lengths) == 0:
38
+ lengths[0] = min(mask_length, sz - 1)
39
+
40
+ min_len = min(lengths)
41
+ if sz - min_len <= num_mask:
42
+ min_len = sz - num_mask - 1
43
+
44
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
45
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
46
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
47
+
48
+ min_len = min([len(m) for m in mask_idcs])
49
+ for i, mask_idc in enumerate(mask_idcs):
50
+ if len(mask_idc) > min_len:
51
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
52
+ mask[i, mask_idc] = True
53
+ return mask
54
+
55
+
56
+ # linear interpolation layer
57
+ def linear_interpolation(features, input_fps, output_fps, output_len=None):
58
+ # features: (N, C, L)
59
+ seq_len = features.shape[2] / float(input_fps)
60
+ if output_len is None:
61
+ output_len = int(seq_len * output_fps)
62
+ output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
63
+ return output_features
64
+
65
+
66
+ class Wav2Vec2Model(Wav2Vec2Model):
67
+ def __init__(self, config):
68
+ super().__init__(config)
69
+ self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
70
+
71
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
72
+ output_hidden_states=None, return_dict=None, frame_num=None):
73
+ self.config.output_attentions = True
74
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
75
+ output_hidden_states = (
76
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
77
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
78
+
79
+ hidden_states = self.feature_extractor(input_values) # (N, C, L)
80
+ # Resample the audio feature @ 50 fps to `output_fps`.
81
+ if frame_num is not None:
82
+ hidden_states_len = round(frame_num * 50 / output_fps)
83
+ hidden_states = hidden_states[:, :, :hidden_states_len]
84
+ hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
85
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
86
+
87
+ if attention_mask is not None:
88
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
89
+ attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
90
+ device=hidden_states.device)
91
+ attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
92
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
93
+
94
+ if self.is_old_version:
95
+ hidden_states = self.feature_projection(hidden_states)
96
+ else:
97
+ hidden_states = self.feature_projection(hidden_states)[0]
98
+
99
+ if self.config.apply_spec_augment and self.training:
100
+ batch_size, sequence_length, hidden_size = hidden_states.size()
101
+ if self.config.mask_time_prob > 0:
102
+ mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
103
+ self.config.mask_time_length, attention_mask=attention_mask,
104
+ min_masks=2, )
105
+ hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
106
+ if self.config.mask_feature_prob > 0:
107
+ mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
108
+ self.config.mask_feature_length, )
109
+ mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
110
+ hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
111
+ encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
112
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
113
+ return_dict=return_dict, )
114
+ hidden_states = encoder_outputs[0]
115
+ if not return_dict:
116
+ return (hidden_states,) + encoder_outputs[1:]
117
+
118
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
119
+ attentions=encoder_outputs.attentions, )
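`_compute_mask_indices` in `wav2vec2.py` above builds SpecAugment-style boolean masks: spans of `mask_length` consecutive positions are selected per example so that their hidden states can be replaced with the learned mask embedding during training. A quick inspection sketch, assuming the function above is in scope:

```python
import numpy as np

# (batch, seq_len) boolean array; overlapping spans are merged via np.unique.
mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.065, mask_length=10, min_masks=2)
print(mask.shape)        # (2, 100)
print(mask.sum(axis=1))  # at most min_masks * mask_length = 20 positions per example here
```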
src/models/XPose/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/5 21:58
3
+ # @Author : shaoguowen
4
+ # @Email : [email protected]
5
+ # @Project : FasterLivePortrait
6
+ # @FileName: __init__.py.py
src/models/XPose/config_model/UniPose_SwinT.py ADDED
@@ -0,0 +1,125 @@
1
+ _base_ = ['coco_transformer.py']
2
+
3
+ use_label_enc = True
4
+
5
+ num_classes=2
6
+
7
+ lr = 0.0001
8
+ param_dict_type = 'default'
9
+ lr_backbone = 1e-05
10
+ lr_backbone_names = ['backbone.0']
11
+ lr_linear_proj_names = ['reference_points', 'sampling_offsets']
12
+ lr_linear_proj_mult = 0.1
13
+ ddetr_lr_param = False
14
+ batch_size = 2
15
+ weight_decay = 0.0001
16
+ epochs = 12
17
+ lr_drop = 11
18
+ save_checkpoint_interval = 100
19
+ clip_max_norm = 0.1
20
+ onecyclelr = False
21
+ multi_step_lr = False
22
+ lr_drop_list = [33, 45]
23
+
24
+
25
+ modelname = 'UniPose'
26
+ frozen_weights = None
27
+ backbone = 'swin_T_224_1k'
28
+
29
+
30
+ dilation = False
31
+ position_embedding = 'sine'
32
+ pe_temperatureH = 20
33
+ pe_temperatureW = 20
34
+ return_interm_indices = [1, 2, 3]
35
+ backbone_freeze_keywords = None
36
+ enc_layers = 6
37
+ dec_layers = 6
38
+ unic_layers = 0
39
+ pre_norm = False
40
+ dim_feedforward = 2048
41
+ hidden_dim = 256
42
+ dropout = 0.0
43
+ nheads = 8
44
+ num_queries = 900
45
+ query_dim = 4
46
+ num_patterns = 0
47
+ pdetr3_bbox_embed_diff_each_layer = False
48
+ pdetr3_refHW = -1
49
+ random_refpoints_xy = False
50
+ fix_refpoints_hw = -1
51
+ dabdetr_yolo_like_anchor_update = False
52
+ dabdetr_deformable_encoder = False
53
+ dabdetr_deformable_decoder = False
54
+ use_deformable_box_attn = False
55
+ box_attn_type = 'roi_align'
56
+ dec_layer_number = None
57
+ num_feature_levels = 4
58
+ enc_n_points = 4
59
+ dec_n_points = 4
60
+ decoder_layer_noise = False
61
+ dln_xy_noise = 0.2
62
+ dln_hw_noise = 0.2
63
+ add_channel_attention = False
64
+ add_pos_value = False
65
+ two_stage_type = 'standard'
66
+ two_stage_pat_embed = 0
67
+ two_stage_add_query_num = 0
68
+ two_stage_bbox_embed_share = False
69
+ two_stage_class_embed_share = False
70
+ two_stage_learn_wh = False
71
+ two_stage_default_hw = 0.05
72
+ two_stage_keep_all_tokens = False
73
+ num_select = 50
74
+ transformer_activation = 'relu'
75
+ batch_norm_type = 'FrozenBatchNorm2d'
76
+ masks = False
77
+
78
+ decoder_sa_type = 'sa' # ['sa', 'ca_label', 'ca_content']
79
+ matcher_type = 'HungarianMatcher' # or SimpleMinsumMatcher
80
+ decoder_module_seq = ['sa', 'ca', 'ffn']
81
+ nms_iou_threshold = -1
82
+
83
+ dec_pred_bbox_embed_share = True
84
+ dec_pred_class_embed_share = True
85
+
86
+
87
+ use_dn = True
88
+ dn_number = 100
89
+ dn_box_noise_scale = 1.0
90
+ dn_label_noise_ratio = 0.5
91
+ dn_label_coef=1.0
92
+ dn_bbox_coef=1.0
93
+ embed_init_tgt = True
94
+ dn_labelbook_size = 2000
95
+
96
+ match_unstable_error = True
97
+
98
+ # for ema
99
+ use_ema = True
100
+ ema_decay = 0.9997
101
+ ema_epoch = 0
102
+
103
+ use_detached_boxes_dec_out = False
104
+
105
+ max_text_len = 256
106
+ shuffle_type = None
107
+
108
+ use_text_enhancer = True
109
+ use_fusion_layer = True
110
+
111
+ use_checkpoint = False # True
112
+ use_transformer_ckpt = True
113
+ text_encoder_type = 'bert-base-uncased'
114
+
115
+ use_text_cross_attention = True
116
+ text_dropout = 0.0
117
+ fusion_dropout = 0.0
118
+ fusion_droppath = 0.1
119
+
120
+ num_body_points=68
121
+ binary_query_selection = False
122
+ use_cdn = True
123
+ ffn_extra_layernorm = False
124
+
125
+ fix_size=False
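`UniPose_SwinT.py` is a flat Python config whose `_base_` list points at sibling files such as `coco_transformer.py`, in the mmdetection/mmengine style. The repo presumably ships its own config loader; the following is only a hypothetical minimal sketch of how such `_base_`-style files compose into a single dict:

```python
import runpy
from pathlib import Path

def load_config(path) -> dict:
    path = Path(path)
    # Execute the config module and keep its plain top-level names (plus _base_).
    cfg = {k: v for k, v in runpy.run_path(str(path)).items()
           if not k.startswith('_') or k == '_base_'}
    merged = {}
    for base in cfg.pop('_base_', []):
        merged.update(load_config(path.parent / base))  # bases first, then overrides
    merged.update(cfg)
    return merged

cfg = load_config('src/models/XPose/config_model/UniPose_SwinT.py')
print(cfg['backbone'], cfg['num_body_points'], cfg['data_aug_max_size'])
# swin_T_224_1k 68 1333  (the last value inherited from coco_transformer.py)
```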
src/models/XPose/config_model/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/5 21:58
3
+ # @Author : shaoguowen
4
+ # @Email : [email protected]
5
+ # @Project : FasterLivePortrait
6
+ # @FileName: __init__.py.py
src/models/XPose/config_model/coco_transformer.py ADDED
@@ -0,0 +1,8 @@
1
+ data_aug_scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
2
+ data_aug_max_size = 1333
3
+ data_aug_scales2_resize = [400, 500, 600]
4
+ data_aug_scales2_crop = [384, 600]
5
+
6
+
7
+ data_aug_scale_overlap = None
8
+
src/models/XPose/models/UniPose/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # ------------------------------------------------------------------------
2
+ # Conditional DETR
3
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Copied from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
8
+ # ------------------------------------------------------------------------
9
+
10
+ from .unipose import build_unipose
src/models/XPose/models/UniPose/attention.py ADDED
@@ -0,0 +1,373 @@
1
+ # ------------------------------------------------------------------------
2
+ # UniPose
3
+ # url: https://github.com/IDEA-Research/UniPose
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # ED-Pose
8
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Conditional DETR
12
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Modified from codes in torch.nn
16
+ # ------------------------------------------------------------------------
17
+
18
+ """
19
+ MultiheadAttention that support query, key, and value to have different dimensions.
20
+ Query, key, and value projections are removed.
21
+
22
+ Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
23
+ and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
24
+ """
25
+
26
+ import warnings
27
+ import torch
28
+ from torch.nn.modules.linear import Linear
29
+ from torch.nn.init import constant_
30
+ from torch.nn.modules.module import Module
31
+ from torch._jit_internal import Optional, Tuple
32
+ try:
33
+ from torch.overrides import has_torch_function, handle_torch_function
34
+ except:
35
+ from torch._overrides import has_torch_function, handle_torch_function
36
+ from torch.nn.functional import linear, pad, softmax, dropout
37
+ Tensor = torch.Tensor
38
+
39
+ class MultiheadAttention(Module):
40
+ r"""Allows the model to jointly attend to information
41
+ from different representation subspaces.
42
+ See reference: Attention Is All You Need
43
+ .. math::
44
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
45
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
46
+ Args:
47
+ embed_dim: total dimension of the model.
48
+ num_heads: parallel attention heads.
49
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
50
+ bias: add bias as module parameter. Default: True.
51
+ add_bias_kv: add bias to the key and value sequences at dim=0.
52
+ add_zero_attn: add a new batch of zeros to the key and
53
+ value sequences at dim=1.
54
+ kdim: total number of features in key. Default: None.
55
+ vdim: total number of features in value. Default: None.
56
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
57
+ query, key, and value have the same number of features.
58
+ Examples::
59
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
60
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
61
+ """
62
+ bias_k: Optional[torch.Tensor]
63
+ bias_v: Optional[torch.Tensor]
64
+
65
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
66
+ super(MultiheadAttention, self).__init__()
67
+ self.embed_dim = embed_dim
68
+ self.kdim = kdim if kdim is not None else embed_dim
69
+ self.vdim = vdim if vdim is not None else embed_dim
70
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
71
+
72
+ self.num_heads = num_heads
73
+ self.dropout = dropout
74
+ self.head_dim = embed_dim // num_heads
75
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
76
+
77
+ vdim = vdim if vdim is not None else embed_dim
78
+ self.out_proj = Linear(vdim , vdim)
79
+
80
+ self.in_proj_bias = None
81
+ self.in_proj_weight = None
82
+ self.bias_k = self.bias_v = None
83
+ self.q_proj_weight = None
84
+ self.k_proj_weight = None
85
+ self.v_proj_weight = None
86
+
87
+ self.add_zero_attn = add_zero_attn
88
+
89
+ self._reset_parameters()
90
+
91
+ def _reset_parameters(self):
92
+ constant_(self.out_proj.bias, 0.)
93
+
94
+ def __setstate__(self, state):
95
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
96
+ if '_qkv_same_embed_dim' not in state:
97
+ state['_qkv_same_embed_dim'] = True
98
+
99
+ super(MultiheadAttention, self).__setstate__(state)
100
+
101
+ def forward(self, query, key, value, key_padding_mask=None,
102
+ need_weights=True, attn_mask=None):
103
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
104
+ r"""
105
+ Args:
106
+ query, key, value: map a query and a set of key-value pairs to an output.
107
+ See "Attention Is All You Need" for more details.
108
+ key_padding_mask: if provided, specified padding elements in the key will
109
+ be ignored by the attention. When given a binary mask and a value is True,
110
+ the corresponding value on the attention layer will be ignored. When given
111
+ a byte mask and a value is non-zero, the corresponding value on the attention
112
+ layer will be ignored
113
+ need_weights: output attn_output_weights.
114
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
115
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
116
+ Shape:
117
+ - Inputs:
118
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
119
+ the embedding dimension.
120
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
121
+ the embedding dimension.
122
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
123
+ the embedding dimension.
124
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
125
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
126
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
127
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
128
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
129
+ 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
130
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
131
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
132
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
133
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
134
+ is provided, it will be added to the attention weight.
135
+ - Outputs:
136
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
137
+ E is the embedding dimension.
138
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
139
+ L is the target sequence length, S is the source sequence length.
140
+ """
141
+ if not self._qkv_same_embed_dim:
142
+ return multi_head_attention_forward(
143
+ query, key, value, self.embed_dim, self.num_heads,
144
+ self.in_proj_weight, self.in_proj_bias,
145
+ self.bias_k, self.bias_v, self.add_zero_attn,
146
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
147
+ training=self.training,
148
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
149
+ attn_mask=attn_mask, use_separate_proj_weight=True,
150
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
151
+ v_proj_weight=self.v_proj_weight, out_dim=self.vdim)
152
+ else:
153
+ return multi_head_attention_forward(
154
+ query, key, value, self.embed_dim, self.num_heads,
155
+ self.in_proj_weight, self.in_proj_bias,
156
+ self.bias_k, self.bias_v, self.add_zero_attn,
157
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
158
+ training=self.training,
159
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
160
+ attn_mask=attn_mask, out_dim=self.vdim)
161
+
162
+
163
+ def multi_head_attention_forward(query: Tensor,
164
+ key: Tensor,
165
+ value: Tensor,
166
+ embed_dim_to_check: int,
167
+ num_heads: int,
168
+ in_proj_weight: Tensor,
169
+ in_proj_bias: Tensor,
170
+ bias_k: Optional[Tensor],
171
+ bias_v: Optional[Tensor],
172
+ add_zero_attn: bool,
173
+ dropout_p: float,
174
+ out_proj_weight: Tensor,
175
+ out_proj_bias: Tensor,
176
+ training: bool = True,
177
+ key_padding_mask: Optional[Tensor] = None,
178
+ need_weights: bool = True,
179
+ attn_mask: Optional[Tensor] = None,
180
+ use_separate_proj_weight: bool = False,
181
+ q_proj_weight: Optional[Tensor] = None,
182
+ k_proj_weight: Optional[Tensor] = None,
183
+ v_proj_weight: Optional[Tensor] = None,
184
+ static_k: Optional[Tensor] = None,
185
+ static_v: Optional[Tensor] = None,
186
+ out_dim: Optional[Tensor] = None
187
+ ) -> Tuple[Tensor, Optional[Tensor]]:
188
+ r"""
189
+ Args:
190
+ query, key, value: map a query and a set of key-value pairs to an output.
191
+ See "Attention Is All You Need" for more details.
192
+ embed_dim_to_check: total dimension of the model.
193
+ num_heads: parallel attention heads.
194
+ in_proj_weight, in_proj_bias: input projection weight and bias.
195
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
196
+ add_zero_attn: add a new batch of zeros to the key and
197
+ value sequences at dim=1.
198
+ dropout_p: probability of an element to be zeroed.
199
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
200
+ training: apply dropout if is ``True``.
201
+ key_padding_mask: if provided, specified padding elements in the key will
202
+ be ignored by the attention. This is a binary mask. When the value is True,
203
+ the corresponding value on the attention layer will be filled with -inf.
204
+ need_weights: output attn_output_weights.
205
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
206
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
207
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
208
+ and value in different forms. If false, in_proj_weight will be used, which is
209
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
210
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
211
+ static_k, static_v: static key and value used for attention operators.
212
+ Shape:
213
+ Inputs:
214
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
215
+ the embedding dimension.
216
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
217
+ the embedding dimension.
218
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
219
+ the embedding dimension.
220
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
221
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
222
+ will be unchanged. If a BoolTensor is provided, the positions with the
223
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
224
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
225
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
226
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
227
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
228
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
229
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
230
+ is provided, it will be added to the attention weight.
231
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
232
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
233
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
234
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
235
+ Outputs:
236
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
237
+ E is the embedding dimension.
238
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
239
+ L is the target sequence length, S is the source sequence length.
240
+ """
241
+ if not torch.jit.is_scripting():
242
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
243
+ out_proj_weight, out_proj_bias)
244
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
245
+ return handle_torch_function(
246
+ multi_head_attention_forward, tens_ops, query, key, value,
247
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
248
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
249
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
250
+ need_weights=need_weights, attn_mask=attn_mask,
251
+ use_separate_proj_weight=use_separate_proj_weight,
252
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
253
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
254
+ tgt_len, bsz, embed_dim = query.size()
255
+ assert embed_dim == embed_dim_to_check
256
+ # allow MHA to have different sizes for the feature dimension
257
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
258
+
259
+ head_dim = embed_dim // num_heads
260
+ v_head_dim = out_dim // num_heads
261
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
262
+ scaling = float(head_dim) ** -0.5
263
+
264
+ q = query * scaling
265
+ k = key
266
+ v = value
267
+
268
+ if attn_mask is not None:
269
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
270
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
271
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
272
+ if attn_mask.dtype == torch.uint8:
273
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
274
+ attn_mask = attn_mask.to(torch.bool)
275
+
276
+ if attn_mask.dim() == 2:
277
+ attn_mask = attn_mask.unsqueeze(0)
278
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
279
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
280
+ elif attn_mask.dim() == 3:
281
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
282
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
283
+ else:
284
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
285
+ # attn_mask's dim is 3 now.
286
+
287
+ # convert ByteTensor key_padding_mask to bool
288
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
289
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
290
+ key_padding_mask = key_padding_mask.to(torch.bool)
291
+
292
+ if bias_k is not None and bias_v is not None:
293
+ if static_k is None and static_v is None:
294
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
295
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
296
+ if attn_mask is not None:
297
+ attn_mask = pad(attn_mask, (0, 1))
298
+ if key_padding_mask is not None:
299
+ key_padding_mask = pad(key_padding_mask, (0, 1))
300
+ else:
301
+ assert static_k is None, "bias cannot be added to static key."
302
+ assert static_v is None, "bias cannot be added to static value."
303
+ else:
304
+ assert bias_k is None
305
+ assert bias_v is None
306
+
307
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
308
+ if k is not None:
309
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
310
+ if v is not None:
311
+ v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
312
+
313
+ if static_k is not None:
314
+ assert static_k.size(0) == bsz * num_heads
315
+ assert static_k.size(2) == head_dim
316
+ k = static_k
317
+
318
+ if static_v is not None:
319
+ assert static_v.size(0) == bsz * num_heads
320
+ assert static_v.size(2) == v_head_dim
321
+ v = static_v
322
+
323
+ src_len = k.size(1)
324
+
325
+ if key_padding_mask is not None:
326
+ assert key_padding_mask.size(0) == bsz
327
+ assert key_padding_mask.size(1) == src_len
328
+
329
+ if add_zero_attn:
330
+ src_len += 1
331
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
332
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
333
+ if attn_mask is not None:
334
+ attn_mask = pad(attn_mask, (0, 1))
335
+ if key_padding_mask is not None:
336
+ key_padding_mask = pad(key_padding_mask, (0, 1))
337
+
338
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
339
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
340
+
341
+ if attn_mask is not None:
342
+ if attn_mask.dtype == torch.bool:
343
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
344
+ else:
345
+ attn_output_weights += attn_mask
346
+
347
+
348
+ if key_padding_mask is not None:
349
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
350
+ attn_output_weights = attn_output_weights.masked_fill(
351
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
352
+ float('-inf'),
353
+ )
354
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
355
+
356
+ # attn_output_weights = softmax(
357
+ # attn_output_weights, dim=-1)
358
+ attn_output_weights = softmax(
359
+ attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1)
360
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
361
+
362
+ attn_output = torch.bmm(attn_output_weights, v)
363
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
364
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
365
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
366
+
367
+ if need_weights:
368
+ # average attention weights over heads
369
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
370
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
371
+ else:
372
+ return attn_output, None
373
+
src/models/XPose/models/UniPose/backbone.py ADDED
@@ -0,0 +1,211 @@
 
1
+ # ------------------------------------------------------------------------
2
+ # UniPose
3
+ # url: https://github.com/IDEA-Research/UniPose
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # Conditional DETR
8
+ # Copyright (c) 2021 Microsoft. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # Copied from DETR (https://github.com/facebookresearch/detr)
12
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
13
+ # ------------------------------------------------------------------------
14
+
15
+ """
16
+ Backbone modules.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import torchvision
22
+ from torch import nn
23
+ from torchvision.models._utils import IntermediateLayerGetter
24
+ from typing import Dict, List
25
+
26
+ from ...util.misc import NestedTensor, is_main_process
27
+
28
+ from .position_encoding import build_position_encoding
29
+ from .swin_transformer import build_swin_transformer
30
+
31
+ class FrozenBatchNorm2d(torch.nn.Module):
32
+ """
33
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
34
+
35
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt,
36
+ without which any other models than torchvision.models.resnet[18,34,50,101]
37
+ produce nans.
38
+ """
39
+
40
+ def __init__(self, n):
41
+ super(FrozenBatchNorm2d, self).__init__()
42
+ self.register_buffer("weight", torch.ones(n))
43
+ self.register_buffer("bias", torch.zeros(n))
44
+ self.register_buffer("running_mean", torch.zeros(n))
45
+ self.register_buffer("running_var", torch.ones(n))
46
+
47
+ def _load_from_state_dict(
48
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
49
+ ):
50
+ num_batches_tracked_key = prefix + "num_batches_tracked"
51
+ if num_batches_tracked_key in state_dict:
52
+ del state_dict[num_batches_tracked_key]
53
+
54
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
55
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
56
+ )
57
+
58
+ def forward(self, x):
59
+ # move reshapes to the beginning
60
+ # to make it fuser-friendly
61
+ w = self.weight.reshape(1, -1, 1, 1)
62
+ b = self.bias.reshape(1, -1, 1, 1)
63
+ rv = self.running_var.reshape(1, -1, 1, 1)
64
+ rm = self.running_mean.reshape(1, -1, 1, 1)
65
+ eps = 1e-5
66
+ scale = w * (rv + eps).rsqrt()
67
+ bias = b - rm * scale
68
+ return x * scale + bias
69
+
70
+
71
+ class BackboneBase(nn.Module):
72
+ def __init__(
73
+ self,
74
+ backbone: nn.Module,
75
+ train_backbone: bool,
76
+ num_channels: int,
77
+ return_interm_indices: list,
78
+ ):
79
+ super().__init__()
80
+ for name, parameter in backbone.named_parameters():
81
+ if (
82
+ not train_backbone
83
+ or "layer2" not in name
84
+ and "layer3" not in name
85
+ and "layer4" not in name
86
+ ):
87
+ parameter.requires_grad_(False)
88
+
89
+ return_layers = {}
90
+ for idx, layer_index in enumerate(return_interm_indices):
91
+ return_layers.update(
92
+ {"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)}
93
+ )
94
+
95
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
96
+ self.num_channels = num_channels
97
+
98
+ def forward(self, tensor_list: NestedTensor):
99
+ xs = self.body(tensor_list.tensors)
100
+ out: Dict[str, NestedTensor] = {}
101
+ for name, x in xs.items():
102
+ m = tensor_list.mask
103
+ assert m is not None
104
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
105
+ out[name] = NestedTensor(x, mask)
106
+ # import ipdb; ipdb.set_trace()
107
+ return out
108
+
109
+
110
+ class Backbone(BackboneBase):
111
+ """ResNet backbone with frozen BatchNorm."""
112
+
113
+ def __init__(
114
+ self,
115
+ name: str,
116
+ train_backbone: bool,
117
+ dilation: bool,
118
+ return_interm_indices: list,
119
+ batch_norm=FrozenBatchNorm2d,
120
+ ):
121
+ if name in ["resnet18", "resnet34", "resnet50", "resnet101"]:
122
+ backbone = getattr(torchvision.models, name)(
123
+ replace_stride_with_dilation=[False, False, dilation],
124
+ pretrained=is_main_process(),
125
+ norm_layer=batch_norm,
126
+ )
127
+ else:
128
+ raise NotImplementedError("Unsupported backbone name: {}".format(name))
129
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
130
+ assert name not in ("resnet18", "resnet34"), "Only resnet50 and resnet101 are available."
131
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
132
+ num_channels_all = [256, 512, 1024, 2048]
133
+ num_channels = num_channels_all[4 - len(return_interm_indices) :]
134
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
135
+
136
+
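+ # Example of the channel slicing above: with return_interm_indices=[1, 2, 3] the slice
+ # num_channels_all[4 - 3:] yields [512, 1024, 2048], i.e. the layer2-layer4 outputs of a ResNet-50/101.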
137
+ class Joiner(nn.Sequential):
138
+ def __init__(self, backbone, position_embedding):
139
+ super().__init__(backbone, position_embedding)
140
+
141
+ def forward(self, tensor_list: NestedTensor):
142
+ xs = self[0](tensor_list)
143
+ out: List[NestedTensor] = []
144
+ pos = []
145
+ for name, x in xs.items():
146
+ out.append(x)
147
+ # position encoding
148
+ pos.append(self[1](x).to(x.tensors.dtype))
149
+
150
+ return out, pos
151
+
152
+
153
+ def build_backbone(args):
154
+ """
155
+ Useful args:
156
+ - backbone: backbone name
157
+ - lr_backbone:
158
+ - dilation
159
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
160
+ - backbone_freeze_keywords:
161
+ - use_checkpoint: for swin only for now
162
+
163
+ """
164
+ position_embedding = build_position_encoding(args)
165
+ train_backbone = True
166
+ if not train_backbone:
167
+ raise ValueError("Please set lr_backbone > 0")
168
+ return_interm_indices = args.return_interm_indices
169
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
170
+ args.backbone_freeze_keywords
171
+ use_checkpoint = getattr(args, "use_checkpoint", False)
172
+
173
+ if args.backbone in ["resnet50", "resnet101"]:
174
+ backbone = Backbone(
175
+ args.backbone,
176
+ train_backbone,
177
+ args.dilation,
178
+ return_interm_indices,
179
+ batch_norm=FrozenBatchNorm2d,
180
+ )
181
+ bb_num_channels = backbone.num_channels
182
+ elif args.backbone in [
183
+ "swin_T_224_1k",
184
+ "swin_B_224_22k",
185
+ "swin_B_384_22k",
186
+ "swin_L_224_22k",
187
+ "swin_L_384_22k",
188
+ ]:
189
+ pretrain_img_size = int(args.backbone.split("_")[-2])
190
+ backbone = build_swin_transformer(
191
+ args.backbone,
192
+ pretrain_img_size=pretrain_img_size,
193
+ out_indices=tuple(return_interm_indices),
194
+ dilation=False,
195
+ use_checkpoint=use_checkpoint,
196
+ )
197
+
198
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
199
+ else:
200
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
201
+
202
+ assert len(bb_num_channels) == len(
203
+ return_interm_indices
204
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
205
+
206
+ model = Joiner(backbone, position_embedding)
207
+ model.num_channels = bb_num_channels
208
+ assert isinstance(
209
+ bb_num_channels, List
210
+ ), "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
211
+ return model
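+
+ # Minimal usage sketch (hypothetical argparse-style namespace; only the fields read in this
+ # file are listed, plus whatever build_position_encoding expects from `args`):
+ # args = SimpleNamespace(backbone="resnet50", dilation=False, return_interm_indices=[1, 2, 3],
+ #                        backbone_freeze_keywords=None, use_checkpoint=False)
+ # model = build_backbone(args)  # Joiner(backbone, pos_embedding); model.num_channels == [512, 1024, 2048]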
src/models/XPose/models/UniPose/deformable_transformer.py ADDED
@@ -0,0 +1,1230 @@
1
+ # ------------------------------------------------------------------------
2
+ # UniPose
3
+ # url: https://github.com/IDEA-Research/UniPose
4
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
5
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ # ------------------------------------------------------------------------
7
+ # ED-Pose
8
+ # Copyright (c) 2023 IDEA. All Rights Reserved.
9
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
10
+ # ------------------------------------------------------------------------
11
+ # DINO
12
+ # Copyright (c) 2022 IDEA. All Rights Reserved.
13
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
14
+ # ------------------------------------------------------------------------
15
+ # Modified from DETR (https://github.com/facebookresearch/detr)
16
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
17
+ # ------------------------------------------------------------------------
18
+
19
+ import math
20
+ import copy
21
+ import torch
22
+ import torch.utils.checkpoint as checkpoint
23
+ from torch import nn, Tensor
24
+ from typing import Optional
25
+ from ...util.misc import inverse_sigmoid
26
+
27
+ from .transformer_vanilla import TransformerEncoderLayer
28
+ from .fuse_modules import BiAttentionBlock
29
+ from .utils import gen_encoder_output_proposals, MLP, _get_activation_fn, gen_sineembed_for_position, get_sine_pos_embed
30
+ from .ops.modules import MSDeformAttn
31
+
32
+
33
+ class DeformableTransformer(nn.Module):
34
+
35
+ def __init__(self, d_model=256, nhead=8,
36
+ num_queries=300,
37
+ num_encoder_layers=6,
38
+ num_unicoder_layers=0,
39
+ num_decoder_layers=6,
40
+ dim_feedforward=2048, dropout=0.0,
41
+ activation="relu", normalize_before=False,
42
+ return_intermediate_dec=False, query_dim=4,
43
+ num_patterns=0,
44
+ modulate_hw_attn=False,
45
+ # for deformable encoder
46
+ deformable_encoder=False,
47
+ deformable_decoder=False,
48
+ num_feature_levels=1,
49
+ enc_n_points=4,
50
+ dec_n_points=4,
51
+ use_deformable_box_attn=False,
52
+ box_attn_type='roi_align',
53
+ # init query
54
+ learnable_tgt_init=False,
55
+ decoder_query_perturber=None,
56
+ add_channel_attention=False,
57
+ add_pos_value=False,
58
+ random_refpoints_xy=False,
59
+ # two stage
60
+ two_stage_type='no',
61
+ two_stage_pat_embed=0,
62
+ two_stage_add_query_num=0,
63
+ two_stage_learn_wh=False,
64
+ two_stage_keep_all_tokens=False,
65
+ # evo of #anchors
66
+ dec_layer_number=None,
67
+ rm_enc_query_scale=True,
68
+ rm_dec_query_scale=True,
69
+ rm_self_attn_layers=None,
70
+ key_aware_type=None,
71
+ # layer share
72
+ layer_share_type=None,
73
+ # for detach
74
+ rm_detach=None,
75
+ decoder_sa_type='ca',
76
+ module_seq=['sa', 'ca', 'ffn'],
77
+ # for dn
78
+ embed_init_tgt=False,
79
+
80
+ use_detached_boxes_dec_out=False,
81
+ use_text_enhancer=False,
82
+ use_fusion_layer=False,
83
+ use_checkpoint=False,
84
+ use_transformer_ckpt=False,
85
+ use_text_cross_attention=False,
86
+ text_dropout=0.1,
87
+ fusion_dropout=0.1,
88
+ fusion_droppath=0.0,
89
+
90
+ binary_query_selection=False,
91
+ ffn_extra_layernorm=False,
92
+ ):
93
+ super().__init__()
94
+ self.num_feature_levels = num_feature_levels
95
+ self.num_encoder_layers = num_encoder_layers
96
+ self.num_unicoder_layers = num_unicoder_layers
97
+ self.num_decoder_layers = num_decoder_layers
98
+ self.deformable_encoder = deformable_encoder
99
+ self.deformable_decoder = deformable_decoder
100
+ self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
101
+ self.num_queries = num_queries
102
+ self.random_refpoints_xy = random_refpoints_xy
103
+ self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
104
+ self.ffn_extra_layernorm = ffn_extra_layernorm
105
+ assert query_dim == 4
106
+
107
+ self.binary_query_selection = binary_query_selection
108
+ if self.binary_query_selection:
109
+ self.binary_query_selection_layer = nn.Linear(d_model, 1)
110
+ # assert not binary_query_selection, 'binary_query_selection not implemented yet'
111
+
112
+ if num_feature_levels > 1:
113
+ assert deformable_encoder, "only support deformable_encoder for num_feature_levels > 1"
114
+ if use_deformable_box_attn:
115
+ assert deformable_encoder or deformable_decoder
116
+
117
+ assert layer_share_type in [None, 'encoder', 'decoder', 'both']
118
+ if layer_share_type in ['encoder', 'both']:
119
+ enc_layer_share = True
120
+ else:
121
+ enc_layer_share = False
122
+ if layer_share_type in ['decoder', 'both']:
123
+ dec_layer_share = True
124
+ else:
125
+ dec_layer_share = False
126
+ assert layer_share_type is None
127
+
128
+ self.decoder_sa_type = decoder_sa_type
129
+ assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
130
+
131
+ # choose encoder layer type
132
+ if deformable_encoder:
133
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
134
+ dropout, activation,
135
+ num_feature_levels, nhead, enc_n_points,
136
+ add_channel_attention=add_channel_attention,
137
+ use_deformable_box_attn=use_deformable_box_attn,
138
+ box_attn_type=box_attn_type)
139
+ else:
140
+ raise NotImplementedError
141
+
142
+ if use_text_enhancer:
143
+ text_enhance_layer = TransformerEncoderLayer(
144
+ d_model=d_model,
145
+ nhead=nhead // 2,
146
+ dim_feedforward=dim_feedforward // 2,
147
+ dropout=text_dropout
148
+ )
149
+ else:
150
+ text_enhance_layer = None
151
+
152
+ if use_fusion_layer:
153
+ feature_fusion_layer = BiAttentionBlock(
154
+ v_dim=d_model,
155
+ l_dim=d_model,
156
+ embed_dim=dim_feedforward // 2,
157
+ num_heads=nhead // 2,
158
+ dropout=fusion_dropout,
159
+ drop_path=fusion_droppath
160
+ )
161
+ else:
162
+ feature_fusion_layer = None
163
+
164
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
165
+ assert encoder_norm is None
166
+ self.encoder = TransformerEncoder(
167
+ encoder_layer, num_encoder_layers, d_model=d_model,
168
+ num_queries=num_queries,
169
+ enc_layer_share=enc_layer_share,
170
+ text_enhance_layer=text_enhance_layer,
171
+ feature_fusion_layer=feature_fusion_layer,
172
+ use_checkpoint=use_checkpoint,
173
+ use_transformer_ckpt=use_transformer_ckpt,
174
+ )
175
+
176
+ # choose decoder layer type
177
+ if deformable_decoder:
178
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
179
+ dropout, activation,
180
+ num_feature_levels, nhead, dec_n_points,
181
+ use_text_cross_attention=use_text_cross_attention,
182
+ ffn_extra_layernorm=ffn_extra_layernorm, )
183
+
184
+ else:
185
+ raise NotImplementedError
186
+
187
+ decoder_norm = nn.LayerNorm(d_model)
188
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
189
+ return_intermediate=return_intermediate_dec,
190
+ d_model=d_model, query_dim=query_dim,
191
+ modulate_hw_attn=modulate_hw_attn,
192
+ num_feature_levels=num_feature_levels,
193
+ deformable_decoder=deformable_decoder,
194
+ decoder_query_perturber=decoder_query_perturber,
195
+ dec_layer_number=dec_layer_number, rm_dec_query_scale=rm_dec_query_scale,
196
+ dec_layer_share=dec_layer_share,
197
+ use_detached_boxes_dec_out=use_detached_boxes_dec_out
198
+ )
199
+
200
+ self.d_model = d_model
201
+ self.nhead = nhead
202
+ self.dec_layers = num_decoder_layers
203
+ self.num_queries = num_queries # useful for single stage model only
204
+ self.num_patterns = num_patterns
205
+ if not isinstance(num_patterns, int):
206
+ Warning("num_patterns should be int but {}".format(type(num_patterns)))
207
+ self.num_patterns = 0
208
+
209
+ if num_feature_levels > 1:
210
+ if self.num_encoder_layers > 0:
211
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
212
+ else:
213
+ self.level_embed = None
214
+
215
+ self.learnable_tgt_init = learnable_tgt_init
216
+ assert learnable_tgt_init, "learnable_tgt_init must be True"
217
+ self.embed_init_tgt = embed_init_tgt
218
+ if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
219
+ self.tgt_embed = nn.Embedding(self.num_queries, d_model)
220
+ nn.init.normal_(self.tgt_embed.weight.data)
221
+ else:
222
+ self.tgt_embed = None
223
+
224
+ # for two stage
225
+ self.two_stage_type = two_stage_type
226
+ self.two_stage_pat_embed = two_stage_pat_embed
227
+ self.two_stage_add_query_num = two_stage_add_query_num
228
+ self.two_stage_learn_wh = two_stage_learn_wh
229
+ assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type)
230
+ if two_stage_type == 'standard':
231
+ # anchor selection at the output of encoder
232
+ self.enc_output = nn.Linear(d_model, d_model)
233
+ self.enc_output_norm = nn.LayerNorm(d_model)
234
+
235
+ if two_stage_pat_embed > 0:
236
+ self.pat_embed_for_2stage = nn.Parameter(torch.Tensor(two_stage_pat_embed, d_model))
237
+ nn.init.normal_(self.pat_embed_for_2stage)
238
+
239
+ if two_stage_add_query_num > 0:
240
+ self.tgt_embed = nn.Embedding(self.two_stage_add_query_num, d_model)
241
+
242
+ if two_stage_learn_wh:
243
+ # import ipdb; ipdb.set_trace()
244
+ self.two_stage_wh_embedding = nn.Embedding(1, 2)
245
+ else:
246
+ self.two_stage_wh_embedding = None
247
+
248
+ if two_stage_type == 'no':
249
+ self.init_ref_points(num_queries) # init self.refpoint_embed
250
+
251
+ self.enc_out_class_embed = None
252
+ self.enc_out_bbox_embed = None
253
+
254
+ # evolution of anchors
255
+ self.dec_layer_number = dec_layer_number
256
+ if dec_layer_number is not None:
257
+ if self.two_stage_type != 'no' or num_patterns == 0:
258
+ assert dec_layer_number[
259
+ 0] == num_queries, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})"
260
+ else:
261
+ assert dec_layer_number[
262
+ 0] == num_queries * num_patterns, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})"
263
+
264
+ self._reset_parameters()
265
+
266
+ self.rm_self_attn_layers = rm_self_attn_layers
267
+ if rm_self_attn_layers is not None:
268
+ # assert len(rm_self_attn_layers) == num_decoder_layers
269
+ print("Removing the self-attn in {} decoder layers".format(rm_self_attn_layers))
270
+ for lid, dec_layer in enumerate(self.decoder.layers):
271
+ if lid in rm_self_attn_layers:
272
+ dec_layer.rm_self_attn_modules()
273
+
274
+ self.rm_detach = rm_detach
275
+ if self.rm_detach:
276
+ assert isinstance(rm_detach, list)
277
+ assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
278
+ self.decoder.rm_detach = rm_detach
279
+
280
+ def _reset_parameters(self):
281
+ for p in self.parameters():
282
+ if p.dim() > 1:
283
+ nn.init.xavier_uniform_(p)
284
+ for m in self.modules():
285
+ if isinstance(m, MSDeformAttn):
286
+ m._reset_parameters()
287
+ if self.num_feature_levels > 1 and self.level_embed is not None:
288
+ nn.init.normal_(self.level_embed)
289
+
290
+ if self.two_stage_learn_wh:
291
+ nn.init.constant_(self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05)))
292
+
293
+ def get_valid_ratio(self, mask):
294
+ _, H, W = mask.shape
295
+ valid_H = torch.sum(~mask[:, :, 0], 1)
296
+ valid_W = torch.sum(~mask[:, 0, :], 1)
297
+ valid_ratio_h = valid_H.float() / H
298
+ valid_ratio_w = valid_W.float() / W
299
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
300
+ return valid_ratio
301
+
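+ # Worked example for get_valid_ratio: for a 4x6 mask whose bottom row and right two columns
+ # are padding, valid_H = 3 and valid_W = 4, so it returns [4/6, 3/4] per sample, ordered
+ # (width ratio, height ratio) to match the (x, y) layout of the reference points.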
302
+ def init_ref_points(self, use_num_queries):
303
+ self.refpoint_embed = nn.Embedding(use_num_queries, 4)
304
+
305
+ if self.random_refpoints_xy:
306
+ # import ipdb; ipdb.set_trace()
307
+ self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
308
+ self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
309
+ self.refpoint_embed.weight.data[:, :2].requires_grad = False
310
+
311
+ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, attn_mask2=None, text_dict=None,
312
+ dn_meta=None,targets=None,kpt_embed=None):
313
+ """
314
+ Input:
315
+ - srcs: List of multi features [bs, ci, hi, wi]
316
+ - masks: List of multi masks [bs, hi, wi]
317
+ - refpoint_embed: [bs, num_dn, 4]. None in infer
318
+ - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
319
+ - tgt: [bs, num_dn, d_model]. None in infer
320
+
321
+ """
322
+ # if self.two_stage_type != 'no' and self.two_stage_add_query_num == 0:
323
+ # assert refpoint_embed is None
324
+
325
+ # prepare input for encoder
326
+ src_flatten = []
327
+ mask_flatten = []
328
+ lvl_pos_embed_flatten = []
329
+ spatial_shapes = []
330
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
331
+ bs, c, h, w = src.shape
332
+ spatial_shape = (h, w)
333
+ spatial_shapes.append(spatial_shape)
334
+
335
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
336
+ mask = mask.flatten(1) # bs, hw
337
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
338
+ if self.num_feature_levels > 1 and self.level_embed is not None:
339
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
340
+ else:
341
+ lvl_pos_embed = pos_embed
342
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
343
+ src_flatten.append(src)
344
+ mask_flatten.append(mask)
345
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
346
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
347
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
348
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
349
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
350
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
351
+
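+ # Shape note for the flattening above (illustrative): with two levels of 32x32 and 16x16,
+ # src_flatten is (bs, 1024 + 256, c) and level_start_index == [0, 1024], the offset at
+ # which each level begins along the flattened spatial axis.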
352
+ # two stage
353
+ enc_topk_proposals = enc_refpoint_embed = None
354
+
355
+ #########################################################
356
+ # Begin Encoder
357
+ #########################################################
358
+ memory, memory_text = self.encoder(
359
+ src_flatten,
360
+ pos=lvl_pos_embed_flatten,
361
+ level_start_index=level_start_index,
362
+ spatial_shapes=spatial_shapes,
363
+ valid_ratios=valid_ratios,
364
+ key_padding_mask=mask_flatten,
365
+ memory_text=text_dict['encoded_text'],
366
+ text_attention_mask=~text_dict['text_token_mask'],
367
+ # we ~ the mask . False means use the token; True means pad the token
368
+ position_ids=text_dict['position_ids'],
369
+ text_self_attention_masks=text_dict['text_self_attention_masks'],
370
+ )
371
+ #########################################################
372
+ # End Encoder
373
+ # - memory: bs, \sum{hw}, c
374
+ # - mask_flatten: bs, \sum{hw}
375
+ # - lvl_pos_embed_flatten: bs, \sum{hw}, c
376
+ # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
377
+ # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
378
+ #########################################################
379
+ text_dict['encoded_text'] = memory_text
380
+
381
+ if self.two_stage_type == 'standard':
382
+ if self.two_stage_learn_wh:
383
+ input_hw = self.two_stage_wh_embedding.weight[0]
384
+ else:
385
+ input_hw = None
386
+ output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes,
387
+ input_hw)
388
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
389
+
390
+ if self.two_stage_pat_embed > 0:
391
+ bs, nhw, _ = output_memory.shape
392
+ # output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
393
+ output_memory = output_memory.repeat(1, self.two_stage_pat_embed, 1)
394
+ _pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
395
+ output_memory = output_memory + _pats
396
+ output_proposals = output_proposals.repeat(1, self.two_stage_pat_embed, 1)
397
+
398
+ if self.two_stage_add_query_num > 0:
399
+ assert refpoint_embed is not None
400
+ output_memory = torch.cat((output_memory, tgt), dim=1)
401
+ output_proposals = torch.cat((output_proposals, refpoint_embed), dim=1)
402
+
403
+ if self.binary_query_selection:
404
+ topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
405
+ else:
406
+ if text_dict is not None:
407
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory, text_dict)
408
+ else:
409
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
410
+
411
+ topk_logits = enc_outputs_class_unselected.max(-1)[0]
412
+ enc_outputs_coord_unselected = self.enc_out_bbox_embed(
413
+ output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
414
+ topk = self.num_queries
415
+
416
+ topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq
417
+
418
+ # gather boxes
419
+ refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1,
420
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
421
+ refpoint_embed_ = refpoint_embed_undetach.detach()
422
+ init_box_proposal = torch.gather(output_proposals, 1,
423
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid() # sigmoid
424
+
425
+ # gather tgt
426
+ tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
427
+ if self.embed_init_tgt:
428
+ tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
429
+ else:
430
+ tgt_ = tgt_undetach.detach()
431
+
432
+ if refpoint_embed is not None:
433
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
434
+ tgt = torch.cat([tgt, tgt_], dim=1)
435
+ else:
436
+ refpoint_embed, tgt = refpoint_embed_, tgt_
437
+
438
+ elif self.two_stage_type == 'no':
439
+ tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
440
+ refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, 4
441
+
442
+ if refpoint_embed is not None:
443
+ refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1)
444
+ tgt = torch.cat([tgt, tgt_], dim=1)
445
+ else:
446
+ refpoint_embed, tgt = refpoint_embed_, tgt_
447
+
448
+ if self.num_patterns > 0:
449
+ tgt_embed = tgt.repeat(1, self.num_patterns, 1)
450
+ refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
451
+ tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(self.num_queries,
452
+ 1) # 1, n_q*n_pat, d_model
453
+ tgt = tgt_embed + tgt_pat
454
+
455
+ init_box_proposal = refpoint_embed_.sigmoid()
456
+
457
+ else:
458
+ raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
459
+ #########################################################
460
+ # End preparing tgt
461
+ # - tgt: bs, NQ, d_model
462
+ # - refpoint_embed(unsigmoid): bs, NQ, d_model
463
+ #########################################################
464
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
465
+ # if refpoint_embed.isnan().any() | refpoint_embed.isinf().any():
466
+ # import ipdb; ipdb.set_trace()
467
+ # if tgt.isnan().any() | tgt.isinf().any():
468
+ # import ipdb; ipdb.set_trace()
469
+
470
+ #########################################################
471
+ # Begin Decoder
472
+ #########################################################
473
+ hs, references = self.decoder(
474
+ tgt=tgt.transpose(0, 1),
475
+ memory=memory.transpose(0, 1),
476
+ memory_key_padding_mask=mask_flatten,
477
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
478
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
479
+ level_start_index=level_start_index,
480
+ spatial_shapes=spatial_shapes,
481
+ valid_ratios=valid_ratios, tgt_mask=attn_mask,
482
+ tgt_mask2=attn_mask2,
483
+ memory_text=text_dict['encoded_text'],
484
+ text_attention_mask=~text_dict['text_token_mask'],
485
+ text_dict=text_dict,
486
+ dn_meta=dn_meta,
487
+ targets=targets,
488
+ kpt_embed=kpt_embed
489
+ # we ~ the mask . False means use the token; True means pad the token
490
+ )
491
+ #########################################################
492
+ # End Decoder
493
+ # hs: n_dec, bs, nq, d_model
494
+ # references: n_dec+1, bs, nq, query_dim
495
+ #########################################################
496
+
497
+ #########################################################
498
+ # Begin postprocess
499
+ #########################################################
500
+ if self.two_stage_type == 'standard':
501
+ if self.two_stage_keep_all_tokens:
502
+ hs_enc = output_memory.unsqueeze(0)
503
+ ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
504
+ init_box_proposal = output_proposals
505
+ # import ipdb; ipdb.set_trace()
506
+ else:
507
+ hs_enc = tgt_undetach.unsqueeze(0)
508
+ ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
509
+ else:
510
+ hs_enc = ref_enc = None
511
+ #########################################################
512
+ # End postprocess
513
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
514
+ # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
515
+ #########################################################
516
+
517
+ return hs, references, hs_enc, ref_enc, init_box_proposal
518
+ # hs: (n_dec, bs, nq, d_model)
519
+ # references: sigmoid coordinates. (n_dec+1, bs, nq, 4)
520
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
521
+ # ref_enc: sigmoid coordinates. \
522
+ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
523
+
524
+
525
+ class TransformerEncoder(nn.Module):
526
+
527
+ def __init__(self,
528
+ encoder_layer, num_layers, d_model=256,
529
+ num_queries=300,
530
+ enc_layer_share=False,
531
+ text_enhance_layer=None,
532
+ feature_fusion_layer=None,
533
+ use_checkpoint=False,
534
+ use_transformer_ckpt=False,
535
+ ):
536
+ """Deformable encoder stack with optional per-layer text enhancement and vision-language fusion.
537
+
538
+ Args:
539
+ encoder_layer (nn.Module): deformable encoder layer that is cloned num_layers times.
540
+ num_layers (int): number of encoder layers.
541
+ text_enhance_layer (nn.Module, optional): text self-attention layer applied at each depth. Defaults to None.
542
+ d_model (int, optional): feature dimension. Defaults to 256.
543
+ num_queries (int, optional): number of queries. Defaults to 300.
544
+ enc_layer_share (bool, optional): share a single layer instance across all depths. Defaults to False.
545
+
546
+ """
547
+ super().__init__()
548
+ # prepare layers
549
+ self.layers = []
550
+ self.text_layers = []
551
+ self.fusion_layers = []
552
+ if num_layers > 0:
553
+ self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
554
+
555
+ if text_enhance_layer is not None:
556
+ self.text_layers = _get_clones(text_enhance_layer, num_layers, layer_share=enc_layer_share)
557
+ if feature_fusion_layer is not None:
558
+ self.fusion_layers = _get_clones(feature_fusion_layer, num_layers, layer_share=enc_layer_share)
559
+ else:
560
+ self.layers = []
561
+ del encoder_layer
562
+
563
+ if text_enhance_layer is not None:
564
+ self.text_layers = []
565
+ del text_enhance_layer
566
+ if feature_fusion_layer is not None:
567
+ self.fusion_layers = []
568
+ del feature_fusion_layer
569
+
570
+ self.query_scale = None
571
+ self.num_queries = num_queries
572
+ self.num_layers = num_layers
573
+ self.d_model = d_model
574
+
575
+ self.use_checkpoint = use_checkpoint
576
+ self.use_transformer_ckpt = use_transformer_ckpt
577
+
578
+ @staticmethod
579
+ def get_reference_points(spatial_shapes, valid_ratios, device):
580
+ reference_points_list = []
581
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
582
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
583
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),)
584
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
585
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
586
+ ref = torch.stack((ref_x, ref_y), -1)
587
+ reference_points_list.append(ref)
588
+ reference_points = torch.cat(reference_points_list, 1)
589
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
590
+ return reference_points
591
+
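+ # The reference points above are the normalized centers of every location of every level,
+ # scaled by valid_ratios so each point is expressed relative to the un-padded region of each
+ # level; the result has shape [bs, sum(hi*wi), num_levels, 2].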
592
+ def forward(self,
593
+ # for images
594
+ src: Tensor,
595
+ pos: Tensor,
596
+ spatial_shapes: Tensor,
597
+ level_start_index: Tensor,
598
+ valid_ratios: Tensor,
599
+ key_padding_mask: Tensor,
600
+ # for texts
601
+ memory_text: Tensor = None,
602
+ text_attention_mask: Tensor = None,
603
+ pos_text: Tensor = None,
604
+ text_self_attention_masks: Tensor = None,
605
+ position_ids: Tensor = None,
606
+ ):
607
+ """
608
+ Input:
609
+ - src: [bs, sum(hi*wi), 256]
610
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
611
+ - spatial_shapes: h,w of each level [num_level, 2]
612
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
613
+ - valid_ratios: [bs, num_level, 2]
614
+ - key_padding_mask: [bs, sum(hi*wi)]
615
+
616
+ - memory_text: bs, n_text, 256
617
+ - text_attention_mask: bs, n_text
618
+ False for no padding; True for padding
619
+ - pos_text: bs, n_text, 256
620
+
621
+ - position_ids: bs, n_text
622
+ Intermedia:
623
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
624
+ Outpus:
625
+ - output: [bs, sum(hi*wi), 256]
626
+ """
627
+
628
+ output = src
629
+
630
+ # preparation and reshape
631
+ if self.num_layers > 0:
632
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
633
+
634
+ if self.text_layers:
635
+ # generate pos_text
636
+ bs, n_text, text_dim = memory_text.shape
637
+ if pos_text is None and position_ids is None:
638
+ pos_text = torch.arange(n_text, device=memory_text.device).float().unsqueeze(0).unsqueeze(-1).repeat(bs,
639
+ 1,
640
+ 1)
641
+ pos_text = get_sine_pos_embed(pos_text, num_pos_feats=256, exchange_xy=False)
642
+ if position_ids is not None:
643
+ pos_text = get_sine_pos_embed(position_ids[..., None], num_pos_feats=256, exchange_xy=False)
644
+
645
+ # main process
646
+ for layer_id, layer in enumerate(self.layers):
647
+ # if output.isnan().any() or memory_text.isnan().any():
648
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
649
+ # import ipdb; ipdb.set_trace()
650
+ if self.fusion_layers:
651
+ if self.use_checkpoint:
652
+ output, memory_text = checkpoint.checkpoint(
653
+ self.fusion_layers[layer_id],
654
+ output,
655
+ memory_text,
656
+ key_padding_mask,
657
+ text_attention_mask
658
+ )
659
+ else:
660
+ output, memory_text = self.fusion_layers[layer_id](v=output, l=memory_text,
661
+ attention_mask_v=key_padding_mask,
662
+ attention_mask_l=text_attention_mask)
663
+
664
+ if self.text_layers:
665
+ memory_text = self.text_layers[layer_id](
666
+ src=memory_text.transpose(0, 1),
667
+ src_mask=~text_self_attention_masks, # note we use ~ for mask here
668
+ src_key_padding_mask=text_attention_mask,
669
+ pos=(pos_text.transpose(0, 1) if pos_text is not None else None)
670
+ ).transpose(0, 1)
671
+
672
+ # main process
673
+ if self.use_transformer_ckpt:
674
+ output = checkpoint.checkpoint(
675
+ layer,
676
+ output,
677
+ pos,
678
+ reference_points,
679
+ spatial_shapes,
680
+ level_start_index,
681
+ key_padding_mask
682
+ )
683
+ else:
684
+ output = layer(src=output, pos=pos, reference_points=reference_points, spatial_shapes=spatial_shapes,
685
+ level_start_index=level_start_index, key_padding_mask=key_padding_mask)
686
+
687
+ return output, memory_text
688
+
689
+
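+ # Design note: when use_checkpoint / use_transformer_ckpt is enabled, the fusion and
+ # deformable layers run under torch.utils.checkpoint, trading forward recomputation during
+ # backward for a large reduction in stored activations on these long flattened sequences.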
690
+ class TransformerDecoder(nn.Module):
691
+
692
+ def __init__(self, decoder_layer, num_layers, norm=None,
693
+ return_intermediate=False,
694
+ d_model=256, query_dim=4,
695
+ modulate_hw_attn=False,
696
+ num_feature_levels=1,
697
+ deformable_decoder=False,
698
+ decoder_query_perturber=None,
699
+ dec_layer_number=None, # number of queries each layer in decoder
700
+ rm_dec_query_scale=False,
701
+ dec_layer_share=False,
702
+ dec_layer_dropout_prob=None,
703
+ use_detached_boxes_dec_out=False,
704
+ num_box_decoder_layers=2,
705
+ num_body_points=68,
706
+ ):
707
+ super().__init__()
708
+ if num_layers > 0:
709
+ self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
710
+ else:
711
+ self.layers = []
712
+ self.num_layers = num_layers
713
+ self.norm = norm
714
+ self.return_intermediate = return_intermediate
715
+ assert return_intermediate, "support return_intermediate only"
716
+ self.query_dim = query_dim
717
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
718
+ self.num_feature_levels = num_feature_levels
719
+ self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
720
+
721
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
722
+ if not deformable_decoder:
723
+ self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
724
+ else:
725
+ self.query_pos_sine_scale = None
726
+
727
+ if rm_dec_query_scale:
728
+ self.query_scale = None
729
+ else:
730
+ raise NotImplementedError
731
+ self.query_scale = MLP(d_model, d_model, d_model, 2)
732
+ self.bbox_embed = None
733
+ self.class_embed = None
734
+ self.pose_embed = None
735
+ self.pose_hw_embed = None
736
+ self.d_model = d_model
737
+ self.modulate_hw_attn = modulate_hw_attn
738
+ self.deformable_decoder = deformable_decoder
739
+
740
+ if not deformable_decoder and modulate_hw_attn:
741
+ self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
742
+ else:
743
+ self.ref_anchor_head = None
744
+
745
+ self.decoder_query_perturber = decoder_query_perturber
746
+ self.box_pred_damping = None
747
+
748
+ self.dec_layer_number = dec_layer_number
749
+ if dec_layer_number is not None:
750
+ assert isinstance(dec_layer_number, list)
751
+ assert len(dec_layer_number) == num_layers
752
+ # assert dec_layer_number[0] ==
753
+
754
+ self.dec_layer_dropout_prob = dec_layer_dropout_prob
755
+ if dec_layer_dropout_prob is not None:
756
+ assert isinstance(dec_layer_dropout_prob, list)
757
+ assert len(dec_layer_dropout_prob) == num_layers
758
+ for i in dec_layer_dropout_prob:
759
+ assert 0.0 <= i <= 1.0
760
+
761
+ self.rm_detach = None
762
+ self.num_body_points = num_body_points
763
+
764
+ self.hw = nn.Embedding(17, 2)
765
+ self.num_box_decoder_layers = num_box_decoder_layers
766
+ self.kpt_index = [x for x in range(50 * (self.num_body_points + 1)) if x % (self.num_body_points + 1) != 0]
767
+ self.hw_append = nn.Embedding(self.num_body_points-17, 2)
768
+
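+ # Layout note: after the box-decoder layers, every group of (num_body_points + 1) queries is
+ # one box token followed by its keypoint tokens for each of the 50 kept instances; kpt_index
+ # keeps every index that is not a multiple of (num_body_points + 1), i.e. the keypoint tokens
+ # only. The 17 / (num_body_points - 17) split of the h-w embeddings presumably corresponds to
+ # the 17 COCO body keypoints plus UniPose's extra points.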
769
+ def forward(self, tgt, memory,
770
+ tgt_mask: Optional[Tensor] = None,
771
+ tgt_mask2: Optional[Tensor] = None,
772
+ memory_mask: Optional[Tensor] = None,
773
+ tgt_key_padding_mask: Optional[Tensor] = None,
774
+ memory_key_padding_mask: Optional[Tensor] = None,
775
+ pos: Optional[Tensor] = None,
776
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
777
+ # for memory
778
+ level_start_index: Optional[Tensor] = None, # num_levels
779
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
780
+ valid_ratios: Optional[Tensor] = None,
781
+ # for text
782
+ memory_text: Optional[Tensor] = None,
783
+ text_attention_mask: Optional[Tensor] = None,
784
+ text_dict: Optional[Tensor] = None,
785
+ dn_meta: Optional[Tensor] = None,
786
+ targets: Optional[Tensor] = None,
787
+ kpt_embed: Optional[Tensor] = None
788
+ ):
789
+ """
790
+ Input:
791
+ - tgt: nq, bs, d_model
792
+ - memory: hw, bs, d_model
793
+ - pos: hw, bs, d_model
794
+ - refpoints_unsigmoid: nq, bs, 2/4
795
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
796
+ """
797
+
798
+ output = tgt
799
+ output += self.hw.weight[0, 0] * 0.0
800
+
801
+
802
+ intermediate = []
803
+ reference_points = refpoints_unsigmoid.sigmoid()
804
+ ref_points = [reference_points]
805
+ effect_num_dn = dn_meta['pad_size'] if self.training else 0
806
+ inter_select_number = 50
807
+ for layer_id, layer in enumerate(self.layers):
808
+
809
+ if reference_points.shape[-1] == 4:
810
+ reference_points_input = reference_points[:, :, None] \
811
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :] # nq, bs, nlevel, 4
812
+ else:
813
+ assert reference_points.shape[-1] == 2
814
+ reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
815
+ query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # nq, bs, 256*2
816
+
817
+ # conditional query
818
+ raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
819
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
820
+ query_pos = pos_scale * raw_query_pos
821
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
822
+ # if query_pos.isnan().any() | query_pos.isinf().any():
823
+ # import ipdb; ipdb.set_trace()
824
+
825
+ # main process
826
+ output = layer(
827
+ tgt=output,
828
+ tgt_query_pos=query_pos,
829
+ tgt_query_sine_embed=query_sine_embed,
830
+ tgt_key_padding_mask=tgt_key_padding_mask,
831
+ tgt_reference_points=reference_points_input,
832
+
833
+ memory_text=memory_text,
834
+ text_attention_mask=text_attention_mask,
835
+
836
+ memory=memory,
837
+ memory_key_padding_mask=memory_key_padding_mask,
838
+ memory_level_start_index=level_start_index,
839
+ memory_spatial_shapes=spatial_shapes,
840
+ memory_pos=pos,
841
+
842
+ self_attn_mask=tgt_mask,
843
+ cross_attn_mask=memory_mask
844
+ )
845
+ if output.isnan().any() | output.isinf().any():
846
+ print(f"output of decoder layer {layer_id} contains nan/inf")
847
+ try:
848
+ num_nan = output.isnan().sum().item()
849
+ num_inf = output.isinf().sum().item()
850
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
851
+ except Exception as e:
852
+ print(e)
853
+
854
+
855
+
856
+
857
+ intermediate.append(self.norm(output))
858
+ # iter update
859
+ if layer_id < self.num_box_decoder_layers:
860
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
861
+ delta_unsig = self.bbox_embed[layer_id](output)
862
+ outputs_unsig = delta_unsig + reference_before_sigmoid
863
+ new_reference_points = outputs_unsig.sigmoid()
864
+
865
+ # select # ref points as anchors
866
+ if layer_id == self.num_box_decoder_layers - 1:
867
+ dn_output = output[:effect_num_dn]
868
+ dn_new_reference_points = new_reference_points[:effect_num_dn]
869
+ class_unselected = self.class_embed[layer_id](output.transpose(0, 1), text_dict)[:,
870
+ effect_num_dn:].transpose(0, 1)
871
+ topk_proposals = torch.topk(class_unselected.max(-1)[0], inter_select_number, dim=0)[1]
872
+ new_reference_points_for_box = torch.gather(new_reference_points[effect_num_dn:], 0,
873
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
874
+ new_output_for_box = torch.gather(output[effect_num_dn:], 0,
875
+ topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
876
+ keypoint_embed=kpt_embed.transpose(0, 1)
877
+
878
+ new_output_for_keypoint = keypoint_embed[None, :, :, :].repeat(new_output_for_box.shape[0],1,1,1)
879
+ delta_xy = self.pose_embed[-1](new_output_for_keypoint)[..., :2]
880
+ keypoint_xy = (inverse_sigmoid(new_reference_points_for_box[..., :2][:, None]) + delta_xy).sigmoid()
881
+ num_queries, _, bs, _ = keypoint_xy.shape
882
+ aa = torch.cat((self.hw.weight,self.hw_append.weight),dim=0)
883
+ keypoint_wh_weight = aa.unsqueeze(0).unsqueeze(-2).repeat(num_queries, 1, bs, 1).sigmoid()
884
+ keypoint_wh = keypoint_wh_weight * new_reference_points_for_box[..., 2:][:, None]
885
+ new_reference_points_for_keypoint = torch.cat((keypoint_xy, keypoint_wh), dim=-1)
886
+ new_reference_points = torch.cat(
887
+ (new_reference_points_for_box.unsqueeze(1), new_reference_points_for_keypoint), dim=1).flatten(0, 1)
888
+ output = torch.cat((new_output_for_box.unsqueeze(1), new_output_for_keypoint), dim=1).flatten(0, 1)
889
+ new_reference_points = torch.cat((dn_new_reference_points, new_reference_points), dim=0)
890
+ output = torch.cat((dn_output, output), dim=0)
891
+ tgt_mask = tgt_mask2
892
+
893
+ if layer_id >= self.num_box_decoder_layers:
894
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
895
+ output_bbox_dn = output[:effect_num_dn]
896
+ output_bbox_norm = output[effect_num_dn:][0::(self.num_body_points + 1)]
897
+ reference_before_sigmoid_bbox_dn = reference_before_sigmoid[:effect_num_dn]
898
+ reference_before_sigmoid_bbox_norm = reference_before_sigmoid[effect_num_dn:][
899
+ 0::(self.num_body_points + 1)]
900
+ delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn)
901
+ delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm)
902
+ outputs_unsig_dn = delta_unsig_dn + reference_before_sigmoid_bbox_dn
903
+ outputs_unsig_norm = delta_unsig_norm + reference_before_sigmoid_bbox_norm
904
+ new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid()
905
+ new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid()
906
+ output_kpt = output[effect_num_dn:].index_select(0, torch.tensor(self.kpt_index, device=output.device))
907
+ delta_xy_unsig = self.pose_embed[layer_id - self.num_box_decoder_layers](output_kpt)
908
+ outputs_unsig = reference_before_sigmoid[effect_num_dn:].index_select(0, torch.tensor(self.kpt_index,
909
+ device=output.device)).clone() ##
910
+ delta_hw_unsig = self.pose_hw_embed[layer_id - self.num_box_decoder_layers](output_kpt)
911
+ outputs_unsig[..., :2] += delta_xy_unsig[..., :2]
912
+ outputs_unsig[..., 2:] += delta_hw_unsig
913
+ new_reference_points_for_keypoint = outputs_unsig.sigmoid()
914
+ bs = new_reference_points_for_box_norm.shape[1]
915
+ new_reference_points_norm = torch.cat((new_reference_points_for_box_norm.unsqueeze(1),
916
+ new_reference_points_for_keypoint.view(-1, self.num_body_points,
917
+ bs, 4)), dim=1).flatten(0,
918
+ 1)
919
+ new_reference_points = torch.cat((new_reference_points_for_box_dn, new_reference_points_norm), dim=0)
920
+
921
+ if self.rm_detach and 'dec' in self.rm_detach:
922
+ reference_points = new_reference_points
923
+ else:
924
+ reference_points = new_reference_points.detach()
925
+
926
+ # if layer_id != self.num_layers - 1:
927
+ if self.use_detached_boxes_dec_out:
928
+ ref_points.append(reference_points)
929
+ else:
930
+ ref_points.append(new_reference_points)
931
+
932
+ return [
933
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
934
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]
935
+ ]
936
+
937
+
938
+ class DeformableTransformerEncoderLayer(nn.Module):
939
+ def __init__(self,
940
+ d_model=256, d_ffn=1024,
941
+ dropout=0.1, activation="relu",
942
+ n_levels=4, n_heads=8, n_points=4,
943
+ add_channel_attention=False,
944
+ use_deformable_box_attn=False,
945
+ box_attn_type='roi_align',
946
+ ):
947
+ super().__init__()
948
+
949
+ # self attention
950
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
951
+ self.dropout1 = nn.Dropout(dropout)
952
+ self.norm1 = nn.LayerNorm(d_model)
953
+
954
+ # ffn
955
+ self.linear1 = nn.Linear(d_model, d_ffn)
956
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
957
+ self.dropout2 = nn.Dropout(dropout)
958
+ self.linear2 = nn.Linear(d_ffn, d_model)
959
+ self.dropout3 = nn.Dropout(dropout)
960
+ self.norm2 = nn.LayerNorm(d_model)
961
+
962
+ # channel attention
963
+ self.add_channel_attention = add_channel_attention
964
+ if add_channel_attention:
965
+ self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model)
966
+ self.norm_channel = nn.LayerNorm(d_model)
967
+
968
+ @staticmethod
969
+ def with_pos_embed(tensor, pos):
970
+ return tensor if pos is None else tensor + pos
971
+
972
+ def forward_ffn(self, src):
973
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
974
+ src = src + self.dropout3(src2)
975
+ src = self.norm2(src)
976
+ return src
977
+
978
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
979
+ # self attention
980
+ # import ipdb; ipdb.set_trace()
981
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index,
982
+ key_padding_mask)
983
+ src = src + self.dropout1(src2)
984
+ src = self.norm1(src)
985
+
986
+ # ffn
987
+ src = self.forward_ffn(src)
988
+
989
+ # channel attn
990
+ if self.add_channel_attention:
991
+ src = self.norm_channel(src + self.activ_channel(src))
992
+
993
+ return src
994
+
995
+
996
+ class DeformableTransformerDecoderLayer(nn.Module):
997
+ def __init__(self, d_model=256, d_ffn=1024,
998
+ dropout=0.1, activation="relu",
999
+ n_levels=4, n_heads=8, n_points=4,
1000
+ use_text_feat_guide=False,
1001
+ use_text_cross_attention=False,
1002
+ ffn_extra_layernorm=False
1003
+ ):
1004
+ super().__init__()
1005
+
1006
+ # cross attention
1007
+ # self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
1008
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
1009
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
1010
+ self.norm1 = nn.LayerNorm(d_model)
1011
+
1012
+ # cross attention text
1013
+ if use_text_cross_attention:
1014
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
1015
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
1016
+ self.catext_norm = nn.LayerNorm(d_model)
1017
+
1018
+ # self attention
1019
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
1020
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
1021
+ self.norm2 = nn.LayerNorm(d_model)
1022
+
1023
+ # ffn
1024
+ self.linear1 = nn.Linear(d_model, d_ffn)
1025
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
1026
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
1027
+ self.linear2 = nn.Linear(d_ffn, d_model)
1028
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
1029
+ self.norm3 = nn.LayerNorm(d_model)
1030
+ if ffn_extra_layernorm:
1031
+ raise NotImplementedError('ffn_extra_layernorm not implemented')
1032
+ self.norm_ext = nn.LayerNorm(d_ffn)
1033
+ else:
1034
+ self.norm_ext = None
1035
+
1036
+ self.key_aware_proj = None
1037
+ self.use_text_feat_guide = use_text_feat_guide
1038
+ assert not use_text_feat_guide
1039
+ self.use_text_cross_attention = use_text_cross_attention
1040
+
1041
+ def rm_self_attn_modules(self):
1042
+ self.self_attn = None
1043
+ self.dropout2 = None
1044
+ self.norm2 = None
1045
+
1046
+ @staticmethod
1047
+ def with_pos_embed(tensor, pos):
1048
+ return tensor if pos is None else tensor + pos
1049
+
1050
+ def forward_ffn(self, tgt, ipdb_flag=False):
1051
+
1052
+ with torch.cuda.amp.autocast(enabled=False):
1053
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
1054
+
1055
+ tgt = tgt + self.dropout4(tgt2)
1056
+ tgt = self.norm3(tgt)
1057
+ return tgt
1058
+
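+ # Note: the FFN above is forced to run in fp32 (autocast disabled), a common guard against
+ # overflow in the wide d_ffn intermediate activations when the rest of the model runs under AMP.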
1059
+ def forward(self,
1060
+ # for tgt
1061
+ tgt: Optional[Tensor], # nq, bs, d_model
1062
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
1063
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
1064
+ tgt_key_padding_mask: Optional[Tensor] = None,
1065
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
1066
+
1067
+ memory_text: Optional[Tensor] = None, # bs, num_token, d_model
1068
+ text_attention_mask: Optional[Tensor] = None, # bs, num_token
1069
+
1070
+ # for memory
1071
+ memory: Optional[Tensor] = None, # hw, bs, d_model
1072
+ memory_key_padding_mask: Optional[Tensor] = None,
1073
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
1074
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
1075
+ memory_pos: Optional[Tensor] = None, # pos for memory
1076
+
1077
+ # sa
1078
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
1079
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
1080
+ ):
1081
+ """
1082
+ Input:
1083
+ - tgt/tgt_query_pos: nq, bs, d_model
1084
+ -
1085
+ """
1086
+ assert cross_attn_mask is None
1087
+
1088
+ # self attention
1089
+ if self.self_attn is not None:
1090
+ # import ipdb; ipdb.set_trace()
1091
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
1092
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
1093
+ tgt = tgt + self.dropout2(tgt2)
1094
+ tgt = self.norm2(tgt)
1095
+
1096
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
1097
+ # if tgt.isnan().any() | tgt.isinf().any() :
1098
+ # import ipdb; ipdb.set_trace()
1099
+
1100
+ if self.use_text_cross_attention:
1101
+ tgt2 = self.ca_text(self.with_pos_embed(tgt, tgt_query_pos), memory_text.transpose(0, 1),
1102
+ memory_text.transpose(0, 1), key_padding_mask=text_attention_mask)[0]
1103
+ tgt = tgt + self.catext_dropout(tgt2)
1104
+ tgt = self.catext_norm(tgt)
1105
+
1106
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
1107
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
1108
+ # import ipdb; ipdb.set_trace()
1109
+
1110
+ # if tgt.isnan().any() | tgt.isinf().any() :
1111
+ # import ipdb; ipdb.set_trace()
1112
+
1113
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
1114
+ tgt_reference_points.transpose(0, 1).contiguous(),
1115
+ memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index,
1116
+ memory_key_padding_mask).transpose(0, 1)
1117
+ tgt = tgt + self.dropout1(tgt2)
1118
+ tgt = self.norm1(tgt)
1119
+
1120
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
1121
+ # tgtk = tgt.clone()
1122
+ # if tgt.isnan().any() | tgt.isinf().any() :
1123
+ # import ipdb; ipdb.set_trace()
1124
+
1125
+ # ffn
1126
+ tgt = self.forward_ffn(tgt)
1127
+ # if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
1128
+ # if tgt.isnan().any() | tgt.isinf().any() :
1129
+ # tgtk = self.forward_ffn(tgtk, ipdb_flag=True)
1130
+ # import ipdb; ipdb.set_trace()
1131
+
1132
+ return tgt
1133
+
1134
+
1135
+ def _get_clones(module, N, layer_share=False):
1136
+ # import ipdb; ipdb.set_trace()
1137
+ if layer_share:
1138
+ return nn.ModuleList([module for i in range(N)])
1139
+ else:
1140
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
1141
+
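+ # Note: with layer_share=True the returned ModuleList holds N references to the same module
+ # instance (shared weights); with layer_share=False each entry is an independent deep copy.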
1142
+
1143
+ def build_deformable_transformer(args):
1144
+ decoder_query_perturber = None
1145
+ if args.decoder_layer_noise:
1146
+ from .utils import RandomBoxPerturber
1147
+ decoder_query_perturber = RandomBoxPerturber(
1148
+ x_noise_scale=args.dln_xy_noise, y_noise_scale=args.dln_xy_noise,
1149
+ w_noise_scale=args.dln_hw_noise, h_noise_scale=args.dln_hw_noise)
1150
+
1151
+ use_detached_boxes_dec_out = False
1152
+ try:
1153
+ use_detached_boxes_dec_out = args.use_detached_boxes_dec_out
1154
+ except:
1155
+ use_detached_boxes_dec_out = False
1156
+
1157
+ binary_query_selection = False
1158
+ try:
1159
+ binary_query_selection = args.binary_query_selection
1160
+ except:
1161
+ binary_query_selection = False
1162
+
1163
+ ffn_extra_layernorm = False
1164
+ try:
1165
+ ffn_extra_layernorm = args.ffn_extra_layernorm
1166
+ except:
1167
+ print('ffn_extra_layernorm not found, set to False')
1168
+ ffn_extra_layernorm = False
1169
+
1170
+ return DeformableTransformer(
1171
+ d_model=args.hidden_dim,
1172
+ dropout=args.dropout,
1173
+ nhead=args.nheads,
1174
+ num_queries=args.num_queries,
1175
+ dim_feedforward=args.dim_feedforward,
1176
+ num_encoder_layers=args.enc_layers,
1177
+ num_unicoder_layers=args.unic_layers,
1178
+ num_decoder_layers=args.dec_layers,
1179
+ normalize_before=args.pre_norm,
1180
+ return_intermediate_dec=True,
1181
+ query_dim=args.query_dim,
1182
+ activation=args.transformer_activation,
1183
+ num_patterns=args.num_patterns,
1184
+ modulate_hw_attn=True,
1185
+
1186
+ deformable_encoder=True,
1187
+ deformable_decoder=True,
1188
+ num_feature_levels=args.num_feature_levels,
1189
+ enc_n_points=args.enc_n_points,
1190
+ dec_n_points=args.dec_n_points,
1191
+ use_deformable_box_attn=args.use_deformable_box_attn,
1192
+ box_attn_type=args.box_attn_type,
1193
+
1194
+ learnable_tgt_init=True,
1195
+ decoder_query_perturber=decoder_query_perturber,
1196
+
1197
+ add_channel_attention=args.add_channel_attention,
1198
+ add_pos_value=args.add_pos_value,
1199
+ random_refpoints_xy=args.random_refpoints_xy,
1200
+
1201
+ # two stage
1202
+ two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
1203
+ two_stage_pat_embed=args.two_stage_pat_embed,
1204
+ two_stage_add_query_num=args.two_stage_add_query_num,
1205
+ two_stage_learn_wh=args.two_stage_learn_wh,
1206
+ two_stage_keep_all_tokens=args.two_stage_keep_all_tokens,
1207
+ dec_layer_number=args.dec_layer_number,
1208
+ rm_self_attn_layers=None,
1209
+ key_aware_type=None,
1210
+ layer_share_type=None,
1211
+
1212
+ rm_detach=None,
1213
+ decoder_sa_type=args.decoder_sa_type,
1214
+ module_seq=args.decoder_module_seq,
1215
+
1216
+ embed_init_tgt=args.embed_init_tgt,
1217
+ use_detached_boxes_dec_out=use_detached_boxes_dec_out,
1218
+ use_text_enhancer=args.use_text_enhancer,
1219
+ use_fusion_layer=args.use_fusion_layer,
1220
+ use_checkpoint=args.use_checkpoint,
1221
+ use_transformer_ckpt=args.use_transformer_ckpt,
1222
+ use_text_cross_attention=args.use_text_cross_attention,
1223
+
1224
+ text_dropout=args.text_dropout,
1225
+ fusion_dropout=args.fusion_dropout,
1226
+ fusion_droppath=args.fusion_droppath,
1227
+
1228
+ binary_query_selection=binary_query_selection,
1229
+ ffn_extra_layernorm=ffn_extra_layernorm,
1230
+ )
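+
+ # Usage sketch (hypothetical config: every attribute read above must exist on `args`, e.g.
+ # args.hidden_dim=256, args.nheads=8, args.enc_layers=args.dec_layers=6,
+ # args.num_feature_levels=4, args.two_stage_type='standard', ...):
+ # transformer = build_deformable_transformer(args)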
src/models/XPose/models/UniPose/fuse_modules.py ADDED
@@ -0,0 +1,276 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # from timm.models.layers import DropPath
6
+ from src.models.util import DropPath
7
+
8
+
9
+ class FeatureResizer(nn.Module):
10
+ """
11
+ This class takes as input a set of embeddings of dimension C1 and outputs a set of
12
+ embedding of dimension C2, after a linear transformation, dropout and normalization (LN).
13
+ """
14
+
15
+ def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
16
+ super().__init__()
17
+ self.do_ln = do_ln
18
+ # Object feature encoding
19
+ self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
20
+ self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
21
+ self.dropout = nn.Dropout(dropout)
22
+
23
+ def forward(self, encoder_features):
24
+ x = self.fc(encoder_features)
25
+ if self.do_ln:
26
+ x = self.layer_norm(x)
27
+ output = self.dropout(x)
28
+ return output
29
+
30
+
31
+ def l1norm(X, dim, eps=1e-8):
32
+ """L1-normalize columns of X
33
+ """
34
+ norm = torch.abs(X).sum(dim=dim, keepdim=True) + eps
35
+ X = torch.div(X, norm)
36
+ return X
37
+
38
+
39
+ def l2norm(X, dim, eps=1e-8):
40
+ """L2-normalize columns of X
41
+ """
42
+ norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
43
+ X = torch.div(X, norm)
44
+ return X
45
+
46
+
47
+ def func_attention(query, context, smooth=1, raw_feature_norm="softmax", eps=1e-8):
48
+ """
49
+ query: (n_context, queryL, d)
50
+ context: (n_context, sourceL, d)
51
+ """
52
+ batch_size_q, queryL = query.size(0), query.size(1)
53
+ batch_size, sourceL = context.size(0), context.size(1)
54
+
55
+ # Get attention
56
+ # --> (batch, d, queryL)
57
+ queryT = torch.transpose(query, 1, 2)
58
+
59
+ # (batch, sourceL, d)(batch, d, queryL)
60
+ # --> (batch, sourceL, queryL)
61
+ attn = torch.bmm(context, queryT)
62
+ if raw_feature_norm == "softmax":
63
+ # --> (batch*sourceL, queryL)
64
+ attn = attn.view(batch_size * sourceL, queryL)
65
+ attn = nn.Softmax(dim=1)(attn)  # explicit dim (the implicit 2-D default): softmax over queryL
66
+ # --> (batch, sourceL, queryL)
67
+ attn = attn.view(batch_size, sourceL, queryL)
68
+ elif raw_feature_norm == "l2norm":
69
+ attn = l2norm(attn, 2)
70
+ elif raw_feature_norm == "clipped_l2norm":
71
+ attn = nn.LeakyReLU(0.1)(attn)
72
+ attn = l2norm(attn, 2)
73
+ else:
74
+ raise ValueError("unknown first norm type:", raw_feature_norm)
75
+ # --> (batch, queryL, sourceL)
76
+ attn = torch.transpose(attn, 1, 2).contiguous()
77
+ # --> (batch*queryL, sourceL)
78
+ attn = attn.view(batch_size * queryL, sourceL)
79
+ attn = nn.Softmax(dim=1)(attn * smooth)  # softmax over sourceL
80
+ # --> (batch, queryL, sourceL)
81
+ attn = attn.view(batch_size, queryL, sourceL)
82
+ # --> (batch, sourceL, queryL)
83
+ attnT = torch.transpose(attn, 1, 2).contiguous()
84
+
85
+ # --> (batch, d, sourceL)
86
+ contextT = torch.transpose(context, 1, 2)
87
+ # (batch x d x sourceL)(batch x sourceL x queryL)
88
+ # --> (batch, d, queryL)
89
+ weightedContext = torch.bmm(contextT, attnT)
90
+ # --> (batch, queryL, d)
91
+ weightedContext = torch.transpose(weightedContext, 1, 2)
92
+
93
+ return weightedContext, attnT
94
+
95
+
96
+ class BiMultiHeadAttention(nn.Module):
97
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None):
98
+ super(BiMultiHeadAttention, self).__init__()
99
+
100
+ self.embed_dim = embed_dim
101
+ self.num_heads = num_heads
102
+ self.head_dim = embed_dim // num_heads
103
+ self.v_dim = v_dim
104
+ self.l_dim = l_dim
105
+
106
+ assert (
107
+ self.head_dim * self.num_heads == self.embed_dim
108
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
109
+ self.scale = self.head_dim ** (-0.5)
110
+ self.dropout = dropout
111
+
112
+ self.v_proj = nn.Linear(self.v_dim, self.embed_dim)
113
+ self.l_proj = nn.Linear(self.l_dim, self.embed_dim)
114
+ self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim)
115
+ self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim)
116
+
117
+ self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim)
118
+ self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim)
119
+
120
+ self.stable_softmax_2d = True
121
+ self.clamp_min_for_underflow = True
122
+ self.clamp_max_for_overflow = True
123
+
124
+ self._reset_parameters()
125
+
126
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
127
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
128
+
129
+ def _reset_parameters(self):
130
+ nn.init.xavier_uniform_(self.v_proj.weight)
131
+ self.v_proj.bias.data.fill_(0)
132
+ nn.init.xavier_uniform_(self.l_proj.weight)
133
+ self.l_proj.bias.data.fill_(0)
134
+ nn.init.xavier_uniform_(self.values_v_proj.weight)
135
+ self.values_v_proj.bias.data.fill_(0)
136
+ nn.init.xavier_uniform_(self.values_l_proj.weight)
137
+ self.values_l_proj.bias.data.fill_(0)
138
+ nn.init.xavier_uniform_(self.out_v_proj.weight)
139
+ self.out_v_proj.bias.data.fill_(0)
140
+ nn.init.xavier_uniform_(self.out_l_proj.weight)
141
+ self.out_l_proj.bias.data.fill_(0)
142
+
143
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
144
+ """_summary_
145
+
146
+ Args:
147
+ v (_type_): bs, n_img, dim
148
+ l (_type_): bs, n_text, dim
149
+ attention_mask_v (_type_, optional): _description_. bs, n_img
150
+ attention_mask_l (_type_, optional): _description_. bs, n_text
151
+
152
+ Returns:
153
+ _type_: _description_
154
+ """
155
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
156
+ # import ipdb; ipdb.set_trace()
157
+ bsz, tgt_len, _ = v.size()
158
+
159
+ query_states = self.v_proj(v) * self.scale
160
+ key_states = self._shape(self.l_proj(l), -1, bsz)
161
+ value_v_states = self._shape(self.values_v_proj(v), -1, bsz)
162
+ value_l_states = self._shape(self.values_l_proj(l), -1, bsz)
163
+
164
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
165
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
166
+ key_states = key_states.view(*proj_shape)
167
+ value_v_states = value_v_states.view(*proj_shape)
168
+ value_l_states = value_l_states.view(*proj_shape)
169
+
170
+ src_len = key_states.size(1)
171
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) # bs*nhead, nimg, ntxt
172
+
173
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
174
+ raise ValueError(
175
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
176
+ )
177
+
178
+ if self.stable_softmax_2d:
179
+ attn_weights = attn_weights - attn_weights.max()
180
+
181
+ if self.clamp_min_for_underflow:
182
+ attn_weights = torch.clamp(attn_weights,
183
+ min=-50000) # Do not increase -50000, data type half has quite limited range
184
+ if self.clamp_max_for_overflow:
185
+ attn_weights = torch.clamp(attn_weights,
186
+ max=50000) # Do not increase 50000, data type half has quite limited range
187
+
188
+ attn_weights_T = attn_weights.transpose(1, 2)
189
+ attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[
190
+ 0])
191
+ if self.clamp_min_for_underflow:
192
+ attn_weights_l = torch.clamp(attn_weights_l,
193
+ min=-50000) # Do not increase -50000, data type half has quite limited range
194
+ if self.clamp_max_for_overflow:
195
+ attn_weights_l = torch.clamp(attn_weights_l,
196
+ max=50000) # Do not increase 50000, data type half has quite limited range
197
+
198
+ # mask vison for language
199
+ if attention_mask_v is not None:
200
+ attention_mask_v = attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
201
+ attn_weights_l.masked_fill_(attention_mask_v, float('-inf'))
202
+
203
+ attn_weights_l = attn_weights_l.softmax(dim=-1)
204
+
205
+ # mask language for vision
206
+ if attention_mask_l is not None:
207
+ attention_mask_l = attention_mask_l[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
208
+ attn_weights.masked_fill_(attention_mask_l, float('-inf'))
209
+ attn_weights_v = attn_weights.softmax(dim=-1)
210
+
211
+ attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training)
212
+ attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training)
213
+
214
+ attn_output_v = torch.bmm(attn_probs_v, value_l_states)
215
+ attn_output_l = torch.bmm(attn_probs_l, value_v_states)
216
+
217
+ if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
218
+ raise ValueError(
219
+ f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}"
220
+ )
221
+
222
+ if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim):
223
+ raise ValueError(
224
+ f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}"
225
+ )
226
+
227
+ attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim)
228
+ attn_output_v = attn_output_v.transpose(1, 2)
229
+ attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim)
230
+
231
+ attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim)
232
+ attn_output_l = attn_output_l.transpose(1, 2)
233
+ attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim)
234
+
235
+ attn_output_v = self.out_v_proj(attn_output_v)
236
+ attn_output_l = self.out_l_proj(attn_output_l)
237
+
238
+ return attn_output_v, attn_output_l
239
+
240
+
241
+ # Bi-Direction MHA (text->image, image->text)
242
+ class BiAttentionBlock(nn.Module):
243
+ def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1,
244
+ drop_path=.0, init_values=1e-4, cfg=None):
245
+ """
246
+ Inputs:
247
+ embed_dim - Dimensionality of input and attention feature vectors
248
+ hidden_dim - Dimensionality of hidden layer in feed-forward network
249
+ (usually 2-4x larger than embed_dim)
250
+ num_heads - Number of heads to use in the Multi-Head Attention block
251
+ dropout - Amount of dropout to apply in the feed-forward network
252
+ """
253
+ super(BiAttentionBlock, self).__init__()
254
+
255
+ # pre layer norm
256
+ self.layer_norm_v = nn.LayerNorm(v_dim)
257
+ self.layer_norm_l = nn.LayerNorm(l_dim)
258
+ self.attn = BiMultiHeadAttention(v_dim=v_dim,
259
+ l_dim=l_dim,
260
+ embed_dim=embed_dim,
261
+ num_heads=num_heads,
262
+ dropout=dropout)
263
+
264
+ # add layer scale for training stability
265
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
266
+ self.gamma_v = nn.Parameter(init_values * torch.ones((v_dim)), requires_grad=False)
267
+ self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=False)
268
+
269
+ def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
270
+ v = self.layer_norm_v(v)
271
+ l = self.layer_norm_l(l)
272
+ delta_v, delta_l = self.attn(v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l)
273
+ # v, l = v + delta_v, l + delta_l
274
+ v = v + self.drop_path(self.gamma_v * delta_v)
275
+ l = l + self.drop_path(self.gamma_l * delta_l)
276
+ return v, l
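
For orientation, a minimal usage sketch of the BiAttentionBlock defined above, assuming the repository package is importable (it pulls DropPath from src.models.util); the dimensions and mask are illustrative only:

import torch

bs, n_img, n_text = 2, 100, 16
v_dim, l_dim, embed_dim, num_heads = 256, 768, 256, 8

block = BiAttentionBlock(v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim,
                         num_heads=num_heads, dropout=0.1, drop_path=0.0)

v = torch.randn(bs, n_img, v_dim)                       # visual tokens
l = torch.randn(bs, n_text, l_dim)                      # text tokens
pad_mask_l = torch.zeros(bs, n_text, dtype=torch.bool)  # True = padded text position to ignore

v_out, l_out = block(v, l, attention_mask_l=pad_mask_l)
print(v_out.shape, l_out.shape)  # torch.Size([2, 100, 256]) torch.Size([2, 16, 768])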
src/models/XPose/models/UniPose/mask_generate.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+
3
+
4
+ def prepare_for_mask(kpt_mask):
5
+
6
+
7
+ tgt_size2 = 50 * 69
8
+ attn_mask2 = torch.ones(kpt_mask.shape[0], 8, tgt_size2, tgt_size2, device=kpt_mask.device) < 0
9
+ group_bbox_kpt = 69
10
+ num_group = 50
11
+ for matchj in range(num_group * group_bbox_kpt):
12
+ sj = (matchj // group_bbox_kpt) * group_bbox_kpt
13
+ ej = (matchj // group_bbox_kpt + 1)*group_bbox_kpt
14
+ if sj > 0:
15
+ attn_mask2[:,:,matchj, :sj] = True
16
+ if ej < num_group * group_bbox_kpt:
17
+ attn_mask2[:,:,matchj, ej:] = True
18
+
19
+
20
+ bs, length = kpt_mask.shape
21
+ equal_mask = kpt_mask[:, :, None] == kpt_mask[:, None, :]
22
+ equal_mask= equal_mask.unsqueeze(1).repeat(1,8,1,1)
23
+ for idx in range(num_group):
24
+ start_idx = idx * length
25
+ end_idx = (idx + 1) * length
26
+ attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][equal_mask] = False
27
+ attn_mask2[:, :,start_idx:end_idx, start_idx:end_idx][~equal_mask] = True
28
+
29
+
30
+
31
+
32
+ input_query_label = None
33
+ input_query_bbox = None
34
+ attn_mask = None
35
+ dn_meta = None
36
+
37
+ return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0,1), dn_meta
38
+
39
+
40
+ def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
41
+
42
+ if dn_meta and dn_meta['pad_size'] > 0:
43
+
44
+ output_known_class = [outputs_class_i[:, :dn_meta['pad_size'], :] for outputs_class_i in outputs_class]
45
+ output_known_coord = [outputs_coord_i[:, :dn_meta['pad_size'], :] for outputs_coord_i in outputs_coord]
46
+
47
+ outputs_class = [outputs_class_i[:, dn_meta['pad_size']:, :] for outputs_class_i in outputs_class]
48
+ outputs_coord = [outputs_coord_i[:, dn_meta['pad_size']:, :] for outputs_coord_i in outputs_coord]
49
+
50
+ out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]}
51
+ if aux_loss:
52
+ out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
53
+ dn_meta['output_known_lbs_bboxes'] = out
54
+ return outputs_class, outputs_coord
55
+
56
+
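
prepare_for_mask restricts attention to a token's own group (50 groups of 69 tokens), and inside a group only tokens with the same kpt_mask visibility flag may attend to each other (True = blocked). A small self-contained illustration of that rule with 2 groups of 3 tokens and the per-head dimension dropped; sizes are chosen only for readability:

import torch

num_group, group_len = 2, 3              # stand-ins for 50 and 69
total = num_group * group_len
kpt_mask = torch.tensor([[1, 1, 0]])     # visibility flags, shape (bs, group_len)

attn_mask = torch.zeros(1, total, total, dtype=torch.bool)

# 1) Block attention across different groups (block-diagonal structure).
for j in range(total):
    g = j // group_len
    attn_mask[:, j, :g * group_len] = True
    attn_mask[:, j, (g + 1) * group_len:] = True

# 2) Inside each group, only tokens with the same visibility flag see each other.
equal = kpt_mask[:, :, None] == kpt_mask[:, None, :]
for g in range(num_group):
    s, e = g * group_len, (g + 1) * group_len
    attn_mask[:, s:e, s:e] = ~equal

print(attn_mask[0].int())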
src/models/XPose/models/UniPose/ops/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/5 21:58
3
+ # @Author : shaoguowen
4
+ # @Email : [email protected]
5
+ # @Project : FasterLivePortrait
6
+ # @FileName: __init__.py
src/models/XPose/models/UniPose/ops/functions/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from .ms_deform_attn_func import MSDeformAttnFunction
10
+
src/models/XPose/models/UniPose/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,61 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.autograd import Function
16
+ from torch.autograd.function import once_differentiable
17
+
18
+ import MultiScaleDeformableAttention as MSDA
19
+
20
+
21
+ class MSDeformAttnFunction(Function):
22
+ @staticmethod
23
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24
+ ctx.im2col_step = im2col_step
25
+ output = MSDA.ms_deform_attn_forward(
26
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28
+ return output
29
+
30
+ @staticmethod
31
+ @once_differentiable
32
+ def backward(ctx, grad_output):
33
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34
+ grad_value, grad_sampling_loc, grad_attn_weight = \
35
+ MSDA.ms_deform_attn_backward(
36
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37
+
38
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39
+
40
+
41
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42
+ # for debug and test only,
43
+ # need to use cuda version instead
44
+ N_, S_, M_, D_ = value.shape
45
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47
+ sampling_grids = 2 * sampling_locations - 1
48
+ sampling_value_list = []
49
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54
+ # N_*M_, D_, Lq_, P_
55
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56
+ mode='bilinear', padding_mode='zeros', align_corners=False)
57
+ sampling_value_list.append(sampling_value_l_)
58
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61
+ return output.transpose(1, 2).contiguous()
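
ms_deform_attn_core_pytorch is the pure-PyTorch fallback for the CUDA kernel above. The sketch below only checks the tensor layout it expects, with random inputs and illustrative sizes; importing this module still requires the compiled MultiScaleDeformableAttention extension, the fallback itself is plain PyTorch:

import torch

N, M, D = 2, 8, 32                       # batch, heads, channels per head
spatial_shapes = [(32, 32), (16, 16)]    # (H, W) of each feature level
L = len(spatial_shapes)
S = sum(h * w for h, w in spatial_shapes)
Lq, P = 300, 4                           # queries, sampling points per level

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.softmax(torch.randn(N, Lq, M, L * P), dim=-1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 300, 256]) == (N, Lq, M * D)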