Spaces:
Running
Running
MeYourHint
commited on
Commit
•
c0eac48
1
Parent(s):
08572f0
first demo version
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- LICENSE +21 -0
- README.md +221 -13
- app.py +203 -0
- assets/mapping.json +1 -0
- assets/mapping6.json +1 -0
- assets/text_prompt.txt +12 -0
- common/__init__.py +0 -0
- common/quaternion.py +423 -0
- common/skeleton.py +199 -0
- data/__init__.py +0 -0
- data/t2m_dataset.py +348 -0
- dataset/__init__.py +0 -0
- edit_t2m.py +195 -0
- environment.yml +204 -0
- eval_t2m_trans_res.py +199 -0
- eval_t2m_vq.py +123 -0
- example_data/000612.mp4 +0 -0
- example_data/000612.npy +3 -0
- gen_t2m.py +261 -0
- models/.DS_Store +0 -0
- models/__init__.py +0 -0
- models/mask_transformer/__init__.py +0 -0
- models/mask_transformer/tools.py +165 -0
- models/mask_transformer/transformer.py +1039 -0
- models/mask_transformer/transformer_trainer.py +359 -0
- models/t2m_eval_modules.py +182 -0
- models/t2m_eval_wrapper.py +191 -0
- models/vq/__init__.py +0 -0
- models/vq/encdec.py +68 -0
- models/vq/model.py +124 -0
- models/vq/quantizer.py +180 -0
- models/vq/residual_vq.py +194 -0
- models/vq/resnet.py +84 -0
- models/vq/vq_trainer.py +359 -0
- motion_loaders/__init__.py +0 -0
- motion_loaders/dataset_motion_loader.py +27 -0
- options/__init__.py +0 -0
- options/base_option.py +61 -0
- options/eval_option.py +38 -0
- options/train_option.py +64 -0
- options/vq_option.py +89 -0
- prepare/.DS_Store +0 -0
- prepare/download_evaluator.sh +24 -0
- prepare/download_glove.sh +9 -0
- prepare/download_models.sh +31 -0
- prepare/download_models_demo.sh +10 -0
- requirements.txt +140 -0
- train_res_transformer.py +171 -0
- train_t2m_transformer.py +153 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Chuan Guo
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,221 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MoMask: Generative Masked Modeling of 3D Human Motions
|
2 |
+
## [[Project Page]](https://ericguo5513.github.io/momask) [[Paper]](https://arxiv.org/abs/2312.00063)
|
3 |
+
![teaser_image](https://ericguo5513.github.io/momask/static/images/teaser.png)
|
4 |
+
|
5 |
+
If you find our code or paper helpful, please consider citing:
|
6 |
+
```
|
7 |
+
@article{guo2023momask,
|
8 |
+
title={MoMask: Generative Masked Modeling of 3D Human Motions},
|
9 |
+
author={Chuan Guo and Yuxuan Mu and Muhammad Gohar Javed and Sen Wang and Li Cheng},
|
10 |
+
year={2023},
|
11 |
+
eprint={2312.00063},
|
12 |
+
archivePrefix={arXiv},
|
13 |
+
primaryClass={cs.CV}
|
14 |
+
}
|
15 |
+
```
|
16 |
+
|
17 |
+
## :postbox: News
|
18 |
+
📢 **2023-12-19** --- Release scripts for temporal inpainting.
|
19 |
+
|
20 |
+
📢 **2023-12-15** --- Release codes and models for momask. Including training/eval/generation scripts.
|
21 |
+
|
22 |
+
📢 **2023-11-29** --- Initialized the webpage and git project.
|
23 |
+
|
24 |
+
|
25 |
+
## :round_pushpin: Get You Ready
|
26 |
+
|
27 |
+
<details>
|
28 |
+
|
29 |
+
### 1. Conda Environment
|
30 |
+
```
|
31 |
+
conda env create -f environment.yml
|
32 |
+
conda activate momask
|
33 |
+
pip install git+https://github.com/openai/CLIP.git
|
34 |
+
```
|
35 |
+
We test our code on Python 3.7.13 and PyTorch 1.7.1
|
36 |
+
|
37 |
+
|
38 |
+
### 2. Models and Dependencies
|
39 |
+
|
40 |
+
#### Download Pre-trained Models
|
41 |
+
```
|
42 |
+
bash prepare/download_models.sh
|
43 |
+
```
|
44 |
+
|
45 |
+
#### Download Evaluation Models and Gloves
|
46 |
+
For evaluation only.
|
47 |
+
```
|
48 |
+
bash prepare/download_evaluator.sh
|
49 |
+
bash prepare/download_glove.sh
|
50 |
+
```
|
51 |
+
|
52 |
+
#### Troubleshooting
|
53 |
+
To address the download error related to gdown: "Cannot retrieve the public link of the file. You may need to change the permission to 'Anyone with the link', or have had many accesses". A potential solution is to run `pip install --upgrade --no-cache-dir gdown`, as suggested on https://github.com/wkentaro/gdown/issues/43. This should help resolve the issue.
|
54 |
+
|
55 |
+
#### (Optional) Download Mannually
|
56 |
+
Visit [[Google Drive]](https://drive.google.com/drive/folders/1b3GnAbERH8jAoO5mdWgZhyxHB73n23sK?usp=drive_link) to download the models and evaluators mannually.
|
57 |
+
|
58 |
+
### 3. Get Data
|
59 |
+
|
60 |
+
You have two options here:
|
61 |
+
* **Skip getting data**, if you just want to generate motions using *own* descriptions.
|
62 |
+
* **Get full data**, if you want to *re-train* and *evaluate* the model.
|
63 |
+
|
64 |
+
**(a). Full data (text + motion)**
|
65 |
+
|
66 |
+
**HumanML3D** - Follow the instruction in [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git), then copy the result dataset to our repository:
|
67 |
+
```
|
68 |
+
cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
|
69 |
+
```
|
70 |
+
**KIT**-Download from [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git), then place result in `./dataset/KIT-ML`
|
71 |
+
|
72 |
+
####
|
73 |
+
|
74 |
+
</details>
|
75 |
+
|
76 |
+
## :rocket: Demo
|
77 |
+
<details>
|
78 |
+
|
79 |
+
### (a) Generate from a single prompt
|
80 |
+
```
|
81 |
+
python gen_t2m.py --gpu_id 1 --ext exp1 --text_prompt "A person is running on a treadmill."
|
82 |
+
```
|
83 |
+
### (b) Generate from a prompt file
|
84 |
+
An example of prompt file is given in `./assets/text_prompt.txt`. Please follow the format of `<text description>#<motion length>` at each line. Motion length indicates the number of poses, which must be integeter and will be rounded by 4. In our work, motion is in 20 fps.
|
85 |
+
|
86 |
+
If you write `<text description>#NA`, our model will determine a length. Note once there is **one** NA, all the others will be **NA** automatically.
|
87 |
+
|
88 |
+
```
|
89 |
+
python gen_t2m.py --gpu_id 1 --ext exp2 --text_path ./assets/text_prompt.txt
|
90 |
+
```
|
91 |
+
|
92 |
+
|
93 |
+
A few more parameters you may be interested:
|
94 |
+
* `--repeat_times`: number of replications for generation, default `1`.
|
95 |
+
* `--motion_length`: specify the number of poses for generation, only applicable in (a).
|
96 |
+
|
97 |
+
The output files are stored under folder `./generation/<ext>/`. They are
|
98 |
+
* `numpy files`: generated motions with shape of (nframe, 22, 3), under subfolder `./joints`.
|
99 |
+
* `video files`: stick figure animation in mp4 format, under subfolder `./animation`.
|
100 |
+
* `bvh files`: bvh files of the generated motion, under subfolder `./animation`.
|
101 |
+
|
102 |
+
We also apply naive foot ik to the generated motions, see files with suffix `_ik`. It sometimes works well, but sometimes will fail.
|
103 |
+
|
104 |
+
</details>
|
105 |
+
|
106 |
+
## :dancers: Visualization
|
107 |
+
<details>
|
108 |
+
|
109 |
+
All the animations are manually rendered in blender. We use the characters from [mixamo](https://www.mixamo.com/#/). You need to download the characters in T-Pose with skeleton.
|
110 |
+
|
111 |
+
### Retargeting
|
112 |
+
For retargeting, we found rokoko usually leads to large error on foot. On the other hand, [keemap.rig.transfer](https://github.com/nkeeline/Keemap-Blender-Rig-ReTargeting-Addon/releases) shows more precise retargetting. You could watch the [tutorial](https://www.youtube.com/watch?v=EG-VCMkVpxg) here.
|
113 |
+
|
114 |
+
Following these steps:
|
115 |
+
* Download keemap.rig.transfer from the github, and install it in blender.
|
116 |
+
* Import both the motion files (.bvh) and character files (.fbx) in blender.
|
117 |
+
* `Shift + Select` the both source and target skeleton. (Do not need to be Rest Position)
|
118 |
+
* Switch to `Pose Mode`, then unfold the `KeeMapRig` tool at the top-right corner of the view window.
|
119 |
+
* Load and read the bone mapping file `./assets/mapping.json`(or `mapping6.json` if it doesn't work). This file is manually made by us. It works for most characters in mixamo. You could make your own.
|
120 |
+
* Adjust the `Number of Samples`, `Source Rig`, `Destination Rig Name`.
|
121 |
+
* Clik `Transfer Animation from Source Destination`, wait a few seconds.
|
122 |
+
|
123 |
+
We didn't tried other retargetting tools. Welcome to comment if you find others are more useful.
|
124 |
+
|
125 |
+
### Scene
|
126 |
+
|
127 |
+
We use this [scene](https://drive.google.com/file/d/1lg62nugD7RTAIz0Q_YP2iZsxpUzzOkT1/view?usp=sharing) for animation.
|
128 |
+
|
129 |
+
|
130 |
+
</details>
|
131 |
+
|
132 |
+
## :clapper: Temporal Inpainting
|
133 |
+
<details>
|
134 |
+
We conduct mask-based editing in the m-transformer stage, followed by the regeneration of residual tokens for the entire sequence. To load your own motion, provide the path through `--source_motion`. Utilize `-msec` to specify the mask section, supporting either ratio or frame index. For instance, `-msec 0.3,0.6` with `max_motion_length=196` is equivalent to `-msec 59,118`, indicating the editing of the frame section [59, 118].
|
135 |
+
|
136 |
+
```
|
137 |
+
python edit_t2m.py --gpu_id 1 --ext exp3 --use_res_model -msec 0.4,0.7 --text_prompt "A man picks something from the ground using his right hand."
|
138 |
+
```
|
139 |
+
|
140 |
+
Note: Presently, the source motion must adhere to the format of a HumanML3D dim-263 feature vector. An example motion vector data from the HumanML3D test set is available in `example_data/000612.npy`. To process your own motion data, you can utilize the `process_file` function from `utils/motion_process.py`.
|
141 |
+
|
142 |
+
</details>
|
143 |
+
|
144 |
+
## :space_invader: Train Your Own Models
|
145 |
+
<details>
|
146 |
+
|
147 |
+
|
148 |
+
**Note**: You have to train RVQ **BEFORE** training masked/residual transformers. The latter two can be trained simultaneously.
|
149 |
+
|
150 |
+
### Train RVQ
|
151 |
+
```
|
152 |
+
python train_vq.py --name rvq_name --gpu_id 1 --dataset_name t2m --batch_size 512 --num_quantizers 6 --max_epoch 500 --quantize_drop_prob 0.2
|
153 |
+
```
|
154 |
+
|
155 |
+
### Train Masked Transformer
|
156 |
+
```
|
157 |
+
python train_t2m_transformer.py --name mtrans_name --gpu_id 2 --dataset_name t2m --batch_size 64 --vq_name rvq_name
|
158 |
+
```
|
159 |
+
|
160 |
+
### Train Residual Transformer
|
161 |
+
```
|
162 |
+
python train_res_transformer.py --name rtrans_name --gpu_id 2 --dataset_name t2m --batch_size 64 --vq_name rvq_name --cond_drop_prob 0.2 --share_weight
|
163 |
+
```
|
164 |
+
|
165 |
+
* `--dataset_name`: motion dataset, `t2m` for HumanML3D and `kit` for KIT-ML.
|
166 |
+
* `--name`: name your model. This will create to model space as `./checkpoints/<dataset_name>/<name>`
|
167 |
+
* `--gpu_id`: GPU id.
|
168 |
+
* `--batch_size`: we use `512` for rvq training. For masked/residual transformer, we use `64` on HumanML3D and `16` for KIT-ML.
|
169 |
+
* `--num_quantizers`: number of quantization layers, `6` is used in our case.
|
170 |
+
* `--quantize_drop_prob`: quantization dropout ratio, `0.2` is used.
|
171 |
+
* `--vq_name`: when training masked/residual transformer, you need to specify the name of rvq model for tokenization.
|
172 |
+
* `--cond_drop_prob`: condition drop ratio, for classifier-free guidance. `0.2` is used.
|
173 |
+
* `--share_weight`: whether to share the projection/embedding weights in residual transformer.
|
174 |
+
|
175 |
+
All the pre-trained models and intermediate results will be saved in space `./checkpoints/<dataset_name>/<name>`.
|
176 |
+
</details>
|
177 |
+
|
178 |
+
## :book: Evaluation
|
179 |
+
<details>
|
180 |
+
|
181 |
+
### Evaluate RVQ Reconstruction:
|
182 |
+
HumanML3D:
|
183 |
+
```
|
184 |
+
python eval_t2m_vq.py --gpu_id 0 --name rvq_nq6_dc512_nc512_noshare_qdp0.2 --dataset_name t2m --ext rvq_nq6
|
185 |
+
|
186 |
+
```
|
187 |
+
KIT-ML:
|
188 |
+
```
|
189 |
+
python eval_t2m_vq.py --gpu_id 0 --name rvq_nq6_dc512_nc512_noshare_qdp0.2_k --dataset_name kit --ext rvq_nq6
|
190 |
+
```
|
191 |
+
|
192 |
+
### Evaluate Text2motion Generation:
|
193 |
+
HumanML3D:
|
194 |
+
```
|
195 |
+
python eval_t2m_trans_res.py --res_name tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw --dataset_name t2m --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns --gpu_id 1 --cond_scale 4 --time_steps 10 --ext evaluation
|
196 |
+
```
|
197 |
+
KIT-ML:
|
198 |
+
```
|
199 |
+
python eval_t2m_trans_res.py --res_name tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw_k --dataset_name kit --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns_k --gpu_id 0 --cond_scale 2 --time_steps 10 --ext evaluation
|
200 |
+
```
|
201 |
+
|
202 |
+
* `--res_name`: model name of `residual transformer`.
|
203 |
+
* `--name`: model name of `masked transformer`.
|
204 |
+
* `--cond_scale`: scale of classifer-free guidance.
|
205 |
+
* `--time_steps`: number of iterations for inference.
|
206 |
+
* `--ext`: filename for saving evaluation results.
|
207 |
+
|
208 |
+
The final evaluation results will be saved in `./checkpoints/<dataset_name>/<name>/eval/<ext>.log`
|
209 |
+
|
210 |
+
</details>
|
211 |
+
|
212 |
+
## Acknowlegements
|
213 |
+
|
214 |
+
We sincerely thank the open-sourcing of these works where our code is based on:
|
215 |
+
|
216 |
+
[deep-motion-editing](https://github.com/DeepMotionEditing/deep-motion-editing), [Muse](https://github.com/lucidrains/muse-maskgit-pytorch), [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch), [T2M-GPT](https://github.com/Mael-zys/T2M-GPT), [MDM](https://github.com/GuyTevet/motion-diffusion-model/tree/main) and [MLD](https://github.com/ChenFengYe/motion-latent-diffusion/tree/main)
|
217 |
+
|
218 |
+
## License
|
219 |
+
This code is distributed under an [MIT LICENSE](https://github.com/EricGuo5513/momask-codes/tree/main?tab=MIT-1-ov-file#readme).
|
220 |
+
|
221 |
+
Note that our code depends on other libraries, including SMPL, SMPL-X, PyTorch3D, and uses datasets which each have their own respective licenses that must also be followed.
|
app.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import gradio as gr
|
7 |
+
import gdown
|
8 |
+
|
9 |
+
|
10 |
+
WEBSITE = """
|
11 |
+
<div class="embed_hidden">
|
12 |
+
<h1 style='text-align: center'> MoMask: Generative Masked Modeling of 3D Human Motions </h1>
|
13 |
+
<h2 style='text-align: center'>
|
14 |
+
<a href="https://ericguo5513.github.io" target="_blank"><nobr>Chuan Guo*</nobr></a>  
|
15 |
+
<a href="https://yxmu.foo/" target="_blank"><nobr>Yuxuan Mu*</nobr></a>  
|
16 |
+
<a href="https://scholar.google.com/citations?user=w4e-j9sAAAAJ&hl=en" target="_blank"><nobr>Muhammad Gohar Javed*</nobr></a>  
|
17 |
+
<a href="https://sites.google.com/site/senwang1312home/" target="_blank"><nobr>Sen Wang</nobr></a>  
|
18 |
+
<a href="https://www.ece.ualberta.ca/~lcheng5/" target="_blank"><nobr>Li Cheng</nobr></a>
|
19 |
+
</h2>
|
20 |
+
<h2 style='text-align: center'>
|
21 |
+
<nobr>arXiv 2023</nobr>
|
22 |
+
</h2>
|
23 |
+
<h3 style="text-align:center;">
|
24 |
+
<a target="_blank" href="https://arxiv.org/abs/2312.00063"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>  
|
25 |
+
<a target="_blank" href="https://github.com/EricGuo5513/momask-codes"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>  
|
26 |
+
<a target="_blank" href="https://ericguo5513.github.io/momask/"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>  
|
27 |
+
<a target="_blank" href="https://ericguo5513.github.io/source_files/momask_2023_bib.txt"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
|
28 |
+
</h3>
|
29 |
+
<h3> Description </h3>
|
30 |
+
<p>
|
31 |
+
This space illustrates <a href='https://ericguo5513.github.io/momask/' target='_blank'><b>MoMask</b></a>, a method for text-to-motion generation.
|
32 |
+
</p>
|
33 |
+
</div>
|
34 |
+
"""
|
35 |
+
|
36 |
+
EXAMPLES = [
|
37 |
+
"A person is walking slowly",
|
38 |
+
"A person is walking in a circle",
|
39 |
+
"A person is jumping rope",
|
40 |
+
"Someone is doing a backflip",
|
41 |
+
"A person is doing a moonwalk",
|
42 |
+
"A person walks forward and then turns back",
|
43 |
+
"Picking up an object",
|
44 |
+
"A person is swimming in the sea",
|
45 |
+
"A human is squatting",
|
46 |
+
"Someone is jumping with one foot",
|
47 |
+
"A person is chopping vegetables",
|
48 |
+
"Someone walks backward",
|
49 |
+
"Somebody is ascending a staircase",
|
50 |
+
"A person is sitting down",
|
51 |
+
"A person is taking the stairs",
|
52 |
+
"Someone is doing jumping jacks",
|
53 |
+
"The person walked forward and is picking up his toolbox",
|
54 |
+
"The person angrily punching the air",
|
55 |
+
]
|
56 |
+
|
57 |
+
# Show closest text in the training
|
58 |
+
|
59 |
+
|
60 |
+
# css to make videos look nice
|
61 |
+
# var(--block-border-color); TODO
|
62 |
+
CSS = """
|
63 |
+
.retrieved_video {
|
64 |
+
position: relative;
|
65 |
+
margin: 0;
|
66 |
+
box-shadow: var(--block-shadow);
|
67 |
+
border-width: var(--block-border-width);
|
68 |
+
border-color: #000000;
|
69 |
+
border-radius: var(--block-radius);
|
70 |
+
background: var(--block-background-fill);
|
71 |
+
width: 100%;
|
72 |
+
line-height: var(--line-sm);
|
73 |
+
}
|
74 |
+
}
|
75 |
+
"""
|
76 |
+
|
77 |
+
|
78 |
+
DEFAULT_TEXT = "A person is "
|
79 |
+
|
80 |
+
def generate(
|
81 |
+
text, uid, motion_length=0, seed=351540, repeat_times=4,
|
82 |
+
):
|
83 |
+
os.system(f'python gen_t2m.py --gpu_id 0 --seed {seed} --ext {uid} --repeat_times {repeat_times} --motion_length {motion_length} --text_prompt {text}')
|
84 |
+
datas = []
|
85 |
+
for n in repeat_times:
|
86 |
+
data_unit = {
|
87 |
+
"url": f"./generation/{uid}/animations/0/sample0_repeat{n}_len196_ik.mp4"
|
88 |
+
}
|
89 |
+
datas.append(data_unit)
|
90 |
+
return datas
|
91 |
+
|
92 |
+
|
93 |
+
# HTML component
|
94 |
+
def get_video_html(data, video_id, width=700, height=700):
|
95 |
+
url = data["url"]
|
96 |
+
# class="wrap default svelte-gjihhp hide"
|
97 |
+
# <div class="contour_video" style="position: absolute; padding: 10px;">
|
98 |
+
# width="{width}" height="{height}"
|
99 |
+
video_html = f"""
|
100 |
+
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
|
101 |
+
autoplay loop disablepictureinpicture id="{video_id}">
|
102 |
+
<source src="{url}" type="video/mp4">
|
103 |
+
Your browser does not support the video tag.
|
104 |
+
</video>
|
105 |
+
"""
|
106 |
+
return video_html
|
107 |
+
|
108 |
+
|
109 |
+
def generate_component(generate_function, text):
|
110 |
+
if text == DEFAULT_TEXT or text == "" or text is None:
|
111 |
+
return [None for _ in range(4)]
|
112 |
+
|
113 |
+
datas = generate_function(text, )
|
114 |
+
htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
|
115 |
+
return htmls
|
116 |
+
|
117 |
+
|
118 |
+
if not os.path.exists("checkpoints/t2m"):
|
119 |
+
os.system("bash prepare/download_models.sh")
|
120 |
+
|
121 |
+
|
122 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
123 |
+
|
124 |
+
# LOADING
|
125 |
+
|
126 |
+
# DEMO
|
127 |
+
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
|
128 |
+
generate_and_show = partial(generate_component, generate)
|
129 |
+
|
130 |
+
with gr.Blocks(css=CSS, theme=theme) as demo:
|
131 |
+
gr.Markdown(WEBSITE)
|
132 |
+
videos = []
|
133 |
+
|
134 |
+
with gr.Row():
|
135 |
+
with gr.Column(scale=3):
|
136 |
+
with gr.Column(scale=2):
|
137 |
+
text = gr.Textbox(
|
138 |
+
show_label=True,
|
139 |
+
label="Text prompt",
|
140 |
+
value=DEFAULT_TEXT,
|
141 |
+
)
|
142 |
+
with gr.Column(scale=1):
|
143 |
+
gen_btn = gr.Button("Generate", variant="primary")
|
144 |
+
clear = gr.Button("Clear", variant="secondary")
|
145 |
+
|
146 |
+
with gr.Column(scale=2):
|
147 |
+
|
148 |
+
def generate_example(text):
|
149 |
+
return generate_and_show(text)
|
150 |
+
|
151 |
+
examples = gr.Examples(
|
152 |
+
examples=[[x, None, None] for x in EXAMPLES],
|
153 |
+
inputs=[text],
|
154 |
+
examples_per_page=20,
|
155 |
+
run_on_click=False,
|
156 |
+
cache_examples=False,
|
157 |
+
fn=generate_example,
|
158 |
+
outputs=[],
|
159 |
+
)
|
160 |
+
|
161 |
+
i = -1
|
162 |
+
# should indent
|
163 |
+
for _ in range(1):
|
164 |
+
with gr.Row():
|
165 |
+
for _ in range(4):
|
166 |
+
i += 1
|
167 |
+
video = gr.HTML()
|
168 |
+
videos.append(video)
|
169 |
+
|
170 |
+
# connect the examples to the output
|
171 |
+
# a bit hacky
|
172 |
+
examples.outputs = videos
|
173 |
+
|
174 |
+
def load_example(example_id):
|
175 |
+
processed_example = examples.non_none_processed_examples[example_id]
|
176 |
+
return gr.utils.resolve_singleton(processed_example)
|
177 |
+
|
178 |
+
examples.dataset.click(
|
179 |
+
load_example,
|
180 |
+
inputs=[examples.dataset],
|
181 |
+
outputs=examples.inputs_with_examples, # type: ignore
|
182 |
+
show_progress=False,
|
183 |
+
postprocess=False,
|
184 |
+
queue=False,
|
185 |
+
).then(fn=generate_example, inputs=examples.inputs, outputs=videos)
|
186 |
+
|
187 |
+
gen_btn.click(
|
188 |
+
fn=generate_and_show,
|
189 |
+
inputs=[text],
|
190 |
+
outputs=videos,
|
191 |
+
)
|
192 |
+
text.submit(
|
193 |
+
fn=generate_and_show,
|
194 |
+
inputs=[text],
|
195 |
+
outputs=videos,
|
196 |
+
)
|
197 |
+
|
198 |
+
def clear_videos():
|
199 |
+
return [None for x in range(4)] + [DEFAULT_TEXT]
|
200 |
+
|
201 |
+
clear.click(fn=clear_videos, outputs=videos + [text])
|
202 |
+
|
203 |
+
demo.launch()
|
assets/mapping.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bones": [{"name": "Hips", "label": "", "description": "", "SourceBoneName": "Hips", "DestinationBoneName": "mixamorig:Hips", "keyframe_this_bone": true, "CorrectionFactorX": 2.6179938316345215, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": true, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.2588190734386444, "QuatCorrectionFactorx": 0.965925931930542, "QuatCorrectionFactory": 2.7939677238464355e-09, "QuatCorrectionFactorz": -2.7939677238464355e-09, "scale_secondary_bone_name": ""}, {"name": "RightUpLeg", "label": "", "description": "", "SourceBoneName": "RightUpLeg", "DestinationBoneName": "mixamorig:RightUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftUpLeg", "label": "", "description": "", "SourceBoneName": "LeftUpLeg", "DestinationBoneName": "mixamorig:LeftUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightLeg", "label": "", "description": "", "SourceBoneName": "RightLeg", "DestinationBoneName": "mixamorig:RightLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftLeg", "label": "", "description": "", "SourceBoneName": "LeftLeg", "DestinationBoneName": "mixamorig:LeftLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 3.665191411972046, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightShoulder", "label": "", "description": "", "SourceBoneName": "RightShoulder", "DestinationBoneName": "mixamorig:RightShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftShoulder", "label": "", "description": "", "SourceBoneName": "LeftShoulder", "DestinationBoneName": "mixamorig:LeftShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightArm", "label": "", "description": "", "SourceBoneName": "RightArm", "DestinationBoneName": "mixamorig:RightArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -1.0471975803375244, "CorrectionFactorZ": -0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftArm", "label": "", "description": "", "SourceBoneName": "LeftArm", "DestinationBoneName": "mixamorig:LeftArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.0471975803375244, "CorrectionFactorZ": 0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightForeArm", "label": "", "description": "", "SourceBoneName": "RightForeArm", "DestinationBoneName": "mixamorig:RightForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftForeArm", "label": "", "description": "", "SourceBoneName": "LeftForeArm", "DestinationBoneName": "mixamorig:LeftForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.5707963705062866, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine", "label": "", "description": "", "SourceBoneName": "Spine", "DestinationBoneName": "mixamorig:Spine", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine1", "label": "", "description": "", "SourceBoneName": "Spine1", "DestinationBoneName": "mixamorig:Spine1", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine2", "label": "", "description": "", "SourceBoneName": "Spine2", "DestinationBoneName": "mixamorig:Spine2", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Neck", "label": "", "description": "", "SourceBoneName": "Neck", "DestinationBoneName": "mixamorig:Neck", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Head", "label": "", "description": "", "SourceBoneName": "Head", "DestinationBoneName": "mixamorig:Head", "keyframe_this_bone": true, "CorrectionFactorX": 0.3490658402442932, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightFoot", "label": "", "description": "", "SourceBoneName": "RightFoot", "DestinationBoneName": "mixamorig:RightFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.19192171096801758, "CorrectionFactorY": 2.979980945587158, "CorrectionFactorZ": -0.05134282633662224, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": -0.082771435379982, "QuatCorrectionFactorx": -0.0177358016371727, "QuatCorrectionFactory": -0.9920229315757751, "QuatCorrectionFactorz": -0.09340716898441315, "scale_secondary_bone_name": ""}, {"name": "LeftFoot", "label": "", "description": "", "SourceBoneName": "LeftFoot", "DestinationBoneName": "mixamorig:LeftFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.25592508912086487, "CorrectionFactorY": -2.936899423599243, "CorrectionFactorZ": 0.2450830191373825, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.11609010398387909, "QuatCorrectionFactorx": 0.10766097158193588, "QuatCorrectionFactory": -0.9808290004730225, "QuatCorrectionFactorz": -0.11360746622085571, "scale_secondary_bone_name": ""}], "start_frame_to_apply": 0, "number_of_frames_to_apply": 196, "keyframe_every_n_frames": 1, "source_rig_name": "bvh_batch1_sample30_repeat1_len48", "destination_rig_name": "Armature", "bone_rotation_mode": "EULER", "bone_mapping_file": "C:\\Users\\cguo2\\Documents\\CVPR2024_MoMask\\mapping.json"}
|
assets/mapping6.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bones": [{"name": "Hips", "label": "", "description": "", "SourceBoneName": "Hips", "DestinationBoneName": "mixamorig6:Hips", "keyframe_this_bone": true, "CorrectionFactorX": 2.6179938316345215, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": true, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.2588190734386444, "QuatCorrectionFactorx": 0.965925931930542, "QuatCorrectionFactory": 2.7939677238464355e-09, "QuatCorrectionFactorz": -2.7939677238464355e-09, "scale_secondary_bone_name": ""}, {"name": "RightUpLeg", "label": "", "description": "", "SourceBoneName": "RightUpLeg", "DestinationBoneName": "mixamorig6:RightUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftUpLeg", "label": "", "description": "", "SourceBoneName": "LeftUpLeg", "DestinationBoneName": "mixamorig6:LeftUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightLeg", "label": "", "description": "", "SourceBoneName": "RightLeg", "DestinationBoneName": "mixamorig6:RightLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftLeg", "label": "", "description": "", "SourceBoneName": "LeftLeg", "DestinationBoneName": "mixamorig6:LeftLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 3.665191411972046, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightShoulder", "label": "", "description": "", "SourceBoneName": "RightShoulder", "DestinationBoneName": "mixamorig6:RightShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftShoulder", "label": "", "description": "", "SourceBoneName": "LeftShoulder", "DestinationBoneName": "mixamorig6:LeftShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightArm", "label": "", "description": "", "SourceBoneName": "RightArm", "DestinationBoneName": "mixamorig6:RightArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -1.0471975803375244, "CorrectionFactorZ": -0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftArm", "label": "", "description": "", "SourceBoneName": "LeftArm", "DestinationBoneName": "mixamorig6:LeftArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.0471975803375244, "CorrectionFactorZ": 0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightForeArm", "label": "", "description": "", "SourceBoneName": "RightForeArm", "DestinationBoneName": "mixamorig6:RightForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftForeArm", "label": "", "description": "", "SourceBoneName": "LeftForeArm", "DestinationBoneName": "mixamorig6:LeftForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.5707963705062866, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine", "label": "", "description": "", "SourceBoneName": "Spine", "DestinationBoneName": "mixamorig6:Spine", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine1", "label": "", "description": "", "SourceBoneName": "Spine1", "DestinationBoneName": "mixamorig6:Spine1", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine2", "label": "", "description": "", "SourceBoneName": "Spine2", "DestinationBoneName": "mixamorig6:Spine2", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Neck", "label": "", "description": "", "SourceBoneName": "Neck", "DestinationBoneName": "mixamorig6:Neck", "keyframe_this_bone": true, "CorrectionFactorX": -0.994345486164093, "CorrectionFactorY": -0.006703000050038099, "CorrectionFactorZ": 0.04061730206012726, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.8787809014320374, "QuatCorrectionFactorx": -0.4767816960811615, "QuatCorrectionFactory": -0.01263047568500042, "QuatCorrectionFactorz": 0.016250507906079292, "scale_secondary_bone_name": ""}, {"name": "Head", "label": "", "description": "", "SourceBoneName": "Head", "DestinationBoneName": "mixamorig6:Head", "keyframe_this_bone": true, "CorrectionFactorX": -0.07639937847852707, "CorrectionFactorY": 0.011205507442355156, "CorrectionFactorZ": 0.011367863975465298, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.9992374181747437, "QuatCorrectionFactorx": -0.038221005350351334, "QuatCorrectionFactory": 0.0053814793936908245, "QuatCorrectionFactorz": 0.005893632769584656, "scale_secondary_bone_name": ""}, {"name": "RightFoot", "label": "", "description": "", "SourceBoneName": "RightFoot", "DestinationBoneName": "mixamorig6:RightFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.17194896936416626, "CorrectionFactorY": 2.7372374534606934, "CorrectionFactorZ": -0.029542576521635056, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": -0.20128199458122253, "QuatCorrectionFactorx": 0.002824343740940094, "QuatCorrectionFactory": -0.9761614799499512, "QuatCorrectionFactorz": -0.08115538209676743, "scale_secondary_bone_name": ""}, {"name": "LeftFoot", "label": "", "description": "", "SourceBoneName": "LeftFoot", "DestinationBoneName": "mixamorig6:LeftFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.09363158047199249, "CorrectionFactorY": -2.9336421489715576, "CorrectionFactorZ": -0.17343592643737793, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": -0.09925344586372375, "QuatCorrectionFactorx": 0.09088610112667084, "QuatCorrectionFactory": 0.9893556833267212, "QuatCorrectionFactorz": 0.05535021424293518, "scale_secondary_bone_name": ""}], "start_frame_to_apply": 0, "number_of_frames_to_apply": 196, "keyframe_every_n_frames": 1, "source_rig_name": "MoMask__02_ik", "destination_rig_name": "Armature", "bone_rotation_mode": "EULER", "bone_mapping_file": "C:\\Users\\cguo2\\Documents\\CVPR2024_MoMask\\mapping6.json"}
|
assets/text_prompt.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
the person holds his left foot with his left hand, puts his right foot up and left hand up too.#132
|
2 |
+
a man bends down and picks something up with his left hand.#84
|
3 |
+
A man stands for few seconds and picks up his arms and shakes them.#176
|
4 |
+
A person walks with a limp, their left leg get injured.#192
|
5 |
+
a person jumps up and then lands.#52
|
6 |
+
a person performs a standing back kick.#52
|
7 |
+
A person pokes their right hand along the ground, like they might be planting seeds.#60
|
8 |
+
the person steps forward and uses the left leg to kick something forward.#92
|
9 |
+
the man walked forward, spun right on one foot and walked back to his original position.#92
|
10 |
+
the person was pushed but did not fall.#124
|
11 |
+
this person stumbles left and right while moving forward.#132
|
12 |
+
a person reaching down and picking something up.#148
|
common/__init__.py
ADDED
File without changes
|
common/quaternion.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2018-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
_EPS4 = np.finfo(float).eps * 4.0
|
12 |
+
|
13 |
+
_FLOAT_EPS = np.finfo(np.float).eps
|
14 |
+
|
15 |
+
# PyTorch-backed implementations
|
16 |
+
def qinv(q):
|
17 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
18 |
+
mask = torch.ones_like(q)
|
19 |
+
mask[..., 1:] = -mask[..., 1:]
|
20 |
+
return q * mask
|
21 |
+
|
22 |
+
|
23 |
+
def qinv_np(q):
|
24 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
25 |
+
return qinv(torch.from_numpy(q).float()).numpy()
|
26 |
+
|
27 |
+
|
28 |
+
def qnormalize(q):
|
29 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
30 |
+
return q / torch.norm(q, dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
|
33 |
+
def qmul(q, r):
|
34 |
+
"""
|
35 |
+
Multiply quaternion(s) q with quaternion(s) r.
|
36 |
+
Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
|
37 |
+
Returns q*r as a tensor of shape (*, 4).
|
38 |
+
"""
|
39 |
+
assert q.shape[-1] == 4
|
40 |
+
assert r.shape[-1] == 4
|
41 |
+
|
42 |
+
original_shape = q.shape
|
43 |
+
|
44 |
+
# Compute outer product
|
45 |
+
terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
|
46 |
+
|
47 |
+
w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
|
48 |
+
x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
|
49 |
+
y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
|
50 |
+
z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
|
51 |
+
return torch.stack((w, x, y, z), dim=1).view(original_shape)
|
52 |
+
|
53 |
+
|
54 |
+
def qrot(q, v):
|
55 |
+
"""
|
56 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
57 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
58 |
+
where * denotes any number of dimensions.
|
59 |
+
Returns a tensor of shape (*, 3).
|
60 |
+
"""
|
61 |
+
assert q.shape[-1] == 4
|
62 |
+
assert v.shape[-1] == 3
|
63 |
+
assert q.shape[:-1] == v.shape[:-1]
|
64 |
+
|
65 |
+
original_shape = list(v.shape)
|
66 |
+
# print(q.shape)
|
67 |
+
q = q.contiguous().view(-1, 4)
|
68 |
+
v = v.contiguous().view(-1, 3)
|
69 |
+
|
70 |
+
qvec = q[:, 1:]
|
71 |
+
uv = torch.cross(qvec, v, dim=1)
|
72 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
73 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
74 |
+
|
75 |
+
|
76 |
+
def qeuler(q, order, epsilon=0, deg=True):
|
77 |
+
"""
|
78 |
+
Convert quaternion(s) q to Euler angles.
|
79 |
+
Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
|
80 |
+
Returns a tensor of shape (*, 3).
|
81 |
+
"""
|
82 |
+
assert q.shape[-1] == 4
|
83 |
+
|
84 |
+
original_shape = list(q.shape)
|
85 |
+
original_shape[-1] = 3
|
86 |
+
q = q.view(-1, 4)
|
87 |
+
|
88 |
+
q0 = q[:, 0]
|
89 |
+
q1 = q[:, 1]
|
90 |
+
q2 = q[:, 2]
|
91 |
+
q3 = q[:, 3]
|
92 |
+
|
93 |
+
if order == 'xyz':
|
94 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
95 |
+
y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
|
96 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
97 |
+
elif order == 'yzx':
|
98 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
99 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
100 |
+
z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
|
101 |
+
elif order == 'zxy':
|
102 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
|
103 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
104 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
|
105 |
+
elif order == 'xzy':
|
106 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
107 |
+
y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
108 |
+
z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
|
109 |
+
elif order == 'yxz':
|
110 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
|
111 |
+
y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
|
112 |
+
z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
113 |
+
elif order == 'zyx':
|
114 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
115 |
+
y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
|
116 |
+
z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
117 |
+
else:
|
118 |
+
raise
|
119 |
+
|
120 |
+
if deg:
|
121 |
+
return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
|
122 |
+
else:
|
123 |
+
return torch.stack((x, y, z), dim=1).view(original_shape)
|
124 |
+
|
125 |
+
|
126 |
+
# Numpy-backed implementations
|
127 |
+
|
128 |
+
def qmul_np(q, r):
|
129 |
+
q = torch.from_numpy(q).contiguous().float()
|
130 |
+
r = torch.from_numpy(r).contiguous().float()
|
131 |
+
return qmul(q, r).numpy()
|
132 |
+
|
133 |
+
|
134 |
+
def qrot_np(q, v):
|
135 |
+
q = torch.from_numpy(q).contiguous().float()
|
136 |
+
v = torch.from_numpy(v).contiguous().float()
|
137 |
+
return qrot(q, v).numpy()
|
138 |
+
|
139 |
+
|
140 |
+
def qeuler_np(q, order, epsilon=0, use_gpu=False):
|
141 |
+
if use_gpu:
|
142 |
+
q = torch.from_numpy(q).cuda().float()
|
143 |
+
return qeuler(q, order, epsilon).cpu().numpy()
|
144 |
+
else:
|
145 |
+
q = torch.from_numpy(q).contiguous().float()
|
146 |
+
return qeuler(q, order, epsilon).numpy()
|
147 |
+
|
148 |
+
|
149 |
+
def qfix(q):
|
150 |
+
"""
|
151 |
+
Enforce quaternion continuity across the time dimension by selecting
|
152 |
+
the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
|
153 |
+
between two consecutive frames.
|
154 |
+
|
155 |
+
Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
|
156 |
+
Returns a tensor of the same shape.
|
157 |
+
"""
|
158 |
+
assert len(q.shape) == 3
|
159 |
+
assert q.shape[-1] == 4
|
160 |
+
|
161 |
+
result = q.copy()
|
162 |
+
dot_products = np.sum(q[1:] * q[:-1], axis=2)
|
163 |
+
mask = dot_products < 0
|
164 |
+
mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
|
165 |
+
result[1:][mask] *= -1
|
166 |
+
return result
|
167 |
+
|
168 |
+
|
169 |
+
def euler2quat(e, order, deg=True):
|
170 |
+
"""
|
171 |
+
Convert Euler angles to quaternions.
|
172 |
+
"""
|
173 |
+
assert e.shape[-1] == 3
|
174 |
+
|
175 |
+
original_shape = list(e.shape)
|
176 |
+
original_shape[-1] = 4
|
177 |
+
|
178 |
+
e = e.view(-1, 3)
|
179 |
+
|
180 |
+
## if euler angles in degrees
|
181 |
+
if deg:
|
182 |
+
e = e * np.pi / 180.
|
183 |
+
|
184 |
+
x = e[:, 0]
|
185 |
+
y = e[:, 1]
|
186 |
+
z = e[:, 2]
|
187 |
+
|
188 |
+
rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
|
189 |
+
ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
|
190 |
+
rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
|
191 |
+
|
192 |
+
result = None
|
193 |
+
for coord in order:
|
194 |
+
if coord == 'x':
|
195 |
+
r = rx
|
196 |
+
elif coord == 'y':
|
197 |
+
r = ry
|
198 |
+
elif coord == 'z':
|
199 |
+
r = rz
|
200 |
+
else:
|
201 |
+
raise
|
202 |
+
if result is None:
|
203 |
+
result = r
|
204 |
+
else:
|
205 |
+
result = qmul(result, r)
|
206 |
+
|
207 |
+
# Reverse antipodal representation to have a non-negative "w"
|
208 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
209 |
+
result *= -1
|
210 |
+
|
211 |
+
return result.view(original_shape)
|
212 |
+
|
213 |
+
|
214 |
+
def expmap_to_quaternion(e):
|
215 |
+
"""
|
216 |
+
Convert axis-angle rotations (aka exponential maps) to quaternions.
|
217 |
+
Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
|
218 |
+
Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
|
219 |
+
Returns a tensor of shape (*, 4).
|
220 |
+
"""
|
221 |
+
assert e.shape[-1] == 3
|
222 |
+
|
223 |
+
original_shape = list(e.shape)
|
224 |
+
original_shape[-1] = 4
|
225 |
+
e = e.reshape(-1, 3)
|
226 |
+
|
227 |
+
theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
|
228 |
+
w = np.cos(0.5 * theta).reshape(-1, 1)
|
229 |
+
xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
|
230 |
+
return np.concatenate((w, xyz), axis=1).reshape(original_shape)
|
231 |
+
|
232 |
+
|
233 |
+
def euler_to_quaternion(e, order):
|
234 |
+
"""
|
235 |
+
Convert Euler angles to quaternions.
|
236 |
+
"""
|
237 |
+
assert e.shape[-1] == 3
|
238 |
+
|
239 |
+
original_shape = list(e.shape)
|
240 |
+
original_shape[-1] = 4
|
241 |
+
|
242 |
+
e = e.reshape(-1, 3)
|
243 |
+
|
244 |
+
x = e[:, 0]
|
245 |
+
y = e[:, 1]
|
246 |
+
z = e[:, 2]
|
247 |
+
|
248 |
+
rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
|
249 |
+
ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
|
250 |
+
rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
|
251 |
+
|
252 |
+
result = None
|
253 |
+
for coord in order:
|
254 |
+
if coord == 'x':
|
255 |
+
r = rx
|
256 |
+
elif coord == 'y':
|
257 |
+
r = ry
|
258 |
+
elif coord == 'z':
|
259 |
+
r = rz
|
260 |
+
else:
|
261 |
+
raise
|
262 |
+
if result is None:
|
263 |
+
result = r
|
264 |
+
else:
|
265 |
+
result = qmul_np(result, r)
|
266 |
+
|
267 |
+
# Reverse antipodal representation to have a non-negative "w"
|
268 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
269 |
+
result *= -1
|
270 |
+
|
271 |
+
return result.reshape(original_shape)
|
272 |
+
|
273 |
+
|
274 |
+
def quaternion_to_matrix(quaternions):
|
275 |
+
"""
|
276 |
+
Convert rotations given as quaternions to rotation matrices.
|
277 |
+
Args:
|
278 |
+
quaternions: quaternions with real part first,
|
279 |
+
as tensor of shape (..., 4).
|
280 |
+
Returns:
|
281 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
282 |
+
"""
|
283 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
284 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
285 |
+
|
286 |
+
o = torch.stack(
|
287 |
+
(
|
288 |
+
1 - two_s * (j * j + k * k),
|
289 |
+
two_s * (i * j - k * r),
|
290 |
+
two_s * (i * k + j * r),
|
291 |
+
two_s * (i * j + k * r),
|
292 |
+
1 - two_s * (i * i + k * k),
|
293 |
+
two_s * (j * k - i * r),
|
294 |
+
two_s * (i * k - j * r),
|
295 |
+
two_s * (j * k + i * r),
|
296 |
+
1 - two_s * (i * i + j * j),
|
297 |
+
),
|
298 |
+
-1,
|
299 |
+
)
|
300 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
301 |
+
|
302 |
+
|
303 |
+
def quaternion_to_matrix_np(quaternions):
|
304 |
+
q = torch.from_numpy(quaternions).contiguous().float()
|
305 |
+
return quaternion_to_matrix(q).numpy()
|
306 |
+
|
307 |
+
|
308 |
+
def quaternion_to_cont6d_np(quaternions):
|
309 |
+
rotation_mat = quaternion_to_matrix_np(quaternions)
|
310 |
+
cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
|
311 |
+
return cont_6d
|
312 |
+
|
313 |
+
|
314 |
+
def quaternion_to_cont6d(quaternions):
|
315 |
+
rotation_mat = quaternion_to_matrix(quaternions)
|
316 |
+
cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
|
317 |
+
return cont_6d
|
318 |
+
|
319 |
+
|
320 |
+
def cont6d_to_matrix(cont6d):
|
321 |
+
assert cont6d.shape[-1] == 6, "The last dimension must be 6"
|
322 |
+
x_raw = cont6d[..., 0:3]
|
323 |
+
y_raw = cont6d[..., 3:6]
|
324 |
+
|
325 |
+
x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
|
326 |
+
z = torch.cross(x, y_raw, dim=-1)
|
327 |
+
z = z / torch.norm(z, dim=-1, keepdim=True)
|
328 |
+
|
329 |
+
y = torch.cross(z, x, dim=-1)
|
330 |
+
|
331 |
+
x = x[..., None]
|
332 |
+
y = y[..., None]
|
333 |
+
z = z[..., None]
|
334 |
+
|
335 |
+
mat = torch.cat([x, y, z], dim=-1)
|
336 |
+
return mat
|
337 |
+
|
338 |
+
|
339 |
+
def cont6d_to_matrix_np(cont6d):
|
340 |
+
q = torch.from_numpy(cont6d).contiguous().float()
|
341 |
+
return cont6d_to_matrix(q).numpy()
|
342 |
+
|
343 |
+
|
344 |
+
def qpow(q0, t, dtype=torch.float):
|
345 |
+
''' q0 : tensor of quaternions
|
346 |
+
t: tensor of powers
|
347 |
+
'''
|
348 |
+
q0 = qnormalize(q0)
|
349 |
+
theta0 = torch.acos(q0[..., 0])
|
350 |
+
|
351 |
+
## if theta0 is close to zero, add epsilon to avoid NaNs
|
352 |
+
mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
|
353 |
+
theta0 = (1 - mask) * theta0 + mask * 10e-10
|
354 |
+
v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
|
355 |
+
|
356 |
+
if isinstance(t, torch.Tensor):
|
357 |
+
q = torch.zeros(t.shape + q0.shape)
|
358 |
+
theta = t.view(-1, 1) * theta0.view(1, -1)
|
359 |
+
else: ## if t is a number
|
360 |
+
q = torch.zeros(q0.shape)
|
361 |
+
theta = t * theta0
|
362 |
+
|
363 |
+
q[..., 0] = torch.cos(theta)
|
364 |
+
q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
|
365 |
+
|
366 |
+
return q.to(dtype)
|
367 |
+
|
368 |
+
|
369 |
+
def qslerp(q0, q1, t):
|
370 |
+
'''
|
371 |
+
q0: starting quaternion
|
372 |
+
q1: ending quaternion
|
373 |
+
t: array of points along the way
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
Tensor of Slerps: t.shape + q0.shape
|
377 |
+
'''
|
378 |
+
|
379 |
+
q0 = qnormalize(q0)
|
380 |
+
q1 = qnormalize(q1)
|
381 |
+
q_ = qpow(qmul(q1, qinv(q0)), t)
|
382 |
+
|
383 |
+
return qmul(q_,
|
384 |
+
q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
|
385 |
+
|
386 |
+
|
387 |
+
def qbetween(v0, v1):
|
388 |
+
'''
|
389 |
+
find the quaternion used to rotate v0 to v1
|
390 |
+
'''
|
391 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
392 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
393 |
+
|
394 |
+
v = torch.cross(v0, v1)
|
395 |
+
w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
|
396 |
+
keepdim=True)
|
397 |
+
return qnormalize(torch.cat([w, v], dim=-1))
|
398 |
+
|
399 |
+
|
400 |
+
def qbetween_np(v0, v1):
|
401 |
+
'''
|
402 |
+
find the quaternion used to rotate v0 to v1
|
403 |
+
'''
|
404 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
405 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
406 |
+
|
407 |
+
v0 = torch.from_numpy(v0).float()
|
408 |
+
v1 = torch.from_numpy(v1).float()
|
409 |
+
return qbetween(v0, v1).numpy()
|
410 |
+
|
411 |
+
|
412 |
+
def lerp(p0, p1, t):
|
413 |
+
if not isinstance(t, torch.Tensor):
|
414 |
+
t = torch.Tensor([t])
|
415 |
+
|
416 |
+
new_shape = t.shape + p0.shape
|
417 |
+
new_view_t = t.shape + torch.Size([1] * len(p0.shape))
|
418 |
+
new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
|
419 |
+
p0 = p0.view(new_view_p).expand(new_shape)
|
420 |
+
p1 = p1.view(new_view_p).expand(new_shape)
|
421 |
+
t = t.view(new_view_t).expand(new_shape)
|
422 |
+
|
423 |
+
return p0 + t * (p1 - p0)
|
common/skeleton.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from common.quaternion import *
|
2 |
+
import scipy.ndimage.filters as filters
|
3 |
+
|
4 |
+
class Skeleton(object):
|
5 |
+
def __init__(self, offset, kinematic_tree, device):
|
6 |
+
self.device = device
|
7 |
+
self._raw_offset_np = offset.numpy()
|
8 |
+
self._raw_offset = offset.clone().detach().to(device).float()
|
9 |
+
self._kinematic_tree = kinematic_tree
|
10 |
+
self._offset = None
|
11 |
+
self._parents = [0] * len(self._raw_offset)
|
12 |
+
self._parents[0] = -1
|
13 |
+
for chain in self._kinematic_tree:
|
14 |
+
for j in range(1, len(chain)):
|
15 |
+
self._parents[chain[j]] = chain[j-1]
|
16 |
+
|
17 |
+
def njoints(self):
|
18 |
+
return len(self._raw_offset)
|
19 |
+
|
20 |
+
def offset(self):
|
21 |
+
return self._offset
|
22 |
+
|
23 |
+
def set_offset(self, offsets):
|
24 |
+
self._offset = offsets.clone().detach().to(self.device).float()
|
25 |
+
|
26 |
+
def kinematic_tree(self):
|
27 |
+
return self._kinematic_tree
|
28 |
+
|
29 |
+
def parents(self):
|
30 |
+
return self._parents
|
31 |
+
|
32 |
+
# joints (batch_size, joints_num, 3)
|
33 |
+
def get_offsets_joints_batch(self, joints):
|
34 |
+
assert len(joints.shape) == 3
|
35 |
+
_offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
|
36 |
+
for i in range(1, self._raw_offset.shape[0]):
|
37 |
+
_offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
|
38 |
+
|
39 |
+
self._offset = _offsets.detach()
|
40 |
+
return _offsets
|
41 |
+
|
42 |
+
# joints (joints_num, 3)
|
43 |
+
def get_offsets_joints(self, joints):
|
44 |
+
assert len(joints.shape) == 2
|
45 |
+
_offsets = self._raw_offset.clone()
|
46 |
+
for i in range(1, self._raw_offset.shape[0]):
|
47 |
+
# print(joints.shape)
|
48 |
+
_offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
|
49 |
+
|
50 |
+
self._offset = _offsets.detach()
|
51 |
+
return _offsets
|
52 |
+
|
53 |
+
# face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
|
54 |
+
# joints (batch_size, joints_num, 3)
|
55 |
+
def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
|
56 |
+
assert len(face_joint_idx) == 4
|
57 |
+
'''Get Forward Direction'''
|
58 |
+
l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
|
59 |
+
across1 = joints[:, r_hip] - joints[:, l_hip]
|
60 |
+
across2 = joints[:, sdr_r] - joints[:, sdr_l]
|
61 |
+
across = across1 + across2
|
62 |
+
across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
|
63 |
+
# print(across1.shape, across2.shape)
|
64 |
+
|
65 |
+
# forward (batch_size, 3)
|
66 |
+
forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
|
67 |
+
if smooth_forward:
|
68 |
+
forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
|
69 |
+
# forward (batch_size, 3)
|
70 |
+
forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
|
71 |
+
|
72 |
+
'''Get Root Rotation'''
|
73 |
+
target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
|
74 |
+
root_quat = qbetween_np(forward, target)
|
75 |
+
|
76 |
+
'''Inverse Kinematics'''
|
77 |
+
# quat_params (batch_size, joints_num, 4)
|
78 |
+
# print(joints.shape[:-1])
|
79 |
+
quat_params = np.zeros(joints.shape[:-1] + (4,))
|
80 |
+
# print(quat_params.shape)
|
81 |
+
root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
82 |
+
quat_params[:, 0] = root_quat
|
83 |
+
# quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
84 |
+
for chain in self._kinematic_tree:
|
85 |
+
R = root_quat
|
86 |
+
for j in range(len(chain) - 1):
|
87 |
+
# (batch, 3)
|
88 |
+
u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
|
89 |
+
# print(u.shape)
|
90 |
+
# (batch, 3)
|
91 |
+
v = joints[:, chain[j+1]] - joints[:, chain[j]]
|
92 |
+
v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
|
93 |
+
# print(u.shape, v.shape)
|
94 |
+
rot_u_v = qbetween_np(u, v)
|
95 |
+
|
96 |
+
R_loc = qmul_np(qinv_np(R), rot_u_v)
|
97 |
+
|
98 |
+
quat_params[:,chain[j + 1], :] = R_loc
|
99 |
+
R = qmul_np(R, R_loc)
|
100 |
+
|
101 |
+
return quat_params
|
102 |
+
|
103 |
+
# Be sure root joint is at the beginning of kinematic chains
|
104 |
+
def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
105 |
+
# quat_params (batch_size, joints_num, 4)
|
106 |
+
# joints (batch_size, joints_num, 3)
|
107 |
+
# root_pos (batch_size, 3)
|
108 |
+
if skel_joints is not None:
|
109 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
110 |
+
if len(self._offset.shape) == 2:
|
111 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
112 |
+
joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
|
113 |
+
joints[:, 0] = root_pos
|
114 |
+
for chain in self._kinematic_tree:
|
115 |
+
if do_root_R:
|
116 |
+
R = quat_params[:, 0]
|
117 |
+
else:
|
118 |
+
R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
|
119 |
+
for i in range(1, len(chain)):
|
120 |
+
R = qmul(R, quat_params[:, chain[i]])
|
121 |
+
offset_vec = offsets[:, chain[i]]
|
122 |
+
joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
|
123 |
+
return joints
|
124 |
+
|
125 |
+
# Be sure root joint is at the beginning of kinematic chains
|
126 |
+
def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
127 |
+
# quat_params (batch_size, joints_num, 4)
|
128 |
+
# joints (batch_size, joints_num, 3)
|
129 |
+
# root_pos (batch_size, 3)
|
130 |
+
if skel_joints is not None:
|
131 |
+
skel_joints = torch.from_numpy(skel_joints)
|
132 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
133 |
+
if len(self._offset.shape) == 2:
|
134 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
135 |
+
offsets = offsets.numpy()
|
136 |
+
joints = np.zeros(quat_params.shape[:-1] + (3,))
|
137 |
+
joints[:, 0] = root_pos
|
138 |
+
for chain in self._kinematic_tree:
|
139 |
+
if do_root_R:
|
140 |
+
R = quat_params[:, 0]
|
141 |
+
else:
|
142 |
+
R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
|
143 |
+
for i in range(1, len(chain)):
|
144 |
+
R = qmul_np(R, quat_params[:, chain[i]])
|
145 |
+
offset_vec = offsets[:, chain[i]]
|
146 |
+
joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
|
147 |
+
return joints
|
148 |
+
|
149 |
+
def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
150 |
+
# cont6d_params (batch_size, joints_num, 6)
|
151 |
+
# joints (batch_size, joints_num, 3)
|
152 |
+
# root_pos (batch_size, 3)
|
153 |
+
if skel_joints is not None:
|
154 |
+
skel_joints = torch.from_numpy(skel_joints)
|
155 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
156 |
+
if len(self._offset.shape) == 2:
|
157 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
158 |
+
offsets = offsets.numpy()
|
159 |
+
joints = np.zeros(cont6d_params.shape[:-1] + (3,))
|
160 |
+
joints[:, 0] = root_pos
|
161 |
+
for chain in self._kinematic_tree:
|
162 |
+
if do_root_R:
|
163 |
+
matR = cont6d_to_matrix_np(cont6d_params[:, 0])
|
164 |
+
else:
|
165 |
+
matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
|
166 |
+
for i in range(1, len(chain)):
|
167 |
+
matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
|
168 |
+
offset_vec = offsets[:, chain[i]][..., np.newaxis]
|
169 |
+
# print(matR.shape, offset_vec.shape)
|
170 |
+
joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
171 |
+
return joints
|
172 |
+
|
173 |
+
def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
174 |
+
# cont6d_params (batch_size, joints_num, 6)
|
175 |
+
# joints (batch_size, joints_num, 3)
|
176 |
+
# root_pos (batch_size, 3)
|
177 |
+
if skel_joints is not None:
|
178 |
+
# skel_joints = torch.from_numpy(skel_joints)
|
179 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
180 |
+
if len(self._offset.shape) == 2:
|
181 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
182 |
+
joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
|
183 |
+
joints[..., 0, :] = root_pos
|
184 |
+
for chain in self._kinematic_tree:
|
185 |
+
if do_root_R:
|
186 |
+
matR = cont6d_to_matrix(cont6d_params[:, 0])
|
187 |
+
else:
|
188 |
+
matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
|
189 |
+
for i in range(1, len(chain)):
|
190 |
+
matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
|
191 |
+
offset_vec = offsets[:, chain[i]].unsqueeze(-1)
|
192 |
+
# print(matR.shape, offset_vec.shape)
|
193 |
+
joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
194 |
+
return joints
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
data/__init__.py
ADDED
File without changes
|
data/t2m_dataset.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import join as pjoin
|
2 |
+
import torch
|
3 |
+
from torch.utils import data
|
4 |
+
import numpy as np
|
5 |
+
from tqdm import tqdm
|
6 |
+
from torch.utils.data._utils.collate import default_collate
|
7 |
+
import random
|
8 |
+
import codecs as cs
|
9 |
+
|
10 |
+
|
11 |
+
def collate_fn(batch):
|
12 |
+
batch.sort(key=lambda x: x[3], reverse=True)
|
13 |
+
return default_collate(batch)
|
14 |
+
|
15 |
+
class MotionDataset(data.Dataset):
|
16 |
+
def __init__(self, opt, mean, std, split_file):
|
17 |
+
self.opt = opt
|
18 |
+
joints_num = opt.joints_num
|
19 |
+
|
20 |
+
self.data = []
|
21 |
+
self.lengths = []
|
22 |
+
id_list = []
|
23 |
+
with open(split_file, 'r') as f:
|
24 |
+
for line in f.readlines():
|
25 |
+
id_list.append(line.strip())
|
26 |
+
|
27 |
+
for name in tqdm(id_list):
|
28 |
+
try:
|
29 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
30 |
+
if motion.shape[0] < opt.window_size:
|
31 |
+
continue
|
32 |
+
self.lengths.append(motion.shape[0] - opt.window_size)
|
33 |
+
self.data.append(motion)
|
34 |
+
except Exception as e:
|
35 |
+
# Some motion may not exist in KIT dataset
|
36 |
+
print(e)
|
37 |
+
pass
|
38 |
+
|
39 |
+
self.cumsum = np.cumsum([0] + self.lengths)
|
40 |
+
|
41 |
+
if opt.is_train:
|
42 |
+
# root_rot_velocity (B, seq_len, 1)
|
43 |
+
std[0:1] = std[0:1] / opt.feat_bias
|
44 |
+
# root_linear_velocity (B, seq_len, 2)
|
45 |
+
std[1:3] = std[1:3] / opt.feat_bias
|
46 |
+
# root_y (B, seq_len, 1)
|
47 |
+
std[3:4] = std[3:4] / opt.feat_bias
|
48 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
49 |
+
std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
|
50 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
51 |
+
std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
|
52 |
+
joints_num - 1) * 9] / 1.0
|
53 |
+
# local_velocity (B, seq_len, joint_num*3)
|
54 |
+
std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
|
55 |
+
4 + (joints_num - 1) * 9: 4 + (
|
56 |
+
joints_num - 1) * 9 + joints_num * 3] / 1.0
|
57 |
+
# foot contact (B, seq_len, 4)
|
58 |
+
std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
|
59 |
+
4 + (
|
60 |
+
joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
|
61 |
+
|
62 |
+
assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
|
63 |
+
np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
|
64 |
+
np.save(pjoin(opt.meta_dir, 'std.npy'), std)
|
65 |
+
|
66 |
+
self.mean = mean
|
67 |
+
self.std = std
|
68 |
+
print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
|
69 |
+
|
70 |
+
def inv_transform(self, data):
|
71 |
+
return data * self.std + self.mean
|
72 |
+
|
73 |
+
def __len__(self):
|
74 |
+
return self.cumsum[-1]
|
75 |
+
|
76 |
+
def __getitem__(self, item):
|
77 |
+
if item != 0:
|
78 |
+
motion_id = np.searchsorted(self.cumsum, item) - 1
|
79 |
+
idx = item - self.cumsum[motion_id] - 1
|
80 |
+
else:
|
81 |
+
motion_id = 0
|
82 |
+
idx = 0
|
83 |
+
motion = self.data[motion_id][idx:idx + self.opt.window_size]
|
84 |
+
"Z Normalization"
|
85 |
+
motion = (motion - self.mean) / self.std
|
86 |
+
|
87 |
+
return motion
|
88 |
+
|
89 |
+
|
90 |
+
class Text2MotionDatasetEval(data.Dataset):
|
91 |
+
def __init__(self, opt, mean, std, split_file, w_vectorizer):
|
92 |
+
self.opt = opt
|
93 |
+
self.w_vectorizer = w_vectorizer
|
94 |
+
self.max_length = 20
|
95 |
+
self.pointer = 0
|
96 |
+
self.max_motion_length = opt.max_motion_length
|
97 |
+
min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
|
98 |
+
|
99 |
+
data_dict = {}
|
100 |
+
id_list = []
|
101 |
+
with cs.open(split_file, 'r') as f:
|
102 |
+
for line in f.readlines():
|
103 |
+
id_list.append(line.strip())
|
104 |
+
# id_list = id_list[:250]
|
105 |
+
|
106 |
+
new_name_list = []
|
107 |
+
length_list = []
|
108 |
+
for name in tqdm(id_list):
|
109 |
+
try:
|
110 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
111 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
112 |
+
continue
|
113 |
+
text_data = []
|
114 |
+
flag = False
|
115 |
+
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
|
116 |
+
for line in f.readlines():
|
117 |
+
text_dict = {}
|
118 |
+
line_split = line.strip().split('#')
|
119 |
+
caption = line_split[0]
|
120 |
+
tokens = line_split[1].split(' ')
|
121 |
+
f_tag = float(line_split[2])
|
122 |
+
to_tag = float(line_split[3])
|
123 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
124 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
125 |
+
|
126 |
+
text_dict['caption'] = caption
|
127 |
+
text_dict['tokens'] = tokens
|
128 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
129 |
+
flag = True
|
130 |
+
text_data.append(text_dict)
|
131 |
+
else:
|
132 |
+
try:
|
133 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
134 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
135 |
+
continue
|
136 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
137 |
+
while new_name in data_dict:
|
138 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
139 |
+
data_dict[new_name] = {'motion': n_motion,
|
140 |
+
'length': len(n_motion),
|
141 |
+
'text':[text_dict]}
|
142 |
+
new_name_list.append(new_name)
|
143 |
+
length_list.append(len(n_motion))
|
144 |
+
except:
|
145 |
+
print(line_split)
|
146 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
147 |
+
# break
|
148 |
+
|
149 |
+
if flag:
|
150 |
+
data_dict[name] = {'motion': motion,
|
151 |
+
'length': len(motion),
|
152 |
+
'text': text_data}
|
153 |
+
new_name_list.append(name)
|
154 |
+
length_list.append(len(motion))
|
155 |
+
except:
|
156 |
+
pass
|
157 |
+
|
158 |
+
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
159 |
+
|
160 |
+
self.mean = mean
|
161 |
+
self.std = std
|
162 |
+
self.length_arr = np.array(length_list)
|
163 |
+
self.data_dict = data_dict
|
164 |
+
self.name_list = name_list
|
165 |
+
self.reset_max_len(self.max_length)
|
166 |
+
|
167 |
+
def reset_max_len(self, length):
|
168 |
+
assert length <= self.max_motion_length
|
169 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
170 |
+
print("Pointer Pointing at %d"%self.pointer)
|
171 |
+
self.max_length = length
|
172 |
+
|
173 |
+
def inv_transform(self, data):
|
174 |
+
return data * self.std + self.mean
|
175 |
+
|
176 |
+
def __len__(self):
|
177 |
+
return len(self.data_dict) - self.pointer
|
178 |
+
|
179 |
+
def __getitem__(self, item):
|
180 |
+
idx = self.pointer + item
|
181 |
+
data = self.data_dict[self.name_list[idx]]
|
182 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
183 |
+
# Randomly select a caption
|
184 |
+
text_data = random.choice(text_list)
|
185 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
186 |
+
|
187 |
+
if len(tokens) < self.opt.max_text_len:
|
188 |
+
# pad with "unk"
|
189 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
190 |
+
sent_len = len(tokens)
|
191 |
+
tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
|
192 |
+
else:
|
193 |
+
# crop
|
194 |
+
tokens = tokens[:self.opt.max_text_len]
|
195 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
196 |
+
sent_len = len(tokens)
|
197 |
+
pos_one_hots = []
|
198 |
+
word_embeddings = []
|
199 |
+
for token in tokens:
|
200 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
201 |
+
pos_one_hots.append(pos_oh[None, :])
|
202 |
+
word_embeddings.append(word_emb[None, :])
|
203 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
204 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
205 |
+
|
206 |
+
if self.opt.unit_length < 10:
|
207 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
208 |
+
else:
|
209 |
+
coin2 = 'single'
|
210 |
+
|
211 |
+
if coin2 == 'double':
|
212 |
+
m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
|
213 |
+
elif coin2 == 'single':
|
214 |
+
m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
|
215 |
+
idx = random.randint(0, len(motion) - m_length)
|
216 |
+
motion = motion[idx:idx+m_length]
|
217 |
+
|
218 |
+
"Z Normalization"
|
219 |
+
motion = (motion - self.mean) / self.std
|
220 |
+
|
221 |
+
if m_length < self.max_motion_length:
|
222 |
+
motion = np.concatenate([motion,
|
223 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
224 |
+
], axis=0)
|
225 |
+
# print(word_embeddings.shape, motion.shape)
|
226 |
+
# print(tokens)
|
227 |
+
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
|
228 |
+
|
229 |
+
|
230 |
+
class Text2MotionDataset(data.Dataset):
|
231 |
+
def __init__(self, opt, mean, std, split_file):
|
232 |
+
self.opt = opt
|
233 |
+
self.max_length = 20
|
234 |
+
self.pointer = 0
|
235 |
+
self.max_motion_length = opt.max_motion_length
|
236 |
+
min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
|
237 |
+
|
238 |
+
data_dict = {}
|
239 |
+
id_list = []
|
240 |
+
with cs.open(split_file, 'r') as f:
|
241 |
+
for line in f.readlines():
|
242 |
+
id_list.append(line.strip())
|
243 |
+
# id_list = id_list[:250]
|
244 |
+
|
245 |
+
new_name_list = []
|
246 |
+
length_list = []
|
247 |
+
for name in tqdm(id_list):
|
248 |
+
try:
|
249 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
250 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
251 |
+
continue
|
252 |
+
text_data = []
|
253 |
+
flag = False
|
254 |
+
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
|
255 |
+
for line in f.readlines():
|
256 |
+
text_dict = {}
|
257 |
+
line_split = line.strip().split('#')
|
258 |
+
# print(line)
|
259 |
+
caption = line_split[0]
|
260 |
+
tokens = line_split[1].split(' ')
|
261 |
+
f_tag = float(line_split[2])
|
262 |
+
to_tag = float(line_split[3])
|
263 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
264 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
265 |
+
|
266 |
+
text_dict['caption'] = caption
|
267 |
+
text_dict['tokens'] = tokens
|
268 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
269 |
+
flag = True
|
270 |
+
text_data.append(text_dict)
|
271 |
+
else:
|
272 |
+
try:
|
273 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
274 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
275 |
+
continue
|
276 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
277 |
+
while new_name in data_dict:
|
278 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
279 |
+
data_dict[new_name] = {'motion': n_motion,
|
280 |
+
'length': len(n_motion),
|
281 |
+
'text':[text_dict]}
|
282 |
+
new_name_list.append(new_name)
|
283 |
+
length_list.append(len(n_motion))
|
284 |
+
except:
|
285 |
+
print(line_split)
|
286 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
287 |
+
# break
|
288 |
+
|
289 |
+
if flag:
|
290 |
+
data_dict[name] = {'motion': motion,
|
291 |
+
'length': len(motion),
|
292 |
+
'text': text_data}
|
293 |
+
new_name_list.append(name)
|
294 |
+
length_list.append(len(motion))
|
295 |
+
except Exception as e:
|
296 |
+
# print(e)
|
297 |
+
pass
|
298 |
+
|
299 |
+
# name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
300 |
+
name_list, length_list = new_name_list, length_list
|
301 |
+
|
302 |
+
self.mean = mean
|
303 |
+
self.std = std
|
304 |
+
self.length_arr = np.array(length_list)
|
305 |
+
self.data_dict = data_dict
|
306 |
+
self.name_list = name_list
|
307 |
+
|
308 |
+
def inv_transform(self, data):
|
309 |
+
return data * self.std + self.mean
|
310 |
+
|
311 |
+
def __len__(self):
|
312 |
+
return len(self.data_dict) - self.pointer
|
313 |
+
|
314 |
+
def __getitem__(self, item):
|
315 |
+
idx = self.pointer + item
|
316 |
+
data = self.data_dict[self.name_list[idx]]
|
317 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
318 |
+
# Randomly select a caption
|
319 |
+
text_data = random.choice(text_list)
|
320 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
321 |
+
|
322 |
+
if self.opt.unit_length < 10:
|
323 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
324 |
+
else:
|
325 |
+
coin2 = 'single'
|
326 |
+
|
327 |
+
if coin2 == 'double':
|
328 |
+
m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
|
329 |
+
elif coin2 == 'single':
|
330 |
+
m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
|
331 |
+
idx = random.randint(0, len(motion) - m_length)
|
332 |
+
motion = motion[idx:idx+m_length]
|
333 |
+
|
334 |
+
"Z Normalization"
|
335 |
+
motion = (motion - self.mean) / self.std
|
336 |
+
|
337 |
+
if m_length < self.max_motion_length:
|
338 |
+
motion = np.concatenate([motion,
|
339 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
340 |
+
], axis=0)
|
341 |
+
# print(word_embeddings.shape, motion.shape)
|
342 |
+
# print(tokens)
|
343 |
+
return caption, motion, m_length
|
344 |
+
|
345 |
+
def reset_min_len(self, length):
|
346 |
+
assert length <= self.max_motion_length
|
347 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
348 |
+
print("Pointer Pointing at %d" % self.pointer)
|
dataset/__init__.py
ADDED
File without changes
|
edit_t2m.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join as pjoin
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
|
8 |
+
from models.vq.model import RVQVAE, LengthEstimator
|
9 |
+
|
10 |
+
from options.eval_option import EvalT2MOptions
|
11 |
+
from utils.get_opt import get_opt
|
12 |
+
|
13 |
+
from utils.fixseed import fixseed
|
14 |
+
from visualization.joints2bvh import Joint2BVHConvertor
|
15 |
+
|
16 |
+
from utils.motion_process import recover_from_ric
|
17 |
+
from utils.plot_script import plot_3d_motion
|
18 |
+
|
19 |
+
from utils.paramUtil import t2m_kinematic_chain
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from gen_t2m import load_vq_model, load_res_model, load_trans_model
|
24 |
+
|
25 |
+
if __name__ == '__main__':
|
26 |
+
parser = EvalT2MOptions()
|
27 |
+
opt = parser.parse()
|
28 |
+
fixseed(opt.seed)
|
29 |
+
|
30 |
+
opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
|
31 |
+
torch.autograd.set_detect_anomaly(True)
|
32 |
+
|
33 |
+
dim_pose = 251 if opt.dataset_name == 'kit' else 263
|
34 |
+
|
35 |
+
root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
36 |
+
model_dir = pjoin(root_dir, 'model')
|
37 |
+
result_dir = pjoin('./editing', opt.ext)
|
38 |
+
joints_dir = pjoin(result_dir, 'joints')
|
39 |
+
animation_dir = pjoin(result_dir, 'animations')
|
40 |
+
os.makedirs(joints_dir, exist_ok=True)
|
41 |
+
os.makedirs(animation_dir,exist_ok=True)
|
42 |
+
|
43 |
+
model_opt_path = pjoin(root_dir, 'opt.txt')
|
44 |
+
model_opt = get_opt(model_opt_path, device=opt.device)
|
45 |
+
|
46 |
+
#######################
|
47 |
+
######Loading RVQ######
|
48 |
+
#######################
|
49 |
+
vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
|
50 |
+
vq_opt = get_opt(vq_opt_path, device=opt.device)
|
51 |
+
vq_opt.dim_pose = dim_pose
|
52 |
+
vq_model, vq_opt = load_vq_model(vq_opt)
|
53 |
+
|
54 |
+
model_opt.num_tokens = vq_opt.nb_code
|
55 |
+
model_opt.num_quantizers = vq_opt.num_quantizers
|
56 |
+
model_opt.code_dim = vq_opt.code_dim
|
57 |
+
|
58 |
+
#################################
|
59 |
+
######Loading R-Transformer######
|
60 |
+
#################################
|
61 |
+
res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
|
62 |
+
res_opt = get_opt(res_opt_path, device=opt.device)
|
63 |
+
res_model = load_res_model(res_opt, vq_opt, opt)
|
64 |
+
|
65 |
+
assert res_opt.vq_name == model_opt.vq_name
|
66 |
+
|
67 |
+
#################################
|
68 |
+
######Loading M-Transformer######
|
69 |
+
#################################
|
70 |
+
t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
|
71 |
+
|
72 |
+
t2m_transformer.eval()
|
73 |
+
vq_model.eval()
|
74 |
+
res_model.eval()
|
75 |
+
|
76 |
+
res_model.to(opt.device)
|
77 |
+
t2m_transformer.to(opt.device)
|
78 |
+
vq_model.to(opt.device)
|
79 |
+
|
80 |
+
##### ---- Data ---- #####
|
81 |
+
max_motion_length = 196
|
82 |
+
mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
|
83 |
+
std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
|
84 |
+
def inv_transform(data):
|
85 |
+
return data * std + mean
|
86 |
+
### We provided an example source motion (from 'new_joint_vecs') for editing. See './example_data/000612.mp4'###
|
87 |
+
motion = np.load(opt.source_motion)
|
88 |
+
m_length = len(motion)
|
89 |
+
motion = (motion - mean) / std
|
90 |
+
if max_motion_length > m_length:
|
91 |
+
motion = np.concatenate([motion, np.zeros((max_motion_length - m_length, motion.shape[1])) ], axis=0)
|
92 |
+
motion = torch.from_numpy(motion)[None].to(opt.device)
|
93 |
+
|
94 |
+
prompt_list = []
|
95 |
+
length_list = []
|
96 |
+
if opt.motion_length == 0:
|
97 |
+
opt.motion_length = m_length
|
98 |
+
print("Using default motion length.")
|
99 |
+
|
100 |
+
prompt_list.append(opt.text_prompt)
|
101 |
+
length_list.append(opt.motion_length)
|
102 |
+
if opt.text_prompt == "":
|
103 |
+
raise "Using an empty text prompt."
|
104 |
+
|
105 |
+
token_lens = torch.LongTensor(length_list) // 4
|
106 |
+
token_lens = token_lens.to(opt.device).long()
|
107 |
+
|
108 |
+
m_length = token_lens * 4
|
109 |
+
captions = prompt_list
|
110 |
+
print_captions = captions[0]
|
111 |
+
|
112 |
+
_edit_slice = opt.mask_edit_section
|
113 |
+
edit_slice = []
|
114 |
+
for eds in _edit_slice:
|
115 |
+
_start, _end = eds.split(',')
|
116 |
+
_start = eval(_start)
|
117 |
+
_end = eval(_end)
|
118 |
+
edit_slice.append([_start, _end])
|
119 |
+
|
120 |
+
sample = 0
|
121 |
+
kinematic_chain = t2m_kinematic_chain
|
122 |
+
converter = Joint2BVHConvertor()
|
123 |
+
|
124 |
+
with torch.no_grad():
|
125 |
+
tokens, features = vq_model.encode(motion)
|
126 |
+
### build editing mask, TOEDIT marked as 1 ###
|
127 |
+
edit_mask = torch.zeros_like(tokens[..., 0])
|
128 |
+
seq_len = tokens.shape[1]
|
129 |
+
for _start, _end in edit_slice:
|
130 |
+
if isinstance(_start, float):
|
131 |
+
_start = int(_start*seq_len)
|
132 |
+
_end = int(_end*seq_len)
|
133 |
+
else:
|
134 |
+
_start //= 4
|
135 |
+
_end //= 4
|
136 |
+
edit_mask[:, _start: _end] = 1
|
137 |
+
print_captions = f'{print_captions} [{_start*4/20.}s - {_end*4/20.}s]'
|
138 |
+
edit_mask = edit_mask.bool()
|
139 |
+
for r in range(opt.repeat_times):
|
140 |
+
print("-->Repeat %d"%r)
|
141 |
+
with torch.no_grad():
|
142 |
+
mids = t2m_transformer.edit(
|
143 |
+
captions, tokens[..., 0].clone(), m_length//4,
|
144 |
+
timesteps=opt.time_steps,
|
145 |
+
cond_scale=opt.cond_scale,
|
146 |
+
temperature=opt.temperature,
|
147 |
+
topk_filter_thres=opt.topkr,
|
148 |
+
gsample=opt.gumbel_sample,
|
149 |
+
force_mask=opt.force_mask,
|
150 |
+
edit_mask=edit_mask.clone(),
|
151 |
+
)
|
152 |
+
if opt.use_res_model:
|
153 |
+
mids = res_model.generate(mids, captions, m_length//4, temperature=1, cond_scale=2)
|
154 |
+
else:
|
155 |
+
mids.unsqueeze_(-1)
|
156 |
+
|
157 |
+
pred_motions = vq_model.forward_decoder(mids)
|
158 |
+
|
159 |
+
pred_motions = pred_motions.detach().cpu().numpy()
|
160 |
+
|
161 |
+
source_motions = motion.detach().cpu().numpy()
|
162 |
+
|
163 |
+
data = inv_transform(pred_motions)
|
164 |
+
source_data = inv_transform(source_motions)
|
165 |
+
|
166 |
+
for k, (caption, joint_data, source_data) in enumerate(zip(captions, data, source_data)):
|
167 |
+
print("---->Sample %d: %s %d"%(k, caption, m_length[k]))
|
168 |
+
animation_path = pjoin(animation_dir, str(k))
|
169 |
+
joint_path = pjoin(joints_dir, str(k))
|
170 |
+
|
171 |
+
os.makedirs(animation_path, exist_ok=True)
|
172 |
+
os.makedirs(joint_path, exist_ok=True)
|
173 |
+
|
174 |
+
joint_data = joint_data[:m_length[k]]
|
175 |
+
joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()
|
176 |
+
|
177 |
+
source_data = source_data[:m_length[k]]
|
178 |
+
soucre_joint = recover_from_ric(torch.from_numpy(source_data).float(), 22).numpy()
|
179 |
+
|
180 |
+
bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))
|
181 |
+
_, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)
|
182 |
+
|
183 |
+
bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))
|
184 |
+
_, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)
|
185 |
+
|
186 |
+
|
187 |
+
save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))
|
188 |
+
ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))
|
189 |
+
source_save_path = pjoin(animation_path, "sample%d_source_len%d.mp4"%(k, m_length[k]))
|
190 |
+
|
191 |
+
plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=print_captions, fps=20)
|
192 |
+
plot_3d_motion(save_path, kinematic_chain, joint, title=print_captions, fps=20)
|
193 |
+
plot_3d_motion(source_save_path, kinematic_chain, soucre_joint, title='None', fps=20)
|
194 |
+
np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)
|
195 |
+
np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)
|
environment.yml
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: momask
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- anaconda
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- _libgcc_mutex=0.1=main
|
9 |
+
- _openmp_mutex=5.1=1_gnu
|
10 |
+
- absl-py=1.4.0=pyhd8ed1ab_0
|
11 |
+
- aiohttp=3.8.3=py37h5eee18b_0
|
12 |
+
- aiosignal=1.2.0=pyhd3eb1b0_0
|
13 |
+
- argon2-cffi=21.3.0=pyhd3eb1b0_0
|
14 |
+
- argon2-cffi-bindings=21.2.0=py37h7f8727e_0
|
15 |
+
- async-timeout=4.0.2=py37h06a4308_0
|
16 |
+
- asynctest=0.13.0=py_0
|
17 |
+
- attrs=22.1.0=py37h06a4308_0
|
18 |
+
- backcall=0.2.0=pyhd3eb1b0_0
|
19 |
+
- beautifulsoup4=4.11.1=pyha770c72_0
|
20 |
+
- blas=1.0=mkl
|
21 |
+
- bleach=4.1.0=pyhd3eb1b0_0
|
22 |
+
- blinker=1.4=py37h06a4308_0
|
23 |
+
- brotlipy=0.7.0=py37h540881e_1004
|
24 |
+
- c-ares=1.19.0=h5eee18b_0
|
25 |
+
- ca-certificates=2023.05.30=h06a4308_0
|
26 |
+
- catalogue=2.0.8=py37h89c1867_0
|
27 |
+
- certifi=2022.12.7=py37h06a4308_0
|
28 |
+
- cffi=1.15.1=py37h74dc2b5_0
|
29 |
+
- charset-normalizer=2.1.1=pyhd8ed1ab_0
|
30 |
+
- click=8.0.4=py37h89c1867_0
|
31 |
+
- colorama=0.4.5=pyhd8ed1ab_0
|
32 |
+
- cryptography=35.0.0=py37hf1a17b8_2
|
33 |
+
- cudatoolkit=11.0.221=h6bb024c_0
|
34 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
35 |
+
- cymem=2.0.6=py37hd23a5d3_3
|
36 |
+
- cython-blis=0.7.7=py37hda87dfa_1
|
37 |
+
- dataclasses=0.8=pyhc8e2a94_3
|
38 |
+
- dbus=1.13.18=hb2f20db_0
|
39 |
+
- debugpy=1.5.1=py37h295c915_0
|
40 |
+
- decorator=5.1.1=pyhd3eb1b0_0
|
41 |
+
- defusedxml=0.7.1=pyhd3eb1b0_0
|
42 |
+
- entrypoints=0.4=py37h06a4308_0
|
43 |
+
- expat=2.4.9=h6a678d5_0
|
44 |
+
- fftw=3.3.9=h27cfd23_1
|
45 |
+
- filelock=3.8.0=pyhd8ed1ab_0
|
46 |
+
- fontconfig=2.13.1=h6c09931_0
|
47 |
+
- freetype=2.11.0=h70c0345_0
|
48 |
+
- frozenlist=1.3.3=py37h5eee18b_0
|
49 |
+
- giflib=5.2.1=h7b6447c_0
|
50 |
+
- glib=2.69.1=h4ff587b_1
|
51 |
+
- gst-plugins-base=1.14.0=h8213a91_2
|
52 |
+
- gstreamer=1.14.0=h28cd5cc_2
|
53 |
+
- h5py=3.7.0=py37h737f45e_0
|
54 |
+
- hdf5=1.10.6=h3ffc7dd_1
|
55 |
+
- icu=58.2=he6710b0_3
|
56 |
+
- idna=3.4=pyhd8ed1ab_0
|
57 |
+
- importlib-metadata=4.11.4=py37h89c1867_0
|
58 |
+
- intel-openmp=2021.4.0=h06a4308_3561
|
59 |
+
- ipykernel=6.15.2=py37h06a4308_0
|
60 |
+
- ipython=7.31.1=py37h06a4308_1
|
61 |
+
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
62 |
+
- jedi=0.18.1=py37h06a4308_1
|
63 |
+
- jinja2=3.1.2=pyhd8ed1ab_1
|
64 |
+
- joblib=1.1.0=pyhd3eb1b0_0
|
65 |
+
- jpeg=9b=h024ee3a_2
|
66 |
+
- jsonschema=3.0.2=py37_0
|
67 |
+
- jupyter_client=7.4.9=py37h06a4308_0
|
68 |
+
- jupyter_core=4.11.2=py37h06a4308_0
|
69 |
+
- jupyterlab_pygments=0.1.2=py_0
|
70 |
+
- kiwisolver=1.4.2=py37h295c915_0
|
71 |
+
- langcodes=3.3.0=pyhd8ed1ab_0
|
72 |
+
- lcms2=2.12=h3be6417_0
|
73 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
74 |
+
- libffi=3.3=he6710b0_2
|
75 |
+
- libgcc-ng=11.2.0=h1234567_1
|
76 |
+
- libgfortran-ng=11.2.0=h00389a5_1
|
77 |
+
- libgfortran5=11.2.0=h1234567_1
|
78 |
+
- libgomp=11.2.0=h1234567_1
|
79 |
+
- libpng=1.6.37=hbc83047_0
|
80 |
+
- libprotobuf=3.15.8=h780b84a_1
|
81 |
+
- libsodium=1.0.18=h7b6447c_0
|
82 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
83 |
+
- libtiff=4.1.0=h2733197_1
|
84 |
+
- libuuid=1.0.3=h7f8727e_2
|
85 |
+
- libuv=1.40.0=h7b6447c_0
|
86 |
+
- libwebp=1.2.0=h89dd481_0
|
87 |
+
- libxcb=1.15=h7f8727e_0
|
88 |
+
- libxml2=2.9.14=h74e7548_0
|
89 |
+
- lz4-c=1.9.3=h295c915_1
|
90 |
+
- markdown=3.4.3=pyhd8ed1ab_0
|
91 |
+
- markupsafe=2.1.1=py37h540881e_1
|
92 |
+
- matplotlib=3.1.3=py37_0
|
93 |
+
- matplotlib-base=3.1.3=py37hef1b27d_0
|
94 |
+
- matplotlib-inline=0.1.6=py37h06a4308_0
|
95 |
+
- mistune=0.8.4=py37h14c3975_1001
|
96 |
+
- mkl=2021.4.0=h06a4308_640
|
97 |
+
- mkl-service=2.4.0=py37h7f8727e_0
|
98 |
+
- mkl_fft=1.3.1=py37hd3c417c_0
|
99 |
+
- mkl_random=1.2.2=py37h51133e4_0
|
100 |
+
- multidict=6.0.2=py37h5eee18b_0
|
101 |
+
- murmurhash=1.0.7=py37hd23a5d3_0
|
102 |
+
- nb_conda_kernels=2.3.1=py37h06a4308_0
|
103 |
+
- nbclient=0.5.13=py37h06a4308_0
|
104 |
+
- nbconvert=6.4.4=py37h06a4308_0
|
105 |
+
- nbformat=5.5.0=py37h06a4308_0
|
106 |
+
- ncurses=6.3=h5eee18b_3
|
107 |
+
- nest-asyncio=1.5.6=py37h06a4308_0
|
108 |
+
- ninja=1.10.2=h06a4308_5
|
109 |
+
- ninja-base=1.10.2=hd09550d_5
|
110 |
+
- notebook=6.4.12=py37h06a4308_0
|
111 |
+
- numpy=1.21.5=py37h6c91a56_3
|
112 |
+
- numpy-base=1.21.5=py37ha15fc14_3
|
113 |
+
- openssl=1.1.1v=h7f8727e_0
|
114 |
+
- packaging=21.3=pyhd8ed1ab_0
|
115 |
+
- pandocfilters=1.5.0=pyhd3eb1b0_0
|
116 |
+
- parso=0.8.3=pyhd3eb1b0_0
|
117 |
+
- pathy=0.6.2=pyhd8ed1ab_0
|
118 |
+
- pcre=8.45=h295c915_0
|
119 |
+
- pexpect=4.8.0=pyhd3eb1b0_3
|
120 |
+
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
121 |
+
- pillow=9.2.0=py37hace64e9_1
|
122 |
+
- pip=22.2.2=py37h06a4308_0
|
123 |
+
- preshed=3.0.6=py37hd23a5d3_2
|
124 |
+
- prometheus_client=0.14.1=py37h06a4308_0
|
125 |
+
- prompt-toolkit=3.0.36=py37h06a4308_0
|
126 |
+
- psutil=5.9.0=py37h5eee18b_0
|
127 |
+
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
128 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
129 |
+
- pydantic=1.8.2=py37h5e8e339_2
|
130 |
+
- pygments=2.11.2=pyhd3eb1b0_0
|
131 |
+
- pyjwt=2.4.0=py37h06a4308_0
|
132 |
+
- pyopenssl=22.0.0=pyhd8ed1ab_1
|
133 |
+
- pyparsing=3.0.9=py37h06a4308_0
|
134 |
+
- pyqt=5.9.2=py37h05f1152_2
|
135 |
+
- pyrsistent=0.18.0=py37heee7806_0
|
136 |
+
- pysocks=1.7.1=py37h89c1867_5
|
137 |
+
- python=3.7.13=h12debd9_0
|
138 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
139 |
+
- python-fastjsonschema=2.16.2=py37h06a4308_0
|
140 |
+
- python_abi=3.7=2_cp37m
|
141 |
+
- pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0
|
142 |
+
- pyzmq=23.2.0=py37h6a678d5_0
|
143 |
+
- qt=5.9.7=h5867ecd_1
|
144 |
+
- readline=8.1.2=h7f8727e_1
|
145 |
+
- requests=2.28.1=pyhd8ed1ab_1
|
146 |
+
- scikit-learn=1.0.2=py37h51133e4_1
|
147 |
+
- scipy=1.7.3=py37h6c91a56_2
|
148 |
+
- send2trash=1.8.0=pyhd3eb1b0_1
|
149 |
+
- setuptools=63.4.1=py37h06a4308_0
|
150 |
+
- shellingham=1.5.0=pyhd8ed1ab_0
|
151 |
+
- sip=4.19.8=py37hf484d3e_0
|
152 |
+
- six=1.16.0=pyhd3eb1b0_1
|
153 |
+
- smart_open=5.2.1=pyhd8ed1ab_0
|
154 |
+
- soupsieve=2.3.2.post1=pyhd8ed1ab_0
|
155 |
+
- spacy=3.3.1=py37h79cecc1_0
|
156 |
+
- spacy-legacy=3.0.10=pyhd8ed1ab_0
|
157 |
+
- spacy-loggers=1.0.3=pyhd8ed1ab_0
|
158 |
+
- sqlite=3.39.3=h5082296_0
|
159 |
+
- srsly=2.4.3=py37hd23a5d3_1
|
160 |
+
- tensorboard-plugin-wit=1.8.1=py37h06a4308_0
|
161 |
+
- terminado=0.17.1=py37h06a4308_0
|
162 |
+
- testpath=0.6.0=py37h06a4308_0
|
163 |
+
- thinc=8.0.15=py37h48bf904_0
|
164 |
+
- threadpoolctl=2.2.0=pyh0d69192_0
|
165 |
+
- tk=8.6.12=h1ccaba5_0
|
166 |
+
- torchaudio=0.7.2=py37
|
167 |
+
- torchvision=0.8.2=py37_cu110
|
168 |
+
- tornado=6.2=py37h5eee18b_0
|
169 |
+
- tqdm=4.64.1=py37h06a4308_0
|
170 |
+
- traitlets=5.7.1=py37h06a4308_0
|
171 |
+
- trimesh=3.15.3=pyh1a96a4e_0
|
172 |
+
- typer=0.4.2=pyhd8ed1ab_0
|
173 |
+
- typing-extensions=3.10.0.2=hd8ed1ab_0
|
174 |
+
- typing_extensions=3.10.0.2=pyha770c72_0
|
175 |
+
- urllib3=1.26.15=pyhd8ed1ab_0
|
176 |
+
- wasabi=0.10.1=pyhd8ed1ab_1
|
177 |
+
- webencodings=0.5.1=py37_1
|
178 |
+
- werkzeug=2.2.3=pyhd8ed1ab_0
|
179 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
180 |
+
- xz=5.2.6=h5eee18b_0
|
181 |
+
- yarl=1.8.1=py37h5eee18b_0
|
182 |
+
- zeromq=4.3.4=h2531618_0
|
183 |
+
- zipp=3.8.1=pyhd8ed1ab_0
|
184 |
+
- zlib=1.2.12=h5eee18b_3
|
185 |
+
- zstd=1.4.9=haebb681_0
|
186 |
+
- pip:
|
187 |
+
- cachetools==5.3.1
|
188 |
+
- einops==0.6.1
|
189 |
+
- ftfy==6.1.1
|
190 |
+
- gdown==4.7.1
|
191 |
+
- google-auth==2.22.0
|
192 |
+
- google-auth-oauthlib==0.4.6
|
193 |
+
- grpcio==1.57.0
|
194 |
+
- oauthlib==3.2.2
|
195 |
+
- protobuf==3.20.3
|
196 |
+
- pyasn1==0.5.0
|
197 |
+
- pyasn1-modules==0.3.0
|
198 |
+
- regex==2023.8.8
|
199 |
+
- requests-oauthlib==1.3.1
|
200 |
+
- rsa==4.9
|
201 |
+
- tensorboard==2.11.2
|
202 |
+
- tensorboard-data-server==0.6.1
|
203 |
+
- wcwidth==0.2.6
|
204 |
+
prefix: /home/chuan/anaconda3/envs/momask
|
eval_t2m_trans_res.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join as pjoin
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
|
7 |
+
from models.vq.model import RVQVAE
|
8 |
+
|
9 |
+
from options.eval_option import EvalT2MOptions
|
10 |
+
from utils.get_opt import get_opt
|
11 |
+
from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
|
12 |
+
from models.t2m_eval_wrapper import EvaluatorModelWrapper
|
13 |
+
|
14 |
+
import utils.eval_t2m as eval_t2m
|
15 |
+
from utils.fixseed import fixseed
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
def load_vq_model(vq_opt):
|
20 |
+
# opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
|
21 |
+
vq_model = RVQVAE(vq_opt,
|
22 |
+
dim_pose,
|
23 |
+
vq_opt.nb_code,
|
24 |
+
vq_opt.code_dim,
|
25 |
+
vq_opt.output_emb_width,
|
26 |
+
vq_opt.down_t,
|
27 |
+
vq_opt.stride_t,
|
28 |
+
vq_opt.width,
|
29 |
+
vq_opt.depth,
|
30 |
+
vq_opt.dilation_growth_rate,
|
31 |
+
vq_opt.vq_act,
|
32 |
+
vq_opt.vq_norm)
|
33 |
+
ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
|
34 |
+
map_location=opt.device)
|
35 |
+
model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
|
36 |
+
vq_model.load_state_dict(ckpt[model_key])
|
37 |
+
print(f'Loading VQ Model {vq_opt.name} Completed!')
|
38 |
+
return vq_model, vq_opt
|
39 |
+
|
40 |
+
def load_trans_model(model_opt, which_model):
|
41 |
+
t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
|
42 |
+
cond_mode='text',
|
43 |
+
latent_dim=model_opt.latent_dim,
|
44 |
+
ff_size=model_opt.ff_size,
|
45 |
+
num_layers=model_opt.n_layers,
|
46 |
+
num_heads=model_opt.n_heads,
|
47 |
+
dropout=model_opt.dropout,
|
48 |
+
clip_dim=512,
|
49 |
+
cond_drop_prob=model_opt.cond_drop_prob,
|
50 |
+
clip_version=clip_version,
|
51 |
+
opt=model_opt)
|
52 |
+
ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),
|
53 |
+
map_location=opt.device)
|
54 |
+
model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
|
55 |
+
# print(ckpt.keys())
|
56 |
+
missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
|
57 |
+
assert len(unexpected_keys) == 0
|
58 |
+
assert all([k.startswith('clip_model.') for k in missing_keys])
|
59 |
+
print(f'Loading Mask Transformer {opt.name} from epoch {ckpt["ep"]}!')
|
60 |
+
return t2m_transformer
|
61 |
+
|
62 |
+
def load_res_model(res_opt):
|
63 |
+
res_opt.num_quantizers = vq_opt.num_quantizers
|
64 |
+
res_opt.num_tokens = vq_opt.nb_code
|
65 |
+
res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
|
66 |
+
cond_mode='text',
|
67 |
+
latent_dim=res_opt.latent_dim,
|
68 |
+
ff_size=res_opt.ff_size,
|
69 |
+
num_layers=res_opt.n_layers,
|
70 |
+
num_heads=res_opt.n_heads,
|
71 |
+
dropout=res_opt.dropout,
|
72 |
+
clip_dim=512,
|
73 |
+
shared_codebook=vq_opt.shared_codebook,
|
74 |
+
cond_drop_prob=res_opt.cond_drop_prob,
|
75 |
+
# codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
|
76 |
+
share_weight=res_opt.share_weight,
|
77 |
+
clip_version=clip_version,
|
78 |
+
opt=res_opt)
|
79 |
+
|
80 |
+
ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),
|
81 |
+
map_location=opt.device)
|
82 |
+
missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)
|
83 |
+
assert len(unexpected_keys) == 0
|
84 |
+
assert all([k.startswith('clip_model.') for k in missing_keys])
|
85 |
+
print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')
|
86 |
+
return res_transformer
|
87 |
+
|
88 |
+
if __name__ == '__main__':
|
89 |
+
parser = EvalT2MOptions()
|
90 |
+
opt = parser.parse()
|
91 |
+
fixseed(opt.seed)
|
92 |
+
|
93 |
+
opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
|
94 |
+
torch.autograd.set_detect_anomaly(True)
|
95 |
+
|
96 |
+
dim_pose = 251 if opt.dataset_name == 'kit' else 263
|
97 |
+
|
98 |
+
# out_dir = pjoin(opt.check)
|
99 |
+
root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
100 |
+
model_dir = pjoin(root_dir, 'model')
|
101 |
+
out_dir = pjoin(root_dir, 'eval')
|
102 |
+
os.makedirs(out_dir, exist_ok=True)
|
103 |
+
|
104 |
+
out_path = pjoin(out_dir, "%s.log"%opt.ext)
|
105 |
+
|
106 |
+
f = open(pjoin(out_path), 'w')
|
107 |
+
|
108 |
+
model_opt_path = pjoin(root_dir, 'opt.txt')
|
109 |
+
model_opt = get_opt(model_opt_path, device=opt.device)
|
110 |
+
clip_version = 'ViT-B/32'
|
111 |
+
|
112 |
+
vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
|
113 |
+
vq_opt = get_opt(vq_opt_path, device=opt.device)
|
114 |
+
vq_model, vq_opt = load_vq_model(vq_opt)
|
115 |
+
|
116 |
+
model_opt.num_tokens = vq_opt.nb_code
|
117 |
+
model_opt.num_quantizers = vq_opt.num_quantizers
|
118 |
+
model_opt.code_dim = vq_opt.code_dim
|
119 |
+
|
120 |
+
res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
|
121 |
+
res_opt = get_opt(res_opt_path, device=opt.device)
|
122 |
+
res_model = load_res_model(res_opt)
|
123 |
+
|
124 |
+
assert res_opt.vq_name == model_opt.vq_name
|
125 |
+
|
126 |
+
dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if opt.dataset_name == 'kit' \
|
127 |
+
else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
|
128 |
+
|
129 |
+
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
|
130 |
+
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
131 |
+
|
132 |
+
##### ---- Dataloader ---- #####
|
133 |
+
opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22
|
134 |
+
|
135 |
+
eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'test', device=opt.device)
|
136 |
+
|
137 |
+
# model_dir = pjoin(opt.)
|
138 |
+
for file in os.listdir(model_dir):
|
139 |
+
if opt.which_epoch != "all" and opt.which_epoch not in file:
|
140 |
+
continue
|
141 |
+
print('loading checkpoint {}'.format(file))
|
142 |
+
t2m_transformer = load_trans_model(model_opt, file)
|
143 |
+
t2m_transformer.eval()
|
144 |
+
vq_model.eval()
|
145 |
+
res_model.eval()
|
146 |
+
|
147 |
+
t2m_transformer.to(opt.device)
|
148 |
+
vq_model.to(opt.device)
|
149 |
+
res_model.to(opt.device)
|
150 |
+
|
151 |
+
fid = []
|
152 |
+
div = []
|
153 |
+
top1 = []
|
154 |
+
top2 = []
|
155 |
+
top3 = []
|
156 |
+
matching = []
|
157 |
+
mm = []
|
158 |
+
|
159 |
+
repeat_time = 20
|
160 |
+
for i in range(repeat_time):
|
161 |
+
with torch.no_grad():
|
162 |
+
best_fid, best_div, Rprecision, best_matching, best_mm = \
|
163 |
+
eval_t2m.evaluation_mask_transformer_test_plus_res(eval_val_loader, vq_model, res_model, t2m_transformer,
|
164 |
+
i, eval_wrapper=eval_wrapper,
|
165 |
+
time_steps=opt.time_steps, cond_scale=opt.cond_scale,
|
166 |
+
temperature=opt.temperature, topkr=opt.topkr,
|
167 |
+
force_mask=opt.force_mask, cal_mm=True)
|
168 |
+
fid.append(best_fid)
|
169 |
+
div.append(best_div)
|
170 |
+
top1.append(Rprecision[0])
|
171 |
+
top2.append(Rprecision[1])
|
172 |
+
top3.append(Rprecision[2])
|
173 |
+
matching.append(best_matching)
|
174 |
+
mm.append(best_mm)
|
175 |
+
|
176 |
+
fid = np.array(fid)
|
177 |
+
div = np.array(div)
|
178 |
+
top1 = np.array(top1)
|
179 |
+
top2 = np.array(top2)
|
180 |
+
top3 = np.array(top3)
|
181 |
+
matching = np.array(matching)
|
182 |
+
mm = np.array(mm)
|
183 |
+
|
184 |
+
print(f'{file} final result:')
|
185 |
+
print(f'{file} final result:', file=f, flush=True)
|
186 |
+
|
187 |
+
msg_final = f"\tFID: {np.mean(fid):.3f}, conf. {np.std(fid) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
|
188 |
+
f"\tDiversity: {np.mean(div):.3f}, conf. {np.std(div) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
|
189 |
+
f"\tTOP1: {np.mean(top1):.3f}, conf. {np.std(top1) * 1.96 / np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2) * 1.96 / np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
|
190 |
+
f"\tMatching: {np.mean(matching):.3f}, conf. {np.std(matching) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
|
191 |
+
f"\tMultimodality:{np.mean(mm):.3f}, conf.{np.std(mm) * 1.96 / np.sqrt(repeat_time):.3f}\n\n"
|
192 |
+
# logger.info(msg_final)
|
193 |
+
print(msg_final)
|
194 |
+
print(msg_final, file=f, flush=True)
|
195 |
+
|
196 |
+
f.close()
|
197 |
+
|
198 |
+
|
199 |
+
# python eval_t2m_trans.py --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_vq --dataset_name t2m --gpu_id 3 --cond_scale 4 --time_steps 18 --temperature 1 --topkr 0.9 --gumbel_sample --ext cs4_ts18_tau1_topkr0.9_gs
|
eval_t2m_vq.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from os.path import join as pjoin
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from models.vq.model import RVQVAE
|
7 |
+
from options.vq_option import arg_parse
|
8 |
+
from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
|
9 |
+
import utils.eval_t2m as eval_t2m
|
10 |
+
from utils.get_opt import get_opt
|
11 |
+
from models.t2m_eval_wrapper import EvaluatorModelWrapper
|
12 |
+
import warnings
|
13 |
+
warnings.filterwarnings('ignore')
|
14 |
+
import numpy as np
|
15 |
+
from utils.word_vectorizer import WordVectorizer
|
16 |
+
|
17 |
+
def load_vq_model(vq_opt, which_epoch):
|
18 |
+
# opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
|
19 |
+
|
20 |
+
vq_model = RVQVAE(vq_opt,
|
21 |
+
dim_pose,
|
22 |
+
vq_opt.nb_code,
|
23 |
+
vq_opt.code_dim,
|
24 |
+
vq_opt.code_dim,
|
25 |
+
vq_opt.down_t,
|
26 |
+
vq_opt.stride_t,
|
27 |
+
vq_opt.width,
|
28 |
+
vq_opt.depth,
|
29 |
+
vq_opt.dilation_growth_rate,
|
30 |
+
vq_opt.vq_act,
|
31 |
+
vq_opt.vq_norm)
|
32 |
+
ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', which_epoch),
|
33 |
+
map_location='cpu')
|
34 |
+
model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
|
35 |
+
vq_model.load_state_dict(ckpt[model_key])
|
36 |
+
vq_epoch = ckpt['ep'] if 'ep' in ckpt else -1
|
37 |
+
print(f'Loading VQ Model {vq_opt.name} Completed!, Epoch {vq_epoch}')
|
38 |
+
return vq_model, vq_epoch
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
##### ---- Exp dirs ---- #####
|
42 |
+
args = arg_parse(False)
|
43 |
+
args.device = torch.device("cpu" if args.gpu_id == -1 else "cuda:" + str(args.gpu_id))
|
44 |
+
|
45 |
+
args.out_dir = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'eval')
|
46 |
+
os.makedirs(args.out_dir, exist_ok=True)
|
47 |
+
|
48 |
+
f = open(pjoin(args.out_dir, '%s.log'%args.ext), 'w')
|
49 |
+
|
50 |
+
dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataset_name == 'kit' \
|
51 |
+
else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
|
52 |
+
|
53 |
+
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
|
54 |
+
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
55 |
+
|
56 |
+
##### ---- Dataloader ---- #####
|
57 |
+
args.nb_joints = 21 if args.dataset_name == 'kit' else 22
|
58 |
+
dim_pose = 251 if args.dataset_name == 'kit' else 263
|
59 |
+
|
60 |
+
eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'test', device=args.device)
|
61 |
+
|
62 |
+
print(len(eval_val_loader))
|
63 |
+
|
64 |
+
##### ---- Network ---- #####
|
65 |
+
vq_opt_path = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'opt.txt')
|
66 |
+
vq_opt = get_opt(vq_opt_path, device=args.device)
|
67 |
+
# net = load_vq_model()
|
68 |
+
|
69 |
+
model_dir = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'model')
|
70 |
+
for file in os.listdir(model_dir):
|
71 |
+
# if not file.endswith('tar'):
|
72 |
+
# continue
|
73 |
+
# if not file.startswith('net_best_fid'):
|
74 |
+
# continue
|
75 |
+
if args.which_epoch != "all" and args.which_epoch not in file:
|
76 |
+
continue
|
77 |
+
print(file)
|
78 |
+
net, ep = load_vq_model(vq_opt, file)
|
79 |
+
|
80 |
+
net.eval()
|
81 |
+
net.cuda()
|
82 |
+
|
83 |
+
fid = []
|
84 |
+
div = []
|
85 |
+
top1 = []
|
86 |
+
top2 = []
|
87 |
+
top3 = []
|
88 |
+
matching = []
|
89 |
+
mae = []
|
90 |
+
repeat_time = 20
|
91 |
+
for i in range(repeat_time):
|
92 |
+
best_fid, best_div, Rprecision, best_matching, l1_dist = \
|
93 |
+
eval_t2m.evaluation_vqvae_plus_mpjpe(eval_val_loader, net, i, eval_wrapper=eval_wrapper, num_joint=args.nb_joints)
|
94 |
+
fid.append(best_fid)
|
95 |
+
div.append(best_div)
|
96 |
+
top1.append(Rprecision[0])
|
97 |
+
top2.append(Rprecision[1])
|
98 |
+
top3.append(Rprecision[2])
|
99 |
+
matching.append(best_matching)
|
100 |
+
mae.append(l1_dist)
|
101 |
+
|
102 |
+
fid = np.array(fid)
|
103 |
+
div = np.array(div)
|
104 |
+
top1 = np.array(top1)
|
105 |
+
top2 = np.array(top2)
|
106 |
+
top3 = np.array(top3)
|
107 |
+
matching = np.array(matching)
|
108 |
+
mae = np.array(mae)
|
109 |
+
|
110 |
+
print(f'{file} final result, epoch {ep}')
|
111 |
+
print(f'{file} final result, epoch {ep}', file=f, flush=True)
|
112 |
+
|
113 |
+
msg_final = f"\tFID: {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}\n" \
|
114 |
+
f"\tDiversity: {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}\n" \
|
115 |
+
f"\tTOP1: {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}\n" \
|
116 |
+
f"\tMatching: {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}\n" \
|
117 |
+
f"\tMAE:{np.mean(mae):.3f}, conf.{np.std(mae)*1.96/np.sqrt(repeat_time):.3f}\n\n"
|
118 |
+
# logger.info(msg_final)
|
119 |
+
print(msg_final)
|
120 |
+
print(msg_final, file=f, flush=True)
|
121 |
+
|
122 |
+
f.close()
|
123 |
+
|
example_data/000612.mp4
ADDED
Binary file (154 kB). View file
|
|
example_data/000612.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85e5a8081278a0e31488eaa29386940b9e4b739fb401042f7ad883afb475ab10
|
3 |
+
size 418824
|
gen_t2m.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join as pjoin
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
|
8 |
+
from models.vq.model import RVQVAE, LengthEstimator
|
9 |
+
|
10 |
+
from options.eval_option import EvalT2MOptions
|
11 |
+
from utils.get_opt import get_opt
|
12 |
+
|
13 |
+
from utils.fixseed import fixseed
|
14 |
+
from visualization.joints2bvh import Joint2BVHConvertor
|
15 |
+
from torch.distributions.categorical import Categorical
|
16 |
+
|
17 |
+
|
18 |
+
from utils.motion_process import recover_from_ric
|
19 |
+
from utils.plot_script import plot_3d_motion
|
20 |
+
|
21 |
+
from utils.paramUtil import t2m_kinematic_chain
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
clip_version = 'ViT-B/32'
|
25 |
+
|
26 |
+
def load_vq_model(vq_opt):
|
27 |
+
# opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
|
28 |
+
vq_model = RVQVAE(vq_opt,
|
29 |
+
vq_opt.dim_pose,
|
30 |
+
vq_opt.nb_code,
|
31 |
+
vq_opt.code_dim,
|
32 |
+
vq_opt.output_emb_width,
|
33 |
+
vq_opt.down_t,
|
34 |
+
vq_opt.stride_t,
|
35 |
+
vq_opt.width,
|
36 |
+
vq_opt.depth,
|
37 |
+
vq_opt.dilation_growth_rate,
|
38 |
+
vq_opt.vq_act,
|
39 |
+
vq_opt.vq_norm)
|
40 |
+
ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
|
41 |
+
map_location='cpu')
|
42 |
+
model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
|
43 |
+
vq_model.load_state_dict(ckpt[model_key])
|
44 |
+
print(f'Loading VQ Model {vq_opt.name} Completed!')
|
45 |
+
return vq_model, vq_opt
|
46 |
+
|
47 |
+
def load_trans_model(model_opt, opt, which_model):
|
48 |
+
t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
|
49 |
+
cond_mode='text',
|
50 |
+
latent_dim=model_opt.latent_dim,
|
51 |
+
ff_size=model_opt.ff_size,
|
52 |
+
num_layers=model_opt.n_layers,
|
53 |
+
num_heads=model_opt.n_heads,
|
54 |
+
dropout=model_opt.dropout,
|
55 |
+
clip_dim=512,
|
56 |
+
cond_drop_prob=model_opt.cond_drop_prob,
|
57 |
+
clip_version=clip_version,
|
58 |
+
opt=model_opt)
|
59 |
+
ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),
|
60 |
+
map_location='cpu')
|
61 |
+
model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
|
62 |
+
# print(ckpt.keys())
|
63 |
+
missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
|
64 |
+
assert len(unexpected_keys) == 0
|
65 |
+
assert all([k.startswith('clip_model.') for k in missing_keys])
|
66 |
+
print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!')
|
67 |
+
return t2m_transformer
|
68 |
+
|
69 |
+
def load_res_model(res_opt, vq_opt, opt):
|
70 |
+
res_opt.num_quantizers = vq_opt.num_quantizers
|
71 |
+
res_opt.num_tokens = vq_opt.nb_code
|
72 |
+
res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
|
73 |
+
cond_mode='text',
|
74 |
+
latent_dim=res_opt.latent_dim,
|
75 |
+
ff_size=res_opt.ff_size,
|
76 |
+
num_layers=res_opt.n_layers,
|
77 |
+
num_heads=res_opt.n_heads,
|
78 |
+
dropout=res_opt.dropout,
|
79 |
+
clip_dim=512,
|
80 |
+
shared_codebook=vq_opt.shared_codebook,
|
81 |
+
cond_drop_prob=res_opt.cond_drop_prob,
|
82 |
+
# codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
|
83 |
+
share_weight=res_opt.share_weight,
|
84 |
+
clip_version=clip_version,
|
85 |
+
opt=res_opt)
|
86 |
+
|
87 |
+
ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),
|
88 |
+
map_location=opt.device)
|
89 |
+
missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)
|
90 |
+
assert len(unexpected_keys) == 0
|
91 |
+
assert all([k.startswith('clip_model.') for k in missing_keys])
|
92 |
+
print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')
|
93 |
+
return res_transformer
|
94 |
+
|
95 |
+
def load_len_estimator(opt):
|
96 |
+
model = LengthEstimator(512, 50)
|
97 |
+
ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'),
|
98 |
+
map_location=opt.device)
|
99 |
+
model.load_state_dict(ckpt['estimator'])
|
100 |
+
print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!')
|
101 |
+
return model
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == '__main__':
|
105 |
+
parser = EvalT2MOptions()
|
106 |
+
opt = parser.parse()
|
107 |
+
fixseed(opt.seed)
|
108 |
+
|
109 |
+
opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
|
110 |
+
torch.autograd.set_detect_anomaly(True)
|
111 |
+
|
112 |
+
dim_pose = 251 if opt.dataset_name == 'kit' else 263
|
113 |
+
|
114 |
+
# out_dir = pjoin(opt.check)
|
115 |
+
root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
116 |
+
model_dir = pjoin(root_dir, 'model')
|
117 |
+
result_dir = pjoin('./generation', opt.ext)
|
118 |
+
joints_dir = pjoin(result_dir, 'joints')
|
119 |
+
animation_dir = pjoin(result_dir, 'animations')
|
120 |
+
os.makedirs(joints_dir, exist_ok=True)
|
121 |
+
os.makedirs(animation_dir,exist_ok=True)
|
122 |
+
|
123 |
+
model_opt_path = pjoin(root_dir, 'opt.txt')
|
124 |
+
model_opt = get_opt(model_opt_path, device=opt.device)
|
125 |
+
|
126 |
+
|
127 |
+
#######################
|
128 |
+
######Loading RVQ######
|
129 |
+
#######################
|
130 |
+
vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
|
131 |
+
vq_opt = get_opt(vq_opt_path, device=opt.device)
|
132 |
+
vq_opt.dim_pose = dim_pose
|
133 |
+
vq_model, vq_opt = load_vq_model(vq_opt)
|
134 |
+
|
135 |
+
model_opt.num_tokens = vq_opt.nb_code
|
136 |
+
model_opt.num_quantizers = vq_opt.num_quantizers
|
137 |
+
model_opt.code_dim = vq_opt.code_dim
|
138 |
+
|
139 |
+
#################################
|
140 |
+
######Loading R-Transformer######
|
141 |
+
#################################
|
142 |
+
res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
|
143 |
+
res_opt = get_opt(res_opt_path, device=opt.device)
|
144 |
+
res_model = load_res_model(res_opt, vq_opt, opt)
|
145 |
+
|
146 |
+
assert res_opt.vq_name == model_opt.vq_name
|
147 |
+
|
148 |
+
#################################
|
149 |
+
######Loading M-Transformer######
|
150 |
+
#################################
|
151 |
+
t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
|
152 |
+
|
153 |
+
##################################
|
154 |
+
#####Loading Length Predictor#####
|
155 |
+
##################################
|
156 |
+
length_estimator = load_len_estimator(model_opt)
|
157 |
+
|
158 |
+
t2m_transformer.eval()
|
159 |
+
vq_model.eval()
|
160 |
+
res_model.eval()
|
161 |
+
length_estimator.eval()
|
162 |
+
|
163 |
+
res_model.to(opt.device)
|
164 |
+
t2m_transformer.to(opt.device)
|
165 |
+
vq_model.to(opt.device)
|
166 |
+
length_estimator.to(opt.device)
|
167 |
+
|
168 |
+
##### ---- Dataloader ---- #####
|
169 |
+
opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22
|
170 |
+
|
171 |
+
mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
|
172 |
+
std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
|
173 |
+
def inv_transform(data):
|
174 |
+
return data * std + mean
|
175 |
+
|
176 |
+
prompt_list = []
|
177 |
+
length_list = []
|
178 |
+
|
179 |
+
est_length = False
|
180 |
+
if opt.text_prompt != "":
|
181 |
+
prompt_list.append(opt.text_prompt)
|
182 |
+
if opt.motion_length == 0:
|
183 |
+
est_length = True
|
184 |
+
else:
|
185 |
+
length_list.append(opt.motion_length)
|
186 |
+
elif opt.text_path != "":
|
187 |
+
with open(opt.text_path, 'r') as f:
|
188 |
+
lines = f.readlines()
|
189 |
+
for line in lines:
|
190 |
+
infos = line.split('#')
|
191 |
+
prompt_list.append(infos[0])
|
192 |
+
if len(infos) == 1 or (not infos[1].isdigit()):
|
193 |
+
est_length = True
|
194 |
+
length_list = []
|
195 |
+
else:
|
196 |
+
length_list.append(int(infos[-1]))
|
197 |
+
else:
|
198 |
+
raise "A text prompt, or a file a text prompts are required!!!"
|
199 |
+
# print('loading checkpoint {}'.format(file))
|
200 |
+
|
201 |
+
if est_length:
|
202 |
+
print("Since no motion length are specified, we will use estimated motion lengthes!!")
|
203 |
+
text_embedding = t2m_transformer.encode_text(prompt_list)
|
204 |
+
pred_dis = length_estimator(text_embedding)
|
205 |
+
probs = F.softmax(pred_dis, dim=-1) # (b, ntoken)
|
206 |
+
token_lens = Categorical(probs).sample() # (b, seqlen)
|
207 |
+
# lengths = torch.multinomial()
|
208 |
+
else:
|
209 |
+
token_lens = torch.LongTensor(length_list) // 4
|
210 |
+
token_lens = token_lens.to(opt.device).long()
|
211 |
+
|
212 |
+
m_length = token_lens * 4
|
213 |
+
captions = prompt_list
|
214 |
+
|
215 |
+
sample = 0
|
216 |
+
kinematic_chain = t2m_kinematic_chain
|
217 |
+
converter = Joint2BVHConvertor()
|
218 |
+
|
219 |
+
for r in range(opt.repeat_times):
|
220 |
+
print("-->Repeat %d"%r)
|
221 |
+
with torch.no_grad():
|
222 |
+
mids = t2m_transformer.generate(captions, token_lens,
|
223 |
+
timesteps=opt.time_steps,
|
224 |
+
cond_scale=opt.cond_scale,
|
225 |
+
temperature=opt.temperature,
|
226 |
+
topk_filter_thres=opt.topkr,
|
227 |
+
gsample=opt.gumbel_sample)
|
228 |
+
# print(mids)
|
229 |
+
# print(mids.shape)
|
230 |
+
mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)
|
231 |
+
pred_motions = vq_model.forward_decoder(mids)
|
232 |
+
|
233 |
+
pred_motions = pred_motions.detach().cpu().numpy()
|
234 |
+
|
235 |
+
data = inv_transform(pred_motions)
|
236 |
+
|
237 |
+
for k, (caption, joint_data) in enumerate(zip(captions, data)):
|
238 |
+
print("---->Sample %d: %s %d"%(k, caption, m_length[k]))
|
239 |
+
animation_path = pjoin(animation_dir, str(k))
|
240 |
+
joint_path = pjoin(joints_dir, str(k))
|
241 |
+
|
242 |
+
os.makedirs(animation_path, exist_ok=True)
|
243 |
+
os.makedirs(joint_path, exist_ok=True)
|
244 |
+
|
245 |
+
joint_data = joint_data[:m_length[k]]
|
246 |
+
joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()
|
247 |
+
|
248 |
+
bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))
|
249 |
+
_, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)
|
250 |
+
|
251 |
+
bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))
|
252 |
+
_, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)
|
253 |
+
|
254 |
+
|
255 |
+
save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))
|
256 |
+
ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))
|
257 |
+
|
258 |
+
plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20)
|
259 |
+
plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
|
260 |
+
np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)
|
261 |
+
np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)
|
models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/__init__.py
ADDED
File without changes
|
models/mask_transformer/__init__.py
ADDED
File without changes
|
models/mask_transformer/tools.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import math
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
# return mask where padding is FALSE
|
7 |
+
def lengths_to_mask(lengths, max_len):
|
8 |
+
# max_len = max(lengths)
|
9 |
+
mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
|
10 |
+
return mask #(b, len)
|
11 |
+
|
12 |
+
# return mask where padding is ALL FALSE
|
13 |
+
def get_pad_mask_idx(seq, pad_idx):
|
14 |
+
return (seq != pad_idx).unsqueeze(1)
|
15 |
+
|
16 |
+
# Given seq: (b, s)
|
17 |
+
# Return mat: (1, s, s)
|
18 |
+
# Example Output:
|
19 |
+
# [[[ True, False, False],
|
20 |
+
# [ True, True, False],
|
21 |
+
# [ True, True, True]]]
|
22 |
+
# For causal attention
|
23 |
+
def get_subsequent_mask(seq):
|
24 |
+
sz_b, seq_len = seq.shape
|
25 |
+
subsequent_mask = (1 - torch.triu(
|
26 |
+
torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
|
27 |
+
return subsequent_mask.to(seq.device)
|
28 |
+
|
29 |
+
|
30 |
+
def exists(val):
|
31 |
+
return val is not None
|
32 |
+
|
33 |
+
def default(val, d):
|
34 |
+
return val if exists(val) else d
|
35 |
+
|
36 |
+
def eval_decorator(fn):
|
37 |
+
def inner(model, *args, **kwargs):
|
38 |
+
was_training = model.training
|
39 |
+
model.eval()
|
40 |
+
out = fn(model, *args, **kwargs)
|
41 |
+
model.train(was_training)
|
42 |
+
return out
|
43 |
+
return inner
|
44 |
+
|
45 |
+
def l2norm(t):
|
46 |
+
return F.normalize(t, dim = -1)
|
47 |
+
|
48 |
+
# tensor helpers
|
49 |
+
|
50 |
+
# Get a random subset of TRUE mask, with prob
|
51 |
+
def get_mask_subset_prob(mask, prob):
|
52 |
+
subset_mask = torch.bernoulli(mask, p=prob) & mask
|
53 |
+
return subset_mask
|
54 |
+
|
55 |
+
|
56 |
+
# Get mask of special_tokens in ids
|
57 |
+
def get_mask_special_tokens(ids, special_ids):
|
58 |
+
mask = torch.zeros_like(ids).bool()
|
59 |
+
for special_id in special_ids:
|
60 |
+
mask |= (ids==special_id)
|
61 |
+
return mask
|
62 |
+
|
63 |
+
# network builder helpers
|
64 |
+
def _get_activation_fn(activation):
|
65 |
+
if activation == "relu":
|
66 |
+
return F.relu
|
67 |
+
elif activation == "gelu":
|
68 |
+
return F.gelu
|
69 |
+
|
70 |
+
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
71 |
+
|
72 |
+
# classifier free guidance functions
|
73 |
+
|
74 |
+
def uniform(shape, device=None):
|
75 |
+
return torch.zeros(shape, device=device).float().uniform_(0, 1)
|
76 |
+
|
77 |
+
def prob_mask_like(shape, prob, device=None):
|
78 |
+
if prob == 1:
|
79 |
+
return torch.ones(shape, device=device, dtype=torch.bool)
|
80 |
+
elif prob == 0:
|
81 |
+
return torch.zeros(shape, device=device, dtype=torch.bool)
|
82 |
+
else:
|
83 |
+
return uniform(shape, device=device) < prob
|
84 |
+
|
85 |
+
# sampling helpers
|
86 |
+
|
87 |
+
def log(t, eps = 1e-20):
|
88 |
+
return torch.log(t.clamp(min = eps))
|
89 |
+
|
90 |
+
def gumbel_noise(t):
|
91 |
+
noise = torch.zeros_like(t).uniform_(0, 1)
|
92 |
+
return -log(-log(noise))
|
93 |
+
|
94 |
+
def gumbel_sample(t, temperature = 1., dim = 1):
|
95 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
|
96 |
+
|
97 |
+
|
98 |
+
# Example input:
|
99 |
+
# [[ 0.3596, 0.0862, 0.9771, -1.0000, -1.0000, -1.0000],
|
100 |
+
# [ 0.4141, 0.1781, 0.6628, 0.5721, -1.0000, -1.0000],
|
101 |
+
# [ 0.9428, 0.3586, 0.1659, 0.8172, 0.9273, -1.0000]]
|
102 |
+
# Example output:
|
103 |
+
# [[ -inf, -inf, 0.9771, -inf, -inf, -inf],
|
104 |
+
# [ -inf, -inf, 0.6628, -inf, -inf, -inf],
|
105 |
+
# [0.9428, -inf, -inf, -inf, -inf, -inf]]
|
106 |
+
def top_k(logits, thres = 0.9, dim = 1):
|
107 |
+
k = math.ceil((1 - thres) * logits.shape[dim])
|
108 |
+
val, ind = logits.topk(k, dim = dim)
|
109 |
+
probs = torch.full_like(logits, float('-inf'))
|
110 |
+
probs.scatter_(dim, ind, val)
|
111 |
+
# func verified
|
112 |
+
# print(probs)
|
113 |
+
# print(logits)
|
114 |
+
# raise
|
115 |
+
return probs
|
116 |
+
|
117 |
+
# noise schedules
|
118 |
+
|
119 |
+
# More on large value, less on small
|
120 |
+
def cosine_schedule(t):
|
121 |
+
return torch.cos(t * math.pi * 0.5)
|
122 |
+
|
123 |
+
def scale_cosine_schedule(t, scale):
|
124 |
+
return torch.clip(scale*torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)
|
125 |
+
|
126 |
+
# More on small value, less on large
|
127 |
+
def q_schedule(bs, low, high, device):
|
128 |
+
noise = uniform((bs,), device=device)
|
129 |
+
schedule = 1 - cosine_schedule(noise)
|
130 |
+
return torch.round(schedule * (high - low - 1)).long() + low
|
131 |
+
|
132 |
+
def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
|
133 |
+
loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)
|
134 |
+
# pred_id = torch.argmax(pred, dim=1)
|
135 |
+
# mask = labels.ne(ignore_index)
|
136 |
+
# n_correct = pred_id.eq(labels).masked_select(mask)
|
137 |
+
# acc = torch.mean(n_correct.float()).item()
|
138 |
+
pred_id_k = torch.topk(pred, k=tk, dim=1).indices
|
139 |
+
pred_id = pred_id_k[:, 0]
|
140 |
+
mask = labels.ne(ignore_index)
|
141 |
+
n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
|
142 |
+
acc = torch.mean(n_correct.float()).item()
|
143 |
+
|
144 |
+
return loss, pred_id, acc
|
145 |
+
|
146 |
+
|
147 |
+
def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
|
148 |
+
'''Calculate cross entropy loss, apply label smoothing if needed.'''
|
149 |
+
# print(pred.shape, labels.shape) #torch.Size([64, 1028, 55]) torch.Size([64, 55])
|
150 |
+
# print(pred.shape, labels.shape) #torch.Size([64, 1027, 55]) torch.Size([64, 55])
|
151 |
+
if smoothing:
|
152 |
+
space = 2
|
153 |
+
n_class = pred.size(1)
|
154 |
+
mask = labels.ne(ignore_index)
|
155 |
+
one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]
|
156 |
+
# one_hot = torch.zeros_like(pred).scatter(1, labels.unsqueeze(1), 1)
|
157 |
+
sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
|
158 |
+
neg_log_prb = -F.log_softmax(pred, dim=1)
|
159 |
+
loss = (sm_one_hot * neg_log_prb).sum(dim=1)
|
160 |
+
# loss = F.cross_entropy(pred, sm_one_hot, reduction='none')
|
161 |
+
loss = torch.mean(loss.masked_select(mask))
|
162 |
+
else:
|
163 |
+
loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)
|
164 |
+
|
165 |
+
return loss
|
models/mask_transformer/transformer.py
ADDED
@@ -0,0 +1,1039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
# from networks.layers import *
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import clip
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
import math
|
9 |
+
from random import random
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
from typing import Callable, Optional, List, Dict
|
12 |
+
from copy import deepcopy
|
13 |
+
from functools import partial
|
14 |
+
from models.mask_transformer.tools import *
|
15 |
+
from torch.distributions.categorical import Categorical
|
16 |
+
|
17 |
+
class InputProcess(nn.Module):
|
18 |
+
def __init__(self, input_feats, latent_dim):
|
19 |
+
super().__init__()
|
20 |
+
self.input_feats = input_feats
|
21 |
+
self.latent_dim = latent_dim
|
22 |
+
self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
# [bs, ntokens, input_feats]
|
26 |
+
x = x.permute((1, 0, 2)) # [seqen, bs, input_feats]
|
27 |
+
# print(x.shape)
|
28 |
+
x = self.poseEmbedding(x) # [seqlen, bs, d]
|
29 |
+
return x
|
30 |
+
|
31 |
+
class PositionalEncoding(nn.Module):
|
32 |
+
#Borrow from MDM, the same as above, but add dropout, exponential may improve precision
|
33 |
+
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
34 |
+
super(PositionalEncoding, self).__init__()
|
35 |
+
self.dropout = nn.Dropout(p=dropout)
|
36 |
+
|
37 |
+
pe = torch.zeros(max_len, d_model)
|
38 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
39 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
|
40 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
41 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
42 |
+
pe = pe.unsqueeze(0).transpose(0, 1) #[max_len, 1, d_model]
|
43 |
+
|
44 |
+
self.register_buffer('pe', pe)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
# not used in the final model
|
48 |
+
x = x + self.pe[:x.shape[0], :]
|
49 |
+
return self.dropout(x)
|
50 |
+
|
51 |
+
class OutputProcess_Bert(nn.Module):
|
52 |
+
def __init__(self, out_feats, latent_dim):
|
53 |
+
super().__init__()
|
54 |
+
self.dense = nn.Linear(latent_dim, latent_dim)
|
55 |
+
self.transform_act_fn = F.gelu
|
56 |
+
self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
|
57 |
+
self.poseFinal = nn.Linear(latent_dim, out_feats) #Bias!
|
58 |
+
|
59 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
60 |
+
hidden_states = self.dense(hidden_states)
|
61 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
62 |
+
hidden_states = self.LayerNorm(hidden_states)
|
63 |
+
output = self.poseFinal(hidden_states) # [seqlen, bs, out_feats]
|
64 |
+
output = output.permute(1, 2, 0) # [bs, c, seqlen]
|
65 |
+
return output
|
66 |
+
|
67 |
+
class OutputProcess(nn.Module):
|
68 |
+
def __init__(self, out_feats, latent_dim):
|
69 |
+
super().__init__()
|
70 |
+
self.dense = nn.Linear(latent_dim, latent_dim)
|
71 |
+
self.transform_act_fn = F.gelu
|
72 |
+
self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
|
73 |
+
self.poseFinal = nn.Linear(latent_dim, out_feats) #Bias!
|
74 |
+
|
75 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
76 |
+
hidden_states = self.dense(hidden_states)
|
77 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
78 |
+
hidden_states = self.LayerNorm(hidden_states)
|
79 |
+
output = self.poseFinal(hidden_states) # [seqlen, bs, out_feats]
|
80 |
+
output = output.permute(1, 2, 0) # [bs, e, seqlen]
|
81 |
+
return output
|
82 |
+
|
83 |
+
|
84 |
+
class MaskTransformer(nn.Module):
|
85 |
+
def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
|
86 |
+
num_heads=4, dropout=0.1, clip_dim=512, cond_drop_prob=0.1,
|
87 |
+
clip_version=None, opt=None, **kargs):
|
88 |
+
super(MaskTransformer, self).__init__()
|
89 |
+
print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')
|
90 |
+
|
91 |
+
self.code_dim = code_dim
|
92 |
+
self.latent_dim = latent_dim
|
93 |
+
self.clip_dim = clip_dim
|
94 |
+
self.dropout = dropout
|
95 |
+
self.opt = opt
|
96 |
+
|
97 |
+
self.cond_mode = cond_mode
|
98 |
+
self.cond_drop_prob = cond_drop_prob
|
99 |
+
|
100 |
+
if self.cond_mode == 'action':
|
101 |
+
assert 'num_actions' in kargs
|
102 |
+
self.num_actions = kargs.get('num_actions', 1)
|
103 |
+
|
104 |
+
'''
|
105 |
+
Preparing Networks
|
106 |
+
'''
|
107 |
+
self.input_process = InputProcess(self.code_dim, self.latent_dim)
|
108 |
+
self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
|
109 |
+
|
110 |
+
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
|
111 |
+
nhead=num_heads,
|
112 |
+
dim_feedforward=ff_size,
|
113 |
+
dropout=dropout,
|
114 |
+
activation='gelu')
|
115 |
+
|
116 |
+
self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
|
117 |
+
num_layers=num_layers)
|
118 |
+
|
119 |
+
self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
|
120 |
+
|
121 |
+
# if self.cond_mode != 'no_cond':
|
122 |
+
if self.cond_mode == 'text':
|
123 |
+
self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
|
124 |
+
elif self.cond_mode == 'action':
|
125 |
+
self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
|
126 |
+
elif self.cond_mode == 'uncond':
|
127 |
+
self.cond_emb = nn.Identity()
|
128 |
+
else:
|
129 |
+
raise KeyError("Unsupported condition mode!!!")
|
130 |
+
|
131 |
+
|
132 |
+
_num_tokens = opt.num_tokens + 2 # two dummy tokens, one for masking, one for padding
|
133 |
+
self.mask_id = opt.num_tokens
|
134 |
+
self.pad_id = opt.num_tokens + 1
|
135 |
+
|
136 |
+
self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
|
137 |
+
|
138 |
+
self.token_emb = nn.Embedding(_num_tokens, self.code_dim)
|
139 |
+
|
140 |
+
self.apply(self.__init_weights)
|
141 |
+
|
142 |
+
'''
|
143 |
+
Preparing frozen weights
|
144 |
+
'''
|
145 |
+
|
146 |
+
if self.cond_mode == 'text':
|
147 |
+
print('Loading CLIP...')
|
148 |
+
self.clip_version = clip_version
|
149 |
+
self.clip_model = self.load_and_freeze_clip(clip_version)
|
150 |
+
|
151 |
+
self.noise_schedule = cosine_schedule
|
152 |
+
|
153 |
+
def load_and_freeze_token_emb(self, codebook):
|
154 |
+
'''
|
155 |
+
:param codebook: (c, d)
|
156 |
+
:return:
|
157 |
+
'''
|
158 |
+
assert self.training, 'Only necessary in training mode'
|
159 |
+
c, d = codebook.shape
|
160 |
+
self.token_emb.weight = nn.Parameter(torch.cat([codebook, torch.zeros(size=(2, d), device=codebook.device)], dim=0)) #add two dummy tokens, 0 vectors
|
161 |
+
self.token_emb.requires_grad_(False)
|
162 |
+
# self.token_emb.weight.requires_grad = False
|
163 |
+
# self.token_emb_ready = True
|
164 |
+
print("Token embedding initialized!")
|
165 |
+
|
166 |
+
def __init_weights(self, module):
|
167 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
168 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
169 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
170 |
+
module.bias.data.zero_()
|
171 |
+
elif isinstance(module, nn.LayerNorm):
|
172 |
+
module.bias.data.zero_()
|
173 |
+
module.weight.data.fill_(1.0)
|
174 |
+
|
175 |
+
def parameters_wo_clip(self):
|
176 |
+
return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
|
177 |
+
|
178 |
+
def load_and_freeze_clip(self, clip_version):
|
179 |
+
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
180 |
+
jit=False) # Must set jit=False for training
|
181 |
+
# Cannot run on cpu
|
182 |
+
clip.model.convert_weights(
|
183 |
+
clip_model) # Actually this line is unnecessary since clip by default already on float16
|
184 |
+
# Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
|
185 |
+
|
186 |
+
# Freeze CLIP weights
|
187 |
+
clip_model.eval()
|
188 |
+
for p in clip_model.parameters():
|
189 |
+
p.requires_grad = False
|
190 |
+
|
191 |
+
return clip_model
|
192 |
+
|
193 |
+
def encode_text(self, raw_text):
|
194 |
+
device = next(self.parameters()).device
|
195 |
+
text = clip.tokenize(raw_text, truncate=True).to(device)
|
196 |
+
feat_clip_text = self.clip_model.encode_text(text).float()
|
197 |
+
return feat_clip_text
|
198 |
+
|
199 |
+
def mask_cond(self, cond, force_mask=False):
|
200 |
+
bs, d = cond.shape
|
201 |
+
if force_mask:
|
202 |
+
return torch.zeros_like(cond)
|
203 |
+
elif self.training and self.cond_drop_prob > 0.:
|
204 |
+
mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
|
205 |
+
return cond * (1. - mask)
|
206 |
+
else:
|
207 |
+
return cond
|
208 |
+
|
209 |
+
def trans_forward(self, motion_ids, cond, padding_mask, force_mask=False):
|
210 |
+
'''
|
211 |
+
:param motion_ids: (b, seqlen)
|
212 |
+
:padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
|
213 |
+
:param cond: (b, embed_dim) for text, (b, num_actions) for action
|
214 |
+
:param force_mask: boolean
|
215 |
+
:return:
|
216 |
+
-logits: (b, num_token, seqlen)
|
217 |
+
'''
|
218 |
+
|
219 |
+
cond = self.mask_cond(cond, force_mask=force_mask)
|
220 |
+
|
221 |
+
# print(motion_ids.shape)
|
222 |
+
x = self.token_emb(motion_ids)
|
223 |
+
# print(x.shape)
|
224 |
+
# (b, seqlen, d) -> (seqlen, b, latent_dim)
|
225 |
+
x = self.input_process(x)
|
226 |
+
|
227 |
+
cond = self.cond_emb(cond).unsqueeze(0) #(1, b, latent_dim)
|
228 |
+
|
229 |
+
x = self.position_enc(x)
|
230 |
+
xseq = torch.cat([cond, x], dim=0) #(seqlen+1, b, latent_dim)
|
231 |
+
|
232 |
+
padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1) #(b, seqlen+1)
|
233 |
+
# print(xseq.shape, padding_mask.shape)
|
234 |
+
|
235 |
+
# print(padding_mask.shape, xseq.shape)
|
236 |
+
|
237 |
+
output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[1:] #(seqlen, b, e)
|
238 |
+
logits = self.output_process(output) #(seqlen, b, e) -> (b, ntoken, seqlen)
|
239 |
+
return logits
|
240 |
+
|
241 |
+
def forward(self, ids, y, m_lens):
|
242 |
+
'''
|
243 |
+
:param ids: (b, n)
|
244 |
+
:param y: raw text for cond_mode=text, (b, ) for cond_mode=action
|
245 |
+
:m_lens: (b,)
|
246 |
+
:return:
|
247 |
+
'''
|
248 |
+
|
249 |
+
bs, ntokens = ids.shape
|
250 |
+
device = ids.device
|
251 |
+
|
252 |
+
# Positions that are PADDED are ALL FALSE
|
253 |
+
non_pad_mask = lengths_to_mask(m_lens, ntokens) #(b, n)
|
254 |
+
ids = torch.where(non_pad_mask, ids, self.pad_id)
|
255 |
+
|
256 |
+
force_mask = False
|
257 |
+
if self.cond_mode == 'text':
|
258 |
+
with torch.no_grad():
|
259 |
+
cond_vector = self.encode_text(y)
|
260 |
+
elif self.cond_mode == 'action':
|
261 |
+
cond_vector = self.enc_action(y).to(device).float()
|
262 |
+
elif self.cond_mode == 'uncond':
|
263 |
+
cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
|
264 |
+
force_mask = True
|
265 |
+
else:
|
266 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
267 |
+
|
268 |
+
|
269 |
+
'''
|
270 |
+
Prepare mask
|
271 |
+
'''
|
272 |
+
rand_time = uniform((bs,), device=device)
|
273 |
+
rand_mask_probs = self.noise_schedule(rand_time)
|
274 |
+
num_token_masked = (ntokens * rand_mask_probs).round().clamp(min=1)
|
275 |
+
|
276 |
+
batch_randperm = torch.rand((bs, ntokens), device=device).argsort(dim=-1)
|
277 |
+
# Positions to be MASKED are ALL TRUE
|
278 |
+
mask = batch_randperm < num_token_masked.unsqueeze(-1)
|
279 |
+
|
280 |
+
# Positions to be MASKED must also be NON-PADDED
|
281 |
+
mask &= non_pad_mask
|
282 |
+
|
283 |
+
# Note this is our training target, not input
|
284 |
+
labels = torch.where(mask, ids, self.mask_id)
|
285 |
+
|
286 |
+
x_ids = ids.clone()
|
287 |
+
|
288 |
+
# Further Apply Bert Masking Scheme
|
289 |
+
# Step 1: 10% replace with an incorrect token
|
290 |
+
mask_rid = get_mask_subset_prob(mask, 0.1)
|
291 |
+
rand_id = torch.randint_like(x_ids, high=self.opt.num_tokens)
|
292 |
+
x_ids = torch.where(mask_rid, rand_id, x_ids)
|
293 |
+
# Step 2: 90% x 10% replace with correct token, and 90% x 88% replace with mask token
|
294 |
+
mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88)
|
295 |
+
|
296 |
+
# mask_mid = mask
|
297 |
+
|
298 |
+
x_ids = torch.where(mask_mid, self.mask_id, x_ids)
|
299 |
+
|
300 |
+
logits = self.trans_forward(x_ids, cond_vector, ~non_pad_mask, force_mask)
|
301 |
+
ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)
|
302 |
+
|
303 |
+
return ce_loss, pred_id, acc
|
304 |
+
|
305 |
+
def forward_with_cond_scale(self,
|
306 |
+
motion_ids,
|
307 |
+
cond_vector,
|
308 |
+
padding_mask,
|
309 |
+
cond_scale=3,
|
310 |
+
force_mask=False):
|
311 |
+
# bs = motion_ids.shape[0]
|
312 |
+
# if cond_scale == 1:
|
313 |
+
if force_mask:
|
314 |
+
return self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)
|
315 |
+
|
316 |
+
logits = self.trans_forward(motion_ids, cond_vector, padding_mask)
|
317 |
+
if cond_scale == 1:
|
318 |
+
return logits
|
319 |
+
|
320 |
+
aux_logits = self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)
|
321 |
+
|
322 |
+
scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
|
323 |
+
return scaled_logits
|
324 |
+
|
325 |
+
@torch.no_grad()
|
326 |
+
@eval_decorator
|
327 |
+
def generate(self,
|
328 |
+
conds,
|
329 |
+
m_lens,
|
330 |
+
timesteps: int,
|
331 |
+
cond_scale: int,
|
332 |
+
temperature=1,
|
333 |
+
topk_filter_thres=0.9,
|
334 |
+
gsample=False,
|
335 |
+
force_mask=False
|
336 |
+
):
|
337 |
+
# print(self.opt.num_quantizers)
|
338 |
+
# assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
|
339 |
+
|
340 |
+
device = next(self.parameters()).device
|
341 |
+
seq_len = max(m_lens)
|
342 |
+
batch_size = len(m_lens)
|
343 |
+
|
344 |
+
if self.cond_mode == 'text':
|
345 |
+
with torch.no_grad():
|
346 |
+
cond_vector = self.encode_text(conds)
|
347 |
+
elif self.cond_mode == 'action':
|
348 |
+
cond_vector = self.enc_action(conds).to(device)
|
349 |
+
elif self.cond_mode == 'uncond':
|
350 |
+
cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
|
351 |
+
else:
|
352 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
353 |
+
|
354 |
+
padding_mask = ~lengths_to_mask(m_lens, seq_len)
|
355 |
+
# print(padding_mask.shape, )
|
356 |
+
|
357 |
+
# Start from all tokens being masked
|
358 |
+
ids = torch.where(padding_mask, self.pad_id, self.mask_id)
|
359 |
+
scores = torch.where(padding_mask, 1e5, 0.)
|
360 |
+
starting_temperature = temperature
|
361 |
+
|
362 |
+
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
|
363 |
+
# 0 < timestep < 1
|
364 |
+
rand_mask_prob = self.noise_schedule(timestep) # Tensor
|
365 |
+
|
366 |
+
'''
|
367 |
+
Maskout, and cope with variable length
|
368 |
+
'''
|
369 |
+
# fix: the ratio regarding lengths, instead of seq_len
|
370 |
+
num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1) # (b, )
|
371 |
+
|
372 |
+
# select num_token_masked tokens with lowest scores to be masked
|
373 |
+
sorted_indices = scores.argsort(
|
374 |
+
dim=1) # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
|
375 |
+
ranks = sorted_indices.argsort(dim=1) # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
|
376 |
+
is_mask = (ranks < num_token_masked.unsqueeze(-1))
|
377 |
+
ids = torch.where(is_mask, self.mask_id, ids)
|
378 |
+
|
379 |
+
'''
|
380 |
+
Preparing input
|
381 |
+
'''
|
382 |
+
# (b, num_token, seqlen)
|
383 |
+
logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
|
384 |
+
padding_mask=padding_mask,
|
385 |
+
cond_scale=cond_scale,
|
386 |
+
force_mask=force_mask)
|
387 |
+
|
388 |
+
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
|
389 |
+
# print(logits.shape, self.opt.num_tokens)
|
390 |
+
# clean low prob token
|
391 |
+
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
|
392 |
+
|
393 |
+
'''
|
394 |
+
Update ids
|
395 |
+
'''
|
396 |
+
# if force_mask:
|
397 |
+
temperature = starting_temperature
|
398 |
+
# else:
|
399 |
+
# temperature = starting_temperature * (steps_until_x0 / timesteps)
|
400 |
+
# temperature = max(temperature, 1e-4)
|
401 |
+
# print(filtered_logits.shape)
|
402 |
+
# temperature is annealed, gradually reducing temperature as well as randomness
|
403 |
+
if gsample: # use gumbel_softmax sampling
|
404 |
+
# print("1111")
|
405 |
+
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
|
406 |
+
else: # use multinomial sampling
|
407 |
+
# print("2222")
|
408 |
+
probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
|
409 |
+
# print(temperature, starting_temperature, steps_until_x0, timesteps)
|
410 |
+
# print(probs / temperature)
|
411 |
+
pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
|
412 |
+
|
413 |
+
# print(pred_ids.max(), pred_ids.min())
|
414 |
+
# if pred_ids.
|
415 |
+
ids = torch.where(is_mask, pred_ids, ids)
|
416 |
+
|
417 |
+
'''
|
418 |
+
Updating scores
|
419 |
+
'''
|
420 |
+
probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
|
421 |
+
scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) # (b, seqlen, 1)
|
422 |
+
scores = scores.squeeze(-1) # (b, seqlen)
|
423 |
+
|
424 |
+
# We do not want to re-mask the previously kept tokens, or pad tokens
|
425 |
+
scores = scores.masked_fill(~is_mask, 1e5)
|
426 |
+
|
427 |
+
ids = torch.where(padding_mask, -1, ids)
|
428 |
+
# print("Final", ids.max(), ids.min())
|
429 |
+
return ids
|
430 |
+
|
431 |
+
|
432 |
+
@torch.no_grad()
|
433 |
+
@eval_decorator
|
434 |
+
def edit(self,
|
435 |
+
conds,
|
436 |
+
tokens,
|
437 |
+
m_lens,
|
438 |
+
timesteps: int,
|
439 |
+
cond_scale: int,
|
440 |
+
temperature=1,
|
441 |
+
topk_filter_thres=0.9,
|
442 |
+
gsample=False,
|
443 |
+
force_mask=False,
|
444 |
+
edit_mask=None,
|
445 |
+
padding_mask=None,
|
446 |
+
):
|
447 |
+
|
448 |
+
assert edit_mask.shape == tokens.shape if edit_mask is not None else True
|
449 |
+
device = next(self.parameters()).device
|
450 |
+
seq_len = tokens.shape[1]
|
451 |
+
|
452 |
+
if self.cond_mode == 'text':
|
453 |
+
with torch.no_grad():
|
454 |
+
cond_vector = self.encode_text(conds)
|
455 |
+
elif self.cond_mode == 'action':
|
456 |
+
cond_vector = self.enc_action(conds).to(device)
|
457 |
+
elif self.cond_mode == 'uncond':
|
458 |
+
cond_vector = torch.zeros(1, self.latent_dim).float().to(device)
|
459 |
+
else:
|
460 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
461 |
+
|
462 |
+
if padding_mask == None:
|
463 |
+
padding_mask = ~lengths_to_mask(m_lens, seq_len)
|
464 |
+
|
465 |
+
# Start from all tokens being masked
|
466 |
+
if edit_mask == None:
|
467 |
+
mask_free = True
|
468 |
+
ids = torch.where(padding_mask, self.pad_id, tokens)
|
469 |
+
edit_mask = torch.ones_like(padding_mask)
|
470 |
+
edit_mask = edit_mask & ~padding_mask
|
471 |
+
edit_len = edit_mask.sum(dim=-1)
|
472 |
+
scores = torch.where(edit_mask, 0., 1e5)
|
473 |
+
else:
|
474 |
+
mask_free = False
|
475 |
+
edit_mask = edit_mask & ~padding_mask
|
476 |
+
edit_len = edit_mask.sum(dim=-1)
|
477 |
+
ids = torch.where(edit_mask, self.mask_id, tokens)
|
478 |
+
scores = torch.where(edit_mask, 0., 1e5)
|
479 |
+
starting_temperature = temperature
|
480 |
+
|
481 |
+
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
|
482 |
+
# 0 < timestep < 1
|
483 |
+
rand_mask_prob = 0.16 if mask_free else self.noise_schedule(timestep) # Tensor
|
484 |
+
|
485 |
+
'''
|
486 |
+
Maskout, and cope with variable length
|
487 |
+
'''
|
488 |
+
# fix: the ratio regarding lengths, instead of seq_len
|
489 |
+
num_token_masked = torch.round(rand_mask_prob * edit_len).clamp(min=1) # (b, )
|
490 |
+
|
491 |
+
# select num_token_masked tokens with lowest scores to be masked
|
492 |
+
sorted_indices = scores.argsort(
|
493 |
+
dim=1) # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
|
494 |
+
ranks = sorted_indices.argsort(dim=1) # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
|
495 |
+
is_mask = (ranks < num_token_masked.unsqueeze(-1))
|
496 |
+
# is_mask = (torch.rand_like(scores) < 0.8) * ~padding_mask if mask_free else is_mask
|
497 |
+
ids = torch.where(is_mask, self.mask_id, ids)
|
498 |
+
|
499 |
+
'''
|
500 |
+
Preparing input
|
501 |
+
'''
|
502 |
+
# (b, num_token, seqlen)
|
503 |
+
logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
|
504 |
+
padding_mask=padding_mask,
|
505 |
+
cond_scale=cond_scale,
|
506 |
+
force_mask=force_mask)
|
507 |
+
|
508 |
+
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
|
509 |
+
# print(logits.shape, self.opt.num_tokens)
|
510 |
+
# clean low prob token
|
511 |
+
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
|
512 |
+
|
513 |
+
'''
|
514 |
+
Update ids
|
515 |
+
'''
|
516 |
+
# if force_mask:
|
517 |
+
temperature = starting_temperature
|
518 |
+
# else:
|
519 |
+
# temperature = starting_temperature * (steps_until_x0 / timesteps)
|
520 |
+
# temperature = max(temperature, 1e-4)
|
521 |
+
# print(filtered_logits.shape)
|
522 |
+
# temperature is annealed, gradually reducing temperature as well as randomness
|
523 |
+
if gsample: # use gumbel_softmax sampling
|
524 |
+
# print("1111")
|
525 |
+
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
|
526 |
+
else: # use multinomial sampling
|
527 |
+
# print("2222")
|
528 |
+
probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
|
529 |
+
# print(temperature, starting_temperature, steps_until_x0, timesteps)
|
530 |
+
# print(probs / temperature)
|
531 |
+
pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
|
532 |
+
|
533 |
+
# print(pred_ids.max(), pred_ids.min())
|
534 |
+
# if pred_ids.
|
535 |
+
ids = torch.where(is_mask, pred_ids, ids)
|
536 |
+
|
537 |
+
'''
|
538 |
+
Updating scores
|
539 |
+
'''
|
540 |
+
probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
|
541 |
+
scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) # (b, seqlen, 1)
|
542 |
+
scores = scores.squeeze(-1) # (b, seqlen)
|
543 |
+
|
544 |
+
# We do not want to re-mask the previously kept tokens, or pad tokens
|
545 |
+
scores = scores.masked_fill(~edit_mask, 1e5) if mask_free else scores.masked_fill(~is_mask, 1e5)
|
546 |
+
|
547 |
+
ids = torch.where(padding_mask, -1, ids)
|
548 |
+
# print("Final", ids.max(), ids.min())
|
549 |
+
return ids
|
550 |
+
|
551 |
+
@torch.no_grad()
|
552 |
+
@eval_decorator
|
553 |
+
def edit_beta(self,
|
554 |
+
conds,
|
555 |
+
conds_og,
|
556 |
+
tokens,
|
557 |
+
m_lens,
|
558 |
+
cond_scale: int,
|
559 |
+
force_mask=False,
|
560 |
+
):
|
561 |
+
|
562 |
+
device = next(self.parameters()).device
|
563 |
+
seq_len = tokens.shape[1]
|
564 |
+
|
565 |
+
if self.cond_mode == 'text':
|
566 |
+
with torch.no_grad():
|
567 |
+
cond_vector = self.encode_text(conds)
|
568 |
+
if conds_og is not None:
|
569 |
+
cond_vector_og = self.encode_text(conds_og)
|
570 |
+
else:
|
571 |
+
cond_vector_og = None
|
572 |
+
elif self.cond_mode == 'action':
|
573 |
+
cond_vector = self.enc_action(conds).to(device)
|
574 |
+
if conds_og is not None:
|
575 |
+
cond_vector_og = self.enc_action(conds_og).to(device)
|
576 |
+
else:
|
577 |
+
cond_vector_og = None
|
578 |
+
else:
|
579 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
580 |
+
|
581 |
+
padding_mask = ~lengths_to_mask(m_lens, seq_len)
|
582 |
+
|
583 |
+
# Start from all tokens being masked
|
584 |
+
ids = torch.where(padding_mask, self.pad_id, tokens) # Do not mask anything
|
585 |
+
|
586 |
+
'''
|
587 |
+
Preparing input
|
588 |
+
'''
|
589 |
+
# (b, num_token, seqlen)
|
590 |
+
logits = self.forward_with_cond_scale(ids,
|
591 |
+
cond_vector=cond_vector,
|
592 |
+
cond_vector_neg=cond_vector_og,
|
593 |
+
padding_mask=padding_mask,
|
594 |
+
cond_scale=cond_scale,
|
595 |
+
force_mask=force_mask)
|
596 |
+
|
597 |
+
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
|
598 |
+
|
599 |
+
'''
|
600 |
+
Updating scores
|
601 |
+
'''
|
602 |
+
probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
|
603 |
+
tokens[tokens == -1] = 0 # just to get through an error when index = -1 using gather
|
604 |
+
og_tokens_scores = probs_without_temperature.gather(2, tokens.unsqueeze(dim=-1)) # (b, seqlen, 1)
|
605 |
+
og_tokens_scores = og_tokens_scores.squeeze(-1) # (b, seqlen)
|
606 |
+
|
607 |
+
return og_tokens_scores
|
608 |
+
|
609 |
+
|
610 |
+
class ResidualTransformer(nn.Module):
|
611 |
+
def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
|
612 |
+
num_heads=4, dropout=0.1, clip_dim=512, shared_codebook=False, share_weight=False,
|
613 |
+
clip_version=None, opt=None, **kargs):
|
614 |
+
super(ResidualTransformer, self).__init__()
|
615 |
+
print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')
|
616 |
+
|
617 |
+
# assert shared_codebook == True, "Only support shared codebook right now!"
|
618 |
+
|
619 |
+
self.code_dim = code_dim
|
620 |
+
self.latent_dim = latent_dim
|
621 |
+
self.clip_dim = clip_dim
|
622 |
+
self.dropout = dropout
|
623 |
+
self.opt = opt
|
624 |
+
|
625 |
+
self.cond_mode = cond_mode
|
626 |
+
# self.cond_drop_prob = cond_drop_prob
|
627 |
+
|
628 |
+
if self.cond_mode == 'action':
|
629 |
+
assert 'num_actions' in kargs
|
630 |
+
self.num_actions = kargs.get('num_actions', 1)
|
631 |
+
self.cond_drop_prob = cond_drop_prob
|
632 |
+
|
633 |
+
'''
|
634 |
+
Preparing Networks
|
635 |
+
'''
|
636 |
+
self.input_process = InputProcess(self.code_dim, self.latent_dim)
|
637 |
+
self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
|
638 |
+
|
639 |
+
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
|
640 |
+
nhead=num_heads,
|
641 |
+
dim_feedforward=ff_size,
|
642 |
+
dropout=dropout,
|
643 |
+
activation='gelu')
|
644 |
+
|
645 |
+
self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
|
646 |
+
num_layers=num_layers)
|
647 |
+
|
648 |
+
self.encode_quant = partial(F.one_hot, num_classes=self.opt.num_quantizers)
|
649 |
+
self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
|
650 |
+
|
651 |
+
self.quant_emb = nn.Linear(self.opt.num_quantizers, self.latent_dim)
|
652 |
+
# if self.cond_mode != 'no_cond':
|
653 |
+
if self.cond_mode == 'text':
|
654 |
+
self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
|
655 |
+
elif self.cond_mode == 'action':
|
656 |
+
self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
|
657 |
+
else:
|
658 |
+
raise KeyError("Unsupported condition mode!!!")
|
659 |
+
|
660 |
+
|
661 |
+
_num_tokens = opt.num_tokens + 1 # one dummy tokens for padding
|
662 |
+
self.pad_id = opt.num_tokens
|
663 |
+
|
664 |
+
# self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
|
665 |
+
self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)
|
666 |
+
|
667 |
+
if shared_codebook:
|
668 |
+
token_embed = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
|
669 |
+
self.token_embed_weight = token_embed.expand(opt.num_quantizers-1, _num_tokens, code_dim)
|
670 |
+
if share_weight:
|
671 |
+
self.output_proj_weight = self.token_embed_weight
|
672 |
+
self.output_proj_bias = None
|
673 |
+
else:
|
674 |
+
output_proj = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
|
675 |
+
output_bias = nn.Parameter(torch.zeros(size=(_num_tokens,)))
|
676 |
+
# self.output_proj_bias = 0
|
677 |
+
self.output_proj_weight = output_proj.expand(opt.num_quantizers-1, _num_tokens, code_dim)
|
678 |
+
self.output_proj_bias = output_bias.expand(opt.num_quantizers-1, _num_tokens)
|
679 |
+
|
680 |
+
else:
|
681 |
+
if share_weight:
|
682 |
+
self.embed_proj_shared_weight = nn.Parameter(torch.normal(mean=0, std=0.02, size=(opt.num_quantizers - 2, _num_tokens, code_dim)))
|
683 |
+
self.token_embed_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
|
684 |
+
self.output_proj_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
|
685 |
+
self.output_proj_bias = None
|
686 |
+
self.registered = False
|
687 |
+
else:
|
688 |
+
output_proj_weight = torch.normal(mean=0, std=0.02,
|
689 |
+
size=(opt.num_quantizers - 1, _num_tokens, code_dim))
|
690 |
+
|
691 |
+
self.output_proj_weight = nn.Parameter(output_proj_weight)
|
692 |
+
self.output_proj_bias = nn.Parameter(torch.zeros(size=(opt.num_quantizers, _num_tokens)))
|
693 |
+
token_embed_weight = torch.normal(mean=0, std=0.02,
|
694 |
+
size=(opt.num_quantizers - 1, _num_tokens, code_dim))
|
695 |
+
self.token_embed_weight = nn.Parameter(token_embed_weight)
|
696 |
+
|
697 |
+
self.apply(self.__init_weights)
|
698 |
+
self.shared_codebook = shared_codebook
|
699 |
+
self.share_weight = share_weight
|
700 |
+
|
701 |
+
if self.cond_mode == 'text':
|
702 |
+
print('Loading CLIP...')
|
703 |
+
self.clip_version = clip_version
|
704 |
+
self.clip_model = self.load_and_freeze_clip(clip_version)
|
705 |
+
|
706 |
+
# def
|
707 |
+
|
708 |
+
def mask_cond(self, cond, force_mask=False):
|
709 |
+
bs, d = cond.shape
|
710 |
+
if force_mask:
|
711 |
+
return torch.zeros_like(cond)
|
712 |
+
elif self.training and self.cond_drop_prob > 0.:
|
713 |
+
mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
|
714 |
+
return cond * (1. - mask)
|
715 |
+
else:
|
716 |
+
return cond
|
717 |
+
|
718 |
+
def __init_weights(self, module):
|
719 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
720 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
721 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
722 |
+
module.bias.data.zero_()
|
723 |
+
elif isinstance(module, nn.LayerNorm):
|
724 |
+
module.bias.data.zero_()
|
725 |
+
module.weight.data.fill_(1.0)
|
726 |
+
|
727 |
+
def parameters_wo_clip(self):
|
728 |
+
return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
|
729 |
+
|
730 |
+
def load_and_freeze_clip(self, clip_version):
|
731 |
+
clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
|
732 |
+
jit=False) # Must set jit=False for training
|
733 |
+
# Cannot run on cpu
|
734 |
+
clip.model.convert_weights(
|
735 |
+
clip_model) # Actually this line is unnecessary since clip by default already on float16
|
736 |
+
# Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
|
737 |
+
|
738 |
+
# Freeze CLIP weights
|
739 |
+
clip_model.eval()
|
740 |
+
for p in clip_model.parameters():
|
741 |
+
p.requires_grad = False
|
742 |
+
|
743 |
+
return clip_model
|
744 |
+
|
745 |
+
def encode_text(self, raw_text):
|
746 |
+
device = next(self.parameters()).device
|
747 |
+
text = clip.tokenize(raw_text, truncate=True).to(device)
|
748 |
+
feat_clip_text = self.clip_model.encode_text(text).float()
|
749 |
+
return feat_clip_text
|
750 |
+
|
751 |
+
|
752 |
+
def q_schedule(self, bs, low, high):
|
753 |
+
noise = uniform((bs,), device=self.opt.device)
|
754 |
+
schedule = 1 - cosine_schedule(noise)
|
755 |
+
return torch.round(schedule * (high - low)) + low
|
756 |
+
|
757 |
+
def process_embed_proj_weight(self):
|
758 |
+
if self.share_weight and (not self.shared_codebook):
|
759 |
+
# if not self.registered:
|
760 |
+
self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
|
761 |
+
self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
|
762 |
+
# self.registered = True
|
763 |
+
|
764 |
+
def output_project(self, logits, qids):
|
765 |
+
'''
|
766 |
+
:logits: (bs, code_dim, seqlen)
|
767 |
+
:qids: (bs)
|
768 |
+
|
769 |
+
:return:
|
770 |
+
-logits (bs, ntoken, seqlen)
|
771 |
+
'''
|
772 |
+
# (num_qlayers-1, num_token, code_dim) -> (bs, ntoken, code_dim)
|
773 |
+
output_proj_weight = self.output_proj_weight[qids]
|
774 |
+
# (num_qlayers, ntoken) -> (bs, ntoken)
|
775 |
+
output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids]
|
776 |
+
|
777 |
+
output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits)
|
778 |
+
if output_proj_bias is not None:
|
779 |
+
output += output + output_proj_bias.unsqueeze(-1)
|
780 |
+
return output
|
781 |
+
|
782 |
+
|
783 |
+
|
784 |
+
def trans_forward(self, motion_codes, qids, cond, padding_mask, force_mask=False):
|
785 |
+
'''
|
786 |
+
:param motion_codes: (b, seqlen, d)
|
787 |
+
:padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
|
788 |
+
:param qids: (b), quantizer layer ids
|
789 |
+
:param cond: (b, embed_dim) for text, (b, num_actions) for action
|
790 |
+
:return:
|
791 |
+
-logits: (b, num_token, seqlen)
|
792 |
+
'''
|
793 |
+
cond = self.mask_cond(cond, force_mask=force_mask)
|
794 |
+
|
795 |
+
# (b, seqlen, d) -> (seqlen, b, latent_dim)
|
796 |
+
x = self.input_process(motion_codes)
|
797 |
+
|
798 |
+
# (b, num_quantizer)
|
799 |
+
q_onehot = self.encode_quant(qids).float().to(x.device)
|
800 |
+
|
801 |
+
q_emb = self.quant_emb(q_onehot).unsqueeze(0) # (1, b, latent_dim)
|
802 |
+
cond = self.cond_emb(cond).unsqueeze(0) # (1, b, latent_dim)
|
803 |
+
|
804 |
+
x = self.position_enc(x)
|
805 |
+
xseq = torch.cat([cond, q_emb, x], dim=0) # (seqlen+2, b, latent_dim)
|
806 |
+
|
807 |
+
padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:2]), padding_mask], dim=1) # (b, seqlen+2)
|
808 |
+
output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[2:] # (seqlen, b, e)
|
809 |
+
logits = self.output_process(output)
|
810 |
+
return logits
|
811 |
+
|
812 |
+
def forward_with_cond_scale(self,
|
813 |
+
motion_codes,
|
814 |
+
q_id,
|
815 |
+
cond_vector,
|
816 |
+
padding_mask,
|
817 |
+
cond_scale=3,
|
818 |
+
force_mask=False):
|
819 |
+
bs = motion_codes.shape[0]
|
820 |
+
# if cond_scale == 1:
|
821 |
+
qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device)
|
822 |
+
if force_mask:
|
823 |
+
logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
|
824 |
+
logits = self.output_project(logits, qids-1)
|
825 |
+
return logits
|
826 |
+
|
827 |
+
logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask)
|
828 |
+
logits = self.output_project(logits, qids-1)
|
829 |
+
if cond_scale == 1:
|
830 |
+
return logits
|
831 |
+
|
832 |
+
aux_logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
|
833 |
+
aux_logits = self.output_project(aux_logits, qids-1)
|
834 |
+
|
835 |
+
scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
|
836 |
+
return scaled_logits
|
837 |
+
|
838 |
+
def forward(self, all_indices, y, m_lens):
|
839 |
+
'''
|
840 |
+
:param all_indices: (b, n, q)
|
841 |
+
:param y: raw text for cond_mode=text, (b, ) for cond_mode=action
|
842 |
+
:m_lens: (b,)
|
843 |
+
:return:
|
844 |
+
'''
|
845 |
+
|
846 |
+
self.process_embed_proj_weight()
|
847 |
+
|
848 |
+
bs, ntokens, num_quant_layers = all_indices.shape
|
849 |
+
device = all_indices.device
|
850 |
+
|
851 |
+
# Positions that are PADDED are ALL FALSE
|
852 |
+
non_pad_mask = lengths_to_mask(m_lens, ntokens) # (b, n)
|
853 |
+
|
854 |
+
q_non_pad_mask = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers)
|
855 |
+
all_indices = torch.where(q_non_pad_mask, all_indices, self.pad_id) #(b, n, q)
|
856 |
+
|
857 |
+
# randomly sample quantization layers to work on, [1, num_q)
|
858 |
+
active_q_layers = q_schedule(bs, low=1, high=num_quant_layers, device=device)
|
859 |
+
|
860 |
+
# print(self.token_embed_weight.shape, all_indices.shape)
|
861 |
+
token_embed = repeat(self.token_embed_weight, 'q c d-> b c d q', b=bs)
|
862 |
+
gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2])
|
863 |
+
# print(token_embed.shape, gather_indices.shape)
|
864 |
+
all_codes = token_embed.gather(1, gather_indices) # (b, n, d, q-1)
|
865 |
+
|
866 |
+
cumsum_codes = torch.cumsum(all_codes, dim=-1) #(b, n, d, q-1)
|
867 |
+
|
868 |
+
active_indices = all_indices[torch.arange(bs), :, active_q_layers] # (b, n)
|
869 |
+
history_sum = cumsum_codes[torch.arange(bs), :, :, active_q_layers - 1]
|
870 |
+
|
871 |
+
force_mask = False
|
872 |
+
if self.cond_mode == 'text':
|
873 |
+
with torch.no_grad():
|
874 |
+
cond_vector = self.encode_text(y)
|
875 |
+
elif self.cond_mode == 'action':
|
876 |
+
cond_vector = self.enc_action(y).to(device).float()
|
877 |
+
elif self.cond_mode == 'uncond':
|
878 |
+
cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
|
879 |
+
force_mask = True
|
880 |
+
else:
|
881 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
882 |
+
|
883 |
+
logits = self.trans_forward(history_sum, active_q_layers, cond_vector, ~non_pad_mask, force_mask)
|
884 |
+
logits = self.output_project(logits, active_q_layers-1)
|
885 |
+
ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id)
|
886 |
+
|
887 |
+
return ce_loss, pred_id, acc
|
888 |
+
|
889 |
+
@torch.no_grad()
|
890 |
+
@eval_decorator
|
891 |
+
def generate(self,
|
892 |
+
motion_ids,
|
893 |
+
conds,
|
894 |
+
m_lens,
|
895 |
+
temperature=1,
|
896 |
+
topk_filter_thres=0.9,
|
897 |
+
cond_scale=2,
|
898 |
+
num_res_layers=-1, # If it's -1, use all.
|
899 |
+
):
|
900 |
+
|
901 |
+
# print(self.opt.num_quantizers)
|
902 |
+
# assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
|
903 |
+
self.process_embed_proj_weight()
|
904 |
+
|
905 |
+
device = next(self.parameters()).device
|
906 |
+
seq_len = motion_ids.shape[1]
|
907 |
+
batch_size = len(conds)
|
908 |
+
|
909 |
+
if self.cond_mode == 'text':
|
910 |
+
with torch.no_grad():
|
911 |
+
cond_vector = self.encode_text(conds)
|
912 |
+
elif self.cond_mode == 'action':
|
913 |
+
cond_vector = self.enc_action(conds).to(device)
|
914 |
+
elif self.cond_mode == 'uncond':
|
915 |
+
cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
|
916 |
+
else:
|
917 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
918 |
+
|
919 |
+
# token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
|
920 |
+
# gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
|
921 |
+
# history_sum = token_embed.gather(1, gathered_ids)
|
922 |
+
|
923 |
+
# print(pa, seq_len)
|
924 |
+
padding_mask = ~lengths_to_mask(m_lens, seq_len)
|
925 |
+
# print(padding_mask.shape, motion_ids.shape)
|
926 |
+
motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
|
927 |
+
all_indices = [motion_ids]
|
928 |
+
history_sum = 0
|
929 |
+
num_quant_layers = self.opt.num_quantizers if num_res_layers==-1 else num_res_layers+1
|
930 |
+
|
931 |
+
for i in range(1, num_quant_layers):
|
932 |
+
# print(f"--> Working on {i}-th quantizer")
|
933 |
+
# Start from all tokens being masked
|
934 |
+
# qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
|
935 |
+
token_embed = self.token_embed_weight[i-1]
|
936 |
+
token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
|
937 |
+
gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
|
938 |
+
history_sum += token_embed.gather(1, gathered_ids)
|
939 |
+
|
940 |
+
logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
|
941 |
+
# logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)
|
942 |
+
|
943 |
+
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
|
944 |
+
# clean low prob token
|
945 |
+
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
|
946 |
+
|
947 |
+
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
|
948 |
+
|
949 |
+
# probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
|
950 |
+
# # print(temperature, starting_temperature, steps_until_x0, timesteps)
|
951 |
+
# # print(probs / temperature)
|
952 |
+
# pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
|
953 |
+
|
954 |
+
ids = torch.where(padding_mask, self.pad_id, pred_ids)
|
955 |
+
|
956 |
+
motion_ids = ids
|
957 |
+
all_indices.append(ids)
|
958 |
+
|
959 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
960 |
+
# padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
|
961 |
+
# all_indices = torch.where(padding_mask, -1, all_indices)
|
962 |
+
all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
|
963 |
+
# all_indices = all_indices.masked_fill()
|
964 |
+
return all_indices
|
965 |
+
|
966 |
+
@torch.no_grad()
|
967 |
+
@eval_decorator
|
968 |
+
def edit(self,
|
969 |
+
motion_ids,
|
970 |
+
conds,
|
971 |
+
m_lens,
|
972 |
+
temperature=1,
|
973 |
+
topk_filter_thres=0.9,
|
974 |
+
cond_scale=2
|
975 |
+
):
|
976 |
+
|
977 |
+
# print(self.opt.num_quantizers)
|
978 |
+
# assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
|
979 |
+
self.process_embed_proj_weight()
|
980 |
+
|
981 |
+
device = next(self.parameters()).device
|
982 |
+
seq_len = motion_ids.shape[1]
|
983 |
+
batch_size = len(conds)
|
984 |
+
|
985 |
+
if self.cond_mode == 'text':
|
986 |
+
with torch.no_grad():
|
987 |
+
cond_vector = self.encode_text(conds)
|
988 |
+
elif self.cond_mode == 'action':
|
989 |
+
cond_vector = self.enc_action(conds).to(device)
|
990 |
+
elif self.cond_mode == 'uncond':
|
991 |
+
cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
|
992 |
+
else:
|
993 |
+
raise NotImplementedError("Unsupported condition mode!!!")
|
994 |
+
|
995 |
+
# token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
|
996 |
+
# gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
|
997 |
+
# history_sum = token_embed.gather(1, gathered_ids)
|
998 |
+
|
999 |
+
# print(pa, seq_len)
|
1000 |
+
padding_mask = ~lengths_to_mask(m_lens, seq_len)
|
1001 |
+
# print(padding_mask.shape, motion_ids.shape)
|
1002 |
+
motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
|
1003 |
+
all_indices = [motion_ids]
|
1004 |
+
history_sum = 0
|
1005 |
+
|
1006 |
+
for i in range(1, self.opt.num_quantizers):
|
1007 |
+
# print(f"--> Working on {i}-th quantizer")
|
1008 |
+
# Start from all tokens being masked
|
1009 |
+
# qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
|
1010 |
+
token_embed = self.token_embed_weight[i-1]
|
1011 |
+
token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
|
1012 |
+
gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
|
1013 |
+
history_sum += token_embed.gather(1, gathered_ids)
|
1014 |
+
|
1015 |
+
logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
|
1016 |
+
# logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)
|
1017 |
+
|
1018 |
+
logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
|
1019 |
+
# clean low prob token
|
1020 |
+
filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
|
1021 |
+
|
1022 |
+
pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
|
1023 |
+
|
1024 |
+
# probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
|
1025 |
+
# # print(temperature, starting_temperature, steps_until_x0, timesteps)
|
1026 |
+
# # print(probs / temperature)
|
1027 |
+
# pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
|
1028 |
+
|
1029 |
+
ids = torch.where(padding_mask, self.pad_id, pred_ids)
|
1030 |
+
|
1031 |
+
motion_ids = ids
|
1032 |
+
all_indices.append(ids)
|
1033 |
+
|
1034 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
1035 |
+
# padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
|
1036 |
+
# all_indices = torch.where(padding_mask, -1, all_indices)
|
1037 |
+
all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
|
1038 |
+
# all_indices = all_indices.masked_fill()
|
1039 |
+
return all_indices
|
models/mask_transformer/transformer_trainer.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from collections import defaultdict
|
3 |
+
import torch.optim as optim
|
4 |
+
# import tensorflow as tf
|
5 |
+
from torch.utils.tensorboard import SummaryWriter
|
6 |
+
from collections import OrderedDict
|
7 |
+
from utils.utils import *
|
8 |
+
from os.path import join as pjoin
|
9 |
+
from utils.eval_t2m import evaluation_mask_transformer, evaluation_res_transformer
|
10 |
+
from models.mask_transformer.tools import *
|
11 |
+
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
|
14 |
+
def def_value():
|
15 |
+
return 0.0
|
16 |
+
|
17 |
+
class MaskTransformerTrainer:
|
18 |
+
def __init__(self, args, t2m_transformer, vq_model):
|
19 |
+
self.opt = args
|
20 |
+
self.t2m_transformer = t2m_transformer
|
21 |
+
self.vq_model = vq_model
|
22 |
+
self.device = args.device
|
23 |
+
self.vq_model.eval()
|
24 |
+
|
25 |
+
if args.is_train:
|
26 |
+
self.logger = SummaryWriter(args.log_dir)
|
27 |
+
|
28 |
+
|
29 |
+
def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
|
30 |
+
|
31 |
+
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
|
32 |
+
for param_group in self.opt_t2m_transformer.param_groups:
|
33 |
+
param_group["lr"] = current_lr
|
34 |
+
|
35 |
+
return current_lr
|
36 |
+
|
37 |
+
|
38 |
+
def forward(self, batch_data):
|
39 |
+
|
40 |
+
conds, motion, m_lens = batch_data
|
41 |
+
motion = motion.detach().float().to(self.device)
|
42 |
+
m_lens = m_lens.detach().long().to(self.device)
|
43 |
+
|
44 |
+
# (b, n, q)
|
45 |
+
code_idx, _ = self.vq_model.encode(motion)
|
46 |
+
m_lens = m_lens // 4
|
47 |
+
|
48 |
+
conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds
|
49 |
+
|
50 |
+
# loss_dict = {}
|
51 |
+
# self.pred_ids = []
|
52 |
+
# self.acc = []
|
53 |
+
|
54 |
+
_loss, _pred_ids, _acc = self.t2m_transformer(code_idx[..., 0], conds, m_lens)
|
55 |
+
|
56 |
+
return _loss, _acc
|
57 |
+
|
58 |
+
def update(self, batch_data):
|
59 |
+
loss, acc = self.forward(batch_data)
|
60 |
+
|
61 |
+
self.opt_t2m_transformer.zero_grad()
|
62 |
+
loss.backward()
|
63 |
+
self.opt_t2m_transformer.step()
|
64 |
+
self.scheduler.step()
|
65 |
+
|
66 |
+
return loss.item(), acc
|
67 |
+
|
68 |
+
def save(self, file_name, ep, total_it):
|
69 |
+
t2m_trans_state_dict = self.t2m_transformer.state_dict()
|
70 |
+
clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_model.')]
|
71 |
+
for e in clip_weights:
|
72 |
+
del t2m_trans_state_dict[e]
|
73 |
+
state = {
|
74 |
+
't2m_transformer': t2m_trans_state_dict,
|
75 |
+
'opt_t2m_transformer': self.opt_t2m_transformer.state_dict(),
|
76 |
+
'scheduler':self.scheduler.state_dict(),
|
77 |
+
'ep': ep,
|
78 |
+
'total_it': total_it,
|
79 |
+
}
|
80 |
+
torch.save(state, file_name)
|
81 |
+
|
82 |
+
def resume(self, model_dir):
|
83 |
+
checkpoint = torch.load(model_dir, map_location=self.device)
|
84 |
+
missing_keys, unexpected_keys = self.t2m_transformer.load_state_dict(checkpoint['t2m_transformer'], strict=False)
|
85 |
+
assert len(unexpected_keys) == 0
|
86 |
+
assert all([k.startswith('clip_model.') for k in missing_keys])
|
87 |
+
|
88 |
+
try:
|
89 |
+
self.opt_t2m_transformer.load_state_dict(checkpoint['opt_t2m_transformer']) # Optimizer
|
90 |
+
|
91 |
+
self.scheduler.load_state_dict(checkpoint['scheduler']) # Scheduler
|
92 |
+
except:
|
93 |
+
print('Resume wo optimizer')
|
94 |
+
return checkpoint['ep'], checkpoint['total_it']
|
95 |
+
|
96 |
+
def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval):
|
97 |
+
self.t2m_transformer.to(self.device)
|
98 |
+
self.vq_model.to(self.device)
|
99 |
+
|
100 |
+
self.opt_t2m_transformer = optim.AdamW(self.t2m_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5)
|
101 |
+
self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_t2m_transformer,
|
102 |
+
milestones=self.opt.milestones,
|
103 |
+
gamma=self.opt.gamma)
|
104 |
+
|
105 |
+
epoch = 0
|
106 |
+
it = 0
|
107 |
+
|
108 |
+
if self.opt.is_continue:
|
109 |
+
model_dir = pjoin(self.opt.model_dir, 'latest.tar') # TODO
|
110 |
+
epoch, it = self.resume(model_dir)
|
111 |
+
print("Load model epoch:%d iterations:%d"%(epoch, it))
|
112 |
+
|
113 |
+
start_time = time.time()
|
114 |
+
total_iters = self.opt.max_epoch * len(train_loader)
|
115 |
+
print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
|
116 |
+
print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
|
117 |
+
logs = defaultdict(def_value, OrderedDict())
|
118 |
+
|
119 |
+
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer(
|
120 |
+
self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch,
|
121 |
+
best_fid=100, best_div=100,
|
122 |
+
best_top1=0, best_top2=0, best_top3=0,
|
123 |
+
best_matching=100, eval_wrapper=eval_wrapper,
|
124 |
+
plot_func=plot_eval, save_ckpt=False, save_anim=False
|
125 |
+
)
|
126 |
+
best_acc = 0.
|
127 |
+
|
128 |
+
while epoch < self.opt.max_epoch:
|
129 |
+
self.t2m_transformer.train()
|
130 |
+
self.vq_model.eval()
|
131 |
+
|
132 |
+
for i, batch in enumerate(train_loader):
|
133 |
+
it += 1
|
134 |
+
if it < self.opt.warm_up_iter:
|
135 |
+
self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
|
136 |
+
|
137 |
+
loss, acc = self.update(batch_data=batch)
|
138 |
+
logs['loss'] += loss
|
139 |
+
logs['acc'] += acc
|
140 |
+
logs['lr'] += self.opt_t2m_transformer.param_groups[0]['lr']
|
141 |
+
|
142 |
+
if it % self.opt.log_every == 0:
|
143 |
+
mean_loss = OrderedDict()
|
144 |
+
# self.logger.add_scalar('val_loss', val_loss, it)
|
145 |
+
# self.l
|
146 |
+
for tag, value in logs.items():
|
147 |
+
self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
|
148 |
+
mean_loss[tag] = value / self.opt.log_every
|
149 |
+
logs = defaultdict(def_value, OrderedDict())
|
150 |
+
print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
|
151 |
+
|
152 |
+
if it % self.opt.save_latest == 0:
|
153 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
154 |
+
|
155 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
156 |
+
epoch += 1
|
157 |
+
|
158 |
+
print('Validation time:')
|
159 |
+
self.vq_model.eval()
|
160 |
+
self.t2m_transformer.eval()
|
161 |
+
|
162 |
+
val_loss = []
|
163 |
+
val_acc = []
|
164 |
+
with torch.no_grad():
|
165 |
+
for i, batch_data in enumerate(val_loader):
|
166 |
+
loss, acc = self.forward(batch_data)
|
167 |
+
val_loss.append(loss.item())
|
168 |
+
val_acc.append(acc)
|
169 |
+
|
170 |
+
print(f"Validation loss:{np.mean(val_loss):.3f}, accuracy:{np.mean(val_acc):.3f}")
|
171 |
+
|
172 |
+
self.logger.add_scalar('Val/loss', np.mean(val_loss), epoch)
|
173 |
+
self.logger.add_scalar('Val/acc', np.mean(val_acc), epoch)
|
174 |
+
|
175 |
+
if np.mean(val_acc) > best_acc:
|
176 |
+
print(f"Improved accuracy from {best_acc:.02f} to {np.mean(val_acc)}!!!")
|
177 |
+
self.save(pjoin(self.opt.model_dir, 'net_best_acc.tar'), epoch, it)
|
178 |
+
best_acc = np.mean(val_acc)
|
179 |
+
|
180 |
+
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer(
|
181 |
+
self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch, best_fid=best_fid,
|
182 |
+
best_div=best_div, best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
|
183 |
+
best_matching=best_matching, eval_wrapper=eval_wrapper,
|
184 |
+
plot_func=plot_eval, save_ckpt=True, save_anim=(epoch%self.opt.eval_every_e==0)
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
class ResidualTransformerTrainer:
|
189 |
+
def __init__(self, args, res_transformer, vq_model):
|
190 |
+
self.opt = args
|
191 |
+
self.res_transformer = res_transformer
|
192 |
+
self.vq_model = vq_model
|
193 |
+
self.device = args.device
|
194 |
+
self.vq_model.eval()
|
195 |
+
|
196 |
+
if args.is_train:
|
197 |
+
self.logger = SummaryWriter(args.log_dir)
|
198 |
+
# self.l1_criterion = torch.nn.SmoothL1Loss()
|
199 |
+
|
200 |
+
|
201 |
+
def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
|
202 |
+
|
203 |
+
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
|
204 |
+
for param_group in self.opt_res_transformer.param_groups:
|
205 |
+
param_group["lr"] = current_lr
|
206 |
+
|
207 |
+
return current_lr
|
208 |
+
|
209 |
+
|
210 |
+
def forward(self, batch_data):
|
211 |
+
|
212 |
+
conds, motion, m_lens = batch_data
|
213 |
+
motion = motion.detach().float().to(self.device)
|
214 |
+
m_lens = m_lens.detach().long().to(self.device)
|
215 |
+
|
216 |
+
# (b, n, q), (q, b, n ,d)
|
217 |
+
code_idx, all_codes = self.vq_model.encode(motion)
|
218 |
+
m_lens = m_lens // 4
|
219 |
+
|
220 |
+
conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds
|
221 |
+
|
222 |
+
ce_loss, pred_ids, acc = self.res_transformer(code_idx, conds, m_lens)
|
223 |
+
|
224 |
+
return ce_loss, acc
|
225 |
+
|
226 |
+
def update(self, batch_data):
|
227 |
+
loss, acc = self.forward(batch_data)
|
228 |
+
|
229 |
+
self.opt_res_transformer.zero_grad()
|
230 |
+
loss.backward()
|
231 |
+
self.opt_res_transformer.step()
|
232 |
+
self.scheduler.step()
|
233 |
+
|
234 |
+
return loss.item(), acc
|
235 |
+
|
236 |
+
def save(self, file_name, ep, total_it):
|
237 |
+
res_trans_state_dict = self.res_transformer.state_dict()
|
238 |
+
clip_weights = [e for e in res_trans_state_dict.keys() if e.startswith('clip_model.')]
|
239 |
+
for e in clip_weights:
|
240 |
+
del res_trans_state_dict[e]
|
241 |
+
state = {
|
242 |
+
'res_transformer': res_trans_state_dict,
|
243 |
+
'opt_res_transformer': self.opt_res_transformer.state_dict(),
|
244 |
+
'scheduler':self.scheduler.state_dict(),
|
245 |
+
'ep': ep,
|
246 |
+
'total_it': total_it,
|
247 |
+
}
|
248 |
+
torch.save(state, file_name)
|
249 |
+
|
250 |
+
def resume(self, model_dir):
|
251 |
+
checkpoint = torch.load(model_dir, map_location=self.device)
|
252 |
+
missing_keys, unexpected_keys = self.res_transformer.load_state_dict(checkpoint['res_transformer'], strict=False)
|
253 |
+
assert len(unexpected_keys) == 0
|
254 |
+
assert all([k.startswith('clip_model.') for k in missing_keys])
|
255 |
+
|
256 |
+
try:
|
257 |
+
self.opt_res_transformer.load_state_dict(checkpoint['opt_res_transformer']) # Optimizer
|
258 |
+
|
259 |
+
self.scheduler.load_state_dict(checkpoint['scheduler']) # Scheduler
|
260 |
+
except:
|
261 |
+
print('Resume wo optimizer')
|
262 |
+
return checkpoint['ep'], checkpoint['total_it']
|
263 |
+
|
264 |
+
def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval):
|
265 |
+
self.res_transformer.to(self.device)
|
266 |
+
self.vq_model.to(self.device)
|
267 |
+
|
268 |
+
self.opt_res_transformer = optim.AdamW(self.res_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5)
|
269 |
+
self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_res_transformer,
|
270 |
+
milestones=self.opt.milestones,
|
271 |
+
gamma=self.opt.gamma)
|
272 |
+
|
273 |
+
epoch = 0
|
274 |
+
it = 0
|
275 |
+
|
276 |
+
if self.opt.is_continue:
|
277 |
+
model_dir = pjoin(self.opt.model_dir, 'latest.tar') # TODO
|
278 |
+
epoch, it = self.resume(model_dir)
|
279 |
+
print("Load model epoch:%d iterations:%d"%(epoch, it))
|
280 |
+
|
281 |
+
start_time = time.time()
|
282 |
+
total_iters = self.opt.max_epoch * len(train_loader)
|
283 |
+
print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
|
284 |
+
print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
|
285 |
+
logs = defaultdict(def_value, OrderedDict())
|
286 |
+
|
287 |
+
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer(
|
288 |
+
self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch,
|
289 |
+
best_fid=100, best_div=100,
|
290 |
+
best_top1=0, best_top2=0, best_top3=0,
|
291 |
+
best_matching=100, eval_wrapper=eval_wrapper,
|
292 |
+
plot_func=plot_eval, save_ckpt=False, save_anim=False
|
293 |
+
)
|
294 |
+
best_loss = 100
|
295 |
+
best_acc = 0
|
296 |
+
|
297 |
+
while epoch < self.opt.max_epoch:
|
298 |
+
self.res_transformer.train()
|
299 |
+
self.vq_model.eval()
|
300 |
+
|
301 |
+
for i, batch in enumerate(train_loader):
|
302 |
+
it += 1
|
303 |
+
if it < self.opt.warm_up_iter:
|
304 |
+
self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
|
305 |
+
|
306 |
+
loss, acc = self.update(batch_data=batch)
|
307 |
+
logs['loss'] += loss
|
308 |
+
logs["acc"] += acc
|
309 |
+
logs['lr'] += self.opt_res_transformer.param_groups[0]['lr']
|
310 |
+
|
311 |
+
if it % self.opt.log_every == 0:
|
312 |
+
mean_loss = OrderedDict()
|
313 |
+
# self.logger.add_scalar('val_loss', val_loss, it)
|
314 |
+
# self.l
|
315 |
+
for tag, value in logs.items():
|
316 |
+
self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
|
317 |
+
mean_loss[tag] = value / self.opt.log_every
|
318 |
+
logs = defaultdict(def_value, OrderedDict())
|
319 |
+
print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
|
320 |
+
|
321 |
+
if it % self.opt.save_latest == 0:
|
322 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
323 |
+
|
324 |
+
epoch += 1
|
325 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
326 |
+
|
327 |
+
print('Validation time:')
|
328 |
+
self.vq_model.eval()
|
329 |
+
self.res_transformer.eval()
|
330 |
+
|
331 |
+
val_loss = []
|
332 |
+
val_acc = []
|
333 |
+
with torch.no_grad():
|
334 |
+
for i, batch_data in enumerate(val_loader):
|
335 |
+
loss, acc = self.forward(batch_data)
|
336 |
+
val_loss.append(loss.item())
|
337 |
+
val_acc.append(acc)
|
338 |
+
|
339 |
+
print(f"Validation loss:{np.mean(val_loss):.3f}, Accuracy:{np.mean(val_acc):.3f}")
|
340 |
+
|
341 |
+
self.logger.add_scalar('Val/loss', np.mean(val_loss), epoch)
|
342 |
+
self.logger.add_scalar('Val/acc', np.mean(val_acc), epoch)
|
343 |
+
|
344 |
+
if np.mean(val_loss) < best_loss:
|
345 |
+
print(f"Improved loss from {best_loss:.02f} to {np.mean(val_loss)}!!!")
|
346 |
+
self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it)
|
347 |
+
best_loss = np.mean(val_loss)
|
348 |
+
|
349 |
+
if np.mean(val_acc) > best_acc:
|
350 |
+
print(f"Improved acc from {best_acc:.02f} to {np.mean(val_acc)}!!!")
|
351 |
+
# self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it)
|
352 |
+
best_acc = np.mean(val_acc)
|
353 |
+
|
354 |
+
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer(
|
355 |
+
self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch, best_fid=best_fid,
|
356 |
+
best_div=best_div, best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
|
357 |
+
best_matching=best_matching, eval_wrapper=eval_wrapper,
|
358 |
+
plot_func=plot_eval, save_ckpt=True, save_anim=(epoch%self.opt.eval_every_e==0)
|
359 |
+
)
|
models/t2m_eval_modules.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import time
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
8 |
+
# from networks.layers import *
|
9 |
+
|
10 |
+
|
11 |
+
def init_weight(m):
|
12 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
|
13 |
+
nn.init.xavier_normal_(m.weight)
|
14 |
+
# m.bias.data.fill_(0.01)
|
15 |
+
if m.bias is not None:
|
16 |
+
nn.init.constant_(m.bias, 0)
|
17 |
+
|
18 |
+
|
19 |
+
# batch_size, dimension and position
|
20 |
+
# output: (batch_size, dim)
|
21 |
+
def positional_encoding(batch_size, dim, pos):
|
22 |
+
assert batch_size == pos.shape[0]
|
23 |
+
positions_enc = np.array([
|
24 |
+
[pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)]
|
25 |
+
for j in range(batch_size)
|
26 |
+
], dtype=np.float32)
|
27 |
+
positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
|
28 |
+
positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
|
29 |
+
return torch.from_numpy(positions_enc).float()
|
30 |
+
|
31 |
+
|
32 |
+
def get_padding_mask(batch_size, seq_len, cap_lens):
|
33 |
+
cap_lens = cap_lens.data.tolist()
|
34 |
+
mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
|
35 |
+
for i, cap_len in enumerate(cap_lens):
|
36 |
+
mask_2d[i, :, :cap_len] = 0
|
37 |
+
return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone()
|
38 |
+
|
39 |
+
|
40 |
+
def top_k_logits(logits, k):
|
41 |
+
v, ix = torch.topk(logits, k)
|
42 |
+
out = logits.clone()
|
43 |
+
out[out < v[:, [-1]]] = -float('Inf')
|
44 |
+
return out
|
45 |
+
|
46 |
+
|
47 |
+
class PositionalEncoding(nn.Module):
|
48 |
+
|
49 |
+
def __init__(self, d_model, max_len=300):
|
50 |
+
super(PositionalEncoding, self).__init__()
|
51 |
+
|
52 |
+
pe = torch.zeros(max_len, d_model)
|
53 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
54 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
55 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
56 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
57 |
+
# pe = pe.unsqueeze(0).transpose(0, 1)
|
58 |
+
self.register_buffer('pe', pe)
|
59 |
+
|
60 |
+
def forward(self, pos):
|
61 |
+
return self.pe[pos]
|
62 |
+
|
63 |
+
|
64 |
+
class MovementConvEncoder(nn.Module):
|
65 |
+
def __init__(self, input_size, hidden_size, output_size):
|
66 |
+
super(MovementConvEncoder, self).__init__()
|
67 |
+
self.main = nn.Sequential(
|
68 |
+
nn.Conv1d(input_size, hidden_size, 4, 2, 1),
|
69 |
+
nn.Dropout(0.2, inplace=True),
|
70 |
+
nn.LeakyReLU(0.2, inplace=True),
|
71 |
+
nn.Conv1d(hidden_size, output_size, 4, 2, 1),
|
72 |
+
nn.Dropout(0.2, inplace=True),
|
73 |
+
nn.LeakyReLU(0.2, inplace=True),
|
74 |
+
)
|
75 |
+
self.out_net = nn.Linear(output_size, output_size)
|
76 |
+
self.main.apply(init_weight)
|
77 |
+
self.out_net.apply(init_weight)
|
78 |
+
|
79 |
+
def forward(self, inputs):
|
80 |
+
inputs = inputs.permute(0, 2, 1)
|
81 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
82 |
+
# print(outputs.shape)
|
83 |
+
return self.out_net(outputs)
|
84 |
+
|
85 |
+
|
86 |
+
class MovementConvDecoder(nn.Module):
|
87 |
+
def __init__(self, input_size, hidden_size, output_size):
|
88 |
+
super(MovementConvDecoder, self).__init__()
|
89 |
+
self.main = nn.Sequential(
|
90 |
+
nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
|
91 |
+
# nn.Dropout(0.2, inplace=True),
|
92 |
+
nn.LeakyReLU(0.2, inplace=True),
|
93 |
+
nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
|
94 |
+
# nn.Dropout(0.2, inplace=True),
|
95 |
+
nn.LeakyReLU(0.2, inplace=True),
|
96 |
+
)
|
97 |
+
self.out_net = nn.Linear(output_size, output_size)
|
98 |
+
|
99 |
+
self.main.apply(init_weight)
|
100 |
+
self.out_net.apply(init_weight)
|
101 |
+
|
102 |
+
def forward(self, inputs):
|
103 |
+
inputs = inputs.permute(0, 2, 1)
|
104 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
105 |
+
return self.out_net(outputs)
|
106 |
+
|
107 |
+
class TextEncoderBiGRUCo(nn.Module):
|
108 |
+
def __init__(self, word_size, pos_size, hidden_size, output_size, device):
|
109 |
+
super(TextEncoderBiGRUCo, self).__init__()
|
110 |
+
self.device = device
|
111 |
+
|
112 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
113 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
114 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
115 |
+
self.output_net = nn.Sequential(
|
116 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
117 |
+
nn.LayerNorm(hidden_size),
|
118 |
+
nn.LeakyReLU(0.2, inplace=True),
|
119 |
+
nn.Linear(hidden_size, output_size)
|
120 |
+
)
|
121 |
+
|
122 |
+
self.input_emb.apply(init_weight)
|
123 |
+
self.pos_emb.apply(init_weight)
|
124 |
+
self.output_net.apply(init_weight)
|
125 |
+
# self.linear2.apply(init_weight)
|
126 |
+
# self.batch_size = batch_size
|
127 |
+
self.hidden_size = hidden_size
|
128 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
129 |
+
|
130 |
+
# input(batch_size, seq_len, dim)
|
131 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
132 |
+
num_samples = word_embs.shape[0]
|
133 |
+
|
134 |
+
pos_embs = self.pos_emb(pos_onehot)
|
135 |
+
inputs = word_embs + pos_embs
|
136 |
+
input_embs = self.input_emb(inputs)
|
137 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
138 |
+
|
139 |
+
cap_lens = cap_lens.data.tolist()
|
140 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
141 |
+
|
142 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
143 |
+
|
144 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
145 |
+
|
146 |
+
return self.output_net(gru_last)
|
147 |
+
|
148 |
+
|
149 |
+
class MotionEncoderBiGRUCo(nn.Module):
|
150 |
+
def __init__(self, input_size, hidden_size, output_size, device):
|
151 |
+
super(MotionEncoderBiGRUCo, self).__init__()
|
152 |
+
self.device = device
|
153 |
+
|
154 |
+
self.input_emb = nn.Linear(input_size, hidden_size)
|
155 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
156 |
+
self.output_net = nn.Sequential(
|
157 |
+
nn.Linear(hidden_size*2, hidden_size),
|
158 |
+
nn.LayerNorm(hidden_size),
|
159 |
+
nn.LeakyReLU(0.2, inplace=True),
|
160 |
+
nn.Linear(hidden_size, output_size)
|
161 |
+
)
|
162 |
+
|
163 |
+
self.input_emb.apply(init_weight)
|
164 |
+
self.output_net.apply(init_weight)
|
165 |
+
self.hidden_size = hidden_size
|
166 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
167 |
+
|
168 |
+
# input(batch_size, seq_len, dim)
|
169 |
+
def forward(self, inputs, m_lens):
|
170 |
+
num_samples = inputs.shape[0]
|
171 |
+
|
172 |
+
input_embs = self.input_emb(inputs)
|
173 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
174 |
+
|
175 |
+
cap_lens = m_lens.data.tolist()
|
176 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
177 |
+
|
178 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
179 |
+
|
180 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
181 |
+
|
182 |
+
return self.output_net(gru_last)
|
models/t2m_eval_wrapper.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models.t2m_eval_modules import *
|
2 |
+
from utils.word_vectorizer import POS_enumerator
|
3 |
+
from os.path import join as pjoin
|
4 |
+
|
5 |
+
def build_models(opt):
|
6 |
+
movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
|
7 |
+
text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
|
8 |
+
pos_size=opt.dim_pos_ohot,
|
9 |
+
hidden_size=opt.dim_text_hidden,
|
10 |
+
output_size=opt.dim_coemb_hidden,
|
11 |
+
device=opt.device)
|
12 |
+
|
13 |
+
motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
|
14 |
+
hidden_size=opt.dim_motion_hidden,
|
15 |
+
output_size=opt.dim_coemb_hidden,
|
16 |
+
device=opt.device)
|
17 |
+
|
18 |
+
checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
|
19 |
+
map_location=opt.device)
|
20 |
+
movement_enc.load_state_dict(checkpoint['movement_encoder'])
|
21 |
+
text_enc.load_state_dict(checkpoint['text_encoder'])
|
22 |
+
motion_enc.load_state_dict(checkpoint['motion_encoder'])
|
23 |
+
print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
|
24 |
+
return text_enc, motion_enc, movement_enc
|
25 |
+
|
26 |
+
|
27 |
+
class EvaluatorModelWrapper(object):
|
28 |
+
|
29 |
+
def __init__(self, opt):
|
30 |
+
|
31 |
+
if opt.dataset_name == 't2m':
|
32 |
+
opt.dim_pose = 263
|
33 |
+
elif opt.dataset_name == 'kit':
|
34 |
+
opt.dim_pose = 251
|
35 |
+
else:
|
36 |
+
raise KeyError('Dataset not Recognized!!!')
|
37 |
+
|
38 |
+
opt.dim_word = 300
|
39 |
+
opt.max_motion_length = 196
|
40 |
+
opt.dim_pos_ohot = len(POS_enumerator)
|
41 |
+
opt.dim_motion_hidden = 1024
|
42 |
+
opt.max_text_len = 20
|
43 |
+
opt.dim_text_hidden = 512
|
44 |
+
opt.dim_coemb_hidden = 512
|
45 |
+
|
46 |
+
# print(opt)
|
47 |
+
|
48 |
+
self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
|
49 |
+
self.opt = opt
|
50 |
+
self.device = opt.device
|
51 |
+
|
52 |
+
self.text_encoder.to(opt.device)
|
53 |
+
self.motion_encoder.to(opt.device)
|
54 |
+
self.movement_encoder.to(opt.device)
|
55 |
+
|
56 |
+
self.text_encoder.eval()
|
57 |
+
self.motion_encoder.eval()
|
58 |
+
self.movement_encoder.eval()
|
59 |
+
|
60 |
+
# Please note that the results does not follow the order of inputs
|
61 |
+
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
|
62 |
+
with torch.no_grad():
|
63 |
+
word_embs = word_embs.detach().to(self.device).float()
|
64 |
+
pos_ohot = pos_ohot.detach().to(self.device).float()
|
65 |
+
motions = motions.detach().to(self.device).float()
|
66 |
+
|
67 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
68 |
+
motions = motions[align_idx]
|
69 |
+
m_lens = m_lens[align_idx]
|
70 |
+
|
71 |
+
'''Movement Encoding'''
|
72 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
73 |
+
m_lens = m_lens // self.opt.unit_length
|
74 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
75 |
+
|
76 |
+
'''Text Encoding'''
|
77 |
+
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
|
78 |
+
text_embedding = text_embedding[align_idx]
|
79 |
+
return text_embedding, motion_embedding
|
80 |
+
|
81 |
+
# Please note that the results does not follow the order of inputs
|
82 |
+
def get_motion_embeddings(self, motions, m_lens):
|
83 |
+
with torch.no_grad():
|
84 |
+
motions = motions.detach().to(self.device).float()
|
85 |
+
|
86 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
87 |
+
motions = motions[align_idx]
|
88 |
+
m_lens = m_lens[align_idx]
|
89 |
+
|
90 |
+
'''Movement Encoding'''
|
91 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
92 |
+
m_lens = m_lens // self.opt.unit_length
|
93 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
94 |
+
return motion_embedding
|
95 |
+
|
96 |
+
## Borrowed form MDM
|
97 |
+
# our version
|
98 |
+
def build_evaluators(opt):
|
99 |
+
movement_enc = MovementConvEncoder(opt['dim_pose']-4, opt['dim_movement_enc_hidden'], opt['dim_movement_latent'])
|
100 |
+
text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'],
|
101 |
+
pos_size=opt['dim_pos_ohot'],
|
102 |
+
hidden_size=opt['dim_text_hidden'],
|
103 |
+
output_size=opt['dim_coemb_hidden'],
|
104 |
+
device=opt['device'])
|
105 |
+
|
106 |
+
motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_latent'],
|
107 |
+
hidden_size=opt['dim_motion_hidden'],
|
108 |
+
output_size=opt['dim_coemb_hidden'],
|
109 |
+
device=opt['device'])
|
110 |
+
|
111 |
+
ckpt_dir = opt['dataset_name']
|
112 |
+
if opt['dataset_name'] == 'humanml':
|
113 |
+
ckpt_dir = 't2m'
|
114 |
+
|
115 |
+
checkpoint = torch.load(pjoin(opt['checkpoints_dir'], ckpt_dir, 'text_mot_match', 'model', 'finest.tar'),
|
116 |
+
map_location=opt['device'])
|
117 |
+
movement_enc.load_state_dict(checkpoint['movement_encoder'])
|
118 |
+
text_enc.load_state_dict(checkpoint['text_encoder'])
|
119 |
+
motion_enc.load_state_dict(checkpoint['motion_encoder'])
|
120 |
+
print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
|
121 |
+
return text_enc, motion_enc, movement_enc
|
122 |
+
|
123 |
+
# our wrapper
|
124 |
+
class EvaluatorWrapper(object):
|
125 |
+
|
126 |
+
def __init__(self, dataset_name, device):
|
127 |
+
opt = {
|
128 |
+
'dataset_name': dataset_name,
|
129 |
+
'device': device,
|
130 |
+
'dim_word': 300,
|
131 |
+
'max_motion_length': 196,
|
132 |
+
'dim_pos_ohot': len(POS_enumerator),
|
133 |
+
'dim_motion_hidden': 1024,
|
134 |
+
'max_text_len': 20,
|
135 |
+
'dim_text_hidden': 512,
|
136 |
+
'dim_coemb_hidden': 512,
|
137 |
+
'dim_pose': 263 if dataset_name == 'humanml' else 251,
|
138 |
+
'dim_movement_enc_hidden': 512,
|
139 |
+
'dim_movement_latent': 512,
|
140 |
+
'checkpoints_dir': './checkpoints',
|
141 |
+
'unit_length': 4,
|
142 |
+
}
|
143 |
+
|
144 |
+
self.text_encoder, self.motion_encoder, self.movement_encoder = build_evaluators(opt)
|
145 |
+
self.opt = opt
|
146 |
+
self.device = opt['device']
|
147 |
+
|
148 |
+
self.text_encoder.to(opt['device'])
|
149 |
+
self.motion_encoder.to(opt['device'])
|
150 |
+
self.movement_encoder.to(opt['device'])
|
151 |
+
|
152 |
+
self.text_encoder.eval()
|
153 |
+
self.motion_encoder.eval()
|
154 |
+
self.movement_encoder.eval()
|
155 |
+
|
156 |
+
# Please note that the results does not following the order of inputs
|
157 |
+
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
|
158 |
+
with torch.no_grad():
|
159 |
+
word_embs = word_embs.detach().to(self.device).float()
|
160 |
+
pos_ohot = pos_ohot.detach().to(self.device).float()
|
161 |
+
motions = motions.detach().to(self.device).float()
|
162 |
+
|
163 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
164 |
+
motions = motions[align_idx]
|
165 |
+
m_lens = m_lens[align_idx]
|
166 |
+
|
167 |
+
'''Movement Encoding'''
|
168 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
169 |
+
m_lens = m_lens // self.opt['unit_length']
|
170 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
171 |
+
# print(motions.shape, movements.shape, motion_embedding.shape, m_lens)
|
172 |
+
|
173 |
+
'''Text Encoding'''
|
174 |
+
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
|
175 |
+
text_embedding = text_embedding[align_idx]
|
176 |
+
return text_embedding, motion_embedding
|
177 |
+
|
178 |
+
# Please note that the results does not following the order of inputs
|
179 |
+
def get_motion_embeddings(self, motions, m_lens):
|
180 |
+
with torch.no_grad():
|
181 |
+
motions = motions.detach().to(self.device).float()
|
182 |
+
|
183 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
184 |
+
motions = motions[align_idx]
|
185 |
+
m_lens = m_lens[align_idx]
|
186 |
+
|
187 |
+
'''Movement Encoding'''
|
188 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
189 |
+
m_lens = m_lens // self.opt['unit_length']
|
190 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
191 |
+
return motion_embedding
|
models/vq/__init__.py
ADDED
File without changes
|
models/vq/encdec.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from models.vq.resnet import Resnet1D
|
3 |
+
|
4 |
+
|
5 |
+
class Encoder(nn.Module):
|
6 |
+
def __init__(self,
|
7 |
+
input_emb_width=3,
|
8 |
+
output_emb_width=512,
|
9 |
+
down_t=2,
|
10 |
+
stride_t=2,
|
11 |
+
width=512,
|
12 |
+
depth=3,
|
13 |
+
dilation_growth_rate=3,
|
14 |
+
activation='relu',
|
15 |
+
norm=None):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
blocks = []
|
19 |
+
filter_t, pad_t = stride_t * 2, stride_t // 2
|
20 |
+
blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
|
21 |
+
blocks.append(nn.ReLU())
|
22 |
+
|
23 |
+
for i in range(down_t):
|
24 |
+
input_dim = width
|
25 |
+
block = nn.Sequential(
|
26 |
+
nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
|
27 |
+
Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm),
|
28 |
+
)
|
29 |
+
blocks.append(block)
|
30 |
+
blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
|
31 |
+
self.model = nn.Sequential(*blocks)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return self.model(x)
|
35 |
+
|
36 |
+
|
37 |
+
class Decoder(nn.Module):
|
38 |
+
def __init__(self,
|
39 |
+
input_emb_width=3,
|
40 |
+
output_emb_width=512,
|
41 |
+
down_t=2,
|
42 |
+
stride_t=2,
|
43 |
+
width=512,
|
44 |
+
depth=3,
|
45 |
+
dilation_growth_rate=3,
|
46 |
+
activation='relu',
|
47 |
+
norm=None):
|
48 |
+
super().__init__()
|
49 |
+
blocks = []
|
50 |
+
|
51 |
+
blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
|
52 |
+
blocks.append(nn.ReLU())
|
53 |
+
for i in range(down_t):
|
54 |
+
out_dim = width
|
55 |
+
block = nn.Sequential(
|
56 |
+
Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm),
|
57 |
+
nn.Upsample(scale_factor=2, mode='nearest'),
|
58 |
+
nn.Conv1d(width, out_dim, 3, 1, 1)
|
59 |
+
)
|
60 |
+
blocks.append(block)
|
61 |
+
blocks.append(nn.Conv1d(width, width, 3, 1, 1))
|
62 |
+
blocks.append(nn.ReLU())
|
63 |
+
blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
|
64 |
+
self.model = nn.Sequential(*blocks)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
x = self.model(x)
|
68 |
+
return x.permute(0, 2, 1)
|
models/vq/model.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
from models.vq.encdec import Encoder, Decoder
|
5 |
+
from models.vq.residual_vq import ResidualVQ
|
6 |
+
|
7 |
+
class RVQVAE(nn.Module):
|
8 |
+
def __init__(self,
|
9 |
+
args,
|
10 |
+
input_width=263,
|
11 |
+
nb_code=1024,
|
12 |
+
code_dim=512,
|
13 |
+
output_emb_width=512,
|
14 |
+
down_t=3,
|
15 |
+
stride_t=2,
|
16 |
+
width=512,
|
17 |
+
depth=3,
|
18 |
+
dilation_growth_rate=3,
|
19 |
+
activation='relu',
|
20 |
+
norm=None):
|
21 |
+
|
22 |
+
super().__init__()
|
23 |
+
assert output_emb_width == code_dim
|
24 |
+
self.code_dim = code_dim
|
25 |
+
self.num_code = nb_code
|
26 |
+
# self.quant = args.quantizer
|
27 |
+
self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
|
28 |
+
dilation_growth_rate, activation=activation, norm=norm)
|
29 |
+
self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
|
30 |
+
dilation_growth_rate, activation=activation, norm=norm)
|
31 |
+
rvqvae_config = {
|
32 |
+
'num_quantizers': args.num_quantizers,
|
33 |
+
'shared_codebook': args.shared_codebook,
|
34 |
+
'quantize_dropout_prob': args.quantize_dropout_prob,
|
35 |
+
'quantize_dropout_cutoff_index': 0,
|
36 |
+
'nb_code': nb_code,
|
37 |
+
'code_dim':code_dim,
|
38 |
+
'args': args,
|
39 |
+
}
|
40 |
+
self.quantizer = ResidualVQ(**rvqvae_config)
|
41 |
+
|
42 |
+
def preprocess(self, x):
|
43 |
+
# (bs, T, Jx3) -> (bs, Jx3, T)
|
44 |
+
x = x.permute(0, 2, 1).float()
|
45 |
+
return x
|
46 |
+
|
47 |
+
def postprocess(self, x):
|
48 |
+
# (bs, Jx3, T) -> (bs, T, Jx3)
|
49 |
+
x = x.permute(0, 2, 1)
|
50 |
+
return x
|
51 |
+
|
52 |
+
def encode(self, x):
|
53 |
+
N, T, _ = x.shape
|
54 |
+
x_in = self.preprocess(x)
|
55 |
+
x_encoder = self.encoder(x_in)
|
56 |
+
# print(x_encoder.shape)
|
57 |
+
code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
|
58 |
+
# print(code_idx.shape)
|
59 |
+
# code_idx = code_idx.view(N, -1)
|
60 |
+
# (N, T, Q)
|
61 |
+
# print()
|
62 |
+
return code_idx, all_codes
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
x_in = self.preprocess(x)
|
66 |
+
# Encode
|
67 |
+
x_encoder = self.encoder(x_in)
|
68 |
+
|
69 |
+
## quantization
|
70 |
+
# x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5,
|
71 |
+
# force_dropout_index=0) #TODO hardcode
|
72 |
+
x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5)
|
73 |
+
|
74 |
+
# print(code_idx[0, :, 1])
|
75 |
+
## decoder
|
76 |
+
x_out = self.decoder(x_quantized)
|
77 |
+
# x_out = self.postprocess(x_decoder)
|
78 |
+
return x_out, commit_loss, perplexity
|
79 |
+
|
80 |
+
def forward_decoder(self, x):
|
81 |
+
x_d = self.quantizer.get_codes_from_indices(x)
|
82 |
+
# x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
|
83 |
+
x = x_d.sum(dim=0).permute(0, 2, 1)
|
84 |
+
|
85 |
+
# decoder
|
86 |
+
x_out = self.decoder(x)
|
87 |
+
# x_out = self.postprocess(x_decoder)
|
88 |
+
return x_out
|
89 |
+
|
90 |
+
class LengthEstimator(nn.Module):
|
91 |
+
def __init__(self, input_size, output_size):
|
92 |
+
super(LengthEstimator, self).__init__()
|
93 |
+
nd = 512
|
94 |
+
self.output = nn.Sequential(
|
95 |
+
nn.Linear(input_size, nd),
|
96 |
+
nn.LayerNorm(nd),
|
97 |
+
nn.LeakyReLU(0.2, inplace=True),
|
98 |
+
|
99 |
+
nn.Dropout(0.2),
|
100 |
+
nn.Linear(nd, nd // 2),
|
101 |
+
nn.LayerNorm(nd // 2),
|
102 |
+
nn.LeakyReLU(0.2, inplace=True),
|
103 |
+
|
104 |
+
nn.Dropout(0.2),
|
105 |
+
nn.Linear(nd // 2, nd // 4),
|
106 |
+
nn.LayerNorm(nd // 4),
|
107 |
+
nn.LeakyReLU(0.2, inplace=True),
|
108 |
+
|
109 |
+
nn.Linear(nd // 4, output_size)
|
110 |
+
)
|
111 |
+
|
112 |
+
self.output.apply(self.__init_weights)
|
113 |
+
|
114 |
+
def __init_weights(self, module):
|
115 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
116 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
117 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
118 |
+
module.bias.data.zero_()
|
119 |
+
elif isinstance(module, nn.LayerNorm):
|
120 |
+
module.bias.data.zero_()
|
121 |
+
module.weight.data.fill_(1.0)
|
122 |
+
|
123 |
+
def forward(self, text_emb):
|
124 |
+
return self.output(text_emb)
|
models/vq/quantizer.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange, repeat, reduce, pack, unpack
|
6 |
+
|
7 |
+
# from vector_quantize_pytorch import ResidualVQ
|
8 |
+
|
9 |
+
#Borrow from vector_quantize_pytorch
|
10 |
+
|
11 |
+
def log(t, eps = 1e-20):
|
12 |
+
return torch.log(t.clamp(min = eps))
|
13 |
+
|
14 |
+
def gumbel_noise(t):
|
15 |
+
noise = torch.zeros_like(t).uniform_(0, 1)
|
16 |
+
return -log(-log(noise))
|
17 |
+
|
18 |
+
def gumbel_sample(
|
19 |
+
logits,
|
20 |
+
temperature = 1.,
|
21 |
+
stochastic = False,
|
22 |
+
dim = -1,
|
23 |
+
training = True
|
24 |
+
):
|
25 |
+
|
26 |
+
if training and stochastic and temperature > 0:
|
27 |
+
sampling_logits = (logits / temperature) + gumbel_noise(logits)
|
28 |
+
else:
|
29 |
+
sampling_logits = logits
|
30 |
+
|
31 |
+
ind = sampling_logits.argmax(dim = dim)
|
32 |
+
|
33 |
+
return ind
|
34 |
+
|
35 |
+
class QuantizeEMAReset(nn.Module):
|
36 |
+
def __init__(self, nb_code, code_dim, args):
|
37 |
+
super(QuantizeEMAReset, self).__init__()
|
38 |
+
self.nb_code = nb_code
|
39 |
+
self.code_dim = code_dim
|
40 |
+
self.mu = args.mu ##TO_DO
|
41 |
+
self.reset_codebook()
|
42 |
+
|
43 |
+
def reset_codebook(self):
|
44 |
+
self.init = False
|
45 |
+
self.code_sum = None
|
46 |
+
self.code_count = None
|
47 |
+
self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False).cuda())
|
48 |
+
|
49 |
+
def _tile(self, x):
|
50 |
+
nb_code_x, code_dim = x.shape
|
51 |
+
if nb_code_x < self.nb_code:
|
52 |
+
n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
|
53 |
+
std = 0.01 / np.sqrt(code_dim)
|
54 |
+
out = x.repeat(n_repeats, 1)
|
55 |
+
out = out + torch.randn_like(out) * std
|
56 |
+
else:
|
57 |
+
out = x
|
58 |
+
return out
|
59 |
+
|
60 |
+
def init_codebook(self, x):
|
61 |
+
out = self._tile(x)
|
62 |
+
self.codebook = out[:self.nb_code]
|
63 |
+
self.code_sum = self.codebook.clone()
|
64 |
+
self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
|
65 |
+
self.init = True
|
66 |
+
|
67 |
+
def quantize(self, x, sample_codebook_temp=0.):
|
68 |
+
# N X C -> C X N
|
69 |
+
k_w = self.codebook.t()
|
70 |
+
# x: NT X C
|
71 |
+
# NT X N
|
72 |
+
distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \
|
73 |
+
2 * torch.matmul(x, k_w) + \
|
74 |
+
torch.sum(k_w ** 2, dim=0, keepdim=True) # (N * L, b)
|
75 |
+
|
76 |
+
# code_idx = torch.argmin(distance, dim=-1)
|
77 |
+
|
78 |
+
code_idx = gumbel_sample(-distance, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training)
|
79 |
+
|
80 |
+
return code_idx
|
81 |
+
|
82 |
+
def dequantize(self, code_idx):
|
83 |
+
x = F.embedding(code_idx, self.codebook)
|
84 |
+
return x
|
85 |
+
|
86 |
+
def get_codebook_entry(self, indices):
|
87 |
+
return self.dequantize(indices).permute(0, 2, 1)
|
88 |
+
|
89 |
+
@torch.no_grad()
|
90 |
+
def compute_perplexity(self, code_idx):
|
91 |
+
# Calculate new centres
|
92 |
+
code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
|
93 |
+
code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
|
94 |
+
|
95 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
96 |
+
prob = code_count / torch.sum(code_count)
|
97 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
98 |
+
return perplexity
|
99 |
+
|
100 |
+
@torch.no_grad()
|
101 |
+
def update_codebook(self, x, code_idx):
|
102 |
+
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
|
103 |
+
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
|
104 |
+
|
105 |
+
code_sum = torch.matmul(code_onehot, x) # nb_code, c
|
106 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
107 |
+
|
108 |
+
out = self._tile(x)
|
109 |
+
code_rand = out[:self.nb_code]
|
110 |
+
|
111 |
+
# Update centres
|
112 |
+
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
|
113 |
+
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
|
114 |
+
|
115 |
+
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
|
116 |
+
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
|
117 |
+
self.codebook = usage * code_update + (1-usage) * code_rand
|
118 |
+
|
119 |
+
|
120 |
+
prob = code_count / torch.sum(code_count)
|
121 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
122 |
+
|
123 |
+
return perplexity
|
124 |
+
|
125 |
+
def preprocess(self, x):
|
126 |
+
# NCT -> NTC -> [NT, C]
|
127 |
+
# x = x.permute(0, 2, 1).contiguous()
|
128 |
+
# x = x.view(-1, x.shape[-1])
|
129 |
+
x = rearrange(x, 'n c t -> (n t) c')
|
130 |
+
return x
|
131 |
+
|
132 |
+
def forward(self, x, return_idx=False, temperature=0.):
|
133 |
+
N, width, T = x.shape
|
134 |
+
|
135 |
+
x = self.preprocess(x)
|
136 |
+
if self.training and not self.init:
|
137 |
+
self.init_codebook(x)
|
138 |
+
|
139 |
+
code_idx = self.quantize(x, temperature)
|
140 |
+
x_d = self.dequantize(code_idx)
|
141 |
+
|
142 |
+
if self.training:
|
143 |
+
perplexity = self.update_codebook(x, code_idx)
|
144 |
+
else:
|
145 |
+
perplexity = self.compute_perplexity(code_idx)
|
146 |
+
|
147 |
+
commit_loss = F.mse_loss(x, x_d.detach()) # It's right. the t2m-gpt paper is wrong on embed loss and commitment loss.
|
148 |
+
|
149 |
+
# Passthrough
|
150 |
+
x_d = x + (x_d - x).detach()
|
151 |
+
|
152 |
+
# Postprocess
|
153 |
+
x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
|
154 |
+
code_idx = code_idx.view(N, T).contiguous()
|
155 |
+
# print(code_idx[0])
|
156 |
+
if return_idx:
|
157 |
+
return x_d, code_idx, commit_loss, perplexity
|
158 |
+
return x_d, commit_loss, perplexity
|
159 |
+
|
160 |
+
class QuantizeEMA(QuantizeEMAReset):
|
161 |
+
@torch.no_grad()
|
162 |
+
def update_codebook(self, x, code_idx):
|
163 |
+
code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
|
164 |
+
code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
|
165 |
+
|
166 |
+
code_sum = torch.matmul(code_onehot, x) # nb_code, c
|
167 |
+
code_count = code_onehot.sum(dim=-1) # nb_code
|
168 |
+
|
169 |
+
# Update centres
|
170 |
+
self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
|
171 |
+
self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
|
172 |
+
|
173 |
+
usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
|
174 |
+
code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
|
175 |
+
self.codebook = usage * code_update + (1-usage) * self.codebook
|
176 |
+
|
177 |
+
prob = code_count / torch.sum(code_count)
|
178 |
+
perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
|
179 |
+
|
180 |
+
return perplexity
|
models/vq/residual_vq.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from math import ceil
|
3 |
+
from functools import partial
|
4 |
+
from itertools import zip_longest
|
5 |
+
from random import randrange
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
# from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
|
11 |
+
from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA
|
12 |
+
|
13 |
+
from einops import rearrange, repeat, pack, unpack
|
14 |
+
|
15 |
+
# helper functions
|
16 |
+
|
17 |
+
def exists(val):
|
18 |
+
return val is not None
|
19 |
+
|
20 |
+
def default(val, d):
|
21 |
+
return val if exists(val) else d
|
22 |
+
|
23 |
+
def round_up_multiple(num, mult):
|
24 |
+
return ceil(num / mult) * mult
|
25 |
+
|
26 |
+
# main class
|
27 |
+
|
28 |
+
class ResidualVQ(nn.Module):
|
29 |
+
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
num_quantizers,
|
33 |
+
shared_codebook=False,
|
34 |
+
quantize_dropout_prob=0.5,
|
35 |
+
quantize_dropout_cutoff_index=0,
|
36 |
+
**kwargs
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.num_quantizers = num_quantizers
|
41 |
+
|
42 |
+
# self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
|
43 |
+
if shared_codebook:
|
44 |
+
layer = QuantizeEMAReset(**kwargs)
|
45 |
+
self.layers = nn.ModuleList([layer for _ in range(num_quantizers)])
|
46 |
+
else:
|
47 |
+
self.layers = nn.ModuleList([QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)])
|
48 |
+
# self.layers = nn.ModuleList([QuantizeEMA(**kwargs) for _ in range(num_quantizers)])
|
49 |
+
|
50 |
+
# self.quantize_dropout = quantize_dropout and num_quantizers > 1
|
51 |
+
|
52 |
+
assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0
|
53 |
+
|
54 |
+
self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
|
55 |
+
self.quantize_dropout_prob = quantize_dropout_prob
|
56 |
+
|
57 |
+
|
58 |
+
@property
|
59 |
+
def codebooks(self):
|
60 |
+
codebooks = [layer.codebook for layer in self.layers]
|
61 |
+
codebooks = torch.stack(codebooks, dim = 0)
|
62 |
+
return codebooks # 'q c d'
|
63 |
+
|
64 |
+
def get_codes_from_indices(self, indices): #indices shape 'b n q' # dequantize
|
65 |
+
|
66 |
+
batch, quantize_dim = indices.shape[0], indices.shape[-1]
|
67 |
+
|
68 |
+
# because of quantize dropout, one can pass in indices that are coarse
|
69 |
+
# and the network should be able to reconstruct
|
70 |
+
|
71 |
+
if quantize_dim < self.num_quantizers:
|
72 |
+
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
|
73 |
+
|
74 |
+
# get ready for gathering
|
75 |
+
|
76 |
+
codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
|
77 |
+
gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
|
78 |
+
|
79 |
+
# take care of quantizer dropout
|
80 |
+
|
81 |
+
mask = gather_indices == -1.
|
82 |
+
gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
|
83 |
+
|
84 |
+
# print(gather_indices.max(), gather_indices.min())
|
85 |
+
all_codes = codebooks.gather(2, gather_indices) # gather all codes
|
86 |
+
|
87 |
+
# mask out any codes that were dropout-ed
|
88 |
+
|
89 |
+
all_codes = all_codes.masked_fill(mask, 0.)
|
90 |
+
|
91 |
+
return all_codes # 'q b n d'
|
92 |
+
|
93 |
+
def get_codebook_entry(self, indices): #indices shape 'b n q'
|
94 |
+
all_codes = self.get_codes_from_indices(indices) #'q b n d'
|
95 |
+
latent = torch.sum(all_codes, dim=0) #'b n d'
|
96 |
+
latent = latent.permute(0, 2, 1)
|
97 |
+
return latent
|
98 |
+
|
99 |
+
def forward(self, x, return_all_codes = False, sample_codebook_temp = None, force_dropout_index=-1):
|
100 |
+
# debug check
|
101 |
+
# print(self.codebooks[:,0,0].detach().cpu().numpy())
|
102 |
+
num_quant, quant_dropout_prob, device = self.num_quantizers, self.quantize_dropout_prob, x.device
|
103 |
+
|
104 |
+
quantized_out = 0.
|
105 |
+
residual = x
|
106 |
+
|
107 |
+
all_losses = []
|
108 |
+
all_indices = []
|
109 |
+
all_perplexity = []
|
110 |
+
|
111 |
+
|
112 |
+
should_quantize_dropout = self.training and random.random() < self.quantize_dropout_prob
|
113 |
+
|
114 |
+
start_drop_quantize_index = num_quant
|
115 |
+
# To ensure the first-k layers learn things as much as possible, we randomly dropout the last q - k layers
|
116 |
+
if should_quantize_dropout:
|
117 |
+
start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant) # keep quant layers <= quantize_dropout_cutoff_index, TODO vary in batch
|
118 |
+
null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n'
|
119 |
+
null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
|
120 |
+
# null_loss = 0.
|
121 |
+
|
122 |
+
if force_dropout_index >= 0:
|
123 |
+
should_quantize_dropout = True
|
124 |
+
start_drop_quantize_index = force_dropout_index
|
125 |
+
null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n'
|
126 |
+
null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
|
127 |
+
|
128 |
+
# print(force_dropout_index)
|
129 |
+
# go through the layers
|
130 |
+
|
131 |
+
for quantizer_index, layer in enumerate(self.layers):
|
132 |
+
|
133 |
+
if should_quantize_dropout and quantizer_index > start_drop_quantize_index:
|
134 |
+
all_indices.append(null_indices)
|
135 |
+
# all_losses.append(null_loss)
|
136 |
+
continue
|
137 |
+
|
138 |
+
# layer_indices = None
|
139 |
+
# if return_loss:
|
140 |
+
# layer_indices = indices[..., quantizer_index] #gt indices
|
141 |
+
|
142 |
+
# quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp) #single quantizer TODO
|
143 |
+
quantized, *rest = layer(residual, return_idx=True, temperature=sample_codebook_temp) #single quantizer
|
144 |
+
|
145 |
+
# print(quantized.shape, residual.shape)
|
146 |
+
residual -= quantized.detach()
|
147 |
+
quantized_out += quantized
|
148 |
+
|
149 |
+
embed_indices, loss, perplexity = rest
|
150 |
+
all_indices.append(embed_indices)
|
151 |
+
all_losses.append(loss)
|
152 |
+
all_perplexity.append(perplexity)
|
153 |
+
|
154 |
+
|
155 |
+
# stack all losses and indices
|
156 |
+
all_indices = torch.stack(all_indices, dim=-1)
|
157 |
+
all_losses = sum(all_losses)/len(all_losses)
|
158 |
+
all_perplexity = sum(all_perplexity)/len(all_perplexity)
|
159 |
+
|
160 |
+
ret = (quantized_out, all_indices, all_losses, all_perplexity)
|
161 |
+
|
162 |
+
if return_all_codes:
|
163 |
+
# whether to return all codes from all codebooks across layers
|
164 |
+
all_codes = self.get_codes_from_indices(all_indices)
|
165 |
+
|
166 |
+
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
|
167 |
+
ret = (*ret, all_codes)
|
168 |
+
|
169 |
+
return ret
|
170 |
+
|
171 |
+
def quantize(self, x, return_latent=False):
|
172 |
+
all_indices = []
|
173 |
+
quantized_out = 0.
|
174 |
+
residual = x
|
175 |
+
all_codes = []
|
176 |
+
for quantizer_index, layer in enumerate(self.layers):
|
177 |
+
|
178 |
+
quantized, *rest = layer(residual, return_idx=True) #single quantizer
|
179 |
+
|
180 |
+
residual = residual - quantized.detach()
|
181 |
+
quantized_out = quantized_out + quantized
|
182 |
+
|
183 |
+
embed_indices, loss, perplexity = rest
|
184 |
+
all_indices.append(embed_indices)
|
185 |
+
# print(quantizer_index, embed_indices[0])
|
186 |
+
# print(quantizer_index, quantized[0])
|
187 |
+
# break
|
188 |
+
all_codes.append(quantized)
|
189 |
+
|
190 |
+
code_idx = torch.stack(all_indices, dim=-1)
|
191 |
+
all_codes = torch.stack(all_codes, dim=0)
|
192 |
+
if return_latent:
|
193 |
+
return code_idx, all_codes
|
194 |
+
return code_idx
|
models/vq/resnet.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class nonlinearity(nn.Module):
|
5 |
+
def __init(self):
|
6 |
+
super().__init__()
|
7 |
+
|
8 |
+
def forward(self, x):
|
9 |
+
return x * torch.sigmoid(x)
|
10 |
+
|
11 |
+
|
12 |
+
class ResConv1DBlock(nn.Module):
|
13 |
+
def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=0.2):
|
14 |
+
super(ResConv1DBlock, self).__init__()
|
15 |
+
|
16 |
+
padding = dilation
|
17 |
+
self.norm = norm
|
18 |
+
|
19 |
+
if norm == "LN":
|
20 |
+
self.norm1 = nn.LayerNorm(n_in)
|
21 |
+
self.norm2 = nn.LayerNorm(n_in)
|
22 |
+
elif norm == "GN":
|
23 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
|
24 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
|
25 |
+
elif norm == "BN":
|
26 |
+
self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
|
27 |
+
self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
|
28 |
+
else:
|
29 |
+
self.norm1 = nn.Identity()
|
30 |
+
self.norm2 = nn.Identity()
|
31 |
+
|
32 |
+
if activation == "relu":
|
33 |
+
self.activation1 = nn.ReLU()
|
34 |
+
self.activation2 = nn.ReLU()
|
35 |
+
|
36 |
+
elif activation == "silu":
|
37 |
+
self.activation1 = nonlinearity()
|
38 |
+
self.activation2 = nonlinearity()
|
39 |
+
|
40 |
+
elif activation == "gelu":
|
41 |
+
self.activation1 = nn.GELU()
|
42 |
+
self.activation2 = nn.GELU()
|
43 |
+
|
44 |
+
self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
|
45 |
+
self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0, )
|
46 |
+
self.dropout = nn.Dropout(dropout)
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
x_orig = x
|
50 |
+
if self.norm == "LN":
|
51 |
+
x = self.norm1(x.transpose(-2, -1))
|
52 |
+
x = self.activation1(x.transpose(-2, -1))
|
53 |
+
else:
|
54 |
+
x = self.norm1(x)
|
55 |
+
x = self.activation1(x)
|
56 |
+
|
57 |
+
x = self.conv1(x)
|
58 |
+
|
59 |
+
if self.norm == "LN":
|
60 |
+
x = self.norm2(x.transpose(-2, -1))
|
61 |
+
x = self.activation2(x.transpose(-2, -1))
|
62 |
+
else:
|
63 |
+
x = self.norm2(x)
|
64 |
+
x = self.activation2(x)
|
65 |
+
|
66 |
+
x = self.conv2(x)
|
67 |
+
x = self.dropout(x)
|
68 |
+
x = x + x_orig
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class Resnet1D(nn.Module):
|
73 |
+
def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm)
|
77 |
+
for depth in range(n_depth)]
|
78 |
+
if reverse_dilation:
|
79 |
+
blocks = blocks[::-1]
|
80 |
+
|
81 |
+
self.model = nn.Sequential(*blocks)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
return self.model(x)
|
models/vq/vq_trainer.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from torch.nn.utils import clip_grad_norm_
|
4 |
+
from torch.utils.tensorboard import SummaryWriter
|
5 |
+
from os.path import join as pjoin
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
import torch.optim as optim
|
9 |
+
|
10 |
+
import time
|
11 |
+
import numpy as np
|
12 |
+
from collections import OrderedDict, defaultdict
|
13 |
+
from utils.eval_t2m import evaluation_vqvae, evaluation_res_conv
|
14 |
+
from utils.utils import print_current_loss
|
15 |
+
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
|
19 |
+
def def_value():
|
20 |
+
return 0.0
|
21 |
+
|
22 |
+
|
23 |
+
class RVQTokenizerTrainer:
|
24 |
+
def __init__(self, args, vq_model):
|
25 |
+
self.opt = args
|
26 |
+
self.vq_model = vq_model
|
27 |
+
self.device = args.device
|
28 |
+
|
29 |
+
if args.is_train:
|
30 |
+
self.logger = SummaryWriter(args.log_dir)
|
31 |
+
if args.recons_loss == 'l1':
|
32 |
+
self.l1_criterion = torch.nn.L1Loss()
|
33 |
+
elif args.recons_loss == 'l1_smooth':
|
34 |
+
self.l1_criterion = torch.nn.SmoothL1Loss()
|
35 |
+
|
36 |
+
# self.critic = CriticWrapper(self.opt.dataset_name, self.opt.device)
|
37 |
+
|
38 |
+
def forward(self, batch_data):
|
39 |
+
motions = batch_data.detach().to(self.device).float()
|
40 |
+
pred_motion, loss_commit, perplexity = self.vq_model(motions)
|
41 |
+
|
42 |
+
self.motions = motions
|
43 |
+
self.pred_motion = pred_motion
|
44 |
+
|
45 |
+
loss_rec = self.l1_criterion(pred_motion, motions)
|
46 |
+
pred_local_pos = pred_motion[..., 4 : (self.opt.joints_num - 1) * 3 + 4]
|
47 |
+
local_pos = motions[..., 4 : (self.opt.joints_num - 1) * 3 + 4]
|
48 |
+
loss_explicit = self.l1_criterion(pred_local_pos, local_pos)
|
49 |
+
|
50 |
+
loss = loss_rec + self.opt.loss_vel * loss_explicit + self.opt.commit * loss_commit
|
51 |
+
|
52 |
+
# return loss, loss_rec, loss_vel, loss_commit, perplexity
|
53 |
+
# return loss, loss_rec, loss_percept, loss_commit, perplexity
|
54 |
+
return loss, loss_rec, loss_explicit, loss_commit, perplexity
|
55 |
+
|
56 |
+
|
57 |
+
# @staticmethod
|
58 |
+
def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
|
59 |
+
|
60 |
+
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
|
61 |
+
for param_group in self.opt_vq_model.param_groups:
|
62 |
+
param_group["lr"] = current_lr
|
63 |
+
|
64 |
+
return current_lr
|
65 |
+
|
66 |
+
def save(self, file_name, ep, total_it):
|
67 |
+
state = {
|
68 |
+
"vq_model": self.vq_model.state_dict(),
|
69 |
+
"opt_vq_model": self.opt_vq_model.state_dict(),
|
70 |
+
"scheduler": self.scheduler.state_dict(),
|
71 |
+
'ep': ep,
|
72 |
+
'total_it': total_it,
|
73 |
+
}
|
74 |
+
torch.save(state, file_name)
|
75 |
+
|
76 |
+
def resume(self, model_dir):
|
77 |
+
checkpoint = torch.load(model_dir, map_location=self.device)
|
78 |
+
self.vq_model.load_state_dict(checkpoint['vq_model'])
|
79 |
+
self.opt_vq_model.load_state_dict(checkpoint['opt_vq_model'])
|
80 |
+
self.scheduler.load_state_dict(checkpoint['scheduler'])
|
81 |
+
return checkpoint['ep'], checkpoint['total_it']
|
82 |
+
|
83 |
+
def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval=None):
|
84 |
+
self.vq_model.to(self.device)
|
85 |
+
|
86 |
+
self.opt_vq_model = optim.AdamW(self.vq_model.parameters(), lr=self.opt.lr, betas=(0.9, 0.99), weight_decay=self.opt.weight_decay)
|
87 |
+
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.opt_vq_model, milestones=self.opt.milestones, gamma=self.opt.gamma)
|
88 |
+
|
89 |
+
epoch = 0
|
90 |
+
it = 0
|
91 |
+
if self.opt.is_continue:
|
92 |
+
model_dir = pjoin(self.opt.model_dir, 'latest.tar')
|
93 |
+
epoch, it = self.resume(model_dir)
|
94 |
+
print("Load model epoch:%d iterations:%d"%(epoch, it))
|
95 |
+
|
96 |
+
start_time = time.time()
|
97 |
+
total_iters = self.opt.max_epoch * len(train_loader)
|
98 |
+
print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
|
99 |
+
print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(eval_val_loader)))
|
100 |
+
# val_loss = 0
|
101 |
+
# min_val_loss = np.inf
|
102 |
+
# min_val_epoch = epoch
|
103 |
+
current_lr = self.opt.lr
|
104 |
+
logs = defaultdict(def_value, OrderedDict())
|
105 |
+
|
106 |
+
# sys.exit()
|
107 |
+
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_vqvae(
|
108 |
+
self.opt.model_dir, eval_val_loader, self.vq_model, self.logger, epoch, best_fid=1000,
|
109 |
+
best_div=100, best_top1=0,
|
110 |
+
best_top2=0, best_top3=0, best_matching=100,
|
111 |
+
eval_wrapper=eval_wrapper, save=False)
|
112 |
+
|
113 |
+
while epoch < self.opt.max_epoch:
|
114 |
+
self.vq_model.train()
|
115 |
+
for i, batch_data in enumerate(train_loader):
|
116 |
+
it += 1
|
117 |
+
if it < self.opt.warm_up_iter:
|
118 |
+
current_lr = self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
|
119 |
+
loss, loss_rec, loss_vel, loss_commit, perplexity = self.forward(batch_data)
|
120 |
+
self.opt_vq_model.zero_grad()
|
121 |
+
loss.backward()
|
122 |
+
self.opt_vq_model.step()
|
123 |
+
|
124 |
+
if it >= self.opt.warm_up_iter:
|
125 |
+
self.scheduler.step()
|
126 |
+
|
127 |
+
logs['loss'] += loss.item()
|
128 |
+
logs['loss_rec'] += loss_rec.item()
|
129 |
+
# Note it not necessarily velocity, too lazy to change the name now
|
130 |
+
logs['loss_vel'] += loss_vel.item()
|
131 |
+
logs['loss_commit'] += loss_commit.item()
|
132 |
+
logs['perplexity'] += perplexity.item()
|
133 |
+
logs['lr'] += self.opt_vq_model.param_groups[0]['lr']
|
134 |
+
|
135 |
+
if it % self.opt.log_every == 0:
|
136 |
+
mean_loss = OrderedDict()
|
137 |
+
# self.logger.add_scalar('val_loss', val_loss, it)
|
138 |
+
# self.l
|
139 |
+
for tag, value in logs.items():
|
140 |
+
self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
|
141 |
+
mean_loss[tag] = value / self.opt.log_every
|
142 |
+
logs = defaultdict(def_value, OrderedDict())
|
143 |
+
print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
|
144 |
+
|
145 |
+
if it % self.opt.save_latest == 0:
|
146 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
147 |
+
|
148 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
149 |
+
|
150 |
+
epoch += 1
|
151 |
+
# if epoch % self.opt.save_every_e == 0:
|
152 |
+
# self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, total_it=it)
|
153 |
+
|
154 |
+
print('Validation time:')
|
155 |
+
self.vq_model.eval()
|
156 |
+
val_loss_rec = []
|
157 |
+
val_loss_vel = []
|
158 |
+
val_loss_commit = []
|
159 |
+
val_loss = []
|
160 |
+
val_perpexity = []
|
161 |
+
with torch.no_grad():
|
162 |
+
for i, batch_data in enumerate(val_loader):
|
163 |
+
loss, loss_rec, loss_vel, loss_commit, perplexity = self.forward(batch_data)
|
164 |
+
# val_loss_rec += self.l1_criterion(self.recon_motions, self.motions).item()
|
165 |
+
# val_loss_emb += self.embedding_loss.item()
|
166 |
+
val_loss.append(loss.item())
|
167 |
+
val_loss_rec.append(loss_rec.item())
|
168 |
+
val_loss_vel.append(loss_vel.item())
|
169 |
+
val_loss_commit.append(loss_commit.item())
|
170 |
+
val_perpexity.append(perplexity.item())
|
171 |
+
|
172 |
+
# val_loss = val_loss_rec / (len(val_dataloader) + 1)
|
173 |
+
# val_loss = val_loss / (len(val_dataloader) + 1)
|
174 |
+
# val_loss_rec = val_loss_rec / (len(val_dataloader) + 1)
|
175 |
+
# val_loss_emb = val_loss_emb / (len(val_dataloader) + 1)
|
176 |
+
self.logger.add_scalar('Val/loss', sum(val_loss) / len(val_loss), epoch)
|
177 |
+
self.logger.add_scalar('Val/loss_rec', sum(val_loss_rec) / len(val_loss_rec), epoch)
|
178 |
+
self.logger.add_scalar('Val/loss_vel', sum(val_loss_vel) / len(val_loss_vel), epoch)
|
179 |
+
self.logger.add_scalar('Val/loss_commit', sum(val_loss_commit) / len(val_loss), epoch)
|
180 |
+
self.logger.add_scalar('Val/loss_perplexity', sum(val_perpexity) / len(val_loss_rec), epoch)
|
181 |
+
|
182 |
+
print('Validation Loss: %.5f Reconstruction: %.5f, Velocity: %.5f, Commit: %.5f' %
|
183 |
+
(sum(val_loss)/len(val_loss), sum(val_loss_rec)/len(val_loss),
|
184 |
+
sum(val_loss_vel)/len(val_loss), sum(val_loss_commit)/len(val_loss)))
|
185 |
+
|
186 |
+
# if sum(val_loss) / len(val_loss) < min_val_loss:
|
187 |
+
# min_val_loss = sum(val_loss) / len(val_loss)
|
188 |
+
# # if sum(val_loss_vel) / len(val_loss_vel) < min_val_loss:
|
189 |
+
# # min_val_loss = sum(val_loss_vel) / len(val_loss_vel)
|
190 |
+
# min_val_epoch = epoch
|
191 |
+
# self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
|
192 |
+
# print('Best Validation Model So Far!~')
|
193 |
+
|
194 |
+
best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_vqvae(
|
195 |
+
self.opt.model_dir, eval_val_loader, self.vq_model, self.logger, epoch, best_fid=best_fid,
|
196 |
+
best_div=best_div, best_top1=best_top1,
|
197 |
+
best_top2=best_top2, best_top3=best_top3, best_matching=best_matching, eval_wrapper=eval_wrapper)
|
198 |
+
|
199 |
+
|
200 |
+
if epoch % self.opt.eval_every_e == 0:
|
201 |
+
data = torch.cat([self.motions[:4], self.pred_motion[:4]], dim=0).detach().cpu().numpy()
|
202 |
+
# np.save(pjoin(self.opt.eval_dir, 'E%04d.npy' % (epoch)), data)
|
203 |
+
save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
|
204 |
+
os.makedirs(save_dir, exist_ok=True)
|
205 |
+
plot_eval(data, save_dir)
|
206 |
+
# if plot_eval is not None:
|
207 |
+
# save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
|
208 |
+
# os.makedirs(save_dir, exist_ok=True)
|
209 |
+
# plot_eval(data, save_dir)
|
210 |
+
|
211 |
+
# if epoch - min_val_epoch >= self.opt.early_stop_e:
|
212 |
+
# print('Early Stopping!~')
|
213 |
+
|
214 |
+
|
215 |
+
class LengthEstTrainer(object):
|
216 |
+
|
217 |
+
def __init__(self, args, estimator, text_encoder, encode_fnc):
|
218 |
+
self.opt = args
|
219 |
+
self.estimator = estimator
|
220 |
+
self.text_encoder = text_encoder
|
221 |
+
self.encode_fnc = encode_fnc
|
222 |
+
self.device = args.device
|
223 |
+
|
224 |
+
if args.is_train:
|
225 |
+
# self.motion_dis
|
226 |
+
self.logger = SummaryWriter(args.log_dir)
|
227 |
+
self.mul_cls_criterion = torch.nn.CrossEntropyLoss()
|
228 |
+
|
229 |
+
def resume(self, model_dir):
|
230 |
+
checkpoints = torch.load(model_dir, map_location=self.device)
|
231 |
+
self.estimator.load_state_dict(checkpoints['estimator'])
|
232 |
+
# self.opt_estimator.load_state_dict(checkpoints['opt_estimator'])
|
233 |
+
return checkpoints['epoch'], checkpoints['iter']
|
234 |
+
|
235 |
+
def save(self, model_dir, epoch, niter):
|
236 |
+
state = {
|
237 |
+
'estimator': self.estimator.state_dict(),
|
238 |
+
# 'opt_estimator': self.opt_estimator.state_dict(),
|
239 |
+
'epoch': epoch,
|
240 |
+
'niter': niter,
|
241 |
+
}
|
242 |
+
torch.save(state, model_dir)
|
243 |
+
|
244 |
+
@staticmethod
|
245 |
+
def zero_grad(opt_list):
|
246 |
+
for opt in opt_list:
|
247 |
+
opt.zero_grad()
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def clip_norm(network_list):
|
251 |
+
for network in network_list:
|
252 |
+
clip_grad_norm_(network.parameters(), 0.5)
|
253 |
+
|
254 |
+
@staticmethod
|
255 |
+
def step(opt_list):
|
256 |
+
for opt in opt_list:
|
257 |
+
opt.step()
|
258 |
+
|
259 |
+
def train(self, train_dataloader, val_dataloader):
|
260 |
+
self.estimator.to(self.device)
|
261 |
+
self.text_encoder.to(self.device)
|
262 |
+
|
263 |
+
self.opt_estimator = optim.Adam(self.estimator.parameters(), lr=self.opt.lr)
|
264 |
+
|
265 |
+
epoch = 0
|
266 |
+
it = 0
|
267 |
+
|
268 |
+
if self.opt.is_continue:
|
269 |
+
model_dir = pjoin(self.opt.model_dir, 'latest.tar')
|
270 |
+
epoch, it = self.resume(model_dir)
|
271 |
+
|
272 |
+
start_time = time.time()
|
273 |
+
total_iters = self.opt.max_epoch * len(train_dataloader)
|
274 |
+
print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
|
275 |
+
val_loss = 0
|
276 |
+
min_val_loss = np.inf
|
277 |
+
logs = defaultdict(float)
|
278 |
+
while epoch < self.opt.max_epoch:
|
279 |
+
# time0 = time.time()
|
280 |
+
for i, batch_data in enumerate(train_dataloader):
|
281 |
+
self.estimator.train()
|
282 |
+
|
283 |
+
conds, _, m_lens = batch_data
|
284 |
+
# word_emb = word_emb.detach().to(self.device).float()
|
285 |
+
# pos_ohot = pos_ohot.detach().to(self.device).float()
|
286 |
+
# m_lens = m_lens.to(self.device).long()
|
287 |
+
text_embs = self.encode_fnc(self.text_encoder, conds, self.opt.device).detach()
|
288 |
+
# print(text_embs.shape, text_embs.device)
|
289 |
+
|
290 |
+
pred_dis = self.estimator(text_embs)
|
291 |
+
|
292 |
+
self.zero_grad([self.opt_estimator])
|
293 |
+
|
294 |
+
gt_labels = m_lens // self.opt.unit_length
|
295 |
+
gt_labels = gt_labels.long().to(self.device)
|
296 |
+
# print(gt_labels.shape, pred_dis.shape)
|
297 |
+
# print(gt_labels.max(), gt_labels.min())
|
298 |
+
# print(pred_dis)
|
299 |
+
acc = (gt_labels == pred_dis.argmax(dim=-1)).sum() / len(gt_labels)
|
300 |
+
loss = self.mul_cls_criterion(pred_dis, gt_labels)
|
301 |
+
|
302 |
+
loss.backward()
|
303 |
+
|
304 |
+
self.clip_norm([self.estimator])
|
305 |
+
self.step([self.opt_estimator])
|
306 |
+
|
307 |
+
logs['loss'] += loss.item()
|
308 |
+
logs['acc'] += acc.item()
|
309 |
+
|
310 |
+
it += 1
|
311 |
+
if it % self.opt.log_every == 0:
|
312 |
+
mean_loss = OrderedDict({'val_loss': val_loss})
|
313 |
+
# self.logger.add_scalar('Val/loss', val_loss, it)
|
314 |
+
|
315 |
+
for tag, value in logs.items():
|
316 |
+
self.logger.add_scalar("Train/%s"%tag, value / self.opt.log_every, it)
|
317 |
+
mean_loss[tag] = value / self.opt.log_every
|
318 |
+
logs = defaultdict(float)
|
319 |
+
print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
|
320 |
+
|
321 |
+
if it % self.opt.save_latest == 0:
|
322 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
323 |
+
|
324 |
+
self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
|
325 |
+
|
326 |
+
epoch += 1
|
327 |
+
|
328 |
+
print('Validation time:')
|
329 |
+
|
330 |
+
val_loss = 0
|
331 |
+
val_acc = 0
|
332 |
+
# self.estimator.eval()
|
333 |
+
with torch.no_grad():
|
334 |
+
for i, batch_data in enumerate(val_dataloader):
|
335 |
+
self.estimator.eval()
|
336 |
+
|
337 |
+
conds, _, m_lens = batch_data
|
338 |
+
# word_emb = word_emb.detach().to(self.device).float()
|
339 |
+
# pos_ohot = pos_ohot.detach().to(self.device).float()
|
340 |
+
# m_lens = m_lens.to(self.device).long()
|
341 |
+
text_embs = self.encode_fnc(self.text_encoder, conds, self.opt.device)
|
342 |
+
pred_dis = self.estimator(text_embs)
|
343 |
+
|
344 |
+
gt_labels = m_lens // self.opt.unit_length
|
345 |
+
gt_labels = gt_labels.long().to(self.device)
|
346 |
+
loss = self.mul_cls_criterion(pred_dis, gt_labels)
|
347 |
+
acc = (gt_labels == pred_dis.argmax(dim=-1)).sum() / len(gt_labels)
|
348 |
+
|
349 |
+
val_loss += loss.item()
|
350 |
+
val_acc += acc.item()
|
351 |
+
|
352 |
+
|
353 |
+
val_loss = val_loss / len(val_dataloader)
|
354 |
+
val_acc = val_acc / len(val_dataloader)
|
355 |
+
print('Validation Loss: %.5f Validation Acc: %.5f' % (val_loss, val_acc))
|
356 |
+
|
357 |
+
if val_loss < min_val_loss:
|
358 |
+
self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
|
359 |
+
min_val_loss = val_loss
|
motion_loaders/__init__.py
ADDED
File without changes
|
motion_loaders/dataset_motion_loader.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data.t2m_dataset import Text2MotionDatasetEval, collate_fn # TODO
|
2 |
+
from utils.word_vectorizer import WordVectorizer
|
3 |
+
import numpy as np
|
4 |
+
from os.path import join as pjoin
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from utils.get_opt import get_opt
|
7 |
+
|
8 |
+
def get_dataset_motion_loader(opt_path, batch_size, fname, device):
|
9 |
+
opt = get_opt(opt_path, device)
|
10 |
+
|
11 |
+
# Configurations of T2M dataset and KIT dataset is almost the same
|
12 |
+
if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
|
13 |
+
print('Loading dataset %s ...' % opt.dataset_name)
|
14 |
+
|
15 |
+
mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
|
16 |
+
std = np.load(pjoin(opt.meta_dir, 'std.npy'))
|
17 |
+
|
18 |
+
w_vectorizer = WordVectorizer('./glove', 'our_vab')
|
19 |
+
split_file = pjoin(opt.data_root, '%s.txt'%fname)
|
20 |
+
dataset = Text2MotionDatasetEval(opt, mean, std, split_file, w_vectorizer)
|
21 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
|
22 |
+
collate_fn=collate_fn, shuffle=True)
|
23 |
+
else:
|
24 |
+
raise KeyError('Dataset not Recognized !!')
|
25 |
+
|
26 |
+
print('Ground Truth Dataset Loading Completed!!!')
|
27 |
+
return dataloader, dataset
|
options/__init__.py
ADDED
File without changes
|
options/base_option.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
|
5 |
+
class BaseOptions():
|
6 |
+
def __init__(self):
|
7 |
+
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
8 |
+
self.initialized = False
|
9 |
+
|
10 |
+
def initialize(self):
|
11 |
+
self.parser.add_argument('--name', type=str, default="t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns", help='Name of this trial')
|
12 |
+
|
13 |
+
self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
|
14 |
+
|
15 |
+
self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
|
16 |
+
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
|
17 |
+
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
|
18 |
+
|
19 |
+
self.parser.add_argument('--latent_dim', type=int, default=384, help='Dimension of transformer latent.')
|
20 |
+
self.parser.add_argument('--n_heads', type=int, default=6, help='Number of heads.')
|
21 |
+
self.parser.add_argument('--n_layers', type=int, default=8, help='Number of attention layers.')
|
22 |
+
self.parser.add_argument('--ff_size', type=int, default=1024, help='FF_Size')
|
23 |
+
self.parser.add_argument('--dropout', type=float, default=0.2, help='Dropout ratio in transformer')
|
24 |
+
|
25 |
+
self.parser.add_argument("--max_motion_length", type=int, default=196, help="Max length of motion")
|
26 |
+
self.parser.add_argument("--unit_length", type=int, default=4, help="Downscale ratio of VQ")
|
27 |
+
|
28 |
+
self.parser.add_argument('--force_mask', action="store_true", help='True: mask out conditions')
|
29 |
+
|
30 |
+
self.initialized = True
|
31 |
+
|
32 |
+
def parse(self):
|
33 |
+
if not self.initialized:
|
34 |
+
self.initialize()
|
35 |
+
|
36 |
+
self.opt = self.parser.parse_args()
|
37 |
+
|
38 |
+
self.opt.is_train = self.is_train
|
39 |
+
|
40 |
+
if self.opt.gpu_id != -1:
|
41 |
+
# self.opt.gpu_id = int(self.opt.gpu_id)
|
42 |
+
torch.cuda.set_device(self.opt.gpu_id)
|
43 |
+
|
44 |
+
args = vars(self.opt)
|
45 |
+
|
46 |
+
print('------------ Options -------------')
|
47 |
+
for k, v in sorted(args.items()):
|
48 |
+
print('%s: %s' % (str(k), str(v)))
|
49 |
+
print('-------------- End ----------------')
|
50 |
+
if self.is_train:
|
51 |
+
# save to the disk
|
52 |
+
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
|
53 |
+
if not os.path.exists(expr_dir):
|
54 |
+
os.makedirs(expr_dir)
|
55 |
+
file_name = os.path.join(expr_dir, 'opt.txt')
|
56 |
+
with open(file_name, 'wt') as opt_file:
|
57 |
+
opt_file.write('------------ Options -------------\n')
|
58 |
+
for k, v in sorted(args.items()):
|
59 |
+
opt_file.write('%s: %s\n' % (str(k), str(v)))
|
60 |
+
opt_file.write('-------------- End ----------------\n')
|
61 |
+
return self.opt
|
options/eval_option.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from options.base_option import BaseOptions
|
2 |
+
|
3 |
+
class EvalT2MOptions(BaseOptions):
|
4 |
+
def initialize(self):
|
5 |
+
BaseOptions.initialize(self)
|
6 |
+
self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}')
|
7 |
+
self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
|
8 |
+
|
9 |
+
self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder')
|
10 |
+
self.parser.add_argument("--num_batch", default=2, type=int,
|
11 |
+
help="Number of batch for generation")
|
12 |
+
self.parser.add_argument("--repeat_times", default=1, type=int,
|
13 |
+
help="Number of repetitions, per sample text prompt")
|
14 |
+
self.parser.add_argument("--cond_scale", default=4, type=float,
|
15 |
+
help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
|
16 |
+
self.parser.add_argument("--temperature", default=1., type=float,
|
17 |
+
help="Sampling Temperature.")
|
18 |
+
self.parser.add_argument("--topkr", default=0.9, type=float,
|
19 |
+
help="Filter out percentil low prop entries.")
|
20 |
+
self.parser.add_argument("--time_steps", default=18, type=int,
|
21 |
+
help="Mask Generate steps.")
|
22 |
+
self.parser.add_argument("--seed", default=10107, type=int)
|
23 |
+
|
24 |
+
self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.')
|
25 |
+
self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.')
|
26 |
+
# self.parser.add_argument('--est_length', action="store_true", help='Training iterations')
|
27 |
+
|
28 |
+
self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer')
|
29 |
+
self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file')
|
30 |
+
|
31 |
+
|
32 |
+
self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing, use comma to separate the start and end of a section'
|
33 |
+
'type int will specify the token frame, type float will specify the ratio of seq_len')
|
34 |
+
self.parser.add_argument('--text_prompt', default='', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.")
|
35 |
+
self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing. (new_joint_vecs format .npy file)")
|
36 |
+
self.parser.add_argument("--motion_length", default=0, type=int,
|
37 |
+
help="Motion length for generation, only applicable with single text prompt.")
|
38 |
+
self.is_train = False
|
options/train_option.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from options.base_option import BaseOptions
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
class TrainT2MOptions(BaseOptions):
|
5 |
+
def initialize(self):
|
6 |
+
BaseOptions.initialize(self)
|
7 |
+
self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
|
8 |
+
self.parser.add_argument('--max_epoch', type=int, default=500, help='Maximum number of epoch for training')
|
9 |
+
# self.parser.add_argument('--max_iters', type=int, default=150_000, help='Training iterations')
|
10 |
+
|
11 |
+
'''LR scheduler'''
|
12 |
+
self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
|
13 |
+
self.parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate schedule factor')
|
14 |
+
self.parser.add_argument('--milestones', default=[50_000], nargs="+", type=int,
|
15 |
+
help="learning rate schedule (iterations)")
|
16 |
+
self.parser.add_argument('--warm_up_iter', default=2000, type=int, help='number of total iterations for warmup')
|
17 |
+
|
18 |
+
'''Condition'''
|
19 |
+
self.parser.add_argument('--cond_drop_prob', type=float, default=0.1, help='Drop ratio of condition, for classifier-free guidance')
|
20 |
+
self.parser.add_argument("--seed", default=3407, type=int, help="Seed")
|
21 |
+
|
22 |
+
self.parser.add_argument('--is_continue', action="store_true", help='Is this trial continuing previous state?')
|
23 |
+
self.parser.add_argument('--gumbel_sample', action="store_true", help='Strategy for token sampling, True: Gumbel sampling, False: Categorical sampling')
|
24 |
+
self.parser.add_argument('--share_weight', action="store_true", help='Whether to share weight for projection/embedding, for residual transformer.')
|
25 |
+
|
26 |
+
self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress, (iteration)')
|
27 |
+
# self.parser.add_argument('--save_every_e', type=int, default=100, help='Frequency of printing training progress')
|
28 |
+
self.parser.add_argument('--eval_every_e', type=int, default=10, help='Frequency of animating eval results, (epoch)')
|
29 |
+
self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of saving checkpoint, (iteration)')
|
30 |
+
|
31 |
+
|
32 |
+
self.is_train = True
|
33 |
+
|
34 |
+
|
35 |
+
class TrainLenEstOptions():
|
36 |
+
def __init__(self):
|
37 |
+
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
38 |
+
self.parser.add_argument('--name', type=str, default="test", help='Name of this trial')
|
39 |
+
self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
|
40 |
+
|
41 |
+
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
|
42 |
+
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
43 |
+
|
44 |
+
self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
|
45 |
+
|
46 |
+
self.parser.add_argument("--unit_length", type=int, default=4, help="Length of motion")
|
47 |
+
self.parser.add_argument("--max_text_len", type=int, default=20, help="Length of motion")
|
48 |
+
|
49 |
+
self.parser.add_argument('--max_epoch', type=int, default=300, help='Training iterations')
|
50 |
+
|
51 |
+
self.parser.add_argument('--lr', type=float, default=1e-4, help='Layers of GRU')
|
52 |
+
|
53 |
+
self.parser.add_argument('--is_continue', action="store_true", help='Training iterations')
|
54 |
+
|
55 |
+
self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress')
|
56 |
+
self.parser.add_argument('--save_every_e', type=int, default=5, help='Frequency of printing training progress')
|
57 |
+
self.parser.add_argument('--eval_every_e', type=int, default=3, help='Frequency of printing training progress')
|
58 |
+
self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of printing training progress')
|
59 |
+
|
60 |
+
def parse(self):
|
61 |
+
self.opt = self.parser.parse_args()
|
62 |
+
self.opt.is_train = True
|
63 |
+
# args = vars(self.opt)
|
64 |
+
return self.opt
|
options/vq_option.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
|
5 |
+
def arg_parse(is_train=False):
|
6 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
7 |
+
|
8 |
+
## dataloader
|
9 |
+
parser.add_argument('--dataset_name', type=str, default='humanml3d', help='dataset directory')
|
10 |
+
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
|
11 |
+
parser.add_argument('--window_size', type=int, default=64, help='training motion length')
|
12 |
+
parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
|
13 |
+
|
14 |
+
## optimization
|
15 |
+
parser.add_argument('--max_epoch', default=50, type=int, help='number of total epochs to run')
|
16 |
+
# parser.add_argument('--total_iter', default=None, type=int, help='number of total iterations to run')
|
17 |
+
parser.add_argument('--warm_up_iter', default=2000, type=int, help='number of total iterations for warmup')
|
18 |
+
parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
|
19 |
+
parser.add_argument('--milestones', default=[150000, 250000], nargs="+", type=int, help="learning rate schedule (iterations)")
|
20 |
+
parser.add_argument('--gamma', default=0.1, type=float, help="learning rate decay")
|
21 |
+
|
22 |
+
parser.add_argument('--weight_decay', default=0.0, type=float, help='weight decay')
|
23 |
+
parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
|
24 |
+
parser.add_argument('--loss_vel', type=float, default=0.5, help='hyper-parameter for the velocity loss')
|
25 |
+
parser.add_argument('--recons_loss', type=str, default='l1_smooth', help='reconstruction loss')
|
26 |
+
|
27 |
+
## vqvae arch
|
28 |
+
parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension")
|
29 |
+
parser.add_argument("--nb_code", type=int, default=512, help="nb of embedding")
|
30 |
+
parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
|
31 |
+
parser.add_argument("--down_t", type=int, default=2, help="downsampling rate")
|
32 |
+
parser.add_argument("--stride_t", type=int, default=2, help="stride size")
|
33 |
+
parser.add_argument("--width", type=int, default=512, help="width of the network")
|
34 |
+
parser.add_argument("--depth", type=int, default=3, help="num of resblocks for each res")
|
35 |
+
parser.add_argument("--dilation_growth_rate", type=int, default=3, help="dilation growth rate")
|
36 |
+
parser.add_argument("--output_emb_width", type=int, default=512, help="output embedding width")
|
37 |
+
parser.add_argument('--vq_act', type=str, default='relu', choices=['relu', 'silu', 'gelu'],
|
38 |
+
help='dataset directory')
|
39 |
+
parser.add_argument('--vq_norm', type=str, default=None, help='dataset directory')
|
40 |
+
|
41 |
+
parser.add_argument('--num_quantizers', type=int, default=3, help='num_quantizers')
|
42 |
+
parser.add_argument('--shared_codebook', action="store_true")
|
43 |
+
parser.add_argument('--quantize_dropout_prob', type=float, default=0.2, help='quantize_dropout_prob')
|
44 |
+
# parser.add_argument('--use_vq_prob', type=float, default=0.8, help='quantize_dropout_prob')
|
45 |
+
|
46 |
+
parser.add_argument('--ext', type=str, default='default', help='reconstruction loss')
|
47 |
+
|
48 |
+
|
49 |
+
## other
|
50 |
+
parser.add_argument('--name', type=str, default="test", help='Name of this trial')
|
51 |
+
parser.add_argument('--is_continue', action="store_true", help='Name of this trial')
|
52 |
+
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
53 |
+
parser.add_argument('--log_every', default=10, type=int, help='iter log frequency')
|
54 |
+
parser.add_argument('--save_latest', default=500, type=int, help='iter save latest model frequency')
|
55 |
+
parser.add_argument('--save_every_e', default=2, type=int, help='save model every n epoch')
|
56 |
+
parser.add_argument('--eval_every_e', default=1, type=int, help='save eval results every n epoch')
|
57 |
+
# parser.add_argument('--early_stop_e', default=5, type=int, help='early stopping epoch')
|
58 |
+
parser.add_argument('--feat_bias', type=float, default=5, help='Layers of GRU')
|
59 |
+
|
60 |
+
parser.add_argument('--which_epoch', type=str, default="all", help='Name of this trial')
|
61 |
+
|
62 |
+
## For Res Predictor only
|
63 |
+
parser.add_argument('--vq_name', type=str, default="rvq_nq6_dc512_nc512_noshare_qdp0.2", help='Name of this trial')
|
64 |
+
parser.add_argument('--n_res', type=int, default=2, help='Name of this trial')
|
65 |
+
parser.add_argument('--do_vq_res', action="store_true")
|
66 |
+
parser.add_argument("--seed", default=3407, type=int)
|
67 |
+
|
68 |
+
opt = parser.parse_args()
|
69 |
+
torch.cuda.set_device(opt.gpu_id)
|
70 |
+
|
71 |
+
args = vars(opt)
|
72 |
+
|
73 |
+
print('------------ Options -------------')
|
74 |
+
for k, v in sorted(args.items()):
|
75 |
+
print('%s: %s' % (str(k), str(v)))
|
76 |
+
print('-------------- End ----------------')
|
77 |
+
opt.is_train = is_train
|
78 |
+
if is_train:
|
79 |
+
# save to the disk
|
80 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
81 |
+
if not os.path.exists(expr_dir):
|
82 |
+
os.makedirs(expr_dir)
|
83 |
+
file_name = os.path.join(expr_dir, 'opt.txt')
|
84 |
+
with open(file_name, 'wt') as opt_file:
|
85 |
+
opt_file.write('------------ Options -------------\n')
|
86 |
+
for k, v in sorted(args.items()):
|
87 |
+
opt_file.write('%s: %s\n' % (str(k), str(v)))
|
88 |
+
opt_file.write('-------------- End ----------------\n')
|
89 |
+
return opt
|
prepare/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
prepare/download_evaluator.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cd checkpoints
|
2 |
+
|
3 |
+
cd t2m
|
4 |
+
echo -e "Downloading evaluation models for HumanML3D dataset"
|
5 |
+
gdown --fuzzy https://drive.google.com/file/d/1oLhSH7zTlYkQdUWPv3-v4opigB7pXkFk/view?usp=sharing
|
6 |
+
echo -e "Unzipping humanml3d_evaluator.zip"
|
7 |
+
unzip humanml3d_evaluator.zip
|
8 |
+
|
9 |
+
echo -e "Clearning humanml3d_evaluator.zip"
|
10 |
+
rm humanml3d_evaluator.zip
|
11 |
+
|
12 |
+
cd ../kit/
|
13 |
+
echo -e "Downloading pretrained models for KIT-ML dataset"
|
14 |
+
gdown --fuzzy https://drive.google.com/file/d/115n1ijntyKDDIZZEuA_aBgffyplNE5az/view?usp=sharing
|
15 |
+
|
16 |
+
echo -e "Unzipping kit_evaluator.zip"
|
17 |
+
unzip kit_evaluator.zip
|
18 |
+
|
19 |
+
echo -e "Clearning kit_evaluator.zip"
|
20 |
+
rm kit_evaluator.zip
|
21 |
+
|
22 |
+
cd ../../
|
23 |
+
|
24 |
+
echo -e "Downloading done!"
|
prepare/download_glove.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
echo -e "Downloading glove (in use by the evaluators, not by MoMask itself)"
|
2 |
+
gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing
|
3 |
+
rm -rf glove
|
4 |
+
|
5 |
+
unzip glove.zip
|
6 |
+
echo -e "Cleaning\n"
|
7 |
+
rm glove.zip
|
8 |
+
|
9 |
+
echo -e "Downloading done!"
|
prepare/download_models.sh
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
rm -rf checkpoints
|
2 |
+
mkdir checkpoints
|
3 |
+
cd checkpoints
|
4 |
+
mkdir t2m
|
5 |
+
|
6 |
+
cd t2m
|
7 |
+
echo -e "Downloading pretrained models for HumanML3D dataset"
|
8 |
+
gdown --fuzzy https://drive.google.com/file/d/1dtKP2xBk-UjG9o16MVfBJDmGNSI56Dch/view?usp=sharing
|
9 |
+
|
10 |
+
echo -e "Unzipping humanml3d_models.zip"
|
11 |
+
unzip humanml3d_models.zip
|
12 |
+
|
13 |
+
echo -e "Cleaning humanml3d_models.zip"
|
14 |
+
rm humanml3d_models.zip
|
15 |
+
|
16 |
+
cd ../
|
17 |
+
mkdir kit
|
18 |
+
cd kit
|
19 |
+
|
20 |
+
echo -e "Downloading pretrained models for KIT-ML dataset"
|
21 |
+
gdown --fuzzy https://drive.google.com/file/d/1MNMdUdn5QoO8UW1iwTcZ0QNaLSH4A6G9/view?usp=sharing
|
22 |
+
|
23 |
+
echo -e "Unzipping kit_models.zip"
|
24 |
+
unzip kit_models.zip
|
25 |
+
|
26 |
+
echo -e "Cleaning kit_models.zip"
|
27 |
+
rm kit_models.zip
|
28 |
+
|
29 |
+
cd ../../
|
30 |
+
|
31 |
+
echo -e "Downloading done!"
|
prepare/download_models_demo.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
rm -rf checkpoints
|
2 |
+
mkdir checkpoints
|
3 |
+
cd checkpoints
|
4 |
+
mkdir t2m
|
5 |
+
cd t2m
|
6 |
+
echo -e "Downloading pretrained models for HumanML3D dataset"
|
7 |
+
gdown --fuzzy https://drive.google.com/file/d/1dtKP2xBk-UjG9o16MVfBJDmGNSI56Dch/view?usp=sharing
|
8 |
+
unzip humanml3d_models.zip
|
9 |
+
rm humanml3d_models.zip
|
10 |
+
cd ../../
|
requirements.txt
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1673535674859/work
|
2 |
+
aiofiles==23.2.1
|
3 |
+
aiohttp @ file:///croot/aiohttp_1670009560265/work
|
4 |
+
aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
|
5 |
+
altair==5.0.1
|
6 |
+
anyio==3.7.1
|
7 |
+
async-timeout @ file:///opt/conda/conda-bld/async-timeout_1664876359750/work
|
8 |
+
asynctest==0.13.0
|
9 |
+
attrs @ file:///croot/attrs_1668696182826/work
|
10 |
+
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1649463573192/work
|
11 |
+
blinker==1.4
|
12 |
+
blis==0.7.8
|
13 |
+
blobfile==2.0.2
|
14 |
+
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854164153/work
|
15 |
+
cachetools==5.3.1
|
16 |
+
catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1661366519934/work
|
17 |
+
certifi @ file:///croot/certifi_1671487769961/work/certifi
|
18 |
+
cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
|
19 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
|
20 |
+
chumpy==0.70
|
21 |
+
click==8.1.3
|
22 |
+
clip @ git+https://github.com/openai/CLIP.git@a9b1bf5920416aaeaec965c25dd9e8f98c864f16
|
23 |
+
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1655412516417/work
|
24 |
+
confection==0.0.2
|
25 |
+
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1636040646098/work
|
26 |
+
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
|
27 |
+
cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1649412169067/work
|
28 |
+
dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
|
29 |
+
einops==0.6.1
|
30 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl
|
31 |
+
exceptiongroup==1.2.0
|
32 |
+
fastapi==0.103.2
|
33 |
+
ffmpy==0.3.1
|
34 |
+
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1660129891014/work
|
35 |
+
frozenlist @ file:///croot/frozenlist_1670004507010/work
|
36 |
+
fsspec==2023.1.0
|
37 |
+
ftfy==6.1.1
|
38 |
+
gdown==4.7.1
|
39 |
+
google-auth==2.19.1
|
40 |
+
google-auth-oauthlib==0.4.6
|
41 |
+
gradio==3.34.0
|
42 |
+
gradio_client==0.2.6
|
43 |
+
grpcio==1.54.2
|
44 |
+
h11==0.14.0
|
45 |
+
h5py @ file:///tmp/abs_4aewd3wzey/croots/recipe/h5py_1659091371897/work
|
46 |
+
httpcore==0.17.3
|
47 |
+
httpx==0.24.1
|
48 |
+
huggingface-hub==0.16.4
|
49 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
|
50 |
+
importlib-metadata==5.0.0
|
51 |
+
importlib-resources==5.12.0
|
52 |
+
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
|
53 |
+
joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work
|
54 |
+
jsonschema==4.17.3
|
55 |
+
kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1653292039266/work
|
56 |
+
langcodes @ file:///home/conda/feedstock_root/build_artifacts/langcodes_1636741340529/work
|
57 |
+
linkify-it-py==2.0.2
|
58 |
+
loralib==0.1.1
|
59 |
+
lxml==4.9.1
|
60 |
+
Markdown @ file:///home/conda/feedstock_root/build_artifacts/markdown_1679584000376/work
|
61 |
+
markdown-it-py==2.2.0
|
62 |
+
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1648737551960/work
|
63 |
+
matplotlib==3.1.3
|
64 |
+
mdit-py-plugins==0.3.3
|
65 |
+
mdurl==0.1.2
|
66 |
+
mkl-fft==1.3.1
|
67 |
+
mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work
|
68 |
+
mkl-service==2.4.0
|
69 |
+
multidict @ file:///croot/multidict_1665674239670/work
|
70 |
+
murmurhash==1.0.8
|
71 |
+
numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1653915516269/work
|
72 |
+
oauthlib==3.2.2
|
73 |
+
orjson==3.9.7
|
74 |
+
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
|
75 |
+
pandas==1.3.5
|
76 |
+
pathy @ file:///home/conda/feedstock_root/build_artifacts/pathy_1656568808184/work
|
77 |
+
Pillow==9.2.0
|
78 |
+
pkgutil_resolve_name==1.3.10
|
79 |
+
preshed==3.0.7
|
80 |
+
protobuf==3.20.3
|
81 |
+
pyasn1==0.5.0
|
82 |
+
pyasn1-modules==0.3.0
|
83 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
|
84 |
+
pycryptodomex==3.15.0
|
85 |
+
pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1636021129189/work
|
86 |
+
pydub==0.25.1
|
87 |
+
Pygments==2.17.2
|
88 |
+
PyJWT @ file:///opt/conda/conda-bld/pyjwt_1657544592787/work
|
89 |
+
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1663846997386/work
|
90 |
+
pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
|
91 |
+
pyrsistent==0.19.3
|
92 |
+
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1648857264451/work
|
93 |
+
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
|
94 |
+
python-multipart==0.0.6
|
95 |
+
pytz==2023.3
|
96 |
+
PyYAML==6.0
|
97 |
+
regex==2022.9.13
|
98 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1661872987712/work
|
99 |
+
requests-oauthlib==1.3.1
|
100 |
+
rsa==4.9
|
101 |
+
scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1642601761909/work
|
102 |
+
scipy @ file:///opt/conda/conda-bld/scipy_1661390393401/work
|
103 |
+
semantic-version==2.10.0
|
104 |
+
shellingham @ file:///home/conda/feedstock_root/build_artifacts/shellingham_1659638615822/work
|
105 |
+
six @ file:///tmp/build/80754af9/six_1644875935023/work
|
106 |
+
smart-open @ file:///home/conda/feedstock_root/build_artifacts/smart_open_1630238320325/work
|
107 |
+
smplx==0.1.28
|
108 |
+
sniffio==1.3.0
|
109 |
+
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
|
110 |
+
spacy @ file:///opt/conda/conda-bld/spacy_1656601313568/work
|
111 |
+
spacy-legacy @ file:///home/conda/feedstock_root/build_artifacts/spacy-legacy_1660748275723/work
|
112 |
+
spacy-loggers @ file:///home/conda/feedstock_root/build_artifacts/spacy-loggers_1661365735520/work
|
113 |
+
srsly==2.4.4
|
114 |
+
starlette==0.27.0
|
115 |
+
tensorboard==2.11.2
|
116 |
+
tensorboard-data-server==0.6.1
|
117 |
+
tensorboard-plugin-wit @ file:///home/builder/tkoch/workspace/tensorflow/tensorboard-plugin-wit_1658918494740/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
|
118 |
+
tensorboardX==2.6
|
119 |
+
thinc==8.0.17
|
120 |
+
threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
|
121 |
+
toolz==0.12.0
|
122 |
+
torch==1.7.1
|
123 |
+
torch-tb-profiler==0.4.1
|
124 |
+
torchaudio==0.7.0a0+a853dff
|
125 |
+
torchvision==0.8.2
|
126 |
+
tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work
|
127 |
+
tqdm @ file:///opt/conda/conda-bld/tqdm_1664392687731/work
|
128 |
+
trimesh @ file:///home/conda/feedstock_root/build_artifacts/trimesh_1664841281434/work
|
129 |
+
typer @ file:///home/conda/feedstock_root/build_artifacts/typer_1657029164904/work
|
130 |
+
typing_extensions==4.7.1
|
131 |
+
uc-micro-py==1.0.2
|
132 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1678635778344/work
|
133 |
+
uvicorn==0.22.0
|
134 |
+
vector-quantize-pytorch==1.6.30
|
135 |
+
wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1668249950899/work
|
136 |
+
wcwidth==0.2.5
|
137 |
+
websockets==11.0.3
|
138 |
+
Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1676411946679/work
|
139 |
+
yarl @ file:///opt/conda/conda-bld/yarl_1661437085904/work
|
140 |
+
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1659400682470/work
|
train_res_transformer.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from os.path import join as pjoin
|
7 |
+
|
8 |
+
from models.mask_transformer.transformer import ResidualTransformer
|
9 |
+
from models.mask_transformer.transformer_trainer import ResidualTransformerTrainer
|
10 |
+
from models.vq.model import RVQVAE
|
11 |
+
|
12 |
+
from options.train_option import TrainT2MOptions
|
13 |
+
|
14 |
+
from utils.plot_script import plot_3d_motion
|
15 |
+
from utils.motion_process import recover_from_ric
|
16 |
+
from utils.get_opt import get_opt
|
17 |
+
from utils.fixseed import fixseed
|
18 |
+
from utils.paramUtil import t2m_kinematic_chain, kit_kinematic_chain
|
19 |
+
|
20 |
+
from data.t2m_dataset import Text2MotionDataset
|
21 |
+
from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
|
22 |
+
from models.t2m_eval_wrapper import EvaluatorModelWrapper
|
23 |
+
|
24 |
+
|
25 |
+
def plot_t2m(data, save_dir, captions, m_lengths):
|
26 |
+
data = train_dataset.inv_transform(data)
|
27 |
+
|
28 |
+
# print(ep_curves.shape)
|
29 |
+
for i, (caption, joint_data) in enumerate(zip(captions, data)):
|
30 |
+
joint_data = joint_data[:m_lengths[i]]
|
31 |
+
joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy()
|
32 |
+
save_path = pjoin(save_dir, '%02d.mp4'%i)
|
33 |
+
# print(joint.shape)
|
34 |
+
plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
|
35 |
+
|
36 |
+
def load_vq_model():
|
37 |
+
opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
|
38 |
+
vq_opt = get_opt(opt_path, opt.device)
|
39 |
+
vq_model = RVQVAE(vq_opt,
|
40 |
+
dim_pose,
|
41 |
+
vq_opt.nb_code,
|
42 |
+
vq_opt.code_dim,
|
43 |
+
vq_opt.output_emb_width,
|
44 |
+
vq_opt.down_t,
|
45 |
+
vq_opt.stride_t,
|
46 |
+
vq_opt.width,
|
47 |
+
vq_opt.depth,
|
48 |
+
vq_opt.dilation_growth_rate,
|
49 |
+
vq_opt.vq_act,
|
50 |
+
vq_opt.vq_norm)
|
51 |
+
ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
|
52 |
+
map_location=opt.device)
|
53 |
+
model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
|
54 |
+
vq_model.load_state_dict(ckpt[model_key])
|
55 |
+
print(f'Loading VQ Model {opt.vq_name}')
|
56 |
+
vq_model.to(opt.device)
|
57 |
+
return vq_model, vq_opt
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
parser = TrainT2MOptions()
|
61 |
+
opt = parser.parse()
|
62 |
+
fixseed(opt.seed)
|
63 |
+
|
64 |
+
opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
|
65 |
+
torch.autograd.set_detect_anomaly(True)
|
66 |
+
|
67 |
+
opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
68 |
+
opt.model_dir = pjoin(opt.save_root, 'model')
|
69 |
+
# opt.meta_dir = pjoin(opt.save_root, 'meta')
|
70 |
+
opt.eval_dir = pjoin(opt.save_root, 'animation')
|
71 |
+
opt.log_dir = pjoin('./log/res/', opt.dataset_name, opt.name)
|
72 |
+
|
73 |
+
os.makedirs(opt.model_dir, exist_ok=True)
|
74 |
+
# os.makedirs(opt.meta_dir, exist_ok=True)
|
75 |
+
os.makedirs(opt.eval_dir, exist_ok=True)
|
76 |
+
os.makedirs(opt.log_dir, exist_ok=True)
|
77 |
+
|
78 |
+
if opt.dataset_name == 't2m':
|
79 |
+
opt.data_root = './dataset/HumanML3D'
|
80 |
+
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
|
81 |
+
opt.joints_num = 22
|
82 |
+
opt.max_motion_len = 55
|
83 |
+
dim_pose = 263
|
84 |
+
radius = 4
|
85 |
+
fps = 20
|
86 |
+
kinematic_chain = t2m_kinematic_chain
|
87 |
+
dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
|
88 |
+
|
89 |
+
elif opt.dataset_name == 'kit': #TODO
|
90 |
+
opt.data_root = './dataset/KIT-ML'
|
91 |
+
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
|
92 |
+
opt.joints_num = 21
|
93 |
+
radius = 240 * 8
|
94 |
+
fps = 12.5
|
95 |
+
dim_pose = 251
|
96 |
+
opt.max_motion_len = 55
|
97 |
+
kinematic_chain = kit_kinematic_chain
|
98 |
+
dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt'
|
99 |
+
|
100 |
+
else:
|
101 |
+
raise KeyError('Dataset Does Not Exist')
|
102 |
+
|
103 |
+
opt.text_dir = pjoin(opt.data_root, 'texts')
|
104 |
+
|
105 |
+
vq_model, vq_opt = load_vq_model()
|
106 |
+
|
107 |
+
clip_version = 'ViT-B/32'
|
108 |
+
|
109 |
+
opt.num_tokens = vq_opt.nb_code
|
110 |
+
opt.num_quantizers = vq_opt.num_quantizers
|
111 |
+
|
112 |
+
# if opt.is_v2:
|
113 |
+
res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
|
114 |
+
cond_mode='text',
|
115 |
+
latent_dim=opt.latent_dim,
|
116 |
+
ff_size=opt.ff_size,
|
117 |
+
num_layers=opt.n_layers,
|
118 |
+
num_heads=opt.n_heads,
|
119 |
+
dropout=opt.dropout,
|
120 |
+
clip_dim=512,
|
121 |
+
shared_codebook=vq_opt.shared_codebook,
|
122 |
+
cond_drop_prob=opt.cond_drop_prob,
|
123 |
+
# codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
|
124 |
+
share_weight=opt.share_weight,
|
125 |
+
clip_version=clip_version,
|
126 |
+
opt=opt)
|
127 |
+
# else:
|
128 |
+
# res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
|
129 |
+
# cond_mode='text',
|
130 |
+
# latent_dim=opt.latent_dim,
|
131 |
+
# ff_size=opt.ff_size,
|
132 |
+
# num_layers=opt.n_layers,
|
133 |
+
# num_heads=opt.n_heads,
|
134 |
+
# dropout=opt.dropout,
|
135 |
+
# clip_dim=512,
|
136 |
+
# shared_codebook=vq_opt.shared_codebook,
|
137 |
+
# cond_drop_prob=opt.cond_drop_prob,
|
138 |
+
# # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
|
139 |
+
# clip_version=clip_version,
|
140 |
+
# opt=opt)
|
141 |
+
|
142 |
+
|
143 |
+
all_params = 0
|
144 |
+
pc_transformer = sum(param.numel() for param in res_transformer.parameters_wo_clip())
|
145 |
+
|
146 |
+
print(res_transformer)
|
147 |
+
# print("Total parameters of t2m_transformer net: {:.2f}M".format(pc_transformer / 1000_000))
|
148 |
+
all_params += pc_transformer
|
149 |
+
|
150 |
+
print('Total parameters of all models: {:.2f}M'.format(all_params / 1000_000))
|
151 |
+
|
152 |
+
mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'mean.npy'))
|
153 |
+
std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'std.npy'))
|
154 |
+
|
155 |
+
train_split_file = pjoin(opt.data_root, 'train.txt')
|
156 |
+
val_split_file = pjoin(opt.data_root, 'val.txt')
|
157 |
+
|
158 |
+
train_dataset = Text2MotionDataset(opt, mean, std, train_split_file)
|
159 |
+
val_dataset = Text2MotionDataset(opt, mean, std, val_split_file)
|
160 |
+
|
161 |
+
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
|
162 |
+
val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
|
163 |
+
|
164 |
+
eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', device=opt.device)
|
165 |
+
|
166 |
+
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
|
167 |
+
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
168 |
+
|
169 |
+
trainer = ResidualTransformerTrainer(opt, res_transformer, vq_model)
|
170 |
+
|
171 |
+
trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper=eval_wrapper, plot_eval=plot_t2m)
|
train_t2m_transformer.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from os.path import join as pjoin
|
7 |
+
|
8 |
+
from models.mask_transformer.transformer import MaskTransformer
|
9 |
+
from models.mask_transformer.transformer_trainer import MaskTransformerTrainer
|
10 |
+
from models.vq.model import RVQVAE
|
11 |
+
|
12 |
+
from options.train_option import TrainT2MOptions
|
13 |
+
|
14 |
+
from utils.plot_script import plot_3d_motion
|
15 |
+
from utils.motion_process import recover_from_ric
|
16 |
+
from utils.get_opt import get_opt
|
17 |
+
from utils.fixseed import fixseed
|
18 |
+
from utils.paramUtil import t2m_kinematic_chain, kit_kinematic_chain
|
19 |
+
|
20 |
+
from data.t2m_dataset import Text2MotionDataset
|
21 |
+
from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
|
22 |
+
from models.t2m_eval_wrapper import EvaluatorModelWrapper
|
23 |
+
|
24 |
+
|
25 |
+
def plot_t2m(data, save_dir, captions, m_lengths):
|
26 |
+
data = train_dataset.inv_transform(data)
|
27 |
+
|
28 |
+
# print(ep_curves.shape)
|
29 |
+
for i, (caption, joint_data) in enumerate(zip(captions, data)):
|
30 |
+
joint_data = joint_data[:m_lengths[i]]
|
31 |
+
joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy()
|
32 |
+
save_path = pjoin(save_dir, '%02d.mp4'%i)
|
33 |
+
# print(joint.shape)
|
34 |
+
plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
|
35 |
+
|
36 |
+
def load_vq_model():
|
37 |
+
opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
|
38 |
+
vq_opt = get_opt(opt_path, opt.device)
|
39 |
+
vq_model = RVQVAE(vq_opt,
|
40 |
+
dim_pose,
|
41 |
+
vq_opt.nb_code,
|
42 |
+
vq_opt.code_dim,
|
43 |
+
vq_opt.output_emb_width,
|
44 |
+
vq_opt.down_t,
|
45 |
+
vq_opt.stride_t,
|
46 |
+
vq_opt.width,
|
47 |
+
vq_opt.depth,
|
48 |
+
vq_opt.dilation_growth_rate,
|
49 |
+
vq_opt.vq_act,
|
50 |
+
vq_opt.vq_norm)
|
51 |
+
ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
|
52 |
+
map_location='cpu')
|
53 |
+
model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
|
54 |
+
vq_model.load_state_dict(ckpt[model_key])
|
55 |
+
print(f'Loading VQ Model {opt.vq_name}')
|
56 |
+
return vq_model, vq_opt
|
57 |
+
|
58 |
+
if __name__ == '__main__':
|
59 |
+
parser = TrainT2MOptions()
|
60 |
+
opt = parser.parse()
|
61 |
+
fixseed(opt.seed)
|
62 |
+
|
63 |
+
opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
|
64 |
+
torch.autograd.set_detect_anomaly(True)
|
65 |
+
|
66 |
+
opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
67 |
+
opt.model_dir = pjoin(opt.save_root, 'model')
|
68 |
+
# opt.meta_dir = pjoin(opt.save_root, 'meta')
|
69 |
+
opt.eval_dir = pjoin(opt.save_root, 'animation')
|
70 |
+
opt.log_dir = pjoin('./log/t2m/', opt.dataset_name, opt.name)
|
71 |
+
|
72 |
+
os.makedirs(opt.model_dir, exist_ok=True)
|
73 |
+
# os.makedirs(opt.meta_dir, exist_ok=True)
|
74 |
+
os.makedirs(opt.eval_dir, exist_ok=True)
|
75 |
+
os.makedirs(opt.log_dir, exist_ok=True)
|
76 |
+
|
77 |
+
if opt.dataset_name == 't2m':
|
78 |
+
opt.data_root = './dataset/HumanML3D'
|
79 |
+
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
|
80 |
+
opt.joints_num = 22
|
81 |
+
opt.max_motion_len = 55
|
82 |
+
dim_pose = 263
|
83 |
+
radius = 4
|
84 |
+
fps = 20
|
85 |
+
kinematic_chain = t2m_kinematic_chain
|
86 |
+
dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
|
87 |
+
|
88 |
+
elif opt.dataset_name == 'kit': #TODO
|
89 |
+
opt.data_root = './dataset/KIT-ML'
|
90 |
+
opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
|
91 |
+
opt.joints_num = 21
|
92 |
+
radius = 240 * 8
|
93 |
+
fps = 12.5
|
94 |
+
dim_pose = 251
|
95 |
+
opt.max_motion_len = 55
|
96 |
+
kinematic_chain = kit_kinematic_chain
|
97 |
+
dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt'
|
98 |
+
|
99 |
+
else:
|
100 |
+
raise KeyError('Dataset Does Not Exist')
|
101 |
+
|
102 |
+
opt.text_dir = pjoin(opt.data_root, 'texts')
|
103 |
+
|
104 |
+
vq_model, vq_opt = load_vq_model()
|
105 |
+
|
106 |
+
clip_version = 'ViT-B/32'
|
107 |
+
|
108 |
+
opt.num_tokens = vq_opt.nb_code
|
109 |
+
|
110 |
+
t2m_transformer = MaskTransformer(code_dim=vq_opt.code_dim,
|
111 |
+
cond_mode='text',
|
112 |
+
latent_dim=opt.latent_dim,
|
113 |
+
ff_size=opt.ff_size,
|
114 |
+
num_layers=opt.n_layers,
|
115 |
+
num_heads=opt.n_heads,
|
116 |
+
dropout=opt.dropout,
|
117 |
+
clip_dim=512,
|
118 |
+
cond_drop_prob=opt.cond_drop_prob,
|
119 |
+
clip_version=clip_version,
|
120 |
+
opt=opt)
|
121 |
+
|
122 |
+
# if opt.fix_token_emb:
|
123 |
+
# t2m_transformer.load_and_freeze_token_emb(vq_model.quantizer.codebooks[0])
|
124 |
+
|
125 |
+
all_params = 0
|
126 |
+
pc_transformer = sum(param.numel() for param in t2m_transformer.parameters_wo_clip())
|
127 |
+
|
128 |
+
# print(t2m_transformer)
|
129 |
+
# print("Total parameters of t2m_transformer net: {:.2f}M".format(pc_transformer / 1000_000))
|
130 |
+
all_params += pc_transformer
|
131 |
+
|
132 |
+
print('Total parameters of all models: {:.2f}M'.format(all_params / 1000_000))
|
133 |
+
|
134 |
+
mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'mean.npy'))
|
135 |
+
std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'std.npy'))
|
136 |
+
|
137 |
+
train_split_file = pjoin(opt.data_root, 'train.txt')
|
138 |
+
val_split_file = pjoin(opt.data_root, 'val.txt')
|
139 |
+
|
140 |
+
train_dataset = Text2MotionDataset(opt, mean, std, train_split_file)
|
141 |
+
val_dataset = Text2MotionDataset(opt, mean, std, val_split_file)
|
142 |
+
|
143 |
+
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
|
144 |
+
val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
|
145 |
+
|
146 |
+
eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', device=opt.device)
|
147 |
+
|
148 |
+
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
|
149 |
+
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
150 |
+
|
151 |
+
trainer = MaskTransformerTrainer(opt, t2m_transformer, vq_model)
|
152 |
+
|
153 |
+
trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper=eval_wrapper, plot_eval=plot_t2m)
|