mvreddy13
commited on
Commit
·
f0c7f08
1
Parent(s):
ed77e1a
Adding new Folders
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +159 -13
- __init__.py +0 -0
- config/LS3DCG.json +64 -0
- config/body_pixel.json +63 -0
- config/body_vq.json +62 -0
- config/face.json +59 -0
- data_utils/__init__.py +3 -0
- data_utils/__pycache__/__init__.cpython-37.pyc +0 -0
- data_utils/__pycache__/consts.cpython-37.pyc +0 -0
- data_utils/__pycache__/dataloader_torch.cpython-37.pyc +0 -0
- data_utils/__pycache__/lower_body.cpython-37.pyc +0 -0
- data_utils/__pycache__/mesh_dataset.cpython-37.pyc +0 -0
- data_utils/__pycache__/rotation_conversion.cpython-37.pyc +0 -0
- data_utils/__pycache__/utils.cpython-37.pyc +0 -0
- data_utils/apply_split.py +51 -0
- data_utils/axis2matrix.py +29 -0
- data_utils/consts.py +0 -0
- data_utils/dataloader_torch.py +279 -0
- data_utils/dataset_preprocess.py +170 -0
- data_utils/get_j.py +51 -0
- data_utils/hand_component.json +0 -0
- data_utils/lower_body.py +143 -0
- data_utils/mesh_dataset.py +348 -0
- data_utils/rotation_conversion.py +551 -0
- data_utils/split_more_than_2s.pkl +3 -0
- data_utils/split_train_val_test.py +27 -0
- data_utils/train_val_test.json +0 -0
- data_utils/utils.py +318 -0
- evaluation/FGD.py +199 -0
- evaluation/__init__.py +0 -0
- evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
- evaluation/__pycache__/metrics.cpython-37.pyc +0 -0
- evaluation/diversity_LVD.py +64 -0
- evaluation/get_quality_samples.py +62 -0
- evaluation/metrics.py +109 -0
- evaluation/mode_transition.py +60 -0
- evaluation/peak_velocity.py +65 -0
- evaluation/util.py +148 -0
- losses/__init__.py +1 -0
- losses/__pycache__/__init__.cpython-37.pyc +0 -0
- losses/__pycache__/losses.cpython-37.pyc +0 -0
- losses/losses.py +91 -0
- nets/LS3DCG.py +414 -0
- nets/__init__.py +8 -0
- nets/__pycache__/__init__.cpython-37.pyc +0 -0
- nets/__pycache__/base.cpython-37.pyc +0 -0
- nets/__pycache__/init_model.cpython-37.pyc +0 -0
- nets/__pycache__/layers.cpython-37.pyc +0 -0
- nets/__pycache__/smplx_body_pixel.cpython-37.pyc +0 -0
- nets/__pycache__/smplx_body_vq.cpython-37.pyc +0 -0
README.md
CHANGED
@@ -1,13 +1,159 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# TalkSHOW: Generating Holistic 3D Human Motion from Speech [CVPR2023]
|
2 |
+
|
3 |
+
The official PyTorch implementation of the **CVPR2023** paper [**"Generating Holistic 3D Human Motion from Speech"**](https://arxiv.org/abs/2212.04420).
|
4 |
+
|
5 |
+
Please visit our [**webpage**](https://talkshow.is.tue.mpg.de/) for more details.
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
## HighLight
|
10 |
+
|
11 |
+
We directly provide the input and our output for the demo data, you can find them in `/demo/` and `/demo_audio/`. TalkSHOW can generalize well on English, French, Songs so far. Looking forward to more demos.
|
12 |
+
|
13 |
+
You can directly use the generated motion to animate your 3D character or your own digital avatar. We will provide more demos, please stay tuned. And we are quite looking forward to your pull request.
|
14 |
+
|
15 |
+
## Notes
|
16 |
+
|
17 |
+
We are using 100 dimension parameters for SMPL-X facial expression, if you need other dimensions parameters, you can use this code to convert.
|
18 |
+
|
19 |
+
```
|
20 |
+
https://github.com/yhw-yhw/SHOW/blob/main/cvt_exp_dim_tool.py
|
21 |
+
```
|
22 |
+
|
23 |
+
## TODO
|
24 |
+
|
25 |
+
- [x] [🤗Hugging Face Demo](https://huggingface.co/spaces/feifeifeiliu/TalkSHOW)
|
26 |
+
- [ ] Animated 2D videos by the generated motion from TalkSHOW.
|
27 |
+
|
28 |
+
|
29 |
+
## Getting started
|
30 |
+
|
31 |
+
The training code was tested on `Ubuntu 18.04.5 LTS` and the visualization code was test on `Windows 10`, and it requires:
|
32 |
+
|
33 |
+
* Python 3.7
|
34 |
+
* conda3 or miniconda3
|
35 |
+
* CUDA capable GPU (one is enough)
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
### 1. Setup environment
|
40 |
+
|
41 |
+
Clone the repo:
|
42 |
+
```bash
|
43 |
+
git clone https://github.com/yhw-yhw/TalkSHOW
|
44 |
+
cd TalkSHOW
|
45 |
+
```
|
46 |
+
Create conda environment:
|
47 |
+
```bash
|
48 |
+
conda create --name talkshow python=3.7
|
49 |
+
conda activate talkshow
|
50 |
+
```
|
51 |
+
Please install pytorch (v1.10.1).
|
52 |
+
|
53 |
+
pip install -r requirements.txt
|
54 |
+
|
55 |
+
Please install [**MPI-Mesh**](https://github.com/MPI-IS/mesh).
|
56 |
+
|
57 |
+
### 2. Get data
|
58 |
+
|
59 |
+
Please note that if you only want to generate demo videos, you can skip this step and directly download the pretrained models.
|
60 |
+
|
61 |
+
Download [**SHOW_dataset_v1.0.zip**](https://download.is.tue.mpg.de/download.php?domain=talkshow&resume=1&sfile=SHOW_dataset_v1.0.zip) from [**TalkSHOW download webpage**](https://talkshow.is.tue.mpg.de/download.php),
|
62 |
+
unzip using ``for i in $(ls *.tar.gz);do tar xvf $i;done``.
|
63 |
+
|
64 |
+
~~Run ``python data_utils/dataset_preprocess.py`` to check and split dataset.
|
65 |
+
Modify ``data_root`` in ``config/*.json`` to the dataset-path.~~
|
66 |
+
|
67 |
+
Modify ``data_root`` in ``data_utils/apply_split.py`` to the dataset path and run it to apply ``data_utils/split_more_than_2s.pkl`` to the dataset.
|
68 |
+
|
69 |
+
We will update the benchmark soon.
|
70 |
+
|
71 |
+
### 3. Download the pretrained models (Optional)
|
72 |
+
|
73 |
+
Download [**pretrained models**](https://drive.google.com/file/d/1bC0ZTza8HOhLB46WOJ05sBywFvcotDZG/view?usp=sharing),
|
74 |
+
unzip and place it in the TalkSHOW folder, i.e. ``path-to-TalkSHOW/experiments``.
|
75 |
+
|
76 |
+
### 4. Training
|
77 |
+
Please note that the process of loading data for the first time can be quite slow. If you have already completed the loading process, setting ``dataset_load_mode`` to ``pickle`` in ``config/[config_name].json`` will make the loading process much faster.
|
78 |
+
|
79 |
+
# 1. Train VQ-VAEs.
|
80 |
+
bash train_body_vq.sh
|
81 |
+
# 2. Train PixelCNN. Please modify "Model:vq_path" in config/body_pixel.json to the path of VQ-VAEs.
|
82 |
+
bash train_body_pixel.sh
|
83 |
+
# 3. Train face generator.
|
84 |
+
bash train_face.sh
|
85 |
+
|
86 |
+
### 5. Testing
|
87 |
+
|
88 |
+
Modify the arguments in ``test_face.sh`` and ``test_body.sh``. Then
|
89 |
+
|
90 |
+
bash test_face.sh
|
91 |
+
bash test_body.sh
|
92 |
+
|
93 |
+
### 5. Visualization
|
94 |
+
|
95 |
+
If you ssh into the linux machine, NotImplementedError might occur. In this case, please refer to [**issue**](https://github.com/MPI-IS/mesh/issues/66) for solving the error.
|
96 |
+
|
97 |
+
Download [**smplx model**](https://drive.google.com/file/d/1Ly_hQNLQcZ89KG0Nj4jYZwccQiimSUVn/view?usp=share_link) (Please register in the official [**SMPLX webpage**](https://smpl-x.is.tue.mpg.de) before you use it.)
|
98 |
+
and place it in ``path-to-TalkSHOW/visualise/smplx_model``.
|
99 |
+
To visualise the test set and generated result (in each video, left: generated result | right: ground truth).
|
100 |
+
The videos and generated motion data are saved in ``./visualise/video/body-pixel``:
|
101 |
+
|
102 |
+
bash visualise.sh
|
103 |
+
|
104 |
+
If you ssh into the linux machine, there might be an error about OffscreenRenderer. In this case, please refer to [**issue**](https://github.com/MPI-IS/mesh/issues/66) for solving the error.
|
105 |
+
|
106 |
+
To reproduce the demo videos, run
|
107 |
+
```bash
|
108 |
+
# the whole body demo
|
109 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/1st-page.wav --id 0 --whole_body
|
110 |
+
# the face demo
|
111 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --only_face
|
112 |
+
# the identity-specific demo
|
113 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0
|
114 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 1
|
115 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 2
|
116 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 3 --stand
|
117 |
+
# the diversity demo
|
118 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --num_samples 12
|
119 |
+
# the french demo
|
120 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/french.wav --id 0
|
121 |
+
# the synthetic speech demo
|
122 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/rich.wav --id 0
|
123 |
+
# the song demo
|
124 |
+
python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/song.wav --id 0
|
125 |
+
````
|
126 |
+
### 6. Baseline
|
127 |
+
|
128 |
+
For training the reproducted "Learning Speech-driven 3D Conversational Gestures from Video" (Habibie et al.), you could run
|
129 |
+
```bash
|
130 |
+
python -W ignore scripts/train.py --speakers oliver seth conan chemistry --config_file ./config/LS3DCG.json
|
131 |
+
```
|
132 |
+
|
133 |
+
For visualization with the pretrained model, download the above [pretrained models](#3-download-the-pretrained-models--optional-) and run
|
134 |
+
```bash
|
135 |
+
python scripts/demo.py --config_file ./config/LS3DCG.json --infer --audio_file ./demo_audio/style.wav --body_model_name s2g_LS3DCG --body_model_path experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth --id 0
|
136 |
+
```
|
137 |
+
|
138 |
+
## Citation
|
139 |
+
If you find our work useful to your research, please consider citing:
|
140 |
+
```
|
141 |
+
@inproceedings{yi2022generating,
|
142 |
+
title={Generating Holistic 3D Human Motion from Speech},
|
143 |
+
author={Yi, Hongwei and Liang, Hualin and Liu, Yifei and Cao, Qiong and Wen, Yandong and Bolkart, Timo and Tao, Dacheng and Black, Michael J},
|
144 |
+
booktitle={CVPR},
|
145 |
+
year={2023}
|
146 |
+
}
|
147 |
+
```
|
148 |
+
|
149 |
+
## Acknowledgements
|
150 |
+
For functions or scripts that are based on external sources, we acknowledge the origin individually in each file.
|
151 |
+
Here are some great resources we benefit:
|
152 |
+
- [Freeform](https://github.com/TheTempAccount/Co-Speech-Motion-Generation) for training pipeline
|
153 |
+
- [MPI-Mesh](https://github.com/MPI-IS/mesh), [Pyrender](https://github.com/mmatl/pyrender), [Smplx](https://github.com/vchoutas/smplx), [VOCA](https://github.com/TimoBolkart/voca) for rendering
|
154 |
+
- [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) and [Faceformer](https://github.com/EvelynFan/FaceFormer) for audio encoder
|
155 |
+
|
156 |
+
## Contact
|
157 |
+
For questions, please contact [email protected] or [email protected] or [email protected] or [email protected]
|
158 |
+
|
159 |
+
For commercial licensing, please contact [email protected]
|
__init__.py
ADDED
File without changes
|
config/LS3DCG.json
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
|
3 |
+
"dataset_load_mode": "pickle",
|
4 |
+
"store_file_path": "store.pkl",
|
5 |
+
"smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
|
6 |
+
"extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
|
7 |
+
"j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
|
8 |
+
"param": {
|
9 |
+
"w_j": 1,
|
10 |
+
"w_b": 1,
|
11 |
+
"w_h": 1
|
12 |
+
},
|
13 |
+
"Data": {
|
14 |
+
"data_root": "../ExpressiveWholeBodyDatasetv1.0/",
|
15 |
+
"pklname": "_3d_mfcc.pkl",
|
16 |
+
"whole_video": false,
|
17 |
+
"pose": {
|
18 |
+
"normalization": false,
|
19 |
+
"convert_to_6d": false,
|
20 |
+
"norm_method": "all",
|
21 |
+
"augmentation": false,
|
22 |
+
"generate_length": 88,
|
23 |
+
"pre_pose_length": 0,
|
24 |
+
"pose_dim": 99,
|
25 |
+
"expression": true
|
26 |
+
},
|
27 |
+
"aud": {
|
28 |
+
"feat_method": "mfcc",
|
29 |
+
"aud_feat_dim": 64,
|
30 |
+
"aud_feat_win_size": null,
|
31 |
+
"context_info": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"Model": {
|
35 |
+
"model_type": "body",
|
36 |
+
"model_name": "s2g_LS3DCG",
|
37 |
+
"code_num": 2048,
|
38 |
+
"AudioOpt": "Adam",
|
39 |
+
"encoder_choice": "mfcc",
|
40 |
+
"gan": false
|
41 |
+
},
|
42 |
+
"DataLoader": {
|
43 |
+
"batch_size": 128,
|
44 |
+
"num_workers": 0
|
45 |
+
},
|
46 |
+
"Train": {
|
47 |
+
"epochs": 100,
|
48 |
+
"max_gradient_norm": 5,
|
49 |
+
"learning_rate": {
|
50 |
+
"generator_learning_rate": 1e-4,
|
51 |
+
"discriminator_learning_rate": 1e-4
|
52 |
+
},
|
53 |
+
"weights": {
|
54 |
+
"keypoint_loss_weight": 1.0,
|
55 |
+
"gan_loss_weight": 1.0
|
56 |
+
}
|
57 |
+
},
|
58 |
+
"Log": {
|
59 |
+
"save_every": 50,
|
60 |
+
"print_every": 200,
|
61 |
+
"name": "LS3DCG"
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
config/body_pixel.json
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
|
3 |
+
"dataset_load_mode": "json",
|
4 |
+
"store_file_path": "store.pkl",
|
5 |
+
"smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
|
6 |
+
"extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
|
7 |
+
"j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
|
8 |
+
"param": {
|
9 |
+
"w_j": 1,
|
10 |
+
"w_b": 1,
|
11 |
+
"w_h": 1
|
12 |
+
},
|
13 |
+
"Data": {
|
14 |
+
"data_root": "../ExpressiveWholeBodyDatasetv1.0/",
|
15 |
+
"pklname": "_3d_mfcc.pkl",
|
16 |
+
"whole_video": false,
|
17 |
+
"pose": {
|
18 |
+
"normalization": false,
|
19 |
+
"convert_to_6d": false,
|
20 |
+
"norm_method": "all",
|
21 |
+
"augmentation": false,
|
22 |
+
"generate_length": 88,
|
23 |
+
"pre_pose_length": 0,
|
24 |
+
"pose_dim": 99,
|
25 |
+
"expression": true
|
26 |
+
},
|
27 |
+
"aud": {
|
28 |
+
"feat_method": "mfcc",
|
29 |
+
"aud_feat_dim": 64,
|
30 |
+
"aud_feat_win_size": null,
|
31 |
+
"context_info": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"Model": {
|
35 |
+
"model_type": "body",
|
36 |
+
"model_name": "s2g_body_pixel",
|
37 |
+
"composition": true,
|
38 |
+
"code_num": 2048,
|
39 |
+
"bh_model": true,
|
40 |
+
"AudioOpt": "Adam",
|
41 |
+
"encoder_choice": "mfcc",
|
42 |
+
"gan": false,
|
43 |
+
"vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth"
|
44 |
+
},
|
45 |
+
"DataLoader": {
|
46 |
+
"batch_size": 128,
|
47 |
+
"num_workers": 0
|
48 |
+
},
|
49 |
+
"Train": {
|
50 |
+
"epochs": 100,
|
51 |
+
"max_gradient_norm": 5,
|
52 |
+
"learning_rate": {
|
53 |
+
"generator_learning_rate": 1e-4,
|
54 |
+
"discriminator_learning_rate": 1e-4
|
55 |
+
}
|
56 |
+
},
|
57 |
+
"Log": {
|
58 |
+
"save_every": 50,
|
59 |
+
"print_every": 200,
|
60 |
+
"name": "body-pixel2"
|
61 |
+
}
|
62 |
+
}
|
63 |
+
|
config/body_vq.json
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
|
3 |
+
"dataset_load_mode": "json",
|
4 |
+
"store_file_path": "store.pkl",
|
5 |
+
"smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
|
6 |
+
"extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
|
7 |
+
"j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
|
8 |
+
"param": {
|
9 |
+
"w_j": 1,
|
10 |
+
"w_b": 1,
|
11 |
+
"w_h": 1
|
12 |
+
},
|
13 |
+
"Data": {
|
14 |
+
"data_root": "../ExpressiveWholeBodyDatasetv1.0/",
|
15 |
+
"pklname": "_3d_mfcc.pkl",
|
16 |
+
"whole_video": false,
|
17 |
+
"pose": {
|
18 |
+
"normalization": false,
|
19 |
+
"convert_to_6d": false,
|
20 |
+
"norm_method": "all",
|
21 |
+
"augmentation": false,
|
22 |
+
"generate_length": 88,
|
23 |
+
"pre_pose_length": 0,
|
24 |
+
"pose_dim": 99,
|
25 |
+
"expression": true
|
26 |
+
},
|
27 |
+
"aud": {
|
28 |
+
"feat_method": "mfcc",
|
29 |
+
"aud_feat_dim": 64,
|
30 |
+
"aud_feat_win_size": null,
|
31 |
+
"context_info": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"Model": {
|
35 |
+
"model_type": "body",
|
36 |
+
"model_name": "s2g_body_vq",
|
37 |
+
"composition": true,
|
38 |
+
"code_num": 2048,
|
39 |
+
"bh_model": true,
|
40 |
+
"AudioOpt": "Adam",
|
41 |
+
"encoder_choice": "mfcc",
|
42 |
+
"gan": false
|
43 |
+
},
|
44 |
+
"DataLoader": {
|
45 |
+
"batch_size": 128,
|
46 |
+
"num_workers": 0
|
47 |
+
},
|
48 |
+
"Train": {
|
49 |
+
"epochs": 100,
|
50 |
+
"max_gradient_norm": 5,
|
51 |
+
"learning_rate": {
|
52 |
+
"generator_learning_rate": 1e-4,
|
53 |
+
"discriminator_learning_rate": 1e-4
|
54 |
+
}
|
55 |
+
},
|
56 |
+
"Log": {
|
57 |
+
"save_every": 50,
|
58 |
+
"print_every": 200,
|
59 |
+
"name": "body-vq"
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
config/face.json
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
|
3 |
+
"dataset_load_mode": "json",
|
4 |
+
"store_file_path": "store.pkl",
|
5 |
+
"smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
|
6 |
+
"extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
|
7 |
+
"j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
|
8 |
+
"param": {
|
9 |
+
"w_j": 1,
|
10 |
+
"w_b": 1,
|
11 |
+
"w_h": 1
|
12 |
+
},
|
13 |
+
"Data": {
|
14 |
+
"data_root": "../ExpressiveWholeBodyDatasetv1.0/",
|
15 |
+
"pklname": "_3d_wv2.pkl",
|
16 |
+
"whole_video": true,
|
17 |
+
"pose": {
|
18 |
+
"normalization": false,
|
19 |
+
"convert_to_6d": false,
|
20 |
+
"norm_method": "all",
|
21 |
+
"augmentation": false,
|
22 |
+
"generate_length": 88,
|
23 |
+
"pre_pose_length": 0,
|
24 |
+
"pose_dim": 99,
|
25 |
+
"expression": true
|
26 |
+
},
|
27 |
+
"aud": {
|
28 |
+
"feat_method": "mfcc",
|
29 |
+
"aud_feat_dim": 64,
|
30 |
+
"aud_feat_win_size": null,
|
31 |
+
"context_info": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"Model": {
|
35 |
+
"model_type": "face",
|
36 |
+
"model_name": "s2g_face",
|
37 |
+
"AudioOpt": "SGD",
|
38 |
+
"encoder_choice": "faceformer",
|
39 |
+
"gan": false
|
40 |
+
},
|
41 |
+
"DataLoader": {
|
42 |
+
"batch_size": 1,
|
43 |
+
"num_workers": 0
|
44 |
+
},
|
45 |
+
"Train": {
|
46 |
+
"epochs": 100,
|
47 |
+
"max_gradient_norm": 5,
|
48 |
+
"learning_rate": {
|
49 |
+
"generator_learning_rate": 1e-4,
|
50 |
+
"discriminator_learning_rate": 1e-4
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"Log": {
|
54 |
+
"save_every": 50,
|
55 |
+
"print_every": 1000,
|
56 |
+
"name": "face"
|
57 |
+
}
|
58 |
+
}
|
59 |
+
|
data_utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# from .dataloader_csv import MultiVidData as csv_data
|
2 |
+
from .dataloader_torch import MultiVidData as torch_data
|
3 |
+
from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
|
data_utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (375 Bytes). View file
|
|
data_utils/__pycache__/consts.cpython-37.pyc
ADDED
Binary file (92.7 kB). View file
|
|
data_utils/__pycache__/dataloader_torch.cpython-37.pyc
ADDED
Binary file (5.31 kB). View file
|
|
data_utils/__pycache__/lower_body.cpython-37.pyc
ADDED
Binary file (3.91 kB). View file
|
|
data_utils/__pycache__/mesh_dataset.cpython-37.pyc
ADDED
Binary file (7.9 kB). View file
|
|
data_utils/__pycache__/rotation_conversion.cpython-37.pyc
ADDED
Binary file (16.4 kB). View file
|
|
data_utils/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (7.42 kB). View file
|
|
data_utils/apply_split.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
import pickle
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
speakers = ['seth', 'oliver', 'conan', 'chemistry']
|
7 |
+
source_data_root = "../expressive_body-V0.7"
|
8 |
+
data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0"
|
9 |
+
|
10 |
+
f_read = open('split_more_than_2s.pkl', 'rb')
|
11 |
+
f_save = open('none.pkl', 'wb')
|
12 |
+
data_split = pickle.load(f_read)
|
13 |
+
none_split = []
|
14 |
+
|
15 |
+
train = val = test = 0
|
16 |
+
|
17 |
+
for speaker_name in speakers:
|
18 |
+
speaker_root = os.path.join(data_root, speaker_name)
|
19 |
+
|
20 |
+
videos = [v for v in data_split[speaker_name]]
|
21 |
+
|
22 |
+
for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
|
23 |
+
for split in data_split[speaker_name][vid]:
|
24 |
+
for seq in data_split[speaker_name][vid][split]:
|
25 |
+
|
26 |
+
seq = seq.replace('\\', '/')
|
27 |
+
old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1])
|
28 |
+
old_file_path = old_file_path.replace('\\', '/')
|
29 |
+
new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1])
|
30 |
+
try:
|
31 |
+
shutil.move(old_file_path, new_file_path)
|
32 |
+
if split == 'train':
|
33 |
+
train = train + 1
|
34 |
+
elif split == 'test':
|
35 |
+
test = test + 1
|
36 |
+
elif split == 'val':
|
37 |
+
val = val + 1
|
38 |
+
except FileNotFoundError:
|
39 |
+
none_split.append(old_file_path)
|
40 |
+
print(f"The file {old_file_path} does not exists.")
|
41 |
+
except shutil.Error:
|
42 |
+
none_split.append(old_file_path)
|
43 |
+
print(f"The file {old_file_path} does not exists.")
|
44 |
+
|
45 |
+
print(none_split.__len__())
|
46 |
+
pickle.dump(none_split, f_save)
|
47 |
+
f_save.close()
|
48 |
+
|
49 |
+
print(train, val, test)
|
50 |
+
|
51 |
+
|
data_utils/axis2matrix.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
import scipy.linalg as linalg
|
4 |
+
|
5 |
+
|
6 |
+
def rotate_mat(axis, radian):
|
7 |
+
|
8 |
+
a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian)
|
9 |
+
|
10 |
+
rot_matrix = linalg.expm(a)
|
11 |
+
|
12 |
+
return rot_matrix
|
13 |
+
|
14 |
+
def aaa2mat(axis, sin, cos):
|
15 |
+
i = np.eye(3)
|
16 |
+
nnt = np.dot(axis.T, axis)
|
17 |
+
s = np.asarray([[0, -axis[0,2], axis[0,1]],
|
18 |
+
[axis[0,2], 0, -axis[0,0]],
|
19 |
+
[-axis[0,1], axis[0,0], 0]])
|
20 |
+
r = cos * i + (1-cos)*nnt +sin * s
|
21 |
+
return r
|
22 |
+
|
23 |
+
rand_axis = np.asarray([[1,0,0]])
|
24 |
+
#旋转角度
|
25 |
+
r = math.pi/2
|
26 |
+
#返回旋转矩阵
|
27 |
+
rot_matrix = rotate_mat(rand_axis, r)
|
28 |
+
r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r))
|
29 |
+
print(rot_matrix)
|
data_utils/consts.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_utils/dataloader_torch.py
ADDED
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
sys.path.append(os.getcwd())
|
4 |
+
import os
|
5 |
+
from tqdm import tqdm
|
6 |
+
from data_utils.utils import *
|
7 |
+
import torch.utils.data as data
|
8 |
+
from data_utils.mesh_dataset import SmplxDataset
|
9 |
+
from transformers import Wav2Vec2Processor
|
10 |
+
|
11 |
+
|
12 |
+
class MultiVidData():
|
13 |
+
def __init__(self,
|
14 |
+
data_root,
|
15 |
+
speakers,
|
16 |
+
split='train',
|
17 |
+
limbscaling=False,
|
18 |
+
normalization=False,
|
19 |
+
norm_method='new',
|
20 |
+
split_trans_zero=False,
|
21 |
+
num_frames=25,
|
22 |
+
num_pre_frames=25,
|
23 |
+
num_generate_length=None,
|
24 |
+
aud_feat_win_size=None,
|
25 |
+
aud_feat_dim=64,
|
26 |
+
feat_method='mel_spec',
|
27 |
+
context_info=False,
|
28 |
+
smplx=False,
|
29 |
+
audio_sr=16000,
|
30 |
+
convert_to_6d=False,
|
31 |
+
expression=False,
|
32 |
+
config=None
|
33 |
+
):
|
34 |
+
self.data_root = data_root
|
35 |
+
self.speakers = speakers
|
36 |
+
self.split = split
|
37 |
+
if split == 'pre':
|
38 |
+
self.split = 'train'
|
39 |
+
self.norm_method=norm_method
|
40 |
+
self.normalization = normalization
|
41 |
+
self.limbscaling = limbscaling
|
42 |
+
self.convert_to_6d = convert_to_6d
|
43 |
+
self.num_frames=num_frames
|
44 |
+
self.num_pre_frames=num_pre_frames
|
45 |
+
if num_generate_length is None:
|
46 |
+
self.num_generate_length = num_frames
|
47 |
+
else:
|
48 |
+
self.num_generate_length = num_generate_length
|
49 |
+
self.split_trans_zero=split_trans_zero
|
50 |
+
|
51 |
+
dataset = SmplxDataset
|
52 |
+
|
53 |
+
if self.split_trans_zero:
|
54 |
+
self.trans_dataset_list = []
|
55 |
+
self.zero_dataset_list = []
|
56 |
+
else:
|
57 |
+
self.all_dataset_list = []
|
58 |
+
self.dataset={}
|
59 |
+
self.complete_data=[]
|
60 |
+
self.config=config
|
61 |
+
load_mode=self.config.dataset_load_mode
|
62 |
+
|
63 |
+
######################load with pickle file
|
64 |
+
if load_mode=='pickle':
|
65 |
+
import pickle
|
66 |
+
import subprocess
|
67 |
+
|
68 |
+
# store_file_path='/tmp/store.pkl'
|
69 |
+
# cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl /tmp/store.pkl
|
70 |
+
# subprocess.run(f'cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl {store_file_path}',shell=True)
|
71 |
+
|
72 |
+
# f = open(self.config.store_file_path, 'rb+')
|
73 |
+
f = open(self.split+config.Data.pklname, 'rb+')
|
74 |
+
self.dataset=pickle.load(f)
|
75 |
+
f.close()
|
76 |
+
for key in self.dataset:
|
77 |
+
self.complete_data.append(self.dataset[key].complete_data)
|
78 |
+
######################load with pickle file
|
79 |
+
|
80 |
+
######################load with a csv file
|
81 |
+
elif load_mode=='csv':
|
82 |
+
|
83 |
+
# 这里从我的一个code文件夹导入的,后续再完善进来
|
84 |
+
try:
|
85 |
+
sys.path.append(self.config.config_root_path)
|
86 |
+
from config import config_path
|
87 |
+
from csv_parser import csv_parse
|
88 |
+
|
89 |
+
except ImportError as e:
|
90 |
+
print(f'err: {e}')
|
91 |
+
raise ImportError('config root path error...')
|
92 |
+
|
93 |
+
|
94 |
+
for speaker_name in self.speakers:
|
95 |
+
# df_intervals=pd.read_csv(self.config.voca_csv_file_path)
|
96 |
+
df_intervals=None
|
97 |
+
df_intervals=df_intervals[df_intervals['speaker']==speaker_name]
|
98 |
+
df_intervals = df_intervals[df_intervals['dataset'] == self.split]
|
99 |
+
|
100 |
+
print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
|
101 |
+
for iter_index, (_, interval) in tqdm(
|
102 |
+
(enumerate(df_intervals.iterrows())),desc=f'load {speaker_name}'
|
103 |
+
):
|
104 |
+
|
105 |
+
(
|
106 |
+
interval_index,
|
107 |
+
interval_speaker,
|
108 |
+
interval_video_fn,
|
109 |
+
interval_id,
|
110 |
+
|
111 |
+
start_time,
|
112 |
+
end_time,
|
113 |
+
duration_time,
|
114 |
+
start_time_10,
|
115 |
+
over_flow_flag,
|
116 |
+
short_dur_flag,
|
117 |
+
|
118 |
+
big_video_dir,
|
119 |
+
small_video_dir_name,
|
120 |
+
speaker_video_path,
|
121 |
+
|
122 |
+
voca_basename,
|
123 |
+
json_basename,
|
124 |
+
wav_basename,
|
125 |
+
voca_top_clip_path,
|
126 |
+
voca_json_clip_path,
|
127 |
+
voca_wav_clip_path,
|
128 |
+
|
129 |
+
audio_output_fn,
|
130 |
+
image_output_path,
|
131 |
+
pifpaf_output_path,
|
132 |
+
mp_output_path,
|
133 |
+
op_output_path,
|
134 |
+
deca_output_path,
|
135 |
+
pixie_output_path,
|
136 |
+
cam_output_path,
|
137 |
+
ours_output_path,
|
138 |
+
merge_output_path,
|
139 |
+
multi_output_path,
|
140 |
+
gt_output_path,
|
141 |
+
ours_images_path,
|
142 |
+
pkl_fil_path,
|
143 |
+
)=csv_parse(interval)
|
144 |
+
|
145 |
+
if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
|
146 |
+
continue
|
147 |
+
|
148 |
+
key=f'{interval_video_fn}/{small_video_dir_name}'
|
149 |
+
self.dataset[key] = dataset(
|
150 |
+
data_root=pkl_fil_path,
|
151 |
+
speaker=speaker_name,
|
152 |
+
audio_fn=audio_output_fn,
|
153 |
+
audio_sr=audio_sr,
|
154 |
+
fps=num_frames,
|
155 |
+
feat_method=feat_method,
|
156 |
+
audio_feat_dim=aud_feat_dim,
|
157 |
+
train=(self.split == 'train'),
|
158 |
+
load_all=True,
|
159 |
+
split_trans_zero=self.split_trans_zero,
|
160 |
+
limbscaling=self.limbscaling,
|
161 |
+
num_frames=self.num_frames,
|
162 |
+
num_pre_frames=self.num_pre_frames,
|
163 |
+
num_generate_length=self.num_generate_length,
|
164 |
+
audio_feat_win_size=aud_feat_win_size,
|
165 |
+
context_info=context_info,
|
166 |
+
convert_to_6d=convert_to_6d,
|
167 |
+
expression=expression,
|
168 |
+
config=self.config
|
169 |
+
)
|
170 |
+
self.complete_data.append(self.dataset[key].complete_data)
|
171 |
+
######################load with a csv file
|
172 |
+
|
173 |
+
######################origin load method
|
174 |
+
elif load_mode=='json':
|
175 |
+
|
176 |
+
# if self.split == 'train':
|
177 |
+
# import pickle
|
178 |
+
# f = open('store.pkl', 'rb+')
|
179 |
+
# self.dataset=pickle.load(f)
|
180 |
+
# f.close()
|
181 |
+
# for key in self.dataset:
|
182 |
+
# self.complete_data.append(self.dataset[key].complete_data)
|
183 |
+
# else:https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav
|
184 |
+
# if config.Model.model_type == 'face':
|
185 |
+
am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
|
186 |
+
am_sr = 16000
|
187 |
+
# else:
|
188 |
+
# am, am_sr = None, None
|
189 |
+
for speaker_name in self.speakers:
|
190 |
+
speaker_root = os.path.join(self.data_root, speaker_name)
|
191 |
+
|
192 |
+
videos=[v for v in os.listdir(speaker_root) ]
|
193 |
+
print(videos)
|
194 |
+
|
195 |
+
haode = huaide = 0
|
196 |
+
|
197 |
+
for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
|
198 |
+
source_vid=vid
|
199 |
+
# vid_pth=os.path.join(speaker_root, source_vid, 'images/half', self.split)
|
200 |
+
vid_pth = os.path.join(speaker_root, source_vid, self.split)
|
201 |
+
if smplx == 'pose':
|
202 |
+
seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
|
203 |
+
else:
|
204 |
+
try:
|
205 |
+
seqs = [s for s in os.listdir(vid_pth)]
|
206 |
+
except:
|
207 |
+
continue
|
208 |
+
|
209 |
+
for s in seqs:
|
210 |
+
seq_root=os.path.join(vid_pth, s)
|
211 |
+
key = seq_root # correspond to clip******
|
212 |
+
audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
|
213 |
+
motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
|
214 |
+
if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
|
215 |
+
huaide = huaide + 1
|
216 |
+
continue
|
217 |
+
|
218 |
+
self.dataset[key]=dataset(
|
219 |
+
data_root=seq_root,
|
220 |
+
speaker=speaker_name,
|
221 |
+
motion_fn=motion_fname,
|
222 |
+
audio_fn=audio_fname,
|
223 |
+
audio_sr=audio_sr,
|
224 |
+
fps=num_frames,
|
225 |
+
feat_method=feat_method,
|
226 |
+
audio_feat_dim=aud_feat_dim,
|
227 |
+
train=(self.split=='train'),
|
228 |
+
load_all=True,
|
229 |
+
split_trans_zero=self.split_trans_zero,
|
230 |
+
limbscaling=self.limbscaling,
|
231 |
+
num_frames=self.num_frames,
|
232 |
+
num_pre_frames=self.num_pre_frames,
|
233 |
+
num_generate_length=self.num_generate_length,
|
234 |
+
audio_feat_win_size=aud_feat_win_size,
|
235 |
+
context_info=context_info,
|
236 |
+
convert_to_6d=convert_to_6d,
|
237 |
+
expression=expression,
|
238 |
+
config=self.config,
|
239 |
+
am=am,
|
240 |
+
am_sr=am_sr,
|
241 |
+
whole_video=config.Data.whole_video
|
242 |
+
)
|
243 |
+
self.complete_data.append(self.dataset[key].complete_data)
|
244 |
+
haode = haode + 1
|
245 |
+
print("huaide:{}, haode:{}".format(huaide, haode))
|
246 |
+
import pickle
|
247 |
+
|
248 |
+
f = open(self.split+config.Data.pklname, 'wb')
|
249 |
+
pickle.dump(self.dataset, f)
|
250 |
+
f.close()
|
251 |
+
######################origin load method
|
252 |
+
|
253 |
+
self.complete_data=np.concatenate(self.complete_data, axis=0)
|
254 |
+
|
255 |
+
# assert self.complete_data.shape[-1] == (12+21+21)*2
|
256 |
+
self.normalize_stats = {}
|
257 |
+
|
258 |
+
self.data_mean = None
|
259 |
+
self.data_std = None
|
260 |
+
|
261 |
+
def get_dataset(self):
|
262 |
+
self.normalize_stats['mean'] = self.data_mean
|
263 |
+
self.normalize_stats['std'] = self.data_std
|
264 |
+
|
265 |
+
for key in list(self.dataset.keys()):
|
266 |
+
if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
|
267 |
+
continue
|
268 |
+
self.dataset[key].num_generate_length = self.num_generate_length
|
269 |
+
self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
|
270 |
+
self.all_dataset_list.append(self.dataset[key].all_dataset)
|
271 |
+
|
272 |
+
if self.split_trans_zero:
|
273 |
+
self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
|
274 |
+
self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
|
275 |
+
else:
|
276 |
+
self.all_dataset = data.ConcatDataset(self.all_dataset_list)
|
277 |
+
|
278 |
+
|
279 |
+
|
data_utils/dataset_preprocess.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
from tqdm import tqdm
|
4 |
+
import shutil
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import librosa
|
8 |
+
import random
|
9 |
+
|
10 |
+
speakers = ['seth', 'conan', 'oliver', 'chemistry']
|
11 |
+
data_root = "../ExpressiveWholeBodyDatasetv1.0/"
|
12 |
+
split = 'train'
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
def split_list(full_list,shuffle=False,ratio=0.2):
|
17 |
+
n_total = len(full_list)
|
18 |
+
offset_0 = int(n_total * ratio)
|
19 |
+
offset_1 = int(n_total * ratio * 2)
|
20 |
+
if n_total==0 or offset_1<1:
|
21 |
+
return [],full_list
|
22 |
+
if shuffle:
|
23 |
+
random.shuffle(full_list)
|
24 |
+
sublist_0 = full_list[:offset_0]
|
25 |
+
sublist_1 = full_list[offset_0:offset_1]
|
26 |
+
sublist_2 = full_list[offset_1:]
|
27 |
+
return sublist_0, sublist_1, sublist_2
|
28 |
+
|
29 |
+
|
30 |
+
def moveto(list, file):
|
31 |
+
for f in list:
|
32 |
+
before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1]
|
33 |
+
new_path = os.path.join(before, file)
|
34 |
+
new_path = os.path.join(new_path, after)
|
35 |
+
# os.makedirs(new_path)
|
36 |
+
# os.path.isdir(new_path)
|
37 |
+
# shutil.move(f, new_path)
|
38 |
+
|
39 |
+
#转移到新目录
|
40 |
+
shutil.copytree(f, new_path)
|
41 |
+
#删除原train里的文件
|
42 |
+
shutil.rmtree(f)
|
43 |
+
return None
|
44 |
+
|
45 |
+
|
46 |
+
def read_pkl(data):
|
47 |
+
betas = np.array(data['betas'])
|
48 |
+
|
49 |
+
jaw_pose = np.array(data['jaw_pose'])
|
50 |
+
leye_pose = np.array(data['leye_pose'])
|
51 |
+
reye_pose = np.array(data['reye_pose'])
|
52 |
+
global_orient = np.array(data['global_orient']).squeeze()
|
53 |
+
body_pose = np.array(data['body_pose_axis'])
|
54 |
+
left_hand_pose = np.array(data['left_hand_pose'])
|
55 |
+
right_hand_pose = np.array(data['right_hand_pose'])
|
56 |
+
|
57 |
+
full_body = np.concatenate(
|
58 |
+
(jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
|
59 |
+
|
60 |
+
expression = np.array(data['expression'])
|
61 |
+
full_body = np.concatenate((full_body, expression), axis=1)
|
62 |
+
|
63 |
+
if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0):
|
64 |
+
return 1
|
65 |
+
else:
|
66 |
+
return 0
|
67 |
+
|
68 |
+
|
69 |
+
for speaker_name in speakers:
|
70 |
+
speaker_root = os.path.join(data_root, speaker_name)
|
71 |
+
|
72 |
+
videos = [v for v in os.listdir(speaker_root)]
|
73 |
+
print(videos)
|
74 |
+
|
75 |
+
haode = huaide = 0
|
76 |
+
total_seqs = []
|
77 |
+
|
78 |
+
for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
|
79 |
+
# for vid in videos:
|
80 |
+
source_vid = vid
|
81 |
+
vid_pth = os.path.join(speaker_root, source_vid)
|
82 |
+
# vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split)
|
83 |
+
t = os.path.join(speaker_root, source_vid, 'test')
|
84 |
+
v = os.path.join(speaker_root, source_vid, 'val')
|
85 |
+
|
86 |
+
# if os.path.exists(t):
|
87 |
+
# shutil.rmtree(t)
|
88 |
+
# if os.path.exists(v):
|
89 |
+
# shutil.rmtree(v)
|
90 |
+
try:
|
91 |
+
seqs = [s for s in os.listdir(vid_pth)]
|
92 |
+
except:
|
93 |
+
continue
|
94 |
+
# if len(seqs) == 0:
|
95 |
+
# shutil.rmtree(os.path.join(speaker_root, source_vid))
|
96 |
+
# None
|
97 |
+
for s in seqs:
|
98 |
+
quality = 0
|
99 |
+
total_seqs.append(os.path.join(vid_pth,s))
|
100 |
+
seq_root = os.path.join(vid_pth, s)
|
101 |
+
key = seq_root # correspond to clip******
|
102 |
+
audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s))
|
103 |
+
|
104 |
+
# delete the data without audio or the audio file could not be read
|
105 |
+
if os.path.isfile(audio_fname):
|
106 |
+
try:
|
107 |
+
audio = librosa.load(audio_fname)
|
108 |
+
except:
|
109 |
+
# print(key)
|
110 |
+
shutil.rmtree(key)
|
111 |
+
huaide = huaide + 1
|
112 |
+
continue
|
113 |
+
else:
|
114 |
+
huaide = huaide + 1
|
115 |
+
# print(key)
|
116 |
+
shutil.rmtree(key)
|
117 |
+
continue
|
118 |
+
|
119 |
+
# check motion file
|
120 |
+
motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s))
|
121 |
+
try:
|
122 |
+
f = open(motion_fname, 'rb+')
|
123 |
+
except:
|
124 |
+
shutil.rmtree(key)
|
125 |
+
huaide = huaide + 1
|
126 |
+
continue
|
127 |
+
|
128 |
+
data = pickle.load(f)
|
129 |
+
w = read_pkl(data)
|
130 |
+
f.close()
|
131 |
+
quality = quality + w
|
132 |
+
|
133 |
+
if w == 1:
|
134 |
+
shutil.rmtree(key)
|
135 |
+
# print(key)
|
136 |
+
huaide = huaide + 1
|
137 |
+
continue
|
138 |
+
|
139 |
+
haode = haode + 1
|
140 |
+
|
141 |
+
print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__()))
|
142 |
+
|
143 |
+
for speaker_name in speakers:
|
144 |
+
speaker_root = os.path.join(data_root, speaker_name)
|
145 |
+
|
146 |
+
videos = [v for v in os.listdir(speaker_root)]
|
147 |
+
print(videos)
|
148 |
+
|
149 |
+
haode = huaide = 0
|
150 |
+
total_seqs = []
|
151 |
+
|
152 |
+
for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
|
153 |
+
# for vid in videos:
|
154 |
+
source_vid = vid
|
155 |
+
vid_pth = os.path.join(speaker_root, source_vid)
|
156 |
+
try:
|
157 |
+
seqs = [s for s in os.listdir(vid_pth)]
|
158 |
+
except:
|
159 |
+
continue
|
160 |
+
for s in seqs:
|
161 |
+
quality = 0
|
162 |
+
total_seqs.append(os.path.join(vid_pth, s))
|
163 |
+
print("total_seqs:{}".format(total_seqs.__len__()))
|
164 |
+
# split the dataset
|
165 |
+
test_list, val_list, train_list = split_list(total_seqs, True, 0.1)
|
166 |
+
print(len(test_list), len(val_list), len(train_list))
|
167 |
+
moveto(train_list, 'train')
|
168 |
+
moveto(test_list, 'test')
|
169 |
+
moveto(val_list, 'val')
|
170 |
+
|
data_utils/get_j.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def to3d(poses, config):
|
5 |
+
if config.Data.pose.convert_to_6d:
|
6 |
+
if config.Data.pose.expression:
|
7 |
+
poses_exp = poses[:, -100:]
|
8 |
+
poses = poses[:, :-100]
|
9 |
+
|
10 |
+
poses = poses.reshape(poses.shape[0], -1, 5)
|
11 |
+
sin, cos = poses[:, :, 3], poses[:, :, 4]
|
12 |
+
pose_angle = torch.atan2(sin, cos)
|
13 |
+
poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)
|
14 |
+
|
15 |
+
if config.Data.pose.expression:
|
16 |
+
poses = torch.cat([poses, poses_exp], dim=-1)
|
17 |
+
return poses
|
18 |
+
|
19 |
+
|
20 |
+
def get_joint(smplx_model, betas, pred):
|
21 |
+
joint = smplx_model(betas=betas.repeat(pred.shape[0], 1),
|
22 |
+
expression=pred[:, 165:265],
|
23 |
+
jaw_pose=pred[:, 0:3],
|
24 |
+
leye_pose=pred[:, 3:6],
|
25 |
+
reye_pose=pred[:, 6:9],
|
26 |
+
global_orient=pred[:, 9:12],
|
27 |
+
body_pose=pred[:, 12:75],
|
28 |
+
left_hand_pose=pred[:, 75:120],
|
29 |
+
right_hand_pose=pred[:, 120:165],
|
30 |
+
return_verts=True)['joints']
|
31 |
+
return joint
|
32 |
+
|
33 |
+
|
34 |
+
def get_joints(smplx_model, betas, pred):
|
35 |
+
if len(pred.shape) == 3:
|
36 |
+
B = pred.shape[0]
|
37 |
+
x = 4 if B>= 4 else B
|
38 |
+
T = pred.shape[1]
|
39 |
+
pred = pred.reshape(-1, 265)
|
40 |
+
smplx_model.batch_size = L = T * x
|
41 |
+
|
42 |
+
times = pred.shape[0] // smplx_model.batch_size
|
43 |
+
joints = []
|
44 |
+
for i in range(times):
|
45 |
+
joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L]))
|
46 |
+
joints = torch.cat(joints, dim=0)
|
47 |
+
joints = joints.reshape(B, T, -1, 3)
|
48 |
+
else:
|
49 |
+
smplx_model.batch_size = pred.shape[0]
|
50 |
+
joints = get_joint(smplx_model, betas, pred)
|
51 |
+
return joints
|
data_utils/hand_component.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_utils/lower_body.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
lower_pose = torch.tensor(
|
5 |
+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
|
6 |
+
0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
|
7 |
+
-0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
|
8 |
+
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
9 |
+
lower_pose_stand = torch.tensor([
|
10 |
+
8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
|
11 |
+
3.0747, -0.0158, -0.0152,
|
12 |
+
-3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
|
13 |
+
-3.9716e-01, -4.0229e-02, -1.2637e-01,
|
14 |
+
7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
|
15 |
+
7.8632e-01, -4.3810e-02, 1.4375e-02,
|
16 |
+
-1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
|
17 |
+
# lower_pose_stand = torch.tensor(
|
18 |
+
# [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06,
|
19 |
+
# 3.0747, -0.0158, -0.0152,
|
20 |
+
# -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
|
21 |
+
# 1.1654603481292725, 0.0, 0.0,
|
22 |
+
# 4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01,
|
23 |
+
# 0.0, 0.0, 0.0,
|
24 |
+
# 2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,])
|
25 |
+
lower_body = [0, 1, 3, 4, 6, 7, 9, 10]
|
26 |
+
count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
|
27 |
+
29, 30, 31, 32, 33, 34, 35, 36, 37,
|
28 |
+
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]
|
29 |
+
fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
|
30 |
+
29,
|
31 |
+
35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
|
32 |
+
50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
|
33 |
+
65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
|
34 |
+
all_index = np.ones(275)
|
35 |
+
all_index[fix_index] = 0
|
36 |
+
c_index = []
|
37 |
+
i = 0
|
38 |
+
for num in all_index:
|
39 |
+
if num == 1:
|
40 |
+
c_index.append(i)
|
41 |
+
i = i + 1
|
42 |
+
c_index = np.asarray(c_index)
|
43 |
+
|
44 |
+
fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
|
45 |
+
21, 22, 23, 24, 25, 26,
|
46 |
+
30, 31, 32, 33, 34, 35,
|
47 |
+
45, 46, 47, 48, 49, 50]
|
48 |
+
all_index_3d = np.ones(165)
|
49 |
+
all_index_3d[fix_index_3d] = 0
|
50 |
+
c_index_3d = []
|
51 |
+
i = 0
|
52 |
+
for num in all_index_3d:
|
53 |
+
if num == 1:
|
54 |
+
c_index_3d.append(i)
|
55 |
+
i = i + 1
|
56 |
+
c_index_3d = np.asarray(c_index_3d)
|
57 |
+
|
58 |
+
c_index_6d = []
|
59 |
+
i = 0
|
60 |
+
for num in all_index_3d:
|
61 |
+
if num == 1:
|
62 |
+
c_index_6d.append(2*i)
|
63 |
+
c_index_6d.append(2 * i + 1)
|
64 |
+
i = i + 1
|
65 |
+
c_index_6d = np.asarray(c_index_6d)
|
66 |
+
|
67 |
+
|
68 |
+
def part2full(input, stand=False):
|
69 |
+
if stand:
|
70 |
+
# lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
|
71 |
+
lp = torch.zeros_like(lower_pose)
|
72 |
+
lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
|
73 |
+
lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
|
74 |
+
else:
|
75 |
+
lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
|
76 |
+
|
77 |
+
input = torch.cat([input[:, :3],
|
78 |
+
lp[:, :15],
|
79 |
+
input[:, 3:6],
|
80 |
+
lp[:, 15:21],
|
81 |
+
input[:, 6:9],
|
82 |
+
lp[:, 21:27],
|
83 |
+
input[:, 9:12],
|
84 |
+
lp[:, 27:],
|
85 |
+
input[:, 12:]]
|
86 |
+
, dim=1)
|
87 |
+
return input
|
88 |
+
|
89 |
+
|
90 |
+
def pred2poses(input, gt):
|
91 |
+
input = torch.cat([input[:, :3],
|
92 |
+
gt[0:1, 3:18].repeat(input.shape[0], 1),
|
93 |
+
input[:, 3:6],
|
94 |
+
gt[0:1, 21:27].repeat(input.shape[0], 1),
|
95 |
+
input[:, 6:9],
|
96 |
+
gt[0:1, 30:36].repeat(input.shape[0], 1),
|
97 |
+
input[:, 9:12],
|
98 |
+
gt[0:1, 39:45].repeat(input.shape[0], 1),
|
99 |
+
input[:, 12:]]
|
100 |
+
, dim=1)
|
101 |
+
return input
|
102 |
+
|
103 |
+
|
104 |
+
def poses2poses(input, gt):
|
105 |
+
input = torch.cat([input[:, :3],
|
106 |
+
gt[0:1, 3:18].repeat(input.shape[0], 1),
|
107 |
+
input[:, 18:21],
|
108 |
+
gt[0:1, 21:27].repeat(input.shape[0], 1),
|
109 |
+
input[:, 27:30],
|
110 |
+
gt[0:1, 30:36].repeat(input.shape[0], 1),
|
111 |
+
input[:, 36:39],
|
112 |
+
gt[0:1, 39:45].repeat(input.shape[0], 1),
|
113 |
+
input[:, 45:]]
|
114 |
+
, dim=1)
|
115 |
+
return input
|
116 |
+
|
117 |
+
def poses2pred(input, stand=False):
|
118 |
+
if stand:
|
119 |
+
lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
|
120 |
+
# lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
|
121 |
+
else:
|
122 |
+
lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
|
123 |
+
input = torch.cat([input[:, :3],
|
124 |
+
lp[:, :15],
|
125 |
+
input[:, 18:21],
|
126 |
+
lp[:, 15:21],
|
127 |
+
input[:, 27:30],
|
128 |
+
lp[:, 21:27],
|
129 |
+
input[:, 36:39],
|
130 |
+
lp[:, 27:],
|
131 |
+
input[:, 45:]]
|
132 |
+
, dim=1)
|
133 |
+
return input
|
134 |
+
|
135 |
+
|
136 |
+
rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\
|
137 |
+
# ,22, 23, 24, 25, 40, 26, 41,
|
138 |
+
# 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55,
|
139 |
+
# 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75]
|
140 |
+
|
141 |
+
symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]#, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
142 |
+
# 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
143 |
+
# 1, 1, 1, 1, 1, 1]
|
data_utils/mesh_dataset.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
|
5 |
+
sys.path.append(os.getcwd())
|
6 |
+
|
7 |
+
import json
|
8 |
+
from glob import glob
|
9 |
+
from data_utils.utils import *
|
10 |
+
import torch.utils.data as data
|
11 |
+
from data_utils.consts import speaker_id
|
12 |
+
from data_utils.lower_body import count_part
|
13 |
+
import random
|
14 |
+
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
|
15 |
+
|
16 |
+
with open('data_utils/hand_component.json') as file_obj:
|
17 |
+
comp = json.load(file_obj)
|
18 |
+
left_hand_c = np.asarray(comp['left'])
|
19 |
+
right_hand_c = np.asarray(comp['right'])
|
20 |
+
|
21 |
+
|
22 |
+
def to3d(data):
|
23 |
+
left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
|
24 |
+
right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
|
25 |
+
data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
|
26 |
+
return data
|
27 |
+
|
28 |
+
|
29 |
+
class SmplxDataset():
|
30 |
+
'''
|
31 |
+
creat a dataset for every segment and concat.
|
32 |
+
'''
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
data_root,
|
36 |
+
speaker,
|
37 |
+
motion_fn,
|
38 |
+
audio_fn,
|
39 |
+
audio_sr,
|
40 |
+
fps,
|
41 |
+
feat_method='mel_spec',
|
42 |
+
audio_feat_dim=64,
|
43 |
+
audio_feat_win_size=None,
|
44 |
+
|
45 |
+
train=True,
|
46 |
+
load_all=False,
|
47 |
+
split_trans_zero=False,
|
48 |
+
limbscaling=False,
|
49 |
+
num_frames=25,
|
50 |
+
num_pre_frames=25,
|
51 |
+
num_generate_length=25,
|
52 |
+
context_info=False,
|
53 |
+
convert_to_6d=False,
|
54 |
+
expression=False,
|
55 |
+
config=None,
|
56 |
+
am=None,
|
57 |
+
am_sr=None,
|
58 |
+
whole_video=False
|
59 |
+
):
|
60 |
+
|
61 |
+
self.data_root = data_root
|
62 |
+
self.speaker = speaker
|
63 |
+
|
64 |
+
self.feat_method = feat_method
|
65 |
+
self.audio_fn = audio_fn
|
66 |
+
self.audio_sr = audio_sr
|
67 |
+
self.fps = fps
|
68 |
+
self.audio_feat_dim = audio_feat_dim
|
69 |
+
self.audio_feat_win_size = audio_feat_win_size
|
70 |
+
self.context_info = context_info # for aud feat
|
71 |
+
self.convert_to_6d = convert_to_6d
|
72 |
+
self.expression = expression
|
73 |
+
|
74 |
+
self.train = train
|
75 |
+
self.load_all = load_all
|
76 |
+
self.split_trans_zero = split_trans_zero
|
77 |
+
self.limbscaling = limbscaling
|
78 |
+
self.num_frames = num_frames
|
79 |
+
self.num_pre_frames = num_pre_frames
|
80 |
+
self.num_generate_length = num_generate_length
|
81 |
+
# print('num_generate_length ', self.num_generate_length)
|
82 |
+
|
83 |
+
self.config = config
|
84 |
+
self.am_sr = am_sr
|
85 |
+
self.whole_video = whole_video
|
86 |
+
load_mode = self.config.dataset_load_mode
|
87 |
+
|
88 |
+
if load_mode == 'pickle':
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
elif load_mode == 'csv':
|
92 |
+
import pickle
|
93 |
+
with open(data_root, 'rb') as f:
|
94 |
+
u = pickle._Unpickler(f)
|
95 |
+
data = u.load()
|
96 |
+
self.data = data[0]
|
97 |
+
if self.load_all:
|
98 |
+
self._load_npz_all()
|
99 |
+
|
100 |
+
elif load_mode == 'json':
|
101 |
+
self.annotations = glob(data_root + '/*pkl')
|
102 |
+
if len(self.annotations) == 0:
|
103 |
+
raise FileNotFoundError(data_root + ' are empty')
|
104 |
+
self.annotations = sorted(self.annotations)
|
105 |
+
self.img_name_list = self.annotations
|
106 |
+
|
107 |
+
if self.load_all:
|
108 |
+
self._load_them_all(am, am_sr, motion_fn)
|
109 |
+
|
110 |
+
def _load_npz_all(self):
|
111 |
+
self.loaded_data = {}
|
112 |
+
self.complete_data = []
|
113 |
+
data = self.data
|
114 |
+
shape = data['body_pose_axis'].shape[0]
|
115 |
+
self.betas = data['betas']
|
116 |
+
self.img_name_list = []
|
117 |
+
for index in range(shape):
|
118 |
+
img_name = f'{index:6d}'
|
119 |
+
self.img_name_list.append(img_name)
|
120 |
+
|
121 |
+
jaw_pose = data['jaw_pose'][index]
|
122 |
+
leye_pose = data['leye_pose'][index]
|
123 |
+
reye_pose = data['reye_pose'][index]
|
124 |
+
global_orient = data['global_orient'][index]
|
125 |
+
body_pose = data['body_pose_axis'][index]
|
126 |
+
left_hand_pose = data['left_hand_pose'][index]
|
127 |
+
right_hand_pose = data['right_hand_pose'][index]
|
128 |
+
|
129 |
+
full_body = np.concatenate(
|
130 |
+
(jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
|
131 |
+
assert full_body.shape[0] == 99
|
132 |
+
if self.convert_to_6d:
|
133 |
+
full_body = to3d(full_body)
|
134 |
+
full_body = torch.from_numpy(full_body)
|
135 |
+
full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body))
|
136 |
+
full_body = np.asarray(full_body)
|
137 |
+
if self.expression:
|
138 |
+
expression = data['expression'][index]
|
139 |
+
full_body = np.concatenate((full_body, expression))
|
140 |
+
# full_body = np.concatenate((full_body, non_zero))
|
141 |
+
else:
|
142 |
+
full_body = to3d(full_body)
|
143 |
+
if self.expression:
|
144 |
+
expression = data['expression'][index]
|
145 |
+
full_body = np.concatenate((full_body, expression))
|
146 |
+
|
147 |
+
self.loaded_data[img_name] = full_body.reshape(-1)
|
148 |
+
self.complete_data.append(full_body.reshape(-1))
|
149 |
+
|
150 |
+
self.complete_data = np.array(self.complete_data)
|
151 |
+
|
152 |
+
if self.audio_feat_win_size is not None:
|
153 |
+
self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
|
154 |
+
# print(self.audio_feat.shape)
|
155 |
+
else:
|
156 |
+
if self.feat_method == 'mel_spec':
|
157 |
+
self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
|
158 |
+
elif self.feat_method == 'mfcc':
|
159 |
+
self.audio_feat = get_mfcc(self.audio_fn,
|
160 |
+
smlpx=True,
|
161 |
+
sr=self.audio_sr,
|
162 |
+
n_mfcc=self.audio_feat_dim,
|
163 |
+
win_size=self.audio_feat_win_size
|
164 |
+
)
|
165 |
+
|
166 |
+
def _load_them_all(self, am, am_sr, motion_fn):
|
167 |
+
self.loaded_data = {}
|
168 |
+
self.complete_data = []
|
169 |
+
f = open(motion_fn, 'rb+')
|
170 |
+
data = pickle.load(f)
|
171 |
+
|
172 |
+
self.betas = np.array(data['betas'])
|
173 |
+
|
174 |
+
jaw_pose = np.array(data['jaw_pose'])
|
175 |
+
leye_pose = np.array(data['leye_pose'])
|
176 |
+
reye_pose = np.array(data['reye_pose'])
|
177 |
+
global_orient = np.array(data['global_orient']).squeeze()
|
178 |
+
body_pose = np.array(data['body_pose_axis'])
|
179 |
+
left_hand_pose = np.array(data['left_hand_pose'])
|
180 |
+
right_hand_pose = np.array(data['right_hand_pose'])
|
181 |
+
|
182 |
+
full_body = np.concatenate(
|
183 |
+
(jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
|
184 |
+
assert full_body.shape[1] == 99
|
185 |
+
|
186 |
+
|
187 |
+
if self.convert_to_6d:
|
188 |
+
full_body = to3d(full_body)
|
189 |
+
full_body = torch.from_numpy(full_body)
|
190 |
+
full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
|
191 |
+
full_body = np.asarray(full_body)
|
192 |
+
if self.expression:
|
193 |
+
expression = np.array(data['expression'])
|
194 |
+
full_body = np.concatenate((full_body, expression), axis=1)
|
195 |
+
|
196 |
+
else:
|
197 |
+
full_body = to3d(full_body)
|
198 |
+
expression = np.array(data['expression'])
|
199 |
+
full_body = np.concatenate((full_body, expression), axis=1)
|
200 |
+
|
201 |
+
self.complete_data = full_body
|
202 |
+
self.complete_data = np.array(self.complete_data)
|
203 |
+
|
204 |
+
if self.audio_feat_win_size is not None:
|
205 |
+
self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
|
206 |
+
else:
|
207 |
+
# if self.feat_method == 'mel_spec':
|
208 |
+
# self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
|
209 |
+
# elif self.feat_method == 'mfcc':
|
210 |
+
self.audio_feat = get_mfcc_ta(self.audio_fn,
|
211 |
+
smlpx=True,
|
212 |
+
fps=30,
|
213 |
+
sr=self.audio_sr,
|
214 |
+
n_mfcc=self.audio_feat_dim,
|
215 |
+
win_size=self.audio_feat_win_size,
|
216 |
+
type=self.feat_method,
|
217 |
+
am=am,
|
218 |
+
am_sr=am_sr,
|
219 |
+
encoder_choice=self.config.Model.encoder_choice,
|
220 |
+
)
|
221 |
+
# with open(audio_file, 'w', encoding='utf-8') as file:
|
222 |
+
# file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))
|
223 |
+
|
224 |
+
def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
|
225 |
+
|
226 |
+
class __Worker__(data.Dataset):
|
227 |
+
def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
|
228 |
+
super().__init__()
|
229 |
+
child.index_list = index_list
|
230 |
+
child.normalization = normalization
|
231 |
+
child.normalize_stats = normalize_stats
|
232 |
+
child.split = split
|
233 |
+
|
234 |
+
def __getitem__(child, index):
|
235 |
+
num_generate_length = self.num_generate_length
|
236 |
+
num_pre_frames = self.num_pre_frames
|
237 |
+
seq_len = num_generate_length + num_pre_frames
|
238 |
+
# print(num_generate_length)
|
239 |
+
|
240 |
+
index = child.index_list[index]
|
241 |
+
index_new = index + random.randrange(0, 5, 3)
|
242 |
+
if index_new + seq_len > self.complete_data.shape[0]:
|
243 |
+
index_new = index
|
244 |
+
index = index_new
|
245 |
+
|
246 |
+
if child.split in ['val', 'pre', 'test'] or self.whole_video:
|
247 |
+
index = 0
|
248 |
+
seq_len = self.complete_data.shape[0]
|
249 |
+
seq_data = []
|
250 |
+
assert index + seq_len <= self.complete_data.shape[0]
|
251 |
+
# print(seq_len)
|
252 |
+
seq_data = self.complete_data[index:(index + seq_len), :]
|
253 |
+
seq_data = np.array(seq_data)
|
254 |
+
|
255 |
+
'''
|
256 |
+
audio feature,
|
257 |
+
'''
|
258 |
+
if not self.context_info:
|
259 |
+
if not self.whole_video:
|
260 |
+
audio_feat = self.audio_feat[index:index + seq_len, ...]
|
261 |
+
if audio_feat.shape[0] < seq_len:
|
262 |
+
audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
|
263 |
+
mode='reflect')
|
264 |
+
|
265 |
+
assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
|
266 |
+
else:
|
267 |
+
audio_feat = self.audio_feat
|
268 |
+
|
269 |
+
else: # including feature and history
|
270 |
+
if self.audio_feat_win_size is None:
|
271 |
+
audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
|
272 |
+
if audio_feat.shape[0] < seq_len + num_pre_frames:
|
273 |
+
audio_feat = np.pad(audio_feat,
|
274 |
+
[[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
|
275 |
+
mode='constant')
|
276 |
+
|
277 |
+
assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[
|
278 |
+
1] == self.audio_feat_dim
|
279 |
+
|
280 |
+
if child.normalization:
|
281 |
+
data_mean = child.normalize_stats['mean'].reshape(1, -1)
|
282 |
+
data_std = child.normalize_stats['std'].reshape(1, -1)
|
283 |
+
seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
|
284 |
+
if child.split in['train', 'test']:
|
285 |
+
if self.convert_to_6d:
|
286 |
+
if self.expression:
|
287 |
+
data_sample = {
|
288 |
+
'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
|
289 |
+
'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
|
290 |
+
# 'nzero': seq_data[:, 375:].astype(np.float).transpose(1, 0),
|
291 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
292 |
+
'speaker': speaker_id[self.speaker],
|
293 |
+
'betas': self.betas,
|
294 |
+
'aud_file': self.audio_fn,
|
295 |
+
}
|
296 |
+
else:
|
297 |
+
data_sample = {
|
298 |
+
'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
|
299 |
+
'nzero': seq_data[:, 330:].astype(np.float).transpose(1, 0),
|
300 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
301 |
+
'speaker': speaker_id[self.speaker],
|
302 |
+
'betas': self.betas
|
303 |
+
}
|
304 |
+
else:
|
305 |
+
if self.expression:
|
306 |
+
data_sample = {
|
307 |
+
'poses': seq_data[:, :165].astype(np.float).transpose(1, 0),
|
308 |
+
'expression': seq_data[:, 165:].astype(np.float).transpose(1, 0),
|
309 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
310 |
+
# 'wv2_feat': wv2_feat.astype(np.float).transpose(1, 0),
|
311 |
+
'speaker': speaker_id[self.speaker],
|
312 |
+
'aud_file': self.audio_fn,
|
313 |
+
'betas': self.betas
|
314 |
+
}
|
315 |
+
else:
|
316 |
+
data_sample = {
|
317 |
+
'poses': seq_data.astype(np.float).transpose(1, 0),
|
318 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
319 |
+
'speaker': speaker_id[self.speaker],
|
320 |
+
'betas': self.betas
|
321 |
+
}
|
322 |
+
return data_sample
|
323 |
+
else:
|
324 |
+
data_sample = {
|
325 |
+
'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
|
326 |
+
'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
|
327 |
+
# 'nzero': seq_data[:, 325:].astype(np.float).transpose(1, 0),
|
328 |
+
'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
|
329 |
+
'aud_file': self.audio_fn,
|
330 |
+
'speaker': speaker_id[self.speaker],
|
331 |
+
'betas': self.betas
|
332 |
+
}
|
333 |
+
return data_sample
|
334 |
+
def __len__(child):
|
335 |
+
return len(child.index_list)
|
336 |
+
|
337 |
+
if split == 'train':
|
338 |
+
index_list = list(
|
339 |
+
range(0, min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
|
340 |
+
6))
|
341 |
+
elif split in ['val', 'test']:
|
342 |
+
index_list = list([0])
|
343 |
+
if self.whole_video:
|
344 |
+
index_list = list([0])
|
345 |
+
self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)
|
346 |
+
|
347 |
+
def __len__(self):
|
348 |
+
return len(self.img_name_list)
|
data_utils/rotation_conversion.py
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
2 |
+
# Check PYTORCH3D_LICENCE before use
|
3 |
+
|
4 |
+
import functools
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
"""
|
12 |
+
The transformation matrices returned from the functions in this file assume
|
13 |
+
the points on which the transformation will be applied are column vectors.
|
14 |
+
i.e. the R matrix is structured as
|
15 |
+
|
16 |
+
R = [
|
17 |
+
[Rxx, Rxy, Rxz],
|
18 |
+
[Ryx, Ryy, Ryz],
|
19 |
+
[Rzx, Rzy, Rzz],
|
20 |
+
] # (3, 3)
|
21 |
+
|
22 |
+
This matrix can be applied to column vectors by post multiplication
|
23 |
+
by the points e.g.
|
24 |
+
|
25 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
26 |
+
transformed_points = R * points
|
27 |
+
|
28 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
29 |
+
can be transposed and pre multiplied by the points:
|
30 |
+
|
31 |
+
e.g.
|
32 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
33 |
+
transformed_points = points * R.transpose(1, 0)
|
34 |
+
"""
|
35 |
+
|
36 |
+
|
37 |
+
def quaternion_to_matrix(quaternions):
|
38 |
+
"""
|
39 |
+
Convert rotations given as quaternions to rotation matrices.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
quaternions: quaternions with real part first,
|
43 |
+
as tensor of shape (..., 4).
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
47 |
+
"""
|
48 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
49 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
50 |
+
|
51 |
+
o = torch.stack(
|
52 |
+
(
|
53 |
+
1 - two_s * (j * j + k * k),
|
54 |
+
two_s * (i * j - k * r),
|
55 |
+
two_s * (i * k + j * r),
|
56 |
+
two_s * (i * j + k * r),
|
57 |
+
1 - two_s * (i * i + k * k),
|
58 |
+
two_s * (j * k - i * r),
|
59 |
+
two_s * (i * k - j * r),
|
60 |
+
two_s * (j * k + i * r),
|
61 |
+
1 - two_s * (i * i + j * j),
|
62 |
+
),
|
63 |
+
-1,
|
64 |
+
)
|
65 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
66 |
+
|
67 |
+
|
68 |
+
def _copysign(a, b):
|
69 |
+
"""
|
70 |
+
Return a tensor where each element has the absolute value taken from the,
|
71 |
+
corresponding element of a, with sign taken from the corresponding
|
72 |
+
element of b. This is like the standard copysign floating-point operation,
|
73 |
+
but is not careful about negative 0 and NaN.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
a: source tensor.
|
77 |
+
b: tensor whose signs will be used, of the same shape as a.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Tensor of the same shape as a with the signs of b.
|
81 |
+
"""
|
82 |
+
signs_differ = (a < 0) != (b < 0)
|
83 |
+
return torch.where(signs_differ, -a, a)
|
84 |
+
|
85 |
+
|
86 |
+
def _sqrt_positive_part(x):
|
87 |
+
"""
|
88 |
+
Returns torch.sqrt(torch.max(0, x))
|
89 |
+
but with a zero subgradient where x is 0.
|
90 |
+
"""
|
91 |
+
ret = torch.zeros_like(x)
|
92 |
+
positive_mask = x > 0
|
93 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
94 |
+
return ret
|
95 |
+
|
96 |
+
|
97 |
+
def matrix_to_quaternion(matrix):
|
98 |
+
"""
|
99 |
+
Convert rotations given as rotation matrices to quaternions.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
106 |
+
"""
|
107 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
108 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
109 |
+
m00 = matrix[..., 0, 0]
|
110 |
+
m11 = matrix[..., 1, 1]
|
111 |
+
m22 = matrix[..., 2, 2]
|
112 |
+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
113 |
+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
114 |
+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
115 |
+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
116 |
+
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
117 |
+
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
118 |
+
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
119 |
+
return torch.stack((o0, o1, o2, o3), -1)
|
120 |
+
|
121 |
+
|
122 |
+
def _axis_angle_rotation(axis: str, angle):
|
123 |
+
"""
|
124 |
+
Return the rotation matrices for one of the rotations about an axis
|
125 |
+
of which Euler angles describe, for each value of the angle given.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
axis: Axis label "X" or "Y or "Z".
|
129 |
+
angle: any shape tensor of Euler angles in radians
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
133 |
+
"""
|
134 |
+
|
135 |
+
cos = torch.cos(angle)
|
136 |
+
sin = torch.sin(angle)
|
137 |
+
one = torch.ones_like(angle)
|
138 |
+
zero = torch.zeros_like(angle)
|
139 |
+
|
140 |
+
if axis == "X":
|
141 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
142 |
+
if axis == "Y":
|
143 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
144 |
+
if axis == "Z":
|
145 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
146 |
+
|
147 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
148 |
+
|
149 |
+
|
150 |
+
def euler_angles_to_matrix(euler_angles, convention: str):
|
151 |
+
"""
|
152 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
156 |
+
convention: Convention string of three uppercase letters from
|
157 |
+
{"X", "Y", and "Z"}.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
161 |
+
"""
|
162 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
163 |
+
raise ValueError("Invalid input euler angles.")
|
164 |
+
if len(convention) != 3:
|
165 |
+
raise ValueError("Convention must have 3 letters.")
|
166 |
+
if convention[1] in (convention[0], convention[2]):
|
167 |
+
raise ValueError(f"Invalid convention {convention}.")
|
168 |
+
for letter in convention:
|
169 |
+
if letter not in ("X", "Y", "Z"):
|
170 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
171 |
+
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
172 |
+
return functools.reduce(torch.matmul, matrices)
|
173 |
+
|
174 |
+
|
175 |
+
def _angle_from_tan(
|
176 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
177 |
+
):
|
178 |
+
"""
|
179 |
+
Extract the first or third Euler angle from the two members of
|
180 |
+
the matrix which are positive constant times its sine and cosine.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
184 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
185 |
+
convention.
|
186 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
187 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
188 |
+
which means the relevant entries are in the same row of the
|
189 |
+
rotation matrix. If not, they are in the same column.
|
190 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
Euler Angles in radians for each matrix in data as a tensor
|
194 |
+
of shape (...).
|
195 |
+
"""
|
196 |
+
|
197 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
198 |
+
if horizontal:
|
199 |
+
i2, i1 = i1, i2
|
200 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
201 |
+
if horizontal == even:
|
202 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
203 |
+
if tait_bryan:
|
204 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
205 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
206 |
+
|
207 |
+
|
208 |
+
def _index_from_letter(letter: str):
|
209 |
+
if letter == "X":
|
210 |
+
return 0
|
211 |
+
if letter == "Y":
|
212 |
+
return 1
|
213 |
+
if letter == "Z":
|
214 |
+
return 2
|
215 |
+
|
216 |
+
|
217 |
+
def matrix_to_euler_angles(matrix, convention: str):
|
218 |
+
"""
|
219 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
223 |
+
convention: Convention string of three uppercase letters.
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
Euler angles in radians as tensor of shape (..., 3).
|
227 |
+
"""
|
228 |
+
if len(convention) != 3:
|
229 |
+
raise ValueError("Convention must have 3 letters.")
|
230 |
+
if convention[1] in (convention[0], convention[2]):
|
231 |
+
raise ValueError(f"Invalid convention {convention}.")
|
232 |
+
for letter in convention:
|
233 |
+
if letter not in ("X", "Y", "Z"):
|
234 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
235 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
236 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
237 |
+
i0 = _index_from_letter(convention[0])
|
238 |
+
i2 = _index_from_letter(convention[2])
|
239 |
+
tait_bryan = i0 != i2
|
240 |
+
if tait_bryan:
|
241 |
+
central_angle = torch.asin(
|
242 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
243 |
+
)
|
244 |
+
else:
|
245 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
246 |
+
|
247 |
+
o = (
|
248 |
+
_angle_from_tan(
|
249 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
250 |
+
),
|
251 |
+
central_angle,
|
252 |
+
_angle_from_tan(
|
253 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
254 |
+
),
|
255 |
+
)
|
256 |
+
return torch.stack(o, -1)
|
257 |
+
|
258 |
+
|
259 |
+
def random_quaternions(
|
260 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
261 |
+
):
|
262 |
+
"""
|
263 |
+
Generate random quaternions representing rotations,
|
264 |
+
i.e. versors with nonnegative real part.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
n: Number of quaternions in a batch to return.
|
268 |
+
dtype: Type to return.
|
269 |
+
device: Desired device of returned tensor. Default:
|
270 |
+
uses the current device for the default tensor type.
|
271 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
272 |
+
flag set.
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
Quaternions as tensor of shape (N, 4).
|
276 |
+
"""
|
277 |
+
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
278 |
+
s = (o * o).sum(1)
|
279 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
280 |
+
return o
|
281 |
+
|
282 |
+
|
283 |
+
def random_rotations(
|
284 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
285 |
+
):
|
286 |
+
"""
|
287 |
+
Generate random rotations as 3x3 rotation matrices.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
n: Number of rotation matrices in a batch to return.
|
291 |
+
dtype: Type to return.
|
292 |
+
device: Device of returned tensor. Default: if None,
|
293 |
+
uses the current device for the default tensor type.
|
294 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
295 |
+
flag set.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
299 |
+
"""
|
300 |
+
quaternions = random_quaternions(
|
301 |
+
n, dtype=dtype, device=device, requires_grad=requires_grad
|
302 |
+
)
|
303 |
+
return quaternion_to_matrix(quaternions)
|
304 |
+
|
305 |
+
|
306 |
+
def random_rotation(
|
307 |
+
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
308 |
+
):
|
309 |
+
"""
|
310 |
+
Generate a single random 3x3 rotation matrix.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
dtype: Type to return
|
314 |
+
device: Device of returned tensor. Default: if None,
|
315 |
+
uses the current device for the default tensor type
|
316 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
317 |
+
flag set
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
Rotation matrix as tensor of shape (3, 3).
|
321 |
+
"""
|
322 |
+
return random_rotations(1, dtype, device, requires_grad)[0]
|
323 |
+
|
324 |
+
|
325 |
+
def standardize_quaternion(quaternions):
|
326 |
+
"""
|
327 |
+
Convert a unit quaternion to a standard form: one in which the real
|
328 |
+
part is non negative.
|
329 |
+
|
330 |
+
Args:
|
331 |
+
quaternions: Quaternions with real part first,
|
332 |
+
as tensor of shape (..., 4).
|
333 |
+
|
334 |
+
Returns:
|
335 |
+
Standardized quaternions as tensor of shape (..., 4).
|
336 |
+
"""
|
337 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
338 |
+
|
339 |
+
|
340 |
+
def quaternion_raw_multiply(a, b):
|
341 |
+
"""
|
342 |
+
Multiply two quaternions.
|
343 |
+
Usual torch rules for broadcasting apply.
|
344 |
+
|
345 |
+
Args:
|
346 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
347 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
351 |
+
"""
|
352 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
353 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
354 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
355 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
356 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
357 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
358 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
359 |
+
|
360 |
+
|
361 |
+
def quaternion_multiply(a, b):
|
362 |
+
"""
|
363 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
364 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
365 |
+
Usual torch rules for broadcasting apply.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
369 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
373 |
+
"""
|
374 |
+
ab = quaternion_raw_multiply(a, b)
|
375 |
+
return standardize_quaternion(ab)
|
376 |
+
|
377 |
+
|
378 |
+
def quaternion_invert(quaternion):
|
379 |
+
"""
|
380 |
+
Given a quaternion representing rotation, get the quaternion representing
|
381 |
+
its inverse.
|
382 |
+
|
383 |
+
Args:
|
384 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
385 |
+
first, which must be versors (unit quaternions).
|
386 |
+
|
387 |
+
Returns:
|
388 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
389 |
+
"""
|
390 |
+
|
391 |
+
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
|
392 |
+
|
393 |
+
|
394 |
+
def quaternion_apply(quaternion, point):
|
395 |
+
"""
|
396 |
+
Apply the rotation given by a quaternion to a 3D point.
|
397 |
+
Usual torch rules for broadcasting apply.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
401 |
+
point: Tensor of 3D points of shape (..., 3).
|
402 |
+
|
403 |
+
Returns:
|
404 |
+
Tensor of rotated points of shape (..., 3).
|
405 |
+
"""
|
406 |
+
if point.size(-1) != 3:
|
407 |
+
raise ValueError(f"Points are not in 3D, f{point.shape}.")
|
408 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
409 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
410 |
+
out = quaternion_raw_multiply(
|
411 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
412 |
+
quaternion_invert(quaternion),
|
413 |
+
)
|
414 |
+
return out[..., 1:]
|
415 |
+
|
416 |
+
|
417 |
+
def axis_angle_to_matrix(axis_angle):
|
418 |
+
"""
|
419 |
+
Convert rotations given as axis/angle to rotation matrices.
|
420 |
+
|
421 |
+
Args:
|
422 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
423 |
+
as a tensor of shape (..., 3), where the magnitude is
|
424 |
+
the angle turned anticlockwise in radians around the
|
425 |
+
vector's direction.
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
429 |
+
"""
|
430 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
431 |
+
|
432 |
+
|
433 |
+
def matrix_to_axis_angle(matrix):
|
434 |
+
"""
|
435 |
+
Convert rotations given as rotation matrices to axis/angle.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
439 |
+
|
440 |
+
Returns:
|
441 |
+
Rotations given as a vector in axis angle form, as a tensor
|
442 |
+
of shape (..., 3), where the magnitude is the angle
|
443 |
+
turned anticlockwise in radians around the vector's
|
444 |
+
direction.
|
445 |
+
"""
|
446 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
447 |
+
|
448 |
+
|
449 |
+
def axis_angle_to_quaternion(axis_angle):
|
450 |
+
"""
|
451 |
+
Convert rotations given as axis/angle to quaternions.
|
452 |
+
|
453 |
+
Args:
|
454 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
455 |
+
as a tensor of shape (..., 3), where the magnitude is
|
456 |
+
the angle turned anticlockwise in radians around the
|
457 |
+
vector's direction.
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
461 |
+
"""
|
462 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
463 |
+
half_angles = 0.5 * angles
|
464 |
+
eps = 1e-6
|
465 |
+
small_angles = angles.abs() < eps
|
466 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
467 |
+
sin_half_angles_over_angles[~small_angles] = (
|
468 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
469 |
+
)
|
470 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
471 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
472 |
+
sin_half_angles_over_angles[small_angles] = (
|
473 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
474 |
+
)
|
475 |
+
quaternions = torch.cat(
|
476 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
477 |
+
)
|
478 |
+
return quaternions
|
479 |
+
|
480 |
+
|
481 |
+
def quaternion_to_axis_angle(quaternions):
|
482 |
+
"""
|
483 |
+
Convert rotations given as quaternions to axis/angle.
|
484 |
+
|
485 |
+
Args:
|
486 |
+
quaternions: quaternions with real part first,
|
487 |
+
as tensor of shape (..., 4).
|
488 |
+
|
489 |
+
Returns:
|
490 |
+
Rotations given as a vector in axis angle form, as a tensor
|
491 |
+
of shape (..., 3), where the magnitude is the angle
|
492 |
+
turned anticlockwise in radians around the vector's
|
493 |
+
direction.
|
494 |
+
"""
|
495 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
496 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
497 |
+
angles = 2 * half_angles
|
498 |
+
eps = 1e-6
|
499 |
+
small_angles = angles.abs() < eps
|
500 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
501 |
+
sin_half_angles_over_angles[~small_angles] = (
|
502 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
503 |
+
)
|
504 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
505 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
506 |
+
sin_half_angles_over_angles[small_angles] = (
|
507 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
508 |
+
)
|
509 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
510 |
+
|
511 |
+
|
512 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
513 |
+
"""
|
514 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
515 |
+
using Gram--Schmidt orthogonalisation per Section B of [1].
|
516 |
+
Args:
|
517 |
+
d6: 6D rotation representation, of size (*, 6)
|
518 |
+
|
519 |
+
Returns:
|
520 |
+
batch of rotation matrices of size (*, 3, 3)
|
521 |
+
|
522 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
523 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
524 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
525 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
526 |
+
"""
|
527 |
+
|
528 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
529 |
+
b1 = F.normalize(a1, dim=-1)
|
530 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
531 |
+
b2 = F.normalize(b2, dim=-1)
|
532 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
533 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
534 |
+
|
535 |
+
|
536 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
537 |
+
"""
|
538 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
539 |
+
by dropping the last row. Note that 6D representation is not unique.
|
540 |
+
Args:
|
541 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
542 |
+
|
543 |
+
Returns:
|
544 |
+
6D rotation representation, of size (*, 6)
|
545 |
+
|
546 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
547 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
548 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
549 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
550 |
+
"""
|
551 |
+
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
data_utils/split_more_than_2s.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2df6e745cdf7473f13ce3ae2ed759c3cceb60c9197e7f3fd65110e7bc20b6f2d
|
3 |
+
size 2398875
|
data_utils/split_train_val_test.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
if __name__ =='__main__':
|
6 |
+
id_list = "chemistry conan oliver seth"
|
7 |
+
id_list = id_list.split(' ')
|
8 |
+
|
9 |
+
old_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0'
|
10 |
+
new_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0/talkshow_data_splited'
|
11 |
+
|
12 |
+
with open('train_val_test.json') as f:
|
13 |
+
split_info = json.load(f)
|
14 |
+
phase_list = ['train', 'val', 'test']
|
15 |
+
for phase in phase_list:
|
16 |
+
phase_path_list = split_info[phase]
|
17 |
+
for p in phase_path_list:
|
18 |
+
old_path = os.path.join(old_root, p)
|
19 |
+
if not os.path.exists(old_path):
|
20 |
+
print(f'{old_path} not found, continue' )
|
21 |
+
continue
|
22 |
+
new_path = os.path.join(new_root, phase, p)
|
23 |
+
dir_name = os.path.dirname(new_path)
|
24 |
+
if not os.path.isdir(dir_name):
|
25 |
+
os.makedirs(dir_name, exist_ok=True)
|
26 |
+
shutil.move(old_path, new_path)
|
27 |
+
|
data_utils/train_val_test.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_utils/utils.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
# import librosa #has to do this cause librosa is not supported on my server
|
3 |
+
import python_speech_features
|
4 |
+
from scipy.io import wavfile
|
5 |
+
from scipy import signal
|
6 |
+
import librosa
|
7 |
+
import torch
|
8 |
+
import torchaudio as ta
|
9 |
+
import torchaudio.functional as ta_F
|
10 |
+
import torchaudio.transforms as ta_T
|
11 |
+
# import pyloudnorm as pyln
|
12 |
+
|
13 |
+
|
14 |
+
def load_wav_old(audio_fn, sr = 16000):
|
15 |
+
sample_rate, sig = wavfile.read(audio_fn)
|
16 |
+
if sample_rate != sr:
|
17 |
+
result = int((sig.shape[0]) / sample_rate * sr)
|
18 |
+
x_resampled = signal.resample(sig, result)
|
19 |
+
x_resampled = x_resampled.astype(np.float64)
|
20 |
+
return x_resampled, sr
|
21 |
+
|
22 |
+
sig = sig / (2**15)
|
23 |
+
return sig, sample_rate
|
24 |
+
|
25 |
+
|
26 |
+
def get_mfcc(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
|
27 |
+
|
28 |
+
y, sr = librosa.load(audio_fn, sr=sr, mono=True)
|
29 |
+
|
30 |
+
if win_size is None:
|
31 |
+
hop_len=int(sr / fps)
|
32 |
+
else:
|
33 |
+
hop_len=int(sr / win_size)
|
34 |
+
|
35 |
+
n_fft=2048
|
36 |
+
|
37 |
+
C = librosa.feature.mfcc(
|
38 |
+
y = y,
|
39 |
+
sr = sr,
|
40 |
+
n_mfcc = n_mfcc,
|
41 |
+
hop_length = hop_len,
|
42 |
+
n_fft = n_fft
|
43 |
+
)
|
44 |
+
|
45 |
+
if C.shape[0] == n_mfcc:
|
46 |
+
C = C.transpose(1, 0)
|
47 |
+
|
48 |
+
return C
|
49 |
+
|
50 |
+
|
51 |
+
def get_melspec(audio_fn, eps=1e-6, fps = 25, sr=16000, n_mels=64):
|
52 |
+
raise NotImplementedError
|
53 |
+
'''
|
54 |
+
# y, sr = load_wav(audio_fn=audio_fn, sr=sr)
|
55 |
+
|
56 |
+
# hop_len = int(sr / fps)
|
57 |
+
# n_fft = 2048
|
58 |
+
|
59 |
+
# C = librosa.feature.melspectrogram(
|
60 |
+
# y = y,
|
61 |
+
# sr = sr,
|
62 |
+
# n_fft=n_fft,
|
63 |
+
# hop_length=hop_len,
|
64 |
+
# n_mels = n_mels,
|
65 |
+
# fmin=0,
|
66 |
+
# fmax=8000)
|
67 |
+
|
68 |
+
|
69 |
+
# mask = (C == 0).astype(np.float)
|
70 |
+
# C = mask * eps + (1-mask) * C
|
71 |
+
|
72 |
+
# C = np.log(C)
|
73 |
+
# #wierd error may occur here
|
74 |
+
# assert not (np.isnan(C).any()), audio_fn
|
75 |
+
# if C.shape[0] == n_mels:
|
76 |
+
# C = C.transpose(1, 0)
|
77 |
+
|
78 |
+
# return C
|
79 |
+
'''
|
80 |
+
|
81 |
+
def extract_mfcc(audio,sample_rate=16000):
|
82 |
+
mfcc = zip(*python_speech_features.mfcc(audio,sample_rate, numcep=64, nfilt=64, nfft=2048, winstep=0.04))
|
83 |
+
mfcc = np.stack([np.array(i) for i in mfcc])
|
84 |
+
return mfcc
|
85 |
+
|
86 |
+
def get_mfcc_psf(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
|
87 |
+
y, sr = load_wav_old(audio_fn, sr=sr)
|
88 |
+
|
89 |
+
if y.shape.__len__() > 1:
|
90 |
+
y = (y[:,0]+y[:,1])/2
|
91 |
+
|
92 |
+
if win_size is None:
|
93 |
+
hop_len=int(sr / fps)
|
94 |
+
else:
|
95 |
+
hop_len=int(sr/ win_size)
|
96 |
+
|
97 |
+
n_fft=2048
|
98 |
+
|
99 |
+
#hard coded for 25 fps
|
100 |
+
if not smlpx:
|
101 |
+
C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=0.04)
|
102 |
+
else:
|
103 |
+
C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01/15)
|
104 |
+
# if C.shape[0] == n_mfcc:
|
105 |
+
# C = C.transpose(1, 0)
|
106 |
+
|
107 |
+
return C
|
108 |
+
|
109 |
+
|
110 |
+
def get_mfcc_psf_min(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
|
111 |
+
y, sr = load_wav_old(audio_fn, sr=sr)
|
112 |
+
|
113 |
+
if y.shape.__len__() > 1:
|
114 |
+
y = (y[:, 0] + y[:, 1]) / 2
|
115 |
+
n_fft = 2048
|
116 |
+
|
117 |
+
slice_len = 22000 * 5
|
118 |
+
slice = y.size // slice_len
|
119 |
+
|
120 |
+
C = []
|
121 |
+
|
122 |
+
for i in range(slice):
|
123 |
+
if i != (slice - 1):
|
124 |
+
feat = python_speech_features.mfcc(y[i*slice_len:(i+1)*slice_len], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
|
125 |
+
else:
|
126 |
+
feat = python_speech_features.mfcc(y[i * slice_len:], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
|
127 |
+
|
128 |
+
C.append(feat)
|
129 |
+
|
130 |
+
return C
|
131 |
+
|
132 |
+
|
133 |
+
def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
|
134 |
+
"""
|
135 |
+
:param audio: 1 x T tensor containing a 16kHz audio signal
|
136 |
+
:param frame_rate: frame rate for video (we need one audio chunk per video frame)
|
137 |
+
:param chunk_size: number of audio samples per chunk
|
138 |
+
:return: num_chunks x chunk_size tensor containing sliced audio
|
139 |
+
"""
|
140 |
+
samples_per_frame = chunk_size // frame_rate
|
141 |
+
padding = (chunk_size - samples_per_frame) // 2
|
142 |
+
audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
|
143 |
+
anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
|
144 |
+
audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
|
145 |
+
return audio
|
146 |
+
|
147 |
+
|
148 |
+
def get_mfcc_ta(audio_fn, eps=1e-6, fps=15, smlpx=False, sr=16000, n_mfcc=64, win_size=None, type='mfcc', am=None, am_sr=None, encoder_choice='mfcc'):
|
149 |
+
if am is None:
|
150 |
+
audio, sr_0 = ta.load(audio_fn)
|
151 |
+
if sr != sr_0:
|
152 |
+
audio = ta.transforms.Resample(sr_0, sr)(audio)
|
153 |
+
if audio.shape[0] > 1:
|
154 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
155 |
+
|
156 |
+
n_fft = 2048
|
157 |
+
if fps == 15:
|
158 |
+
hop_length = 1467
|
159 |
+
elif fps == 30:
|
160 |
+
hop_length = 734
|
161 |
+
win_length = hop_length * 2
|
162 |
+
n_mels = 256
|
163 |
+
n_mfcc = 64
|
164 |
+
|
165 |
+
if type == 'mfcc':
|
166 |
+
mfcc_transform = ta_T.MFCC(
|
167 |
+
sample_rate=sr,
|
168 |
+
n_mfcc=n_mfcc,
|
169 |
+
melkwargs={
|
170 |
+
"n_fft": n_fft,
|
171 |
+
"n_mels": n_mels,
|
172 |
+
# "win_length": win_length,
|
173 |
+
"hop_length": hop_length,
|
174 |
+
"mel_scale": "htk",
|
175 |
+
},
|
176 |
+
)
|
177 |
+
audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0,1).numpy()
|
178 |
+
elif type == 'mel':
|
179 |
+
# audio = 0.01 * audio / torch.mean(torch.abs(audio))
|
180 |
+
mel_transform = ta_T.MelSpectrogram(
|
181 |
+
sample_rate=sr, n_fft=n_fft, win_length=None, hop_length=hop_length, n_mels=n_mels
|
182 |
+
)
|
183 |
+
audio_ft = mel_transform(audio).squeeze(0).transpose(0,1).numpy()
|
184 |
+
# audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).transpose(0,1).numpy()
|
185 |
+
elif type == 'mel_mul':
|
186 |
+
audio = 0.01 * audio / torch.mean(torch.abs(audio))
|
187 |
+
audio = audio_chunking(audio, frame_rate=fps, chunk_size=sr)
|
188 |
+
mel_transform = ta_T.MelSpectrogram(
|
189 |
+
sample_rate=sr, n_fft=n_fft, win_length=int(sr/20), hop_length=int(sr/100), n_mels=n_mels
|
190 |
+
)
|
191 |
+
audio_ft = mel_transform(audio).squeeze(1)
|
192 |
+
audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).numpy()
|
193 |
+
else:
|
194 |
+
speech_array, sampling_rate = librosa.load(audio_fn, sr=16000)
|
195 |
+
|
196 |
+
if encoder_choice == 'faceformer':
|
197 |
+
# audio_ft = np.squeeze(am(speech_array, sampling_rate=16000).input_values).reshape(-1, 1)
|
198 |
+
audio_ft = speech_array.reshape(-1, 1)
|
199 |
+
elif encoder_choice == 'meshtalk':
|
200 |
+
audio_ft = 0.01 * speech_array / np.mean(np.abs(speech_array))
|
201 |
+
elif encoder_choice == 'onset':
|
202 |
+
audio_ft = librosa.onset.onset_detect(y=speech_array, sr=16000, units='time').reshape(-1, 1)
|
203 |
+
else:
|
204 |
+
audio, sr_0 = ta.load(audio_fn)
|
205 |
+
if sr != sr_0:
|
206 |
+
audio = ta.transforms.Resample(sr_0, sr)(audio)
|
207 |
+
if audio.shape[0] > 1:
|
208 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
209 |
+
|
210 |
+
n_fft = 2048
|
211 |
+
if fps == 15:
|
212 |
+
hop_length = 1467
|
213 |
+
elif fps == 30:
|
214 |
+
hop_length = 734
|
215 |
+
win_length = hop_length * 2
|
216 |
+
n_mels = 256
|
217 |
+
n_mfcc = 64
|
218 |
+
|
219 |
+
mfcc_transform = ta_T.MFCC(
|
220 |
+
sample_rate=sr,
|
221 |
+
n_mfcc=n_mfcc,
|
222 |
+
melkwargs={
|
223 |
+
"n_fft": n_fft,
|
224 |
+
"n_mels": n_mels,
|
225 |
+
# "win_length": win_length,
|
226 |
+
"hop_length": hop_length,
|
227 |
+
"mel_scale": "htk",
|
228 |
+
},
|
229 |
+
)
|
230 |
+
audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0, 1).numpy()
|
231 |
+
return audio_ft
|
232 |
+
|
233 |
+
|
234 |
+
def get_mfcc_sepa(audio_fn, fps=15, sr=16000):
|
235 |
+
audio, sr_0 = ta.load(audio_fn)
|
236 |
+
if sr != sr_0:
|
237 |
+
audio = ta.transforms.Resample(sr_0, sr)(audio)
|
238 |
+
if audio.shape[0] > 1:
|
239 |
+
audio = torch.mean(audio, dim=0, keepdim=True)
|
240 |
+
|
241 |
+
n_fft = 2048
|
242 |
+
if fps == 15:
|
243 |
+
hop_length = 1467
|
244 |
+
elif fps == 30:
|
245 |
+
hop_length = 734
|
246 |
+
n_mels = 256
|
247 |
+
n_mfcc = 64
|
248 |
+
|
249 |
+
mfcc_transform = ta_T.MFCC(
|
250 |
+
sample_rate=sr,
|
251 |
+
n_mfcc=n_mfcc,
|
252 |
+
melkwargs={
|
253 |
+
"n_fft": n_fft,
|
254 |
+
"n_mels": n_mels,
|
255 |
+
# "win_length": win_length,
|
256 |
+
"hop_length": hop_length,
|
257 |
+
"mel_scale": "htk",
|
258 |
+
},
|
259 |
+
)
|
260 |
+
audio_ft_0 = mfcc_transform(audio[0, :sr*2]).squeeze(dim=0).transpose(0,1).numpy()
|
261 |
+
audio_ft_1 = mfcc_transform(audio[0, sr*2:]).squeeze(dim=0).transpose(0,1).numpy()
|
262 |
+
audio_ft = np.concatenate((audio_ft_0, audio_ft_1), axis=0)
|
263 |
+
return audio_ft, audio_ft_0.shape[0]
|
264 |
+
|
265 |
+
|
266 |
+
def get_mfcc_old(wav_file):
|
267 |
+
sig, sample_rate = load_wav_old(wav_file)
|
268 |
+
mfcc = extract_mfcc(sig)
|
269 |
+
return mfcc
|
270 |
+
|
271 |
+
|
272 |
+
def smooth_geom(geom, mask: torch.Tensor = None, filter_size: int = 9, sigma: float = 2.0):
|
273 |
+
"""
|
274 |
+
:param geom: T x V x 3 tensor containing a temporal sequence of length T with V vertices in each frame
|
275 |
+
:param mask: V-dimensional Tensor containing a mask with vertices to be smoothed
|
276 |
+
:param filter_size: size of the Gaussian filter
|
277 |
+
:param sigma: standard deviation of the Gaussian filter
|
278 |
+
:return: T x V x 3 tensor containing smoothed geometry (i.e., smoothed in the area indicated by the mask)
|
279 |
+
"""
|
280 |
+
assert filter_size % 2 == 1, f"filter size must be odd but is {filter_size}"
|
281 |
+
# Gaussian smoothing (low-pass filtering)
|
282 |
+
fltr = np.arange(-(filter_size // 2), filter_size // 2 + 1)
|
283 |
+
fltr = np.exp(-0.5 * fltr ** 2 / sigma ** 2)
|
284 |
+
fltr = torch.Tensor(fltr) / np.sum(fltr)
|
285 |
+
# apply fltr
|
286 |
+
fltr = fltr.view(1, 1, -1).to(device=geom.device)
|
287 |
+
T, V = geom.shape[1], geom.shape[2]
|
288 |
+
g = torch.nn.functional.pad(
|
289 |
+
geom.permute(2, 0, 1).view(V, 1, T),
|
290 |
+
pad=[filter_size // 2, filter_size // 2], mode='replicate'
|
291 |
+
)
|
292 |
+
g = torch.nn.functional.conv1d(g, fltr).view(V, 1, T)
|
293 |
+
smoothed = g.permute(1, 2, 0).contiguous()
|
294 |
+
# blend smoothed signal with original signal
|
295 |
+
if mask is None:
|
296 |
+
return smoothed
|
297 |
+
else:
|
298 |
+
return smoothed * mask[None, :, None] + geom * (-mask[None, :, None] + 1)
|
299 |
+
|
300 |
+
if __name__ == '__main__':
|
301 |
+
audio_fn = '../sample_audio/clip000028_tCAkv4ggPgI.wav'
|
302 |
+
|
303 |
+
C = get_mfcc_psf(audio_fn)
|
304 |
+
print(C.shape)
|
305 |
+
|
306 |
+
C_2 = get_mfcc_librosa(audio_fn)
|
307 |
+
print(C.shape)
|
308 |
+
|
309 |
+
print(C)
|
310 |
+
print(C_2)
|
311 |
+
print((C == C_2).all())
|
312 |
+
# print(y.shape, sr)
|
313 |
+
# mel_spec = get_melspec(audio_fn)
|
314 |
+
# print(mel_spec.shape)
|
315 |
+
# mfcc = get_mfcc(audio_fn, sr = 16000)
|
316 |
+
# print(mfcc.shape)
|
317 |
+
# print(mel_spec.max(), mel_spec.min())
|
318 |
+
# print(mfcc.max(), mfcc.min())
|
evaluation/FGD.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from scipy import linalg
|
7 |
+
import math
|
8 |
+
from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
|
9 |
+
|
10 |
+
import warnings
|
11 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
|
12 |
+
|
13 |
+
|
14 |
+
change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
|
15 |
+
class EmbeddingSpaceEvaluator:
|
16 |
+
def __init__(self, ae, vae, device):
|
17 |
+
|
18 |
+
# init embed net
|
19 |
+
self.ae = ae
|
20 |
+
# self.vae = vae
|
21 |
+
|
22 |
+
# storage
|
23 |
+
self.real_feat_list = []
|
24 |
+
self.generated_feat_list = []
|
25 |
+
self.real_joints_list = []
|
26 |
+
self.generated_joints_list = []
|
27 |
+
self.real_6d_list = []
|
28 |
+
self.generated_6d_list = []
|
29 |
+
self.audio_beat_list = []
|
30 |
+
|
31 |
+
def reset(self):
|
32 |
+
self.real_feat_list = []
|
33 |
+
self.generated_feat_list = []
|
34 |
+
|
35 |
+
def get_no_of_samples(self):
|
36 |
+
return len(self.real_feat_list)
|
37 |
+
|
38 |
+
def push_samples(self, generated_poses, real_poses):
|
39 |
+
# self.net.eval()
|
40 |
+
# convert poses to latent features
|
41 |
+
real_feat, real_poses = self.ae.extract(real_poses)
|
42 |
+
generated_feat, generated_poses = self.ae.extract(generated_poses)
|
43 |
+
|
44 |
+
num_joints = real_poses.shape[2] // 3
|
45 |
+
|
46 |
+
real_feat = real_feat.squeeze()
|
47 |
+
generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
|
48 |
+
|
49 |
+
self.real_feat_list.append(real_feat.data.cpu().numpy())
|
50 |
+
self.generated_feat_list.append(generated_feat.data.cpu().numpy())
|
51 |
+
|
52 |
+
# real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
|
53 |
+
# generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
|
54 |
+
#
|
55 |
+
# self.real_feat_list.append(real_poses.data.cpu().numpy())
|
56 |
+
# self.generated_feat_list.append(generated_poses.data.cpu().numpy())
|
57 |
+
|
58 |
+
def push_joints(self, generated_poses, real_poses):
|
59 |
+
self.real_joints_list.append(real_poses.data.cpu())
|
60 |
+
self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
|
61 |
+
|
62 |
+
def push_aud(self, aud):
|
63 |
+
self.audio_beat_list.append(aud.squeeze().data.cpu())
|
64 |
+
|
65 |
+
def get_MAAC(self):
|
66 |
+
ang_vel_list = []
|
67 |
+
for real_joints in self.real_joints_list:
|
68 |
+
real_joints[:, 15:21] = real_joints[:, 16:22]
|
69 |
+
vec = real_joints[:, 15:21] - real_joints[:, 13:19]
|
70 |
+
inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
|
71 |
+
inner_product = torch.clamp(inner_product, -1, 1, out=None)
|
72 |
+
angle = torch.acos(inner_product) / math.pi
|
73 |
+
ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
|
74 |
+
ang_vel_list.append(ang_vel.unsqueeze(dim=0))
|
75 |
+
all_vel = torch.cat(ang_vel_list, dim=0)
|
76 |
+
MAAC = all_vel.mean(dim=0)
|
77 |
+
return MAAC
|
78 |
+
|
79 |
+
def get_BCscore(self):
|
80 |
+
thres = 0.01
|
81 |
+
sigma = 0.1
|
82 |
+
sum_1 = 0
|
83 |
+
total_beat = 0
|
84 |
+
for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
|
85 |
+
motion_beat_time = []
|
86 |
+
if joints.dim() == 4:
|
87 |
+
joints = joints[0]
|
88 |
+
joints[:, 15:21] = joints[:, 16:22]
|
89 |
+
vec = joints[:, 15:21] - joints[:, 13:19]
|
90 |
+
inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
|
91 |
+
inner_product = torch.clamp(inner_product, -1, 1, out=None)
|
92 |
+
angle = torch.acos(inner_product) / math.pi
|
93 |
+
ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
|
94 |
+
|
95 |
+
angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
|
96 |
+
|
97 |
+
sum_2 = 0
|
98 |
+
for i in range(angle_diff.shape[1]):
|
99 |
+
motion_beat_time = []
|
100 |
+
for t in range(1, joints.shape[0]-1):
|
101 |
+
if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
|
102 |
+
if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
|
103 |
+
t][i] >= thres):
|
104 |
+
motion_beat_time.append(float(t) / 30.0)
|
105 |
+
if (len(motion_beat_time) == 0):
|
106 |
+
continue
|
107 |
+
motion_beat_time = torch.tensor(motion_beat_time)
|
108 |
+
sum = 0
|
109 |
+
for audio in audio_beat_time:
|
110 |
+
sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
|
111 |
+
sum_2 = sum_2 + sum
|
112 |
+
total_beat = total_beat + len(audio_beat_time)
|
113 |
+
sum_1 = sum_1 + sum_2
|
114 |
+
return sum_1/total_beat
|
115 |
+
|
116 |
+
|
117 |
+
def get_scores(self):
|
118 |
+
generated_feats = np.vstack(self.generated_feat_list)
|
119 |
+
real_feats = np.vstack(self.real_feat_list)
|
120 |
+
|
121 |
+
def frechet_distance(samples_A, samples_B):
|
122 |
+
A_mu = np.mean(samples_A, axis=0)
|
123 |
+
A_sigma = np.cov(samples_A, rowvar=False)
|
124 |
+
B_mu = np.mean(samples_B, axis=0)
|
125 |
+
B_sigma = np.cov(samples_B, rowvar=False)
|
126 |
+
try:
|
127 |
+
frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
|
128 |
+
except ValueError:
|
129 |
+
frechet_dist = 1e+10
|
130 |
+
return frechet_dist
|
131 |
+
|
132 |
+
####################################################################
|
133 |
+
# frechet distance
|
134 |
+
frechet_dist = frechet_distance(generated_feats, real_feats)
|
135 |
+
|
136 |
+
####################################################################
|
137 |
+
# distance between real and generated samples on the latent feature space
|
138 |
+
dists = []
|
139 |
+
for i in range(real_feats.shape[0]):
|
140 |
+
d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
|
141 |
+
dists.append(d)
|
142 |
+
feat_dist = np.mean(dists)
|
143 |
+
|
144 |
+
return frechet_dist, feat_dist
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
148 |
+
""" from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
|
149 |
+
"""Numpy implementation of the Frechet Distance.
|
150 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
151 |
+
and X_2 ~ N(mu_2, C_2) is
|
152 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
153 |
+
Stable version by Dougal J. Sutherland.
|
154 |
+
Params:
|
155 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
156 |
+
inception net (like returned by the function 'get_predictions')
|
157 |
+
for generated samples.
|
158 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
159 |
+
representative data set.
|
160 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
161 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
162 |
+
representative data set.
|
163 |
+
Returns:
|
164 |
+
-- : The Frechet Distance.
|
165 |
+
"""
|
166 |
+
|
167 |
+
mu1 = np.atleast_1d(mu1)
|
168 |
+
mu2 = np.atleast_1d(mu2)
|
169 |
+
|
170 |
+
sigma1 = np.atleast_2d(sigma1)
|
171 |
+
sigma2 = np.atleast_2d(sigma2)
|
172 |
+
|
173 |
+
assert mu1.shape == mu2.shape, \
|
174 |
+
'Training and test mean vectors have different lengths'
|
175 |
+
assert sigma1.shape == sigma2.shape, \
|
176 |
+
'Training and test covariances have different dimensions'
|
177 |
+
|
178 |
+
diff = mu1 - mu2
|
179 |
+
|
180 |
+
# Product might be almost singular
|
181 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
182 |
+
if not np.isfinite(covmean).all():
|
183 |
+
msg = ('fid calculation produces singular product; '
|
184 |
+
'adding %s to diagonal of cov estimates') % eps
|
185 |
+
print(msg)
|
186 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
187 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
188 |
+
|
189 |
+
# Numerical error might give slight imaginary component
|
190 |
+
if np.iscomplexobj(covmean):
|
191 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
192 |
+
m = np.max(np.abs(covmean.imag))
|
193 |
+
raise ValueError('Imaginary component {}'.format(m))
|
194 |
+
covmean = covmean.real
|
195 |
+
|
196 |
+
tr_covmean = np.trace(covmean)
|
197 |
+
|
198 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
199 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
evaluation/__init__.py
ADDED
File without changes
|
evaluation/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (181 Bytes). View file
|
|
evaluation/__pycache__/metrics.cpython-37.pyc
ADDED
Binary file (3.81 kB). View file
|
|
evaluation/diversity_LVD.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
LVD: different initial pose
|
3 |
+
diversity: same initial pose
|
4 |
+
'''
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
sys.path.append(os.getcwd())
|
8 |
+
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
import json
|
13 |
+
|
14 |
+
from evaluation.util import *
|
15 |
+
from evaluation.metrics import *
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
parser = ArgumentParser()
|
19 |
+
parser.add_argument('--speaker', required=True, type=str)
|
20 |
+
parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
speaker = args.speaker
|
24 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
25 |
+
|
26 |
+
LVD_list = []
|
27 |
+
diversity_list = []
|
28 |
+
|
29 |
+
for aud in tqdm(test_audios):
|
30 |
+
base_name = os.path.splitext(aud)[0]
|
31 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
32 |
+
_, gt_poses, _ = get_gts(gt_path)
|
33 |
+
gt_poses = gt_poses[np.newaxis,...]
|
34 |
+
# print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
|
35 |
+
for post_fix in args.post_fix:
|
36 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
37 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
38 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
39 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
40 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
41 |
+
|
42 |
+
gt_valid_points = hand_points(gt_poses)
|
43 |
+
pred_valid_points = hand_points(pred_poses)
|
44 |
+
|
45 |
+
lvd = LVD(gt_valid_points, pred_valid_points)
|
46 |
+
# div = diversity(pred_valid_points)
|
47 |
+
|
48 |
+
LVD_list.append(lvd)
|
49 |
+
# diversity_list.append(div)
|
50 |
+
|
51 |
+
# gt_velocity = peak_velocity(gt_valid_points, order=2)
|
52 |
+
# pred_velocity = peak_velocity(pred_valid_points, order=2)
|
53 |
+
|
54 |
+
# gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
|
55 |
+
# pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
|
56 |
+
|
57 |
+
# gt_consistency_list.append(gt_consistency)
|
58 |
+
# pred_consistency_list.append(pred_consistency)
|
59 |
+
|
60 |
+
lvd = np.mean(LVD_list)
|
61 |
+
# diversity_list = np.mean(diversity_list)
|
62 |
+
|
63 |
+
print('LVD:', lvd)
|
64 |
+
# print("diversity:", diversity_list)
|
evaluation/get_quality_samples.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
'''
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
sys.path.append(os.getcwd())
|
6 |
+
|
7 |
+
from glob import glob
|
8 |
+
|
9 |
+
from argparse import ArgumentParser
|
10 |
+
import json
|
11 |
+
|
12 |
+
from evaluation.util import *
|
13 |
+
from evaluation.metrics import *
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
parser = ArgumentParser()
|
17 |
+
parser.add_argument('--speaker', required=True, type=str)
|
18 |
+
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
|
19 |
+
args = parser.parse_args()
|
20 |
+
|
21 |
+
speaker = args.speaker
|
22 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
23 |
+
|
24 |
+
quality_samples={'gt':[]}
|
25 |
+
for post_fix in args.post_fix:
|
26 |
+
quality_samples[post_fix] = []
|
27 |
+
|
28 |
+
for aud in tqdm(test_audios):
|
29 |
+
base_name = os.path.splitext(aud)[0]
|
30 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
31 |
+
_, gt_poses, _ = get_gts(gt_path)
|
32 |
+
gt_poses = gt_poses[np.newaxis,...]
|
33 |
+
gt_valid_points = valid_points(gt_poses)
|
34 |
+
# print(gt_valid_points.shape)
|
35 |
+
quality_samples['gt'].append(gt_valid_points)
|
36 |
+
|
37 |
+
for post_fix in args.post_fix:
|
38 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
39 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
40 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
41 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
42 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
43 |
+
|
44 |
+
pred_valid_points = valid_points(pred_poses)[0:1]
|
45 |
+
quality_samples[post_fix].append(pred_valid_points)
|
46 |
+
|
47 |
+
quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
|
48 |
+
for post_fix in args.post_fix:
|
49 |
+
quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
|
50 |
+
|
51 |
+
print('gt:', quality_samples['gt'].shape)
|
52 |
+
quality_samples['gt'] = quality_samples['gt'].tolist()
|
53 |
+
for post_fix in args.post_fix:
|
54 |
+
print(post_fix, ':', quality_samples[post_fix].shape)
|
55 |
+
quality_samples[post_fix] = quality_samples[post_fix].tolist()
|
56 |
+
|
57 |
+
save_dir = '../../experiments/'
|
58 |
+
os.makedirs(save_dir, exist_ok=True)
|
59 |
+
save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
|
60 |
+
with open(save_name, 'w') as f:
|
61 |
+
json.dump(quality_samples, f)
|
62 |
+
|
evaluation/metrics.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Warning: metrics are for reference only, may have limited significance
|
3 |
+
'''
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
sys.path.append(os.getcwd())
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from data_utils.lower_body import rearrange, symmetry
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
def data_driven_baselines(gt_kps):
|
14 |
+
'''
|
15 |
+
gt_kps: T, D
|
16 |
+
'''
|
17 |
+
gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
|
18 |
+
|
19 |
+
mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
|
20 |
+
mean = np.mean(np.abs(gt_velocity-mean))
|
21 |
+
last_step = gt_kps[1] - gt_kps[0]
|
22 |
+
last_step = last_step[np.newaxis] #(1, D)
|
23 |
+
last_step = np.mean(np.abs(gt_velocity-last_step))
|
24 |
+
return last_step, mean
|
25 |
+
|
26 |
+
def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
|
27 |
+
if gt_kps.shape[0] > pr_kps.shape[1]:
|
28 |
+
length = pr_kps.shape[1]
|
29 |
+
else:
|
30 |
+
length = gt_kps.shape[0]
|
31 |
+
gt_kps = gt_kps[:length]
|
32 |
+
pr_kps = pr_kps[:, :length]
|
33 |
+
global symmetry
|
34 |
+
symmetry = torch.tensor(symmetry).bool()
|
35 |
+
|
36 |
+
if symmetrical:
|
37 |
+
# rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
|
38 |
+
gt_kps = gt_kps[:, rearrange]
|
39 |
+
ns_gt_kps = gt_kps[:, ~symmetry]
|
40 |
+
ys_gt_kps = gt_kps[:, symmetry]
|
41 |
+
ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
|
42 |
+
ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
|
43 |
+
ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
|
44 |
+
left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
|
45 |
+
right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
|
46 |
+
move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
|
47 |
+
ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
|
48 |
+
ys_gt_velocity = ys_gt_velocity.transpose(0,1)
|
49 |
+
gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
|
50 |
+
|
51 |
+
pr_kps = pr_kps[:, :, rearrange]
|
52 |
+
ns_pr_kps = pr_kps[:, :, ~symmetry]
|
53 |
+
ys_pr_kps = pr_kps[:, :, symmetry]
|
54 |
+
ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
|
55 |
+
ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
|
56 |
+
ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
|
57 |
+
left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
|
58 |
+
right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
|
59 |
+
move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
|
60 |
+
torch.zeros(left_pr_vel.shape).cuda())
|
61 |
+
ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
|
62 |
+
ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long())
|
63 |
+
ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
|
64 |
+
pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
|
65 |
+
else:
|
66 |
+
gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
|
67 |
+
pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
|
68 |
+
|
69 |
+
if weight:
|
70 |
+
w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
|
71 |
+
else:
|
72 |
+
w = 1 / gt_velocity.shape[0]
|
73 |
+
|
74 |
+
v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
|
75 |
+
|
76 |
+
return v_diff
|
77 |
+
|
78 |
+
|
79 |
+
def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
|
80 |
+
gt_kps = gt_kps.squeeze()
|
81 |
+
pr_kps = pr_kps.squeeze()
|
82 |
+
if len(pr_kps.shape) == 4:
|
83 |
+
return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
|
84 |
+
# length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
|
85 |
+
length = gt_kps.shape[0]-10
|
86 |
+
# gt_kps = gt_kps[25:length]
|
87 |
+
# pr_kps = pr_kps[25:length] #(T, D)
|
88 |
+
# if pr_kps.shape[0] < gt_kps.shape[0]:
|
89 |
+
# pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
|
90 |
+
|
91 |
+
gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
|
92 |
+
pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
|
93 |
+
|
94 |
+
return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
|
95 |
+
|
96 |
+
def diversity(kps):
|
97 |
+
'''
|
98 |
+
kps: bs, seq, dim
|
99 |
+
'''
|
100 |
+
dis_list = []
|
101 |
+
#the distance between each pair
|
102 |
+
for i in range(kps.shape[0]):
|
103 |
+
for j in range(i+1, kps.shape[0]):
|
104 |
+
seq_i = kps[i]
|
105 |
+
seq_j = kps[j]
|
106 |
+
|
107 |
+
dis = np.mean(np.abs(seq_i - seq_j))
|
108 |
+
dis_list.append(dis)
|
109 |
+
return np.mean(dis_list)
|
evaluation/mode_transition.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.getcwd())
|
4 |
+
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
import json
|
9 |
+
|
10 |
+
from evaluation.util import *
|
11 |
+
from evaluation.metrics import *
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument('--speaker', required=True, type=str)
|
16 |
+
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
speaker = args.speaker
|
20 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
21 |
+
|
22 |
+
precision_list=[]
|
23 |
+
recall_list=[]
|
24 |
+
accuracy_list=[]
|
25 |
+
|
26 |
+
for aud in tqdm(test_audios):
|
27 |
+
base_name = os.path.splitext(aud)[0]
|
28 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
29 |
+
_, gt_poses, _ = get_gts(gt_path)
|
30 |
+
if gt_poses.shape[0] < 50:
|
31 |
+
continue
|
32 |
+
gt_poses = gt_poses[np.newaxis,...]
|
33 |
+
# print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
|
34 |
+
for post_fix in args.post_fix:
|
35 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
36 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
37 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
38 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
39 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
40 |
+
|
41 |
+
gt_valid_points = valid_points(gt_poses)
|
42 |
+
pred_valid_points = valid_points(pred_poses)
|
43 |
+
|
44 |
+
# print(gt_valid_points.shape, pred_valid_points.shape)
|
45 |
+
|
46 |
+
gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
|
47 |
+
pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
|
48 |
+
|
49 |
+
# baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
|
50 |
+
# pred_mode_transition_seq = baseline
|
51 |
+
precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
|
52 |
+
precision_list.append(precision)
|
53 |
+
recall_list.append(recall)
|
54 |
+
accuracy_list.append(accuracy)
|
55 |
+
print(len(precision_list), len(recall_list), len(accuracy_list))
|
56 |
+
precision_list = np.mean(precision_list)
|
57 |
+
recall_list = np.mean(recall_list)
|
58 |
+
accuracy_list = np.mean(accuracy_list)
|
59 |
+
|
60 |
+
print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
|
evaluation/peak_velocity.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.getcwd())
|
4 |
+
|
5 |
+
from glob import glob
|
6 |
+
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
import json
|
9 |
+
|
10 |
+
from evaluation.util import *
|
11 |
+
from evaluation.metrics import *
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument('--speaker', required=True, type=str)
|
16 |
+
parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
speaker = args.speaker
|
20 |
+
test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
|
21 |
+
|
22 |
+
gt_consistency_list=[]
|
23 |
+
pred_consistency_list=[]
|
24 |
+
|
25 |
+
for aud in tqdm(test_audios):
|
26 |
+
base_name = os.path.splitext(aud)[0]
|
27 |
+
gt_path = get_full_path(aud, speaker, 'val')
|
28 |
+
_, gt_poses, _ = get_gts(gt_path)
|
29 |
+
gt_poses = gt_poses[np.newaxis,...]
|
30 |
+
# print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
|
31 |
+
for post_fix in args.post_fix:
|
32 |
+
pred_path = base_name + '_'+post_fix+'.json'
|
33 |
+
pred_poses = np.array(json.load(open(pred_path)))
|
34 |
+
# print(pred_poses.shape)#(B, seq_len, 108)
|
35 |
+
pred_poses = cvt25(pred_poses, gt_poses)
|
36 |
+
# print(pred_poses.shape)#(B, seq, pose_dim)
|
37 |
+
|
38 |
+
gt_valid_points = hand_points(gt_poses)
|
39 |
+
pred_valid_points = hand_points(pred_poses)
|
40 |
+
|
41 |
+
gt_velocity = peak_velocity(gt_valid_points, order=2)
|
42 |
+
pred_velocity = peak_velocity(pred_valid_points, order=2)
|
43 |
+
|
44 |
+
gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
|
45 |
+
pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
|
46 |
+
|
47 |
+
gt_consistency_list.append(gt_consistency)
|
48 |
+
pred_consistency_list.append(pred_consistency)
|
49 |
+
|
50 |
+
gt_consistency_list = np.concatenate(gt_consistency_list)
|
51 |
+
pred_consistency_list = np.concatenate(pred_consistency_list)
|
52 |
+
|
53 |
+
print(gt_consistency_list.max(), gt_consistency_list.min())
|
54 |
+
print(pred_consistency_list.max(), pred_consistency_list.min())
|
55 |
+
print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
|
56 |
+
print(np.std(gt_consistency_list), np.std(pred_consistency_list))
|
57 |
+
|
58 |
+
draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
|
59 |
+
draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
|
60 |
+
|
61 |
+
to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
|
62 |
+
to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
|
63 |
+
|
64 |
+
np.save('%s_gt.npy'%(speaker), gt_consistency_list)
|
65 |
+
np.save('%s_pred.npy'%(speaker), pred_consistency_list)
|
evaluation/util.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
import pandas as pd
|
7 |
+
def get_gts(clip):
|
8 |
+
'''
|
9 |
+
clip: abs path to the clip dir
|
10 |
+
'''
|
11 |
+
keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
|
12 |
+
|
13 |
+
upper_body_points = list(np.arange(0, 25))
|
14 |
+
poses = []
|
15 |
+
confs = []
|
16 |
+
neck_to_nose_len = []
|
17 |
+
mean_position = []
|
18 |
+
for kp_file in keypoints_files:
|
19 |
+
kp_load = json.load(open(kp_file, 'r'))['people'][0]
|
20 |
+
posepts = kp_load['pose_keypoints_2d']
|
21 |
+
lhandpts = kp_load['hand_left_keypoints_2d']
|
22 |
+
rhandpts = kp_load['hand_right_keypoints_2d']
|
23 |
+
facepts = kp_load['face_keypoints_2d']
|
24 |
+
|
25 |
+
neck = np.array(posepts).reshape(-1,3)[1]
|
26 |
+
nose = np.array(posepts).reshape(-1,3)[0]
|
27 |
+
x_offset = abs(neck[0]-nose[0])
|
28 |
+
y_offset = abs(neck[1]-nose[1])
|
29 |
+
neck_to_nose_len.append(y_offset)
|
30 |
+
mean_position.append([neck[0],neck[1]])
|
31 |
+
|
32 |
+
keypoints=np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
|
33 |
+
|
34 |
+
upper_body = keypoints[upper_body_points, :]
|
35 |
+
hand_points = keypoints[25:, :]
|
36 |
+
keypoints = np.vstack([upper_body, hand_points])
|
37 |
+
|
38 |
+
poses.append(keypoints)
|
39 |
+
|
40 |
+
if len(neck_to_nose_len) > 0:
|
41 |
+
scale_factor = np.mean(neck_to_nose_len)
|
42 |
+
else:
|
43 |
+
raise ValueError(clip)
|
44 |
+
mean_position = np.mean(np.array(mean_position), axis=0)
|
45 |
+
|
46 |
+
unlocalized_poses = np.array(poses).copy()
|
47 |
+
localized_poses = []
|
48 |
+
for i in range(len(poses)):
|
49 |
+
keypoints = poses[i]
|
50 |
+
neck = keypoints[1].copy()
|
51 |
+
|
52 |
+
keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
|
53 |
+
keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
|
54 |
+
localized_poses.append(keypoints.reshape(-1))
|
55 |
+
|
56 |
+
localized_poses=np.array(localized_poses)
|
57 |
+
return unlocalized_poses, localized_poses, (scale_factor, mean_position)
|
58 |
+
|
59 |
+
def get_full_path(wav_name, speaker, split):
|
60 |
+
'''
|
61 |
+
get clip path from aud file
|
62 |
+
'''
|
63 |
+
wav_name = os.path.basename(wav_name)
|
64 |
+
wav_name = os.path.splitext(wav_name)[0]
|
65 |
+
clip_name, vid_name = wav_name[:10], wav_name[11:]
|
66 |
+
|
67 |
+
full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
|
68 |
+
|
69 |
+
assert os.path.isdir(full_path), full_path
|
70 |
+
|
71 |
+
return full_path
|
72 |
+
|
73 |
+
def smooth(res):
|
74 |
+
'''
|
75 |
+
res: (B, seq_len, pose_dim)
|
76 |
+
'''
|
77 |
+
window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
|
78 |
+
w_size=7
|
79 |
+
for i in range(10, res.shape[1]-3):
|
80 |
+
window.append(res[:, i+3, :])
|
81 |
+
if len(window) > w_size:
|
82 |
+
window = window[1:]
|
83 |
+
|
84 |
+
if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
|
85 |
+
res[:, i, :] = np.mean(window, axis=1)
|
86 |
+
|
87 |
+
return res
|
88 |
+
|
89 |
+
def cvt25(pred_poses, gt_poses=None):
|
90 |
+
'''
|
91 |
+
gt_poses: (1, seq_len, 270), 135 *2
|
92 |
+
pred_poses: (B, seq_len, 108), 54 * 2
|
93 |
+
'''
|
94 |
+
if gt_poses is None:
|
95 |
+
gt_poses = np.zeros_like(pred_poses)
|
96 |
+
else:
|
97 |
+
gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
|
98 |
+
|
99 |
+
length = min(pred_poses.shape[1], gt_poses.shape[1])
|
100 |
+
pred_poses = pred_poses[:, :length, :]
|
101 |
+
gt_poses = gt_poses[:, :length, :]
|
102 |
+
gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
|
103 |
+
pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
|
104 |
+
|
105 |
+
gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
|
106 |
+
gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
|
107 |
+
|
108 |
+
return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
|
109 |
+
|
110 |
+
def hand_points(seq):
|
111 |
+
'''
|
112 |
+
seq: (B, seq_len, 135*2)
|
113 |
+
hands only
|
114 |
+
'''
|
115 |
+
hand_idx = [1, 2, 3, 4,5 ,6,7] + list(range(25, 25+21+21))
|
116 |
+
seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
|
117 |
+
return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
|
118 |
+
|
119 |
+
def valid_points(seq):
|
120 |
+
'''
|
121 |
+
hands with some head points
|
122 |
+
'''
|
123 |
+
valid_idx = [0, 1, 2, 3, 4,5 ,6,7, 8, 9, 10, 11] + list(range(25, 25+21+21))
|
124 |
+
seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
|
125 |
+
|
126 |
+
seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
|
127 |
+
assert seq.shape[-1] == 108, seq.shape
|
128 |
+
return seq
|
129 |
+
|
130 |
+
def draw_cdf(seq, save_name='cdf.jpg', color='slatebule'):
|
131 |
+
plt.figure()
|
132 |
+
plt.hist(seq, bins=100, range=(0, 100), color=color)
|
133 |
+
plt.savefig(save_name)
|
134 |
+
|
135 |
+
def to_excel(seq, save_name='res.xlsx'):
|
136 |
+
'''
|
137 |
+
seq: (T)
|
138 |
+
'''
|
139 |
+
df = pd.DataFrame(seq)
|
140 |
+
writer = pd.ExcelWriter(save_name)
|
141 |
+
df.to_excel(writer, 'sheet1')
|
142 |
+
writer.save()
|
143 |
+
writer.close()
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
|
147 |
+
random_data = np.random.randint(0, 10, 100)
|
148 |
+
draw_cdf(random_data)
|
losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .losses import *
|
losses/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (174 Bytes). View file
|
|
losses/__pycache__/losses.cpython-37.pyc
ADDED
Binary file (3.53 kB). View file
|
|
losses/losses.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
sys.path.append(os.getcwd())
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
class KeypointLoss(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super(KeypointLoss, self).__init__()
|
14 |
+
|
15 |
+
def forward(self, pred_seq, gt_seq, gt_conf=None):
|
16 |
+
#pred_seq: (B, C, T)
|
17 |
+
if gt_conf is not None:
|
18 |
+
gt_conf = gt_conf >= 0.01
|
19 |
+
return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
|
20 |
+
else:
|
21 |
+
return F.mse_loss(pred_seq, gt_seq)
|
22 |
+
|
23 |
+
|
24 |
+
class KLLoss(nn.Module):
|
25 |
+
def __init__(self, kl_tolerance):
|
26 |
+
super(KLLoss, self).__init__()
|
27 |
+
self.kl_tolerance = kl_tolerance
|
28 |
+
|
29 |
+
def forward(self, mu, var, mul=1):
|
30 |
+
kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
|
31 |
+
kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
|
32 |
+
# kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
|
33 |
+
if self.kl_tolerance is not None:
|
34 |
+
# above_line = kld_loss[kld_loss > self.kl_tolerance]
|
35 |
+
# if len(above_line) > 0:
|
36 |
+
# kld_loss = torch.mean(kld_loss)
|
37 |
+
# else:
|
38 |
+
# kld_loss = 0
|
39 |
+
kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
|
40 |
+
# else:
|
41 |
+
kld_loss = torch.mean(kld_loss)
|
42 |
+
return kld_loss
|
43 |
+
|
44 |
+
|
45 |
+
class L2KLLoss(nn.Module):
|
46 |
+
def __init__(self, kl_tolerance):
|
47 |
+
super(L2KLLoss, self).__init__()
|
48 |
+
self.kl_tolerance = kl_tolerance
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
# TODO: check
|
52 |
+
kld_loss = torch.sum(x ** 2, dim=1)
|
53 |
+
if self.kl_tolerance is not None:
|
54 |
+
above_line = kld_loss[kld_loss > self.kl_tolerance]
|
55 |
+
if len(above_line) > 0:
|
56 |
+
kld_loss = torch.mean(kld_loss)
|
57 |
+
else:
|
58 |
+
kld_loss = 0
|
59 |
+
else:
|
60 |
+
kld_loss = torch.mean(kld_loss)
|
61 |
+
return kld_loss
|
62 |
+
|
63 |
+
class L2RegLoss(nn.Module):
|
64 |
+
def __init__(self):
|
65 |
+
super(L2RegLoss, self).__init__()
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
#TODO: check
|
69 |
+
return torch.sum(x**2)
|
70 |
+
|
71 |
+
|
72 |
+
class L2Loss(nn.Module):
|
73 |
+
def __init__(self):
|
74 |
+
super(L2Loss, self).__init__()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
# TODO: check
|
78 |
+
return torch.sum(x ** 2)
|
79 |
+
|
80 |
+
|
81 |
+
class AudioLoss(nn.Module):
|
82 |
+
def __init__(self):
|
83 |
+
super(AudioLoss, self).__init__()
|
84 |
+
|
85 |
+
def forward(self, dynamics, gt_poses):
|
86 |
+
#pay attention, normalized
|
87 |
+
mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
|
88 |
+
gt = gt_poses - mean
|
89 |
+
return F.mse_loss(dynamics, gt)
|
90 |
+
|
91 |
+
L1Loss = nn.L1Loss
|
nets/LS3DCG.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
not exactly the same as the official repo but the results are good
|
3 |
+
'''
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
|
7 |
+
from data_utils.lower_body import c_index_3d, c_index_6d
|
8 |
+
|
9 |
+
sys.path.append(os.getcwd())
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.optim as optim
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import math
|
17 |
+
|
18 |
+
from nets.base import TrainWrapperBaseClass
|
19 |
+
from nets.layers import SeqEncoder1D
|
20 |
+
from losses import KeypointLoss, L1Loss, KLLoss
|
21 |
+
from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
|
22 |
+
from nets.utils import denormalize
|
23 |
+
|
24 |
+
class Conv1d_tf(nn.Conv1d):
|
25 |
+
"""
|
26 |
+
Conv1d with the padding behavior from TF
|
27 |
+
modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, *args, **kwargs):
|
31 |
+
super(Conv1d_tf, self).__init__(*args, **kwargs)
|
32 |
+
self.padding = kwargs.get("padding", "same")
|
33 |
+
|
34 |
+
def _compute_padding(self, input, dim):
|
35 |
+
input_size = input.size(dim + 2)
|
36 |
+
filter_size = self.weight.size(dim + 2)
|
37 |
+
effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
|
38 |
+
out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
|
39 |
+
total_padding = max(
|
40 |
+
0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
|
41 |
+
)
|
42 |
+
additional_padding = int(total_padding % 2 != 0)
|
43 |
+
|
44 |
+
return additional_padding, total_padding
|
45 |
+
|
46 |
+
def forward(self, input):
|
47 |
+
if self.padding == "VALID":
|
48 |
+
return F.conv1d(
|
49 |
+
input,
|
50 |
+
self.weight,
|
51 |
+
self.bias,
|
52 |
+
self.stride,
|
53 |
+
padding=0,
|
54 |
+
dilation=self.dilation,
|
55 |
+
groups=self.groups,
|
56 |
+
)
|
57 |
+
rows_odd, padding_rows = self._compute_padding(input, dim=0)
|
58 |
+
if rows_odd:
|
59 |
+
input = F.pad(input, [0, rows_odd])
|
60 |
+
|
61 |
+
return F.conv1d(
|
62 |
+
input,
|
63 |
+
self.weight,
|
64 |
+
self.bias,
|
65 |
+
self.stride,
|
66 |
+
padding=(padding_rows // 2),
|
67 |
+
dilation=self.dilation,
|
68 |
+
groups=self.groups,
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
|
73 |
+
if k is None and s is None:
|
74 |
+
if not downsample:
|
75 |
+
k = 3
|
76 |
+
s = 1
|
77 |
+
else:
|
78 |
+
k = 4
|
79 |
+
s = 2
|
80 |
+
|
81 |
+
if type == '1d':
|
82 |
+
conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
|
83 |
+
if norm == 'bn':
|
84 |
+
norm_block = nn.BatchNorm1d(out_channels)
|
85 |
+
elif norm == 'ln':
|
86 |
+
norm_block = nn.LayerNorm(out_channels)
|
87 |
+
elif type == '2d':
|
88 |
+
conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
|
89 |
+
norm_block = nn.BatchNorm2d(out_channels)
|
90 |
+
else:
|
91 |
+
assert False
|
92 |
+
|
93 |
+
return nn.Sequential(
|
94 |
+
conv_block,
|
95 |
+
norm_block,
|
96 |
+
nn.LeakyReLU(0.2, True)
|
97 |
+
)
|
98 |
+
|
99 |
+
class Decoder(nn.Module):
|
100 |
+
def __init__(self, in_ch, out_ch):
|
101 |
+
super(Decoder, self).__init__()
|
102 |
+
self.up1 = nn.Sequential(
|
103 |
+
ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
|
104 |
+
ConvNormRelu(in_ch // 2, in_ch // 2),
|
105 |
+
nn.Upsample(scale_factor=2, mode='nearest')
|
106 |
+
)
|
107 |
+
self.up2 = nn.Sequential(
|
108 |
+
ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
|
109 |
+
ConvNormRelu(in_ch // 4, in_ch // 4),
|
110 |
+
nn.Upsample(scale_factor=2, mode='nearest')
|
111 |
+
)
|
112 |
+
self.up3 = nn.Sequential(
|
113 |
+
ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
|
114 |
+
ConvNormRelu(in_ch // 8, in_ch // 8),
|
115 |
+
nn.Conv1d(in_ch // 8, out_ch, 1, 1)
|
116 |
+
)
|
117 |
+
|
118 |
+
def forward(self, x, x1, x2, x3):
|
119 |
+
x = F.interpolate(x, x3.shape[2])
|
120 |
+
x = torch.cat([x, x3], dim=1)
|
121 |
+
x = self.up1(x)
|
122 |
+
x = F.interpolate(x, x2.shape[2])
|
123 |
+
x = torch.cat([x, x2], dim=1)
|
124 |
+
x = self.up2(x)
|
125 |
+
x = F.interpolate(x, x1.shape[2])
|
126 |
+
x = torch.cat([x, x1], dim=1)
|
127 |
+
x = self.up3(x)
|
128 |
+
return x
|
129 |
+
|
130 |
+
|
131 |
+
class EncoderDecoder(nn.Module):
|
132 |
+
def __init__(self, n_frames, each_dim):
|
133 |
+
super().__init__()
|
134 |
+
self.n_frames = n_frames
|
135 |
+
|
136 |
+
self.down1 = nn.Sequential(
|
137 |
+
ConvNormRelu(64, 64, '1d', False),
|
138 |
+
ConvNormRelu(64, 128, '1d', False),
|
139 |
+
)
|
140 |
+
self.down2 = nn.Sequential(
|
141 |
+
ConvNormRelu(128, 128, '1d', False),
|
142 |
+
ConvNormRelu(128, 256, '1d', False),
|
143 |
+
)
|
144 |
+
self.down3 = nn.Sequential(
|
145 |
+
ConvNormRelu(256, 256, '1d', False),
|
146 |
+
ConvNormRelu(256, 512, '1d', False),
|
147 |
+
)
|
148 |
+
self.down4 = nn.Sequential(
|
149 |
+
ConvNormRelu(512, 512, '1d', False),
|
150 |
+
ConvNormRelu(512, 1024, '1d', False),
|
151 |
+
)
|
152 |
+
|
153 |
+
self.down = nn.MaxPool1d(kernel_size=2)
|
154 |
+
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
155 |
+
|
156 |
+
self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
|
157 |
+
self.body_decoder = Decoder(1024, each_dim[1])
|
158 |
+
self.hand_decoder = Decoder(1024, each_dim[2])
|
159 |
+
|
160 |
+
def forward(self, spectrogram, time_steps=None):
|
161 |
+
if time_steps is None:
|
162 |
+
time_steps = self.n_frames
|
163 |
+
|
164 |
+
x1 = self.down1(spectrogram)
|
165 |
+
x = self.down(x1)
|
166 |
+
x2 = self.down2(x)
|
167 |
+
x = self.down(x2)
|
168 |
+
x3 = self.down3(x)
|
169 |
+
x = self.down(x3)
|
170 |
+
x = self.down4(x)
|
171 |
+
x = self.up(x)
|
172 |
+
|
173 |
+
face = self.face_decoder(x, x1, x2, x3)
|
174 |
+
body = self.body_decoder(x, x1, x2, x3)
|
175 |
+
hand = self.hand_decoder(x, x1, x2, x3)
|
176 |
+
|
177 |
+
return face, body, hand
|
178 |
+
|
179 |
+
|
180 |
+
class Generator(nn.Module):
|
181 |
+
def __init__(self,
|
182 |
+
each_dim,
|
183 |
+
training=False,
|
184 |
+
device=None
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
|
188 |
+
self.training = training
|
189 |
+
self.device = device
|
190 |
+
|
191 |
+
self.encoderdecoder = EncoderDecoder(15, each_dim)
|
192 |
+
|
193 |
+
def forward(self, in_spec, time_steps=None):
|
194 |
+
if time_steps is not None:
|
195 |
+
self.gen_length = time_steps
|
196 |
+
|
197 |
+
face, body, hand = self.encoderdecoder(in_spec)
|
198 |
+
out = torch.cat([face, body, hand], dim=1)
|
199 |
+
out = out.transpose(1, 2)
|
200 |
+
|
201 |
+
return out
|
202 |
+
|
203 |
+
|
204 |
+
class Discriminator(nn.Module):
|
205 |
+
def __init__(self, input_dim):
|
206 |
+
super().__init__()
|
207 |
+
self.net = nn.Sequential(
|
208 |
+
ConvNormRelu(input_dim, 128, '1d'),
|
209 |
+
ConvNormRelu(128, 256, '1d'),
|
210 |
+
nn.MaxPool1d(kernel_size=2),
|
211 |
+
ConvNormRelu(256, 256, '1d'),
|
212 |
+
ConvNormRelu(256, 512, '1d'),
|
213 |
+
nn.MaxPool1d(kernel_size=2),
|
214 |
+
ConvNormRelu(512, 512, '1d'),
|
215 |
+
ConvNormRelu(512, 1024, '1d'),
|
216 |
+
nn.MaxPool1d(kernel_size=2),
|
217 |
+
nn.Conv1d(1024, 1, 1, 1),
|
218 |
+
nn.Sigmoid()
|
219 |
+
)
|
220 |
+
|
221 |
+
def forward(self, x):
|
222 |
+
x = x.transpose(1, 2)
|
223 |
+
|
224 |
+
out = self.net(x)
|
225 |
+
return out
|
226 |
+
|
227 |
+
|
228 |
+
class TrainWrapper(TrainWrapperBaseClass):
|
229 |
+
def __init__(self, args, config) -> None:
|
230 |
+
self.args = args
|
231 |
+
self.config = config
|
232 |
+
self.device = torch.device(self.args.gpu)
|
233 |
+
self.global_step = 0
|
234 |
+
self.convert_to_6d = self.config.Data.pose.convert_to_6d
|
235 |
+
self.init_params()
|
236 |
+
|
237 |
+
self.generator = Generator(
|
238 |
+
each_dim=self.each_dim,
|
239 |
+
training=not self.args.infer,
|
240 |
+
device=self.device,
|
241 |
+
).to(self.device)
|
242 |
+
self.discriminator = Discriminator(
|
243 |
+
input_dim=self.each_dim[1] + self.each_dim[2] + 64
|
244 |
+
).to(self.device)
|
245 |
+
if self.convert_to_6d:
|
246 |
+
self.c_index = c_index_6d
|
247 |
+
else:
|
248 |
+
self.c_index = c_index_3d
|
249 |
+
self.MSELoss = KeypointLoss().to(self.device)
|
250 |
+
self.L1Loss = L1Loss().to(self.device)
|
251 |
+
super().__init__(args, config)
|
252 |
+
|
253 |
+
def init_params(self):
|
254 |
+
scale = 1
|
255 |
+
|
256 |
+
global_orient = round(0 * scale)
|
257 |
+
leye_pose = reye_pose = round(0 * scale)
|
258 |
+
jaw_pose = round(3 * scale)
|
259 |
+
body_pose = round((63 - 24) * scale)
|
260 |
+
left_hand_pose = right_hand_pose = round(45 * scale)
|
261 |
+
|
262 |
+
expression = 100
|
263 |
+
|
264 |
+
b_j = 0
|
265 |
+
jaw_dim = jaw_pose
|
266 |
+
b_e = b_j + jaw_dim
|
267 |
+
eye_dim = leye_pose + reye_pose
|
268 |
+
b_b = b_e + eye_dim
|
269 |
+
body_dim = global_orient + body_pose
|
270 |
+
b_h = b_b + body_dim
|
271 |
+
hand_dim = left_hand_pose + right_hand_pose
|
272 |
+
b_f = b_h + hand_dim
|
273 |
+
face_dim = expression
|
274 |
+
|
275 |
+
self.dim_list = [b_j, b_e, b_b, b_h, b_f]
|
276 |
+
self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
|
277 |
+
self.pose = int(self.full_dim / round(3 * scale))
|
278 |
+
self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
|
279 |
+
|
280 |
+
def __call__(self, bat):
|
281 |
+
assert (not self.args.infer), "infer mode"
|
282 |
+
self.global_step += 1
|
283 |
+
|
284 |
+
loss_dict = {}
|
285 |
+
|
286 |
+
aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
|
287 |
+
expression = bat['expression'].to(self.device).to(torch.float32)
|
288 |
+
jaw = poses[:, :3, :]
|
289 |
+
poses = poses[:, self.c_index, :]
|
290 |
+
|
291 |
+
pred = self.generator(in_spec=aud)
|
292 |
+
|
293 |
+
D_loss, D_loss_dict = self.get_loss(
|
294 |
+
pred_poses=pred.detach(),
|
295 |
+
gt_poses=poses,
|
296 |
+
aud=aud,
|
297 |
+
mode='training_D',
|
298 |
+
)
|
299 |
+
|
300 |
+
self.discriminator_optimizer.zero_grad()
|
301 |
+
D_loss.backward()
|
302 |
+
self.discriminator_optimizer.step()
|
303 |
+
|
304 |
+
G_loss, G_loss_dict = self.get_loss(
|
305 |
+
pred_poses=pred,
|
306 |
+
gt_poses=poses,
|
307 |
+
aud=aud,
|
308 |
+
expression=expression,
|
309 |
+
jaw=jaw,
|
310 |
+
mode='training_G',
|
311 |
+
)
|
312 |
+
self.generator_optimizer.zero_grad()
|
313 |
+
G_loss.backward()
|
314 |
+
self.generator_optimizer.step()
|
315 |
+
|
316 |
+
total_loss = None
|
317 |
+
loss_dict = {}
|
318 |
+
for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
|
319 |
+
loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)
|
320 |
+
|
321 |
+
return total_loss, loss_dict
|
322 |
+
|
323 |
+
def get_loss(self,
|
324 |
+
pred_poses,
|
325 |
+
gt_poses,
|
326 |
+
aud=None,
|
327 |
+
jaw=None,
|
328 |
+
expression=None,
|
329 |
+
mode='training_G',
|
330 |
+
):
|
331 |
+
loss_dict = {}
|
332 |
+
aud = aud.transpose(1, 2)
|
333 |
+
gt_poses = gt_poses.transpose(1, 2)
|
334 |
+
gt_aud = torch.cat([gt_poses, aud], dim=2)
|
335 |
+
pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)
|
336 |
+
|
337 |
+
if mode == 'training_D':
|
338 |
+
dis_real = self.discriminator(gt_aud)
|
339 |
+
dis_fake = self.discriminator(pred_aud)
|
340 |
+
dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
|
341 |
+
torch.zeros_like(dis_fake).to(self.device), dis_fake)
|
342 |
+
loss_dict['dis'] = dis_error
|
343 |
+
|
344 |
+
return dis_error, loss_dict
|
345 |
+
elif mode == 'training_G':
|
346 |
+
jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
|
347 |
+
face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
|
348 |
+
body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
|
349 |
+
hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
|
350 |
+
l1_loss = jaw_loss + face_loss + body_loss + hand_loss
|
351 |
+
|
352 |
+
dis_output = self.discriminator(pred_aud)
|
353 |
+
gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
|
354 |
+
gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error
|
355 |
+
|
356 |
+
loss_dict['gen'] = gen_error
|
357 |
+
loss_dict['jaw_loss'] = jaw_loss
|
358 |
+
loss_dict['face_loss'] = face_loss
|
359 |
+
loss_dict['body_loss'] = body_loss
|
360 |
+
loss_dict['hand_loss'] = hand_loss
|
361 |
+
return gen_loss, loss_dict
|
362 |
+
else:
|
363 |
+
raise ValueError(mode)
|
364 |
+
|
365 |
+
def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
|
366 |
+
output = []
|
367 |
+
assert self.args.infer, "train mode"
|
368 |
+
self.generator.eval()
|
369 |
+
|
370 |
+
if self.config.Data.pose.normalization:
|
371 |
+
assert norm_stats is not None
|
372 |
+
data_mean = norm_stats[0]
|
373 |
+
data_std = norm_stats[1]
|
374 |
+
|
375 |
+
pre_length = self.config.Data.pose.pre_pose_length
|
376 |
+
generate_length = self.config.Data.pose.generate_length
|
377 |
+
# assert pre_length == initial_pose.shape[-1]
|
378 |
+
# pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
|
379 |
+
# B = pre_poses.shape[0]
|
380 |
+
|
381 |
+
aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
|
382 |
+
num_poses_to_generate = aud_feat.shape[-1]
|
383 |
+
aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
|
384 |
+
aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
|
385 |
+
|
386 |
+
with torch.no_grad():
|
387 |
+
pred_poses = self.generator(aud_feat)
|
388 |
+
pred_poses = pred_poses.cpu().numpy()
|
389 |
+
output = pred_poses.squeeze()
|
390 |
+
|
391 |
+
return output
|
392 |
+
|
393 |
+
def generate(self, aud, id):
|
394 |
+
self.generator.eval()
|
395 |
+
pred_poses = self.generator(aud)
|
396 |
+
return pred_poses
|
397 |
+
|
398 |
+
|
399 |
+
if __name__ == '__main__':
|
400 |
+
from trainer.options import parse_args
|
401 |
+
|
402 |
+
parser = parse_args()
|
403 |
+
args = parser.parse_args(
|
404 |
+
['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
|
405 |
+
'--infer'])
|
406 |
+
|
407 |
+
generator = TrainWrapper(args)
|
408 |
+
|
409 |
+
aud_fn = '../sample_audio/jon.wav'
|
410 |
+
initial_pose = torch.randn(64, 108, 4)
|
411 |
+
norm_stats = (np.random.randn(108), np.random.randn(108))
|
412 |
+
output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)
|
413 |
+
|
414 |
+
print(output.shape)
|
nets/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .smplx_face import TrainWrapper as s2g_face
|
2 |
+
from .smplx_body_vq import TrainWrapper as s2g_body_vq
|
3 |
+
from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
|
4 |
+
from .body_ae import TrainWrapper as s2g_body_ae
|
5 |
+
from .LS3DCG import TrainWrapper as LS3DCG
|
6 |
+
from .base import TrainWrapperBaseClass
|
7 |
+
|
8 |
+
from .utils import normalize, denormalize
|
nets/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (407 Bytes). View file
|
|
nets/__pycache__/base.cpython-37.pyc
ADDED
Binary file (2.29 kB). View file
|
|
nets/__pycache__/init_model.cpython-37.pyc
ADDED
Binary file (460 Bytes). View file
|
|
nets/__pycache__/layers.cpython-37.pyc
ADDED
Binary file (22.7 kB). View file
|
|
nets/__pycache__/smplx_body_pixel.cpython-37.pyc
ADDED
Binary file (9.55 kB). View file
|
|
nets/__pycache__/smplx_body_vq.cpython-37.pyc
ADDED
Binary file (7.89 kB). View file
|
|