mvreddy13 committed on
Commit
f0c7f08
·
1 Parent(s): ed77e1a

Adding new Folders

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. README.md +159 -13
  2. __init__.py +0 -0
  3. config/LS3DCG.json +64 -0
  4. config/body_pixel.json +63 -0
  5. config/body_vq.json +62 -0
  6. config/face.json +59 -0
  7. data_utils/__init__.py +3 -0
  8. data_utils/__pycache__/__init__.cpython-37.pyc +0 -0
  9. data_utils/__pycache__/consts.cpython-37.pyc +0 -0
  10. data_utils/__pycache__/dataloader_torch.cpython-37.pyc +0 -0
  11. data_utils/__pycache__/lower_body.cpython-37.pyc +0 -0
  12. data_utils/__pycache__/mesh_dataset.cpython-37.pyc +0 -0
  13. data_utils/__pycache__/rotation_conversion.cpython-37.pyc +0 -0
  14. data_utils/__pycache__/utils.cpython-37.pyc +0 -0
  15. data_utils/apply_split.py +51 -0
  16. data_utils/axis2matrix.py +29 -0
  17. data_utils/consts.py +0 -0
  18. data_utils/dataloader_torch.py +279 -0
  19. data_utils/dataset_preprocess.py +170 -0
  20. data_utils/get_j.py +51 -0
  21. data_utils/hand_component.json +0 -0
  22. data_utils/lower_body.py +143 -0
  23. data_utils/mesh_dataset.py +348 -0
  24. data_utils/rotation_conversion.py +551 -0
  25. data_utils/split_more_than_2s.pkl +3 -0
  26. data_utils/split_train_val_test.py +27 -0
  27. data_utils/train_val_test.json +0 -0
  28. data_utils/utils.py +318 -0
  29. evaluation/FGD.py +199 -0
  30. evaluation/__init__.py +0 -0
  31. evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
  32. evaluation/__pycache__/metrics.cpython-37.pyc +0 -0
  33. evaluation/diversity_LVD.py +64 -0
  34. evaluation/get_quality_samples.py +62 -0
  35. evaluation/metrics.py +109 -0
  36. evaluation/mode_transition.py +60 -0
  37. evaluation/peak_velocity.py +65 -0
  38. evaluation/util.py +148 -0
  39. losses/__init__.py +1 -0
  40. losses/__pycache__/__init__.cpython-37.pyc +0 -0
  41. losses/__pycache__/losses.cpython-37.pyc +0 -0
  42. losses/losses.py +91 -0
  43. nets/LS3DCG.py +414 -0
  44. nets/__init__.py +8 -0
  45. nets/__pycache__/__init__.cpython-37.pyc +0 -0
  46. nets/__pycache__/base.cpython-37.pyc +0 -0
  47. nets/__pycache__/init_model.cpython-37.pyc +0 -0
  48. nets/__pycache__/layers.cpython-37.pyc +0 -0
  49. nets/__pycache__/smplx_body_pixel.cpython-37.pyc +0 -0
  50. nets/__pycache__/smplx_body_vq.cpython-37.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,159 @@
1
- ---
2
- title: TalkShow
3
- emoji: 👀
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.40.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # TalkSHOW: Generating Holistic 3D Human Motion from Speech [CVPR2023]
2
+
3
+ The official PyTorch implementation of the **CVPR2023** paper [**"Generating Holistic 3D Human Motion from Speech"**](https://arxiv.org/abs/2212.04420).
4
+
5
+ Please visit our [**webpage**](https://talkshow.is.tue.mpg.de/) for more details.
6
+
7
+ ![teaser](visualise/teaser_01.png)
8
+
9
+ ## Highlights
10
+
11
+ We directly provide the input and our output for the demo data; you can find them in `/demo/` and `/demo_audio/`. So far, TalkSHOW generalizes well to English, French, and songs. We look forward to adding more demos.
12
+
13
+ You can directly use the generated motion to animate your 3D character or your own digital avatar. We will provide more demos, so please stay tuned. We also warmly welcome pull requests.
14
+
15
+ ## Notes
16
+
17
+ We use 100-dimensional parameters for the SMPL-X facial expression. If you need a different number of expression dimensions, you can convert them with this tool:
18
+
19
+ ```
20
+ https://github.com/yhw-yhw/SHOW/blob/main/cvt_exp_dim_tool.py
21
+ ```
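+ 
+ For intuition only, here is a minimal, hypothetical sketch of what such a conversion involves (our own assumption, not the linked tool): the SMPL-X expression coefficients are PCA components ordered by importance, so converting to fewer dimensions roughly amounts to truncation and converting to more dimensions to zero-padding.
+ 
+ ```python
+ import numpy as np
+ 
+ def convert_expression_dim(expression, target_dim):
+     """Rough conversion of SMPL-X expression coefficients between dimensionalities.
+ 
+     expression: array of shape (..., D), e.g. (T, 100) as used by TalkSHOW.
+     """
+     cur_dim = expression.shape[-1]
+     if target_dim <= cur_dim:
+         # keep only the leading components
+         return expression[..., :target_dim]
+     # zero-pad the missing trailing components
+     pad_shape = expression.shape[:-1] + (target_dim - cur_dim,)
+     return np.concatenate([expression, np.zeros(pad_shape)], axis=-1)
+ ```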
22
+
23
+ ## TODO
24
+
25
+ - [x] [🤗Hugging Face Demo](https://huggingface.co/spaces/feifeifeiliu/TalkSHOW)
26
+ - [ ] Animated 2D videos driven by the motion generated by TalkSHOW.
27
+
28
+
29
+ ## Getting started
30
+
31
+ The training code was tested on `Ubuntu 18.04.5 LTS` and the visualization code was tested on `Windows 10`. The project requires:
32
+
33
+ * Python 3.7
34
+ * conda3 or miniconda3
35
+ * CUDA capable GPU (one is enough)
36
+
37
+
38
+
39
+ ### 1. Setup environment
40
+
41
+ Clone the repo:
42
+ ```bash
43
+ git clone https://github.com/yhw-yhw/TalkSHOW
44
+ cd TalkSHOW
45
+ ```
46
+ Create conda environment:
47
+ ```bash
48
+ conda create --name talkshow python=3.7
49
+ conda activate talkshow
50
+ ```
51
+ Please install PyTorch (v1.10.1), then install the remaining dependencies:
52
+
53
+ pip install -r requirements.txt
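+ 
+ For reference, one possible way to install PyTorch 1.10.1 (this particular command assumes CUDA 11.3; adjust it to your setup):
+ 
+ pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html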
54
+
55
+ Please install [**MPI-Mesh**](https://github.com/MPI-IS/mesh).
56
+
57
+ ### 2. Get data
58
+
59
+ Please note that if you only want to generate demo videos, you can skip this step and directly download the pretrained models.
60
+
61
+ Download [**SHOW_dataset_v1.0.zip**](https://download.is.tue.mpg.de/download.php?domain=talkshow&resume=1&sfile=SHOW_dataset_v1.0.zip) from [**TalkSHOW download webpage**](https://talkshow.is.tue.mpg.de/download.php),
62
+ then unzip it and extract the tarballs with ``for i in $(ls *.tar.gz);do tar xvf $i;done``.
63
+
64
+ ~~Run ``python data_utils/dataset_preprocess.py`` to check and split dataset.
65
+ Modify ``data_root`` in ``config/*.json`` to the dataset-path.~~
66
+
67
+ Modify ``data_root`` in ``data_utils/apply_split.py`` to the dataset path and run it to apply ``data_utils/split_more_than_2s.pkl`` to the dataset.
68
+
69
+ We will update the benchmark soon.
70
+
71
+ ### 3. Download the pretrained models (Optional)
72
+
73
+ Download [**pretrained models**](https://drive.google.com/file/d/1bC0ZTza8HOhLB46WOJ05sBywFvcotDZG/view?usp=sharing),
74
+ unzip them, and place them in the TalkSHOW folder, i.e. ``path-to-TalkSHOW/experiments``.
75
+
76
+ ### 4. Training
77
+ Please note that the process of loading data for the first time can be quite slow. If you have already completed the loading process, setting ``dataset_load_mode`` to ``pickle`` in ``config/[config_name].json`` will make the loading process much faster.
78
+
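+ As a convenience, the snippet below is one possible way (our own sketch, not part of the repo) to switch a config to the pickle cache once the first run has written the per-split cache files (e.g. ``train_3d_mfcc.pkl``):
+ 
+ ```python
+ import json
+ 
+ def use_pickle_cache(config_path="config/body_pixel.json"):
+     # After the first "json"-mode run has dumped the per-split pickle cache,
+     # switch dataset_load_mode so that later runs simply reload that cache.
+     with open(config_path) as f:
+         cfg = json.load(f)
+     cfg["dataset_load_mode"] = "pickle"
+     with open(config_path, "w") as f:
+         json.dump(cfg, f, indent=4)
+ 
+ use_pickle_cache()
+ ```
+ 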
79
+ # 1. Train VQ-VAEs.
80
+ bash train_body_vq.sh
81
+ # 2. Train PixelCNN. Please modify "Model:vq_path" in config/body_pixel.json to the path of VQ-VAEs.
82
+ bash train_body_pixel.sh
83
+ # 3. Train face generator.
84
+ bash train_face.sh
85
+
86
+ ### 5. Testing
87
+
88
+ Modify the arguments in ``test_face.sh`` and ``test_body.sh``, then run
89
+
90
+ bash test_face.sh
91
+ bash test_body.sh
92
+
93
+ ### 6. Visualization
94
+
95
+ If you ssh into the Linux machine, a `NotImplementedError` might occur. In this case, please refer to this [**issue**](https://github.com/MPI-IS/mesh/issues/66) to resolve the error.
96
+
97
+ Download [**smplx model**](https://drive.google.com/file/d/1Ly_hQNLQcZ89KG0Nj4jYZwccQiimSUVn/view?usp=share_link) (Please register in the official [**SMPLX webpage**](https://smpl-x.is.tue.mpg.de) before you use it.)
98
+ and place it in ``path-to-TalkSHOW/visualise/smplx_model``.
99
+ To visualise the test set and the generated results (in each video, left: generated result | right: ground truth), run the command below.
100
+ The videos and generated motion data are saved in ``./visualise/video/body-pixel``:
101
+
102
+ bash visualise.sh
103
+
104
+ If you ssh into the Linux machine, there might be an error about `OffscreenRenderer`. In this case, please refer to this [**issue**](https://github.com/MPI-IS/mesh/issues/66) to resolve the error.
105
+
106
+ To reproduce the demo videos, run
107
+ ```bash
108
+ # the whole body demo
109
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/1st-page.wav --id 0 --whole_body
110
+ # the face demo
111
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --only_face
112
+ # the identity-specific demo
113
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0
114
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 1
115
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 2
116
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 3 --stand
117
+ # the diversity demo
118
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/style.wav --id 0 --num_samples 12
119
+ # the french demo
120
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/french.wav --id 0
121
+ # the synthetic speech demo
122
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/rich.wav --id 0
123
+ # the song demo
124
+ python scripts/demo.py --config_file ./config/body_pixel.json --infer --audio_file ./demo_audio/song.wav --id 0
125
+ ```
126
+ ### 7. Baseline
127
+
128
+ To train our reproduction of "Learning Speech-driven 3D Conversational Gestures from Video" (Habibie et al.), run
129
+ ```bash
130
+ python -W ignore scripts/train.py --speakers oliver seth conan chemistry --config_file ./config/LS3DCG.json
131
+ ```
132
+
133
+ For visualization with the pretrained model, download the above [pretrained models](#3-download-the-pretrained-models--optional-) and run
134
+ ```bash
135
+ python scripts/demo.py --config_file ./config/LS3DCG.json --infer --audio_file ./demo_audio/style.wav --body_model_name s2g_LS3DCG --body_model_path experiments/2022-10-19-smplx_S2G-LS3DCG/ckpt-99.pth --id 0
136
+ ```
137
+
138
+ ## Citation
139
+ If you find our work useful to your research, please consider citing:
140
+ ```
141
+ @inproceedings{yi2022generating,
142
+ title={Generating Holistic 3D Human Motion from Speech},
143
+ author={Yi, Hongwei and Liang, Hualin and Liu, Yifei and Cao, Qiong and Wen, Yandong and Bolkart, Timo and Tao, Dacheng and Black, Michael J},
144
+ booktitle={CVPR},
145
+ year={2023}
146
+ }
147
+ ```
148
+
149
+ ## Acknowledgements
150
+ For functions or scripts that are based on external sources, we acknowledge the origin individually in each file.
151
+ Here are some great resources we benefit from:
152
+ - [Freeform](https://github.com/TheTempAccount/Co-Speech-Motion-Generation) for training pipeline
153
+ - [MPI-Mesh](https://github.com/MPI-IS/mesh), [Pyrender](https://github.com/mmatl/pyrender), [Smplx](https://github.com/vchoutas/smplx), [VOCA](https://github.com/TimoBolkart/voca) for rendering
154
+ - [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) and [Faceformer](https://github.com/EvelynFan/FaceFormer) for audio encoder
155
+
156
+ ## Contact
157
158
+
159
+ For commercial licensing, please contact [email protected]
__init__.py ADDED
File without changes
config/LS3DCG.json ADDED
@@ -0,0 +1,64 @@
1
+ {
2
+ "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
+ "dataset_load_mode": "pickle",
4
+ "store_file_path": "store.pkl",
5
+ "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
+ "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
+ "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
+ "param": {
9
+ "w_j": 1,
10
+ "w_b": 1,
11
+ "w_h": 1
12
+ },
13
+ "Data": {
14
+ "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
+ "pklname": "_3d_mfcc.pkl",
16
+ "whole_video": false,
17
+ "pose": {
18
+ "normalization": false,
19
+ "convert_to_6d": false,
20
+ "norm_method": "all",
21
+ "augmentation": false,
22
+ "generate_length": 88,
23
+ "pre_pose_length": 0,
24
+ "pose_dim": 99,
25
+ "expression": true
26
+ },
27
+ "aud": {
28
+ "feat_method": "mfcc",
29
+ "aud_feat_dim": 64,
30
+ "aud_feat_win_size": null,
31
+ "context_info": false
32
+ }
33
+ },
34
+ "Model": {
35
+ "model_type": "body",
36
+ "model_name": "s2g_LS3DCG",
37
+ "code_num": 2048,
38
+ "AudioOpt": "Adam",
39
+ "encoder_choice": "mfcc",
40
+ "gan": false
41
+ },
42
+ "DataLoader": {
43
+ "batch_size": 128,
44
+ "num_workers": 0
45
+ },
46
+ "Train": {
47
+ "epochs": 100,
48
+ "max_gradient_norm": 5,
49
+ "learning_rate": {
50
+ "generator_learning_rate": 1e-4,
51
+ "discriminator_learning_rate": 1e-4
52
+ },
53
+ "weights": {
54
+ "keypoint_loss_weight": 1.0,
55
+ "gan_loss_weight": 1.0
56
+ }
57
+ },
58
+ "Log": {
59
+ "save_every": 50,
60
+ "print_every": 200,
61
+ "name": "LS3DCG"
62
+ }
63
+ }
64
+
config/body_pixel.json ADDED
@@ -0,0 +1,63 @@
1
+ {
2
+ "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
+ "dataset_load_mode": "json",
4
+ "store_file_path": "store.pkl",
5
+ "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
+ "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
+ "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
+ "param": {
9
+ "w_j": 1,
10
+ "w_b": 1,
11
+ "w_h": 1
12
+ },
13
+ "Data": {
14
+ "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
+ "pklname": "_3d_mfcc.pkl",
16
+ "whole_video": false,
17
+ "pose": {
18
+ "normalization": false,
19
+ "convert_to_6d": false,
20
+ "norm_method": "all",
21
+ "augmentation": false,
22
+ "generate_length": 88,
23
+ "pre_pose_length": 0,
24
+ "pose_dim": 99,
25
+ "expression": true
26
+ },
27
+ "aud": {
28
+ "feat_method": "mfcc",
29
+ "aud_feat_dim": 64,
30
+ "aud_feat_win_size": null,
31
+ "context_info": false
32
+ }
33
+ },
34
+ "Model": {
35
+ "model_type": "body",
36
+ "model_name": "s2g_body_pixel",
37
+ "composition": true,
38
+ "code_num": 2048,
39
+ "bh_model": true,
40
+ "AudioOpt": "Adam",
41
+ "encoder_choice": "mfcc",
42
+ "gan": false,
43
+ "vq_path": "./experiments/2022-10-31-smplx_S2G-body-vq-3d/ckpt-99.pth"
44
+ },
45
+ "DataLoader": {
46
+ "batch_size": 128,
47
+ "num_workers": 0
48
+ },
49
+ "Train": {
50
+ "epochs": 100,
51
+ "max_gradient_norm": 5,
52
+ "learning_rate": {
53
+ "generator_learning_rate": 1e-4,
54
+ "discriminator_learning_rate": 1e-4
55
+ }
56
+ },
57
+ "Log": {
58
+ "save_every": 50,
59
+ "print_every": 200,
60
+ "name": "body-pixel2"
61
+ }
62
+ }
63
+
config/body_vq.json ADDED
@@ -0,0 +1,62 @@
1
+ {
2
+ "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
+ "dataset_load_mode": "json",
4
+ "store_file_path": "store.pkl",
5
+ "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
+ "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
+ "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
+ "param": {
9
+ "w_j": 1,
10
+ "w_b": 1,
11
+ "w_h": 1
12
+ },
13
+ "Data": {
14
+ "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
+ "pklname": "_3d_mfcc.pkl",
16
+ "whole_video": false,
17
+ "pose": {
18
+ "normalization": false,
19
+ "convert_to_6d": false,
20
+ "norm_method": "all",
21
+ "augmentation": false,
22
+ "generate_length": 88,
23
+ "pre_pose_length": 0,
24
+ "pose_dim": 99,
25
+ "expression": true
26
+ },
27
+ "aud": {
28
+ "feat_method": "mfcc",
29
+ "aud_feat_dim": 64,
30
+ "aud_feat_win_size": null,
31
+ "context_info": false
32
+ }
33
+ },
34
+ "Model": {
35
+ "model_type": "body",
36
+ "model_name": "s2g_body_vq",
37
+ "composition": true,
38
+ "code_num": 2048,
39
+ "bh_model": true,
40
+ "AudioOpt": "Adam",
41
+ "encoder_choice": "mfcc",
42
+ "gan": false
43
+ },
44
+ "DataLoader": {
45
+ "batch_size": 128,
46
+ "num_workers": 0
47
+ },
48
+ "Train": {
49
+ "epochs": 100,
50
+ "max_gradient_norm": 5,
51
+ "learning_rate": {
52
+ "generator_learning_rate": 1e-4,
53
+ "discriminator_learning_rate": 1e-4
54
+ }
55
+ },
56
+ "Log": {
57
+ "save_every": 50,
58
+ "print_every": 200,
59
+ "name": "body-vq"
60
+ }
61
+ }
62
+
config/face.json ADDED
@@ -0,0 +1,59 @@
1
+ {
2
+ "config_root_path": "/is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts",
3
+ "dataset_load_mode": "json",
4
+ "store_file_path": "store.pkl",
5
+ "smplx_npz_path": "visualise/smplx_model/SMPLX_NEUTRAL_2020.npz",
6
+ "extra_joint_path": "visualise/smplx_model/smplx_extra_joints.yaml",
7
+ "j14_regressor_path": "visualise/smplx_model/SMPLX_to_J14.pkl",
8
+ "param": {
9
+ "w_j": 1,
10
+ "w_b": 1,
11
+ "w_h": 1
12
+ },
13
+ "Data": {
14
+ "data_root": "../ExpressiveWholeBodyDatasetv1.0/",
15
+ "pklname": "_3d_wv2.pkl",
16
+ "whole_video": true,
17
+ "pose": {
18
+ "normalization": false,
19
+ "convert_to_6d": false,
20
+ "norm_method": "all",
21
+ "augmentation": false,
22
+ "generate_length": 88,
23
+ "pre_pose_length": 0,
24
+ "pose_dim": 99,
25
+ "expression": true
26
+ },
27
+ "aud": {
28
+ "feat_method": "mfcc",
29
+ "aud_feat_dim": 64,
30
+ "aud_feat_win_size": null,
31
+ "context_info": false
32
+ }
33
+ },
34
+ "Model": {
35
+ "model_type": "face",
36
+ "model_name": "s2g_face",
37
+ "AudioOpt": "SGD",
38
+ "encoder_choice": "faceformer",
39
+ "gan": false
40
+ },
41
+ "DataLoader": {
42
+ "batch_size": 1,
43
+ "num_workers": 0
44
+ },
45
+ "Train": {
46
+ "epochs": 100,
47
+ "max_gradient_norm": 5,
48
+ "learning_rate": {
49
+ "generator_learning_rate": 1e-4,
50
+ "discriminator_learning_rate": 1e-4
51
+ }
52
+ },
53
+ "Log": {
54
+ "save_every": 50,
55
+ "print_every": 1000,
56
+ "name": "face"
57
+ }
58
+ }
59
+
data_utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # from .dataloader_csv import MultiVidData as csv_data
2
+ from .dataloader_torch import MultiVidData as torch_data
3
+ from .utils import get_melspec, get_mfcc, get_mfcc_old, get_mfcc_psf, get_mfcc_psf_min, get_mfcc_ta
data_utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (375 Bytes). View file
 
data_utils/__pycache__/consts.cpython-37.pyc ADDED
Binary file (92.7 kB). View file
 
data_utils/__pycache__/dataloader_torch.cpython-37.pyc ADDED
Binary file (5.31 kB). View file
 
data_utils/__pycache__/lower_body.cpython-37.pyc ADDED
Binary file (3.91 kB). View file
 
data_utils/__pycache__/mesh_dataset.cpython-37.pyc ADDED
Binary file (7.9 kB). View file
 
data_utils/__pycache__/rotation_conversion.cpython-37.pyc ADDED
Binary file (16.4 kB). View file
 
data_utils/__pycache__/utils.cpython-37.pyc ADDED
Binary file (7.42 kB). View file
 
data_utils/apply_split.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import pickle
4
+ import shutil
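+ # Apply the released split to the downloaded dataset: every clip listed in
+ # split_more_than_2s.pkl is moved to the train/val/test location recorded there,
+ # and clips that cannot be found (or moved) are collected in none.pkl.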
5
+
6
+ speakers = ['seth', 'oliver', 'conan', 'chemistry']
7
+ source_data_root = "../expressive_body-V0.7"
8
+ data_root = "D:/Downloads/SHOW_dataset_v1.0/ExpressiveWholeBodyDatasetReleaseV1.0"
9
+
10
+ f_read = open('split_more_than_2s.pkl', 'rb')
11
+ f_save = open('none.pkl', 'wb')
12
+ data_split = pickle.load(f_read)
13
+ none_split = []
14
+
15
+ train = val = test = 0
16
+
17
+ for speaker_name in speakers:
18
+ speaker_root = os.path.join(data_root, speaker_name)
19
+
20
+ videos = [v for v in data_split[speaker_name]]
21
+
22
+ for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
23
+ for split in data_split[speaker_name][vid]:
24
+ for seq in data_split[speaker_name][vid][split]:
25
+
26
+ seq = seq.replace('\\', '/')
27
+ old_file_path = os.path.join(data_root, speaker_name, vid, seq.split('/')[-1])
28
+ old_file_path = old_file_path.replace('\\', '/')
29
+ new_file_path = seq.replace(source_data_root.split('/')[-1], data_root.split('/')[-1])
30
+ try:
31
+ shutil.move(old_file_path, new_file_path)
32
+ if split == 'train':
33
+ train = train + 1
34
+ elif split == 'test':
35
+ test = test + 1
36
+ elif split == 'val':
37
+ val = val + 1
38
+ except FileNotFoundError:
39
+ none_split.append(old_file_path)
40
+ print(f"The file {old_file_path} does not exist.")
41
+ except shutil.Error:
42
+ none_split.append(old_file_path)
43
+ print(f"The file {old_file_path} could not be moved.")
44
+
45
+ print(none_split.__len__())
46
+ pickle.dump(none_split, f_save)
47
+ f_save.close()
48
+
49
+ print(train, val, test)
50
+
51
+
data_utils/axis2matrix.py ADDED
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+ import math
3
+ import scipy.linalg as linalg
4
+
5
+
6
+ def rotate_mat(axis, radian):
7
+
8
+ a = np.cross(np.eye(3), axis / linalg.norm(axis) * radian)
9
+
10
+ rot_matrix = linalg.expm(a)
11
+
12
+ return rot_matrix
13
+
14
+ def aaa2mat(axis, sin, cos):
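+ # Rodrigues' formula: R = cos*I + (1 - cos) * n n^T + sin * [n]_x, for a unit axis n.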
15
+ i = np.eye(3)
16
+ nnt = np.dot(axis.T, axis)
17
+ s = np.asarray([[0, -axis[0,2], axis[0,1]],
18
+ [axis[0,2], 0, -axis[0,0]],
19
+ [-axis[0,1], axis[0,0], 0]])
20
+ r = cos * i + (1-cos)*nnt +sin * s
21
+ return r
22
+
23
+ rand_axis = np.asarray([[1,0,0]])
24
+ # rotation angle
25
+ r = math.pi/2
26
+ # return the rotation matrix
27
+ rot_matrix = rotate_mat(rand_axis, r)
28
+ r2 = aaa2mat(rand_axis, np.sin(r), np.cos(r))
29
+ print(rot_matrix)
data_utils/consts.py ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/dataloader_torch.py ADDED
@@ -0,0 +1,279 @@
1
+ import sys
2
+ import os
3
+ sys.path.append(os.getcwd())
4
+ import os
5
+ from tqdm import tqdm
6
+ from data_utils.utils import *
7
+ import torch.utils.data as data
8
+ from data_utils.mesh_dataset import SmplxDataset
9
+ from transformers import Wav2Vec2Processor
10
+
11
+
12
+ class MultiVidData():
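+ # Builds one SmplxDataset per clip for the requested speakers/split (loaded from a pickle
+ # cache, a csv index, or the raw per-clip wav/pkl folders) and later concatenates them
+ # into a single torch ConcatDataset in get_dataset().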
13
+ def __init__(self,
14
+ data_root,
15
+ speakers,
16
+ split='train',
17
+ limbscaling=False,
18
+ normalization=False,
19
+ norm_method='new',
20
+ split_trans_zero=False,
21
+ num_frames=25,
22
+ num_pre_frames=25,
23
+ num_generate_length=None,
24
+ aud_feat_win_size=None,
25
+ aud_feat_dim=64,
26
+ feat_method='mel_spec',
27
+ context_info=False,
28
+ smplx=False,
29
+ audio_sr=16000,
30
+ convert_to_6d=False,
31
+ expression=False,
32
+ config=None
33
+ ):
34
+ self.data_root = data_root
35
+ self.speakers = speakers
36
+ self.split = split
37
+ if split == 'pre':
38
+ self.split = 'train'
39
+ self.norm_method=norm_method
40
+ self.normalization = normalization
41
+ self.limbscaling = limbscaling
42
+ self.convert_to_6d = convert_to_6d
43
+ self.num_frames=num_frames
44
+ self.num_pre_frames=num_pre_frames
45
+ if num_generate_length is None:
46
+ self.num_generate_length = num_frames
47
+ else:
48
+ self.num_generate_length = num_generate_length
49
+ self.split_trans_zero=split_trans_zero
50
+
51
+ dataset = SmplxDataset
52
+
53
+ if self.split_trans_zero:
54
+ self.trans_dataset_list = []
55
+ self.zero_dataset_list = []
56
+ else:
57
+ self.all_dataset_list = []
58
+ self.dataset={}
59
+ self.complete_data=[]
60
+ self.config=config
61
+ load_mode=self.config.dataset_load_mode
62
+
63
+ ######################load with pickle file
64
+ if load_mode=='pickle':
65
+ import pickle
66
+ import subprocess
67
+
68
+ # store_file_path='/tmp/store.pkl'
69
+ # cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl /tmp/store.pkl
70
+ # subprocess.run(f'cp /is/cluster/scratch/hyi/ExpressiveBody/SMPLifyX4/scripts/store.pkl {store_file_path}',shell=True)
71
+
72
+ # f = open(self.config.store_file_path, 'rb+')
73
+ f = open(self.split+config.Data.pklname, 'rb+')
74
+ self.dataset=pickle.load(f)
75
+ f.close()
76
+ for key in self.dataset:
77
+ self.complete_data.append(self.dataset[key].complete_data)
78
+ ######################load with pickle file
79
+
80
+ ######################load with a csv file
81
+ elif load_mode=='csv':
82
+
83
+ # imported from one of my code folders; to be cleaned up here later
84
+ try:
85
+ sys.path.append(self.config.config_root_path)
86
+ from config import config_path
87
+ from csv_parser import csv_parse
88
+
89
+ except ImportError as e:
90
+ print(f'err: {e}')
91
+ raise ImportError('config root path error...')
92
+
93
+
94
+ for speaker_name in self.speakers:
95
+ # df_intervals=pd.read_csv(self.config.voca_csv_file_path)
96
+ df_intervals=None
97
+ df_intervals=df_intervals[df_intervals['speaker']==speaker_name]
98
+ df_intervals = df_intervals[df_intervals['dataset'] == self.split]
99
+
100
+ print(f'speaker {speaker_name} train interval length: {len(df_intervals)}')
101
+ for iter_index, (_, interval) in tqdm(
102
+ (enumerate(df_intervals.iterrows())),desc=f'load {speaker_name}'
103
+ ):
104
+
105
+ (
106
+ interval_index,
107
+ interval_speaker,
108
+ interval_video_fn,
109
+ interval_id,
110
+
111
+ start_time,
112
+ end_time,
113
+ duration_time,
114
+ start_time_10,
115
+ over_flow_flag,
116
+ short_dur_flag,
117
+
118
+ big_video_dir,
119
+ small_video_dir_name,
120
+ speaker_video_path,
121
+
122
+ voca_basename,
123
+ json_basename,
124
+ wav_basename,
125
+ voca_top_clip_path,
126
+ voca_json_clip_path,
127
+ voca_wav_clip_path,
128
+
129
+ audio_output_fn,
130
+ image_output_path,
131
+ pifpaf_output_path,
132
+ mp_output_path,
133
+ op_output_path,
134
+ deca_output_path,
135
+ pixie_output_path,
136
+ cam_output_path,
137
+ ours_output_path,
138
+ merge_output_path,
139
+ multi_output_path,
140
+ gt_output_path,
141
+ ours_images_path,
142
+ pkl_fil_path,
143
+ )=csv_parse(interval)
144
+
145
+ if not os.path.exists(pkl_fil_path) or not os.path.exists(audio_output_fn):
146
+ continue
147
+
148
+ key=f'{interval_video_fn}/{small_video_dir_name}'
149
+ self.dataset[key] = dataset(
150
+ data_root=pkl_fil_path,
151
+ speaker=speaker_name,
152
+ audio_fn=audio_output_fn,
153
+ audio_sr=audio_sr,
154
+ fps=num_frames,
155
+ feat_method=feat_method,
156
+ audio_feat_dim=aud_feat_dim,
157
+ train=(self.split == 'train'),
158
+ load_all=True,
159
+ split_trans_zero=self.split_trans_zero,
160
+ limbscaling=self.limbscaling,
161
+ num_frames=self.num_frames,
162
+ num_pre_frames=self.num_pre_frames,
163
+ num_generate_length=self.num_generate_length,
164
+ audio_feat_win_size=aud_feat_win_size,
165
+ context_info=context_info,
166
+ convert_to_6d=convert_to_6d,
167
+ expression=expression,
168
+ config=self.config
169
+ )
170
+ self.complete_data.append(self.dataset[key].complete_data)
171
+ ######################load with a csv file
172
+
173
+ ######################origin load method
174
+ elif load_mode=='json':
175
+
176
+ # if self.split == 'train':
177
+ # import pickle
178
+ # f = open('store.pkl', 'rb+')
179
+ # self.dataset=pickle.load(f)
180
+ # f.close()
181
+ # for key in self.dataset:
182
+ # self.complete_data.append(self.dataset[key].complete_data)
183
+ # else:https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav
184
+ # if config.Model.model_type == 'face':
185
+ am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
186
+ am_sr = 16000
187
+ # else:
188
+ # am, am_sr = None, None
189
+ for speaker_name in self.speakers:
190
+ speaker_root = os.path.join(self.data_root, speaker_name)
191
+
192
+ videos=[v for v in os.listdir(speaker_root) ]
193
+ print(videos)
194
+
195
+ haode = huaide = 0
196
+
197
+ for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
198
+ source_vid=vid
199
+ # vid_pth=os.path.join(speaker_root, source_vid, 'images/half', self.split)
200
+ vid_pth = os.path.join(speaker_root, source_vid, self.split)
201
+ if smplx == 'pose':
202
+ seqs = [s for s in os.listdir(vid_pth) if (s.startswith('clip'))]
203
+ else:
204
+ try:
205
+ seqs = [s for s in os.listdir(vid_pth)]
206
+ except:
207
+ continue
208
+
209
+ for s in seqs:
210
+ seq_root=os.path.join(vid_pth, s)
211
+ key = seq_root # correspond to clip******
212
+ audio_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.wav' % (s))
213
+ motion_fname = os.path.join(speaker_root, source_vid, self.split, s, '%s.pkl' % (s))
214
+ if not os.path.isfile(audio_fname) or not os.path.isfile(motion_fname):
215
+ huaide = huaide + 1
216
+ continue
217
+
218
+ self.dataset[key]=dataset(
219
+ data_root=seq_root,
220
+ speaker=speaker_name,
221
+ motion_fn=motion_fname,
222
+ audio_fn=audio_fname,
223
+ audio_sr=audio_sr,
224
+ fps=num_frames,
225
+ feat_method=feat_method,
226
+ audio_feat_dim=aud_feat_dim,
227
+ train=(self.split=='train'),
228
+ load_all=True,
229
+ split_trans_zero=self.split_trans_zero,
230
+ limbscaling=self.limbscaling,
231
+ num_frames=self.num_frames,
232
+ num_pre_frames=self.num_pre_frames,
233
+ num_generate_length=self.num_generate_length,
234
+ audio_feat_win_size=aud_feat_win_size,
235
+ context_info=context_info,
236
+ convert_to_6d=convert_to_6d,
237
+ expression=expression,
238
+ config=self.config,
239
+ am=am,
240
+ am_sr=am_sr,
241
+ whole_video=config.Data.whole_video
242
+ )
243
+ self.complete_data.append(self.dataset[key].complete_data)
244
+ haode = haode + 1
245
+ print("huaide:{}, haode:{}".format(huaide, haode))
246
+ import pickle
247
+
248
+ f = open(self.split+config.Data.pklname, 'wb')
249
+ pickle.dump(self.dataset, f)
250
+ f.close()
251
+ ######################origin load method
252
+
253
+ self.complete_data=np.concatenate(self.complete_data, axis=0)
254
+
255
+ # assert self.complete_data.shape[-1] == (12+21+21)*2
256
+ self.normalize_stats = {}
257
+
258
+ self.data_mean = None
259
+ self.data_std = None
260
+
261
+ def get_dataset(self):
262
+ self.normalize_stats['mean'] = self.data_mean
263
+ self.normalize_stats['std'] = self.data_std
264
+
265
+ for key in list(self.dataset.keys()):
266
+ if self.dataset[key].complete_data.shape[0] < self.num_generate_length:
267
+ continue
268
+ self.dataset[key].num_generate_length = self.num_generate_length
269
+ self.dataset[key].get_dataset(self.normalization, self.normalize_stats, self.split)
270
+ self.all_dataset_list.append(self.dataset[key].all_dataset)
271
+
272
+ if self.split_trans_zero:
273
+ self.trans_dataset = data.ConcatDataset(self.trans_dataset_list)
274
+ self.zero_dataset = data.ConcatDataset(self.zero_dataset_list)
275
+ else:
276
+ self.all_dataset = data.ConcatDataset(self.all_dataset_list)
277
+
278
+
279
+
data_utils/dataset_preprocess.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ import pickle
3
+ from tqdm import tqdm
4
+ import shutil
5
+ import torch
6
+ import numpy as np
7
+ import librosa
8
+ import random
9
+
10
+ speakers = ['seth', 'conan', 'oliver', 'chemistry']
11
+ data_root = "../ExpressiveWholeBodyDatasetv1.0/"
12
+ split = 'train'
13
+
14
+
15
+
16
+ def split_list(full_list,shuffle=False,ratio=0.2):
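+ # Split full_list into three sublists with proportions ratio / ratio / (1 - 2*ratio);
+ # used below as (test, val, train).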
17
+ n_total = len(full_list)
18
+ offset_0 = int(n_total * ratio)
19
+ offset_1 = int(n_total * ratio * 2)
20
+ if n_total==0 or offset_1<1:
21
+ return [],full_list
22
+ if shuffle:
23
+ random.shuffle(full_list)
24
+ sublist_0 = full_list[:offset_0]
25
+ sublist_1 = full_list[offset_0:offset_1]
26
+ sublist_2 = full_list[offset_1:]
27
+ return sublist_0, sublist_1, sublist_2
28
+
29
+
30
+ def moveto(list, file):
31
+ for f in list:
32
+ before, after = '/'.join(f.split('/')[:-1]), f.split('/')[-1]
33
+ new_path = os.path.join(before, file)
34
+ new_path = os.path.join(new_path, after)
35
+ # os.makedirs(new_path)
36
+ # os.path.isdir(new_path)
37
+ # shutil.move(f, new_path)
38
+
39
+ # move to the new directory
40
+ shutil.copytree(f, new_path)
41
+ # delete the original files under train
42
+ shutil.rmtree(f)
43
+ return None
44
+
45
+
46
+ def read_pkl(data):
47
+ betas = np.array(data['betas'])
48
+
49
+ jaw_pose = np.array(data['jaw_pose'])
50
+ leye_pose = np.array(data['leye_pose'])
51
+ reye_pose = np.array(data['reye_pose'])
52
+ global_orient = np.array(data['global_orient']).squeeze()
53
+ body_pose = np.array(data['body_pose_axis'])
54
+ left_hand_pose = np.array(data['left_hand_pose'])
55
+ right_hand_pose = np.array(data['right_hand_pose'])
56
+
57
+ full_body = np.concatenate(
58
+ (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
59
+
60
+ expression = np.array(data['expression'])
61
+ full_body = np.concatenate((full_body, expression), axis=1)
62
+
63
+ if (full_body.shape[0] < 90) or (torch.isnan(torch.from_numpy(full_body)).sum() > 0):
64
+ return 1
65
+ else:
66
+ return 0
67
+
68
+
69
+ for speaker_name in speakers:
70
+ speaker_root = os.path.join(data_root, speaker_name)
71
+
72
+ videos = [v for v in os.listdir(speaker_root)]
73
+ print(videos)
74
+
75
+ haode = huaide = 0
76
+ total_seqs = []
77
+
78
+ for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
79
+ # for vid in videos:
80
+ source_vid = vid
81
+ vid_pth = os.path.join(speaker_root, source_vid)
82
+ # vid_pth = os.path.join(speaker_root, source_vid, 'images/half', split)
83
+ t = os.path.join(speaker_root, source_vid, 'test')
84
+ v = os.path.join(speaker_root, source_vid, 'val')
85
+
86
+ # if os.path.exists(t):
87
+ # shutil.rmtree(t)
88
+ # if os.path.exists(v):
89
+ # shutil.rmtree(v)
90
+ try:
91
+ seqs = [s for s in os.listdir(vid_pth)]
92
+ except:
93
+ continue
94
+ # if len(seqs) == 0:
95
+ # shutil.rmtree(os.path.join(speaker_root, source_vid))
96
+ # None
97
+ for s in seqs:
98
+ quality = 0
99
+ total_seqs.append(os.path.join(vid_pth,s))
100
+ seq_root = os.path.join(vid_pth, s)
101
+ key = seq_root # correspond to clip******
102
+ audio_fname = os.path.join(speaker_root, source_vid, s, '%s.wav' % (s))
103
+
104
+ # delete the data without audio or the audio file could not be read
105
+ if os.path.isfile(audio_fname):
106
+ try:
107
+ audio = librosa.load(audio_fname)
108
+ except:
109
+ # print(key)
110
+ shutil.rmtree(key)
111
+ huaide = huaide + 1
112
+ continue
113
+ else:
114
+ huaide = huaide + 1
115
+ # print(key)
116
+ shutil.rmtree(key)
117
+ continue
118
+
119
+ # check motion file
120
+ motion_fname = os.path.join(speaker_root, source_vid, s, '%s.pkl' % (s))
121
+ try:
122
+ f = open(motion_fname, 'rb+')
123
+ except:
124
+ shutil.rmtree(key)
125
+ huaide = huaide + 1
126
+ continue
127
+
128
+ data = pickle.load(f)
129
+ w = read_pkl(data)
130
+ f.close()
131
+ quality = quality + w
132
+
133
+ if w == 1:
134
+ shutil.rmtree(key)
135
+ # print(key)
136
+ huaide = huaide + 1
137
+ continue
138
+
139
+ haode = haode + 1
140
+
141
+ print("huaide:{}, haode:{}, total_seqs:{}".format(huaide, haode, total_seqs.__len__()))
142
+
143
+ for speaker_name in speakers:
144
+ speaker_root = os.path.join(data_root, speaker_name)
145
+
146
+ videos = [v for v in os.listdir(speaker_root)]
147
+ print(videos)
148
+
149
+ haode = huaide = 0
150
+ total_seqs = []
151
+
152
+ for vid in tqdm(videos, desc="Processing training data of {}......".format(speaker_name)):
153
+ # for vid in videos:
154
+ source_vid = vid
155
+ vid_pth = os.path.join(speaker_root, source_vid)
156
+ try:
157
+ seqs = [s for s in os.listdir(vid_pth)]
158
+ except:
159
+ continue
160
+ for s in seqs:
161
+ quality = 0
162
+ total_seqs.append(os.path.join(vid_pth, s))
163
+ print("total_seqs:{}".format(total_seqs.__len__()))
164
+ # split the dataset
165
+ test_list, val_list, train_list = split_list(total_seqs, True, 0.1)
166
+ print(len(test_list), len(val_list), len(train_list))
167
+ moveto(train_list, 'train')
168
+ moveto(test_list, 'test')
169
+ moveto(val_list, 'val')
170
+
data_utils/get_j.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+
4
+ def to3d(poses, config):
5
+ if config.Data.pose.convert_to_6d:
6
+ if config.Data.pose.expression:
7
+ poses_exp = poses[:, -100:]
8
+ poses = poses[:, :-100]
9
+
10
+ poses = poses.reshape(poses.shape[0], -1, 5)
11
+ sin, cos = poses[:, :, 3], poses[:, :, 4]
12
+ pose_angle = torch.atan2(sin, cos)
13
+ poses = (poses[:, :, :3] * pose_angle.unsqueeze(dim=-1)).reshape(poses.shape[0], -1)
14
+
15
+ if config.Data.pose.expression:
16
+ poses = torch.cat([poses, poses_exp], dim=-1)
17
+ return poses
18
+
19
+
20
+ def get_joint(smplx_model, betas, pred):
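+ # pred is an (N, 265) tensor laid out as jaw(3), leye(3), reye(3), global_orient(3),
+ # body(63), left_hand(45), right_hand(45), expression(100); run SMPL-X to obtain joints.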
21
+ joint = smplx_model(betas=betas.repeat(pred.shape[0], 1),
22
+ expression=pred[:, 165:265],
23
+ jaw_pose=pred[:, 0:3],
24
+ leye_pose=pred[:, 3:6],
25
+ reye_pose=pred[:, 6:9],
26
+ global_orient=pred[:, 9:12],
27
+ body_pose=pred[:, 12:75],
28
+ left_hand_pose=pred[:, 75:120],
29
+ right_hand_pose=pred[:, 120:165],
30
+ return_verts=True)['joints']
31
+ return joint
32
+
33
+
34
+ def get_joints(smplx_model, betas, pred):
35
+ if len(pred.shape) == 3:
36
+ B = pred.shape[0]
37
+ x = 4 if B>= 4 else B
38
+ T = pred.shape[1]
39
+ pred = pred.reshape(-1, 265)
40
+ smplx_model.batch_size = L = T * x
41
+
42
+ times = pred.shape[0] // smplx_model.batch_size
43
+ joints = []
44
+ for i in range(times):
45
+ joints.append(get_joint(smplx_model, betas, pred[i*L:(i+1)*L]))
46
+ joints = torch.cat(joints, dim=0)
47
+ joints = joints.reshape(B, T, -1, 3)
48
+ else:
49
+ smplx_model.batch_size = pred.shape[0]
50
+ joints = get_joint(smplx_model, betas, pred)
51
+ return joints
data_utils/hand_component.json ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/lower_body.py ADDED
@@ -0,0 +1,143 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ lower_pose = torch.tensor(
5
+ [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
6
+ 0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
7
+ -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
8
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
9
+ lower_pose_stand = torch.tensor([
10
+ 8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
11
+ 3.0747, -0.0158, -0.0152,
12
+ -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
13
+ -3.9716e-01, -4.0229e-02, -1.2637e-01,
14
+ 7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
15
+ 7.8632e-01, -4.3810e-02, 1.4375e-02,
16
+ -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])
17
+ # lower_pose_stand = torch.tensor(
18
+ # [6.4919e-02, 3.3018e-02, 1.7485e-02, 8.9759e-04, 7.1074e-04, -5.9163e-06,
19
+ # 3.0747, -0.0158, -0.0152,
20
+ # -3.3633e+00, -9.3915e-02, 3.0996e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
21
+ # 1.1654603481292725, 0.0, 0.0,
22
+ # 4.4167e-01, 6.7183e-03, -3.6379e-03, 7.9163e-01, 6.8519e-02, -1.5091e-01,
23
+ # 0.0, 0.0, 0.0,
24
+ # 2.2910e-02, -2.4797e-02, -5.5657e-03, -1.0675e-01, 1.2635e-01, 1.6711e-02,])
25
+ lower_body = [0, 1, 3, 4, 6, 7, 9, 10]
26
+ count_part = [6, 9, 12, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
27
+ 29, 30, 31, 32, 33, 34, 35, 36, 37,
28
+ 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54]
29
+ fix_index = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
30
+ 29,
31
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
32
+ 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
33
+ 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
34
+ all_index = np.ones(275)
35
+ all_index[fix_index] = 0
36
+ c_index = []
37
+ i = 0
38
+ for num in all_index:
39
+ if num == 1:
40
+ c_index.append(i)
41
+ i = i + 1
42
+ c_index = np.asarray(c_index)
43
+
44
+ fix_index_3d = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
45
+ 21, 22, 23, 24, 25, 26,
46
+ 30, 31, 32, 33, 34, 35,
47
+ 45, 46, 47, 48, 49, 50]
48
+ all_index_3d = np.ones(165)
49
+ all_index_3d[fix_index_3d] = 0
50
+ c_index_3d = []
51
+ i = 0
52
+ for num in all_index_3d:
53
+ if num == 1:
54
+ c_index_3d.append(i)
55
+ i = i + 1
56
+ c_index_3d = np.asarray(c_index_3d)
57
+
58
+ c_index_6d = []
59
+ i = 0
60
+ for num in all_index_3d:
61
+ if num == 1:
62
+ c_index_6d.append(2*i)
63
+ c_index_6d.append(2 * i + 1)
64
+ i = i + 1
65
+ c_index_6d = np.asarray(c_index_6d)
66
+
67
+
68
+ def part2full(input, stand=False):
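+ # Splice the constant lower-body pose (lp) back in between the predicted segments so the
+ # output matches the full SMPL-X pose layout.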
69
+ if stand:
70
+ # lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
71
+ lp = torch.zeros_like(lower_pose)
72
+ lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
73
+ lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
74
+ else:
75
+ lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
76
+
77
+ input = torch.cat([input[:, :3],
78
+ lp[:, :15],
79
+ input[:, 3:6],
80
+ lp[:, 15:21],
81
+ input[:, 6:9],
82
+ lp[:, 21:27],
83
+ input[:, 9:12],
84
+ lp[:, 27:],
85
+ input[:, 12:]]
86
+ , dim=1)
87
+ return input
88
+
89
+
90
+ def pred2poses(input, gt):
91
+ input = torch.cat([input[:, :3],
92
+ gt[0:1, 3:18].repeat(input.shape[0], 1),
93
+ input[:, 3:6],
94
+ gt[0:1, 21:27].repeat(input.shape[0], 1),
95
+ input[:, 6:9],
96
+ gt[0:1, 30:36].repeat(input.shape[0], 1),
97
+ input[:, 9:12],
98
+ gt[0:1, 39:45].repeat(input.shape[0], 1),
99
+ input[:, 12:]]
100
+ , dim=1)
101
+ return input
102
+
103
+
104
+ def poses2poses(input, gt):
105
+ input = torch.cat([input[:, :3],
106
+ gt[0:1, 3:18].repeat(input.shape[0], 1),
107
+ input[:, 18:21],
108
+ gt[0:1, 21:27].repeat(input.shape[0], 1),
109
+ input[:, 27:30],
110
+ gt[0:1, 30:36].repeat(input.shape[0], 1),
111
+ input[:, 36:39],
112
+ gt[0:1, 39:45].repeat(input.shape[0], 1),
113
+ input[:, 45:]]
114
+ , dim=1)
115
+ return input
116
+
117
+ def poses2pred(input, stand=False):
118
+ if stand:
119
+ lp = lower_pose_stand.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
120
+ # lp = torch.zeros_like(lower_pose).unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
121
+ else:
122
+ lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
123
+ input = torch.cat([input[:, :3],
124
+ lp[:, :15],
125
+ input[:, 18:21],
126
+ lp[:, 15:21],
127
+ input[:, 27:30],
128
+ lp[:, 21:27],
129
+ input[:, 36:39],
130
+ lp[:, 27:],
131
+ input[:, 45:]]
132
+ , dim=1)
133
+ return input
134
+
135
+
136
+ rearrange = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]\
137
+ # ,22, 23, 24, 25, 40, 26, 41,
138
+ # 27, 42, 28, 43, 29, 44, 30, 45, 31, 46, 32, 47, 33, 48, 34, 49, 35, 50, 36, 51, 37, 52, 38, 53, 39, 54, 55,
139
+ # 57, 56, 59, 58, 60, 63, 61, 64, 62, 65, 66, 71, 67, 72, 68, 73, 69, 74, 70, 75]
140
+
141
+ symmetry = [0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1]#, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
142
+ # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
143
+ # 1, 1, 1, 1, 1, 1]
data_utils/mesh_dataset.py ADDED
@@ -0,0 +1,348 @@
1
+ import pickle
2
+ import sys
3
+ import os
4
+
5
+ sys.path.append(os.getcwd())
6
+
7
+ import json
8
+ from glob import glob
9
+ from data_utils.utils import *
10
+ import torch.utils.data as data
11
+ from data_utils.consts import speaker_id
12
+ from data_utils.lower_body import count_part
13
+ import random
14
+ from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
15
+
16
+ with open('data_utils/hand_component.json') as file_obj:
17
+ comp = json.load(file_obj)
18
+ left_hand_c = np.asarray(comp['left'])
19
+ right_hand_c = np.asarray(comp['right'])
20
+
21
+
22
+ def to3d(data):
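+ # Expand the 12-dim PCA hand poses (columns 75:87 and 87:99) to full 45-dim axis-angle
+ # rotations using the precomputed hand components from hand_component.json.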
23
+ left_hand_pose = np.einsum('bi,ij->bj', data[:, 75:87], left_hand_c[:12, :])
24
+ right_hand_pose = np.einsum('bi,ij->bj', data[:, 87:99], right_hand_c[:12, :])
25
+ data = np.concatenate((data[:, :75], left_hand_pose, right_hand_pose), axis=-1)
26
+ return data
27
+
28
+
29
+ class SmplxDataset():
30
+ '''
31
+ create a dataset for every segment and concatenate them.
32
+ '''
33
+
34
+ def __init__(self,
35
+ data_root,
36
+ speaker,
37
+ motion_fn,
38
+ audio_fn,
39
+ audio_sr,
40
+ fps,
41
+ feat_method='mel_spec',
42
+ audio_feat_dim=64,
43
+ audio_feat_win_size=None,
44
+
45
+ train=True,
46
+ load_all=False,
47
+ split_trans_zero=False,
48
+ limbscaling=False,
49
+ num_frames=25,
50
+ num_pre_frames=25,
51
+ num_generate_length=25,
52
+ context_info=False,
53
+ convert_to_6d=False,
54
+ expression=False,
55
+ config=None,
56
+ am=None,
57
+ am_sr=None,
58
+ whole_video=False
59
+ ):
60
+
61
+ self.data_root = data_root
62
+ self.speaker = speaker
63
+
64
+ self.feat_method = feat_method
65
+ self.audio_fn = audio_fn
66
+ self.audio_sr = audio_sr
67
+ self.fps = fps
68
+ self.audio_feat_dim = audio_feat_dim
69
+ self.audio_feat_win_size = audio_feat_win_size
70
+ self.context_info = context_info # for aud feat
71
+ self.convert_to_6d = convert_to_6d
72
+ self.expression = expression
73
+
74
+ self.train = train
75
+ self.load_all = load_all
76
+ self.split_trans_zero = split_trans_zero
77
+ self.limbscaling = limbscaling
78
+ self.num_frames = num_frames
79
+ self.num_pre_frames = num_pre_frames
80
+ self.num_generate_length = num_generate_length
81
+ # print('num_generate_length ', self.num_generate_length)
82
+
83
+ self.config = config
84
+ self.am_sr = am_sr
85
+ self.whole_video = whole_video
86
+ load_mode = self.config.dataset_load_mode
87
+
88
+ if load_mode == 'pickle':
89
+ raise NotImplementedError
90
+
91
+ elif load_mode == 'csv':
92
+ import pickle
93
+ with open(data_root, 'rb') as f:
94
+ u = pickle._Unpickler(f)
95
+ data = u.load()
96
+ self.data = data[0]
97
+ if self.load_all:
98
+ self._load_npz_all()
99
+
100
+ elif load_mode == 'json':
101
+ self.annotations = glob(data_root + '/*pkl')
102
+ if len(self.annotations) == 0:
103
+ raise FileNotFoundError(data_root + ' are empty')
104
+ self.annotations = sorted(self.annotations)
105
+ self.img_name_list = self.annotations
106
+
107
+ if self.load_all:
108
+ self._load_them_all(am, am_sr, motion_fn)
109
+
110
+ def _load_npz_all(self):
111
+ self.loaded_data = {}
112
+ self.complete_data = []
113
+ data = self.data
114
+ shape = data['body_pose_axis'].shape[0]
115
+ self.betas = data['betas']
116
+ self.img_name_list = []
117
+ for index in range(shape):
118
+ img_name = f'{index:6d}'
119
+ self.img_name_list.append(img_name)
120
+
121
+ jaw_pose = data['jaw_pose'][index]
122
+ leye_pose = data['leye_pose'][index]
123
+ reye_pose = data['reye_pose'][index]
124
+ global_orient = data['global_orient'][index]
125
+ body_pose = data['body_pose_axis'][index]
126
+ left_hand_pose = data['left_hand_pose'][index]
127
+ right_hand_pose = data['right_hand_pose'][index]
128
+
129
+ full_body = np.concatenate(
130
+ (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose))
131
+ assert full_body.shape[0] == 99
132
+ if self.convert_to_6d:
133
+ full_body = to3d(full_body)
134
+ full_body = torch.from_numpy(full_body)
135
+ full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body))
136
+ full_body = np.asarray(full_body)
137
+ if self.expression:
138
+ expression = data['expression'][index]
139
+ full_body = np.concatenate((full_body, expression))
140
+ # full_body = np.concatenate((full_body, non_zero))
141
+ else:
142
+ full_body = to3d(full_body)
143
+ if self.expression:
144
+ expression = data['expression'][index]
145
+ full_body = np.concatenate((full_body, expression))
146
+
147
+ self.loaded_data[img_name] = full_body.reshape(-1)
148
+ self.complete_data.append(full_body.reshape(-1))
149
+
150
+ self.complete_data = np.array(self.complete_data)
151
+
152
+ if self.audio_feat_win_size is not None:
153
+ self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
154
+ # print(self.audio_feat.shape)
155
+ else:
156
+ if self.feat_method == 'mel_spec':
157
+ self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
158
+ elif self.feat_method == 'mfcc':
159
+ self.audio_feat = get_mfcc(self.audio_fn,
160
+ smlpx=True,
161
+ sr=self.audio_sr,
162
+ n_mfcc=self.audio_feat_dim,
163
+ win_size=self.audio_feat_win_size
164
+ )
165
+
166
+ def _load_them_all(self, am, am_sr, motion_fn):
167
+ self.loaded_data = {}
168
+ self.complete_data = []
169
+ f = open(motion_fn, 'rb+')
170
+ data = pickle.load(f)
171
+
172
+ self.betas = np.array(data['betas'])
173
+
174
+ jaw_pose = np.array(data['jaw_pose'])
175
+ leye_pose = np.array(data['leye_pose'])
176
+ reye_pose = np.array(data['reye_pose'])
177
+ global_orient = np.array(data['global_orient']).squeeze()
178
+ body_pose = np.array(data['body_pose_axis'])
179
+ left_hand_pose = np.array(data['left_hand_pose'])
180
+ right_hand_pose = np.array(data['right_hand_pose'])
181
+
182
+ full_body = np.concatenate(
183
+ (jaw_pose, leye_pose, reye_pose, global_orient, body_pose, left_hand_pose, right_hand_pose), axis=1)
184
+ assert full_body.shape[1] == 99
185
+
186
+
187
+ if self.convert_to_6d:
188
+ full_body = to3d(full_body)
189
+ full_body = torch.from_numpy(full_body)
190
+ full_body = matrix_to_rotation_6d(axis_angle_to_matrix(full_body.reshape(-1, 55, 3))).reshape(-1, 330)
191
+ full_body = np.asarray(full_body)
192
+ if self.expression:
193
+ expression = np.array(data['expression'])
194
+ full_body = np.concatenate((full_body, expression), axis=1)
195
+
196
+ else:
197
+ full_body = to3d(full_body)
198
+ expression = np.array(data['expression'])
199
+ full_body = np.concatenate((full_body, expression), axis=1)
200
+
201
+ self.complete_data = full_body
202
+ self.complete_data = np.array(self.complete_data)
203
+
204
+ if self.audio_feat_win_size is not None:
205
+ self.audio_feat = get_mfcc_old(self.audio_fn).transpose(1, 0)
206
+ else:
207
+ # if self.feat_method == 'mel_spec':
208
+ # self.audio_feat = get_melspec(self.audio_fn, fps=self.fps, sr=self.audio_sr, n_mels=self.audio_feat_dim)
209
+ # elif self.feat_method == 'mfcc':
210
+ self.audio_feat = get_mfcc_ta(self.audio_fn,
211
+ smlpx=True,
212
+ fps=30,
213
+ sr=self.audio_sr,
214
+ n_mfcc=self.audio_feat_dim,
215
+ win_size=self.audio_feat_win_size,
216
+ type=self.feat_method,
217
+ am=am,
218
+ am_sr=am_sr,
219
+ encoder_choice=self.config.Model.encoder_choice,
220
+ )
221
+ # with open(audio_file, 'w', encoding='utf-8') as file:
222
+ # file.write(json.dumps(self.audio_feat.__array__().tolist(), indent=0, ensure_ascii=False))
223
+
224
+ def get_dataset(self, normalization=False, normalize_stats=None, split='train'):
225
+
226
+ class __Worker__(data.Dataset):
227
+ def __init__(child, index_list, normalization, normalize_stats, split='train') -> None:
228
+ super().__init__()
229
+ child.index_list = index_list
230
+ child.normalization = normalization
231
+ child.normalize_stats = normalize_stats
232
+ child.split = split
233
+
234
+ def __getitem__(child, index):
235
+ num_generate_length = self.num_generate_length
236
+ num_pre_frames = self.num_pre_frames
237
+ seq_len = num_generate_length + num_pre_frames
238
+ # print(num_generate_length)
239
+
240
+ index = child.index_list[index]
241
+ index_new = index + random.randrange(0, 5, 3)
242
+ if index_new + seq_len > self.complete_data.shape[0]:
243
+ index_new = index
244
+ index = index_new
245
+
246
+ if child.split in ['val', 'pre', 'test'] or self.whole_video:
247
+ index = 0
248
+ seq_len = self.complete_data.shape[0]
249
+ seq_data = []
250
+ assert index + seq_len <= self.complete_data.shape[0]
251
+ # print(seq_len)
252
+ seq_data = self.complete_data[index:(index + seq_len), :]
253
+ seq_data = np.array(seq_data)
254
+
255
+ '''
256
+ audio feature,
257
+ '''
258
+ if not self.context_info:
259
+ if not self.whole_video:
260
+ audio_feat = self.audio_feat[index:index + seq_len, ...]
261
+ if audio_feat.shape[0] < seq_len:
262
+ audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
263
+ mode='reflect')
264
+
265
+ assert audio_feat.shape[0] == seq_len and audio_feat.shape[1] == self.audio_feat_dim
266
+ else:
267
+ audio_feat = self.audio_feat
268
+
269
+ else: # including feature and history
270
+ if self.audio_feat_win_size is None:
271
+ audio_feat = self.audio_feat[index:index + seq_len + num_pre_frames, ...]
272
+ if audio_feat.shape[0] < seq_len + num_pre_frames:
273
+ audio_feat = np.pad(audio_feat,
274
+ [[0, seq_len + self.num_frames - audio_feat.shape[0]], [0, 0]],
275
+ mode='constant')
276
+
277
+ assert audio_feat.shape[0] == self.num_frames + seq_len and audio_feat.shape[
278
+ 1] == self.audio_feat_dim
279
+
280
+ if child.normalization:
281
+ data_mean = child.normalize_stats['mean'].reshape(1, -1)
282
+ data_std = child.normalize_stats['std'].reshape(1, -1)
283
+ seq_data[:, :330] = (seq_data[:, :330] - data_mean) / data_std
284
+ if child.split in['train', 'test']:
285
+ if self.convert_to_6d:
286
+ if self.expression:
287
+ data_sample = {
288
+ 'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
289
+ 'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
290
+ # 'nzero': seq_data[:, 375:].astype(np.float).transpose(1, 0),
291
+ 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
292
+ 'speaker': speaker_id[self.speaker],
293
+ 'betas': self.betas,
294
+ 'aud_file': self.audio_fn,
295
+ }
296
+ else:
297
+ data_sample = {
298
+ 'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
299
+ 'nzero': seq_data[:, 330:].astype(np.float).transpose(1, 0),
300
+ 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
301
+ 'speaker': speaker_id[self.speaker],
302
+ 'betas': self.betas
303
+ }
304
+ else:
305
+ if self.expression:
306
+ data_sample = {
307
+ 'poses': seq_data[:, :165].astype(np.float).transpose(1, 0),
308
+ 'expression': seq_data[:, 165:].astype(np.float).transpose(1, 0),
309
+ 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
310
+ # 'wv2_feat': wv2_feat.astype(np.float).transpose(1, 0),
311
+ 'speaker': speaker_id[self.speaker],
312
+ 'aud_file': self.audio_fn,
313
+ 'betas': self.betas
314
+ }
315
+ else:
316
+ data_sample = {
317
+ 'poses': seq_data.astype(np.float).transpose(1, 0),
318
+ 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
319
+ 'speaker': speaker_id[self.speaker],
320
+ 'betas': self.betas
321
+ }
322
+ return data_sample
323
+ else:
324
+ data_sample = {
325
+ 'poses': seq_data[:, :330].astype(np.float).transpose(1, 0),
326
+ 'expression': seq_data[:, 330:].astype(np.float).transpose(1, 0),
327
+ # 'nzero': seq_data[:, 325:].astype(np.float).transpose(1, 0),
328
+ 'aud_feat': audio_feat.astype(np.float).transpose(1, 0),
329
+ 'aud_file': self.audio_fn,
330
+ 'speaker': speaker_id[self.speaker],
331
+ 'betas': self.betas
332
+ }
333
+ return data_sample
334
+ def __len__(child):
335
+ return len(child.index_list)
336
+
337
+ if split == 'train':
338
+ index_list = list(
339
+ range(0, min(self.complete_data.shape[0], self.audio_feat.shape[0]) - self.num_generate_length - self.num_pre_frames,
340
+ 6))
341
+ elif split in ['val', 'test']:
342
+ index_list = list([0])
343
+ if self.whole_video:
344
+ index_list = list([0])
345
+ self.all_dataset = __Worker__(index_list, normalization, normalize_stats, split)
346
+
347
+ def __len__(self):
348
+ return len(self.img_name_list)
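For reference, a minimal sketch (not part of this commit) of the reflect-padding step used in the worker dataset above; seq_len and audio_feat_dim are assumed values for illustration:

    import numpy as np

    seq_len, audio_feat_dim = 88, 64                    # assumed values, not taken from the repo
    audio_feat = np.random.randn(60, audio_feat_dim)    # a window shorter than seq_len
    audio_feat = np.pad(audio_feat, [[0, seq_len - audio_feat.shape[0]], [0, 0]],
                        mode='reflect')                 # mirror the tail to reach seq_len frames
    assert audio_feat.shape == (seq_len, audio_feat_dim)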
data_utils/rotation_conversion.py ADDED
@@ -0,0 +1,551 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
+ # Check PYTORCH3D_LICENCE before use
3
+
4
+ import functools
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ """
12
+ The transformation matrices returned from the functions in this file assume
13
+ the points on which the transformation will be applied are column vectors.
14
+ i.e. the R matrix is structured as
15
+
16
+ R = [
17
+ [Rxx, Rxy, Rxz],
18
+ [Ryx, Ryy, Ryz],
19
+ [Rzx, Rzy, Rzz],
20
+ ] # (3, 3)
21
+
22
+ This matrix can be applied to column vectors by post multiplication
23
+ by the points e.g.
24
+
25
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
26
+ transformed_points = R * points
27
+
28
+ To apply the same matrix to points which are row vectors, the R matrix
29
+ can be transposed and pre multiplied by the points:
30
+
31
+ e.g.
32
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
33
+ transformed_points = points * R.transpose(1, 0)
34
+ """
35
+
36
+
37
+ def quaternion_to_matrix(quaternions):
38
+ """
39
+ Convert rotations given as quaternions to rotation matrices.
40
+
41
+ Args:
42
+ quaternions: quaternions with real part first,
43
+ as tensor of shape (..., 4).
44
+
45
+ Returns:
46
+ Rotation matrices as tensor of shape (..., 3, 3).
47
+ """
48
+ r, i, j, k = torch.unbind(quaternions, -1)
49
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
50
+
51
+ o = torch.stack(
52
+ (
53
+ 1 - two_s * (j * j + k * k),
54
+ two_s * (i * j - k * r),
55
+ two_s * (i * k + j * r),
56
+ two_s * (i * j + k * r),
57
+ 1 - two_s * (i * i + k * k),
58
+ two_s * (j * k - i * r),
59
+ two_s * (i * k - j * r),
60
+ two_s * (j * k + i * r),
61
+ 1 - two_s * (i * i + j * j),
62
+ ),
63
+ -1,
64
+ )
65
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
66
+
67
+
68
+ def _copysign(a, b):
69
+ """
70
+ Return a tensor where each element has the absolute value taken from the
71
+ corresponding element of a, with sign taken from the corresponding
72
+ element of b. This is like the standard copysign floating-point operation,
73
+ but is not careful about negative 0 and NaN.
74
+
75
+ Args:
76
+ a: source tensor.
77
+ b: tensor whose signs will be used, of the same shape as a.
78
+
79
+ Returns:
80
+ Tensor of the same shape as a with the signs of b.
81
+ """
82
+ signs_differ = (a < 0) != (b < 0)
83
+ return torch.where(signs_differ, -a, a)
84
+
85
+
86
+ def _sqrt_positive_part(x):
87
+ """
88
+ Returns torch.sqrt(torch.max(0, x))
89
+ but with a zero subgradient where x is 0.
90
+ """
91
+ ret = torch.zeros_like(x)
92
+ positive_mask = x > 0
93
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
94
+ return ret
95
+
96
+
97
+ def matrix_to_quaternion(matrix):
98
+ """
99
+ Convert rotations given as rotation matrices to quaternions.
100
+
101
+ Args:
102
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
103
+
104
+ Returns:
105
+ quaternions with real part first, as tensor of shape (..., 4).
106
+ """
107
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
108
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
109
+ m00 = matrix[..., 0, 0]
110
+ m11 = matrix[..., 1, 1]
111
+ m22 = matrix[..., 2, 2]
112
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
113
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
114
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
115
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
116
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
117
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
118
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
119
+ return torch.stack((o0, o1, o2, o3), -1)
120
+
121
+
122
+ def _axis_angle_rotation(axis: str, angle):
123
+ """
124
+ Return the rotation matrices for one of the rotations about an axis
125
+ of which Euler angles describe, for each value of the angle given.
126
+
127
+ Args:
128
+ axis: Axis label "X" or "Y" or "Z".
129
+ angle: any shape tensor of Euler angles in radians
130
+
131
+ Returns:
132
+ Rotation matrices as tensor of shape (..., 3, 3).
133
+ """
134
+
135
+ cos = torch.cos(angle)
136
+ sin = torch.sin(angle)
137
+ one = torch.ones_like(angle)
138
+ zero = torch.zeros_like(angle)
139
+
140
+ if axis == "X":
141
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
142
+ if axis == "Y":
143
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
144
+ if axis == "Z":
145
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
146
+
147
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
148
+
149
+
150
+ def euler_angles_to_matrix(euler_angles, convention: str):
151
+ """
152
+ Convert rotations given as Euler angles in radians to rotation matrices.
153
+
154
+ Args:
155
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
156
+ convention: Convention string of three uppercase letters from
157
+ {"X", "Y", and "Z"}.
158
+
159
+ Returns:
160
+ Rotation matrices as tensor of shape (..., 3, 3).
161
+ """
162
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
163
+ raise ValueError("Invalid input euler angles.")
164
+ if len(convention) != 3:
165
+ raise ValueError("Convention must have 3 letters.")
166
+ if convention[1] in (convention[0], convention[2]):
167
+ raise ValueError(f"Invalid convention {convention}.")
168
+ for letter in convention:
169
+ if letter not in ("X", "Y", "Z"):
170
+ raise ValueError(f"Invalid letter {letter} in convention string.")
171
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
172
+ return functools.reduce(torch.matmul, matrices)
173
+
174
+
175
+ def _angle_from_tan(
176
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
177
+ ):
178
+ """
179
+ Extract the first or third Euler angle from the two members of
180
+ the matrix which are positive constant times its sine and cosine.
181
+
182
+ Args:
183
+ axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
184
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
185
+ convention.
186
+ data: Rotation matrices as tensor of shape (..., 3, 3).
187
+ horizontal: Whether we are looking for the angle for the third axis,
188
+ which means the relevant entries are in the same row of the
189
+ rotation matrix. If not, they are in the same column.
190
+ tait_bryan: Whether the first and third axes in the convention differ.
191
+
192
+ Returns:
193
+ Euler Angles in radians for each matrix in data as a tensor
194
+ of shape (...).
195
+ """
196
+
197
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
198
+ if horizontal:
199
+ i2, i1 = i1, i2
200
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
201
+ if horizontal == even:
202
+ return torch.atan2(data[..., i1], data[..., i2])
203
+ if tait_bryan:
204
+ return torch.atan2(-data[..., i2], data[..., i1])
205
+ return torch.atan2(data[..., i2], -data[..., i1])
206
+
207
+
208
+ def _index_from_letter(letter: str):
209
+ if letter == "X":
210
+ return 0
211
+ if letter == "Y":
212
+ return 1
213
+ if letter == "Z":
214
+ return 2
215
+
216
+
217
+ def matrix_to_euler_angles(matrix, convention: str):
218
+ """
219
+ Convert rotations given as rotation matrices to Euler angles in radians.
220
+
221
+ Args:
222
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
223
+ convention: Convention string of three uppercase letters.
224
+
225
+ Returns:
226
+ Euler angles in radians as tensor of shape (..., 3).
227
+ """
228
+ if len(convention) != 3:
229
+ raise ValueError("Convention must have 3 letters.")
230
+ if convention[1] in (convention[0], convention[2]):
231
+ raise ValueError(f"Invalid convention {convention}.")
232
+ for letter in convention:
233
+ if letter not in ("X", "Y", "Z"):
234
+ raise ValueError(f"Invalid letter {letter} in convention string.")
235
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
236
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
237
+ i0 = _index_from_letter(convention[0])
238
+ i2 = _index_from_letter(convention[2])
239
+ tait_bryan = i0 != i2
240
+ if tait_bryan:
241
+ central_angle = torch.asin(
242
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
243
+ )
244
+ else:
245
+ central_angle = torch.acos(matrix[..., i0, i0])
246
+
247
+ o = (
248
+ _angle_from_tan(
249
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
250
+ ),
251
+ central_angle,
252
+ _angle_from_tan(
253
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
254
+ ),
255
+ )
256
+ return torch.stack(o, -1)
257
+
258
+
259
+ def random_quaternions(
260
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
261
+ ):
262
+ """
263
+ Generate random quaternions representing rotations,
264
+ i.e. versors with nonnegative real part.
265
+
266
+ Args:
267
+ n: Number of quaternions in a batch to return.
268
+ dtype: Type to return.
269
+ device: Desired device of returned tensor. Default:
270
+ uses the current device for the default tensor type.
271
+ requires_grad: Whether the resulting tensor should have the gradient
272
+ flag set.
273
+
274
+ Returns:
275
+ Quaternions as tensor of shape (N, 4).
276
+ """
277
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
278
+ s = (o * o).sum(1)
279
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
280
+ return o
281
+
282
+
283
+ def random_rotations(
284
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
285
+ ):
286
+ """
287
+ Generate random rotations as 3x3 rotation matrices.
288
+
289
+ Args:
290
+ n: Number of rotation matrices in a batch to return.
291
+ dtype: Type to return.
292
+ device: Device of returned tensor. Default: if None,
293
+ uses the current device for the default tensor type.
294
+ requires_grad: Whether the resulting tensor should have the gradient
295
+ flag set.
296
+
297
+ Returns:
298
+ Rotation matrices as tensor of shape (n, 3, 3).
299
+ """
300
+ quaternions = random_quaternions(
301
+ n, dtype=dtype, device=device, requires_grad=requires_grad
302
+ )
303
+ return quaternion_to_matrix(quaternions)
304
+
305
+
306
+ def random_rotation(
307
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
308
+ ):
309
+ """
310
+ Generate a single random 3x3 rotation matrix.
311
+
312
+ Args:
313
+ dtype: Type to return
314
+ device: Device of returned tensor. Default: if None,
315
+ uses the current device for the default tensor type
316
+ requires_grad: Whether the resulting tensor should have the gradient
317
+ flag set
318
+
319
+ Returns:
320
+ Rotation matrix as tensor of shape (3, 3).
321
+ """
322
+ return random_rotations(1, dtype, device, requires_grad)[0]
323
+
324
+
325
+ def standardize_quaternion(quaternions):
326
+ """
327
+ Convert a unit quaternion to a standard form: one in which the real
328
+ part is non negative.
329
+
330
+ Args:
331
+ quaternions: Quaternions with real part first,
332
+ as tensor of shape (..., 4).
333
+
334
+ Returns:
335
+ Standardized quaternions as tensor of shape (..., 4).
336
+ """
337
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
338
+
339
+
340
+ def quaternion_raw_multiply(a, b):
341
+ """
342
+ Multiply two quaternions.
343
+ Usual torch rules for broadcasting apply.
344
+
345
+ Args:
346
+ a: Quaternions as tensor of shape (..., 4), real part first.
347
+ b: Quaternions as tensor of shape (..., 4), real part first.
348
+
349
+ Returns:
350
+ The product of a and b, a tensor of quaternions shape (..., 4).
351
+ """
352
+ aw, ax, ay, az = torch.unbind(a, -1)
353
+ bw, bx, by, bz = torch.unbind(b, -1)
354
+ ow = aw * bw - ax * bx - ay * by - az * bz
355
+ ox = aw * bx + ax * bw + ay * bz - az * by
356
+ oy = aw * by - ax * bz + ay * bw + az * bx
357
+ oz = aw * bz + ax * by - ay * bx + az * bw
358
+ return torch.stack((ow, ox, oy, oz), -1)
359
+
360
+
361
+ def quaternion_multiply(a, b):
362
+ """
363
+ Multiply two quaternions representing rotations, returning the quaternion
364
+ representing their composition, i.e. the versor with nonnegative real part.
365
+ Usual torch rules for broadcasting apply.
366
+
367
+ Args:
368
+ a: Quaternions as tensor of shape (..., 4), real part first.
369
+ b: Quaternions as tensor of shape (..., 4), real part first.
370
+
371
+ Returns:
372
+ The product of a and b, a tensor of quaternions of shape (..., 4).
373
+ """
374
+ ab = quaternion_raw_multiply(a, b)
375
+ return standardize_quaternion(ab)
376
+
377
+
378
+ def quaternion_invert(quaternion):
379
+ """
380
+ Given a quaternion representing rotation, get the quaternion representing
381
+ its inverse.
382
+
383
+ Args:
384
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
385
+ first, which must be versors (unit quaternions).
386
+
387
+ Returns:
388
+ The inverse, a tensor of quaternions of shape (..., 4).
389
+ """
390
+
391
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
392
+
393
+
394
+ def quaternion_apply(quaternion, point):
395
+ """
396
+ Apply the rotation given by a quaternion to a 3D point.
397
+ Usual torch rules for broadcasting apply.
398
+
399
+ Args:
400
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
401
+ point: Tensor of 3D points of shape (..., 3).
402
+
403
+ Returns:
404
+ Tensor of rotated points of shape (..., 3).
405
+ """
406
+ if point.size(-1) != 3:
407
+ raise ValueError(f"Points are not in 3D, {point.shape}.")
408
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
409
+ point_as_quaternion = torch.cat((real_parts, point), -1)
410
+ out = quaternion_raw_multiply(
411
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
412
+ quaternion_invert(quaternion),
413
+ )
414
+ return out[..., 1:]
415
+
416
+
417
+ def axis_angle_to_matrix(axis_angle):
418
+ """
419
+ Convert rotations given as axis/angle to rotation matrices.
420
+
421
+ Args:
422
+ axis_angle: Rotations given as a vector in axis angle form,
423
+ as a tensor of shape (..., 3), where the magnitude is
424
+ the angle turned anticlockwise in radians around the
425
+ vector's direction.
426
+
427
+ Returns:
428
+ Rotation matrices as tensor of shape (..., 3, 3).
429
+ """
430
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
431
+
432
+
433
+ def matrix_to_axis_angle(matrix):
434
+ """
435
+ Convert rotations given as rotation matrices to axis/angle.
436
+
437
+ Args:
438
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
439
+
440
+ Returns:
441
+ Rotations given as a vector in axis angle form, as a tensor
442
+ of shape (..., 3), where the magnitude is the angle
443
+ turned anticlockwise in radians around the vector's
444
+ direction.
445
+ """
446
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
447
+
448
+
449
+ def axis_angle_to_quaternion(axis_angle):
450
+ """
451
+ Convert rotations given as axis/angle to quaternions.
452
+
453
+ Args:
454
+ axis_angle: Rotations given as a vector in axis angle form,
455
+ as a tensor of shape (..., 3), where the magnitude is
456
+ the angle turned anticlockwise in radians around the
457
+ vector's direction.
458
+
459
+ Returns:
460
+ quaternions with real part first, as tensor of shape (..., 4).
461
+ """
462
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
463
+ half_angles = 0.5 * angles
464
+ eps = 1e-6
465
+ small_angles = angles.abs() < eps
466
+ sin_half_angles_over_angles = torch.empty_like(angles)
467
+ sin_half_angles_over_angles[~small_angles] = (
468
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
469
+ )
470
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
471
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
472
+ sin_half_angles_over_angles[small_angles] = (
473
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
474
+ )
475
+ quaternions = torch.cat(
476
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
477
+ )
478
+ return quaternions
479
+
480
+
481
+ def quaternion_to_axis_angle(quaternions):
482
+ """
483
+ Convert rotations given as quaternions to axis/angle.
484
+
485
+ Args:
486
+ quaternions: quaternions with real part first,
487
+ as tensor of shape (..., 4).
488
+
489
+ Returns:
490
+ Rotations given as a vector in axis angle form, as a tensor
491
+ of shape (..., 3), where the magnitude is the angle
492
+ turned anticlockwise in radians around the vector's
493
+ direction.
494
+ """
495
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
496
+ half_angles = torch.atan2(norms, quaternions[..., :1])
497
+ angles = 2 * half_angles
498
+ eps = 1e-6
499
+ small_angles = angles.abs() < eps
500
+ sin_half_angles_over_angles = torch.empty_like(angles)
501
+ sin_half_angles_over_angles[~small_angles] = (
502
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
503
+ )
504
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
505
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
506
+ sin_half_angles_over_angles[small_angles] = (
507
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
508
+ )
509
+ return quaternions[..., 1:] / sin_half_angles_over_angles
510
+
511
+
512
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
513
+ """
514
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
515
+ using Gram--Schmidt orthogonalisation per Section B of [1].
516
+ Args:
517
+ d6: 6D rotation representation, of size (*, 6)
518
+
519
+ Returns:
520
+ batch of rotation matrices of size (*, 3, 3)
521
+
522
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
523
+ On the Continuity of Rotation Representations in Neural Networks.
524
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
525
+ Retrieved from http://arxiv.org/abs/1812.07035
526
+ """
527
+
528
+ a1, a2 = d6[..., :3], d6[..., 3:]
529
+ b1 = F.normalize(a1, dim=-1)
530
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
531
+ b2 = F.normalize(b2, dim=-1)
532
+ b3 = torch.cross(b1, b2, dim=-1)
533
+ return torch.stack((b1, b2, b3), dim=-2)
534
+
535
+
536
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
537
+ """
538
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
539
+ by dropping the last row. Note that 6D representation is not unique.
540
+ Args:
541
+ matrix: batch of rotation matrices of size (*, 3, 3)
542
+
543
+ Returns:
544
+ 6D rotation representation, of size (*, 6)
545
+
546
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
547
+ On the Continuity of Rotation Representations in Neural Networks.
548
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
549
+ Retrieved from http://arxiv.org/abs/1812.07035
550
+ """
551
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
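As a quick sanity check (not part of this commit), the 6D rotation representation defined above round-trips through a rotation matrix; the import path assumes the repo root is on PYTHONPATH:

    import torch
    from data_utils.rotation_conversion import (
        axis_angle_to_matrix, matrix_to_rotation_6d, rotation_6d_to_matrix)

    aa = 0.3 * torch.randn(8, 3)                # a batch of axis-angle rotations
    R = axis_angle_to_matrix(aa)                # (8, 3, 3)
    d6 = matrix_to_rotation_6d(R)               # (8, 6): first two rows of each matrix
    R_rec = rotation_6d_to_matrix(d6)           # Gram-Schmidt recovers the full matrix
    print(torch.allclose(R, R_rec, atol=1e-5))  # expected: True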
data_utils/split_more_than_2s.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2df6e745cdf7473f13ce3ae2ed759c3cceb60c9197e7f3fd65110e7bc20b6f2d
+ size 2398875
data_utils/split_train_val_test.py ADDED
@@ -0,0 +1,27 @@
+ import os
+ import json
+ import shutil
+
+ if __name__ == '__main__':
+     id_list = "chemistry conan oliver seth"
+     id_list = id_list.split(' ')
+
+     old_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0'
+     new_root = '/home/usename/talkshow_data/ExpressiveWholeBodyDatasetReleaseV1.0/talkshow_data_splited'
+
+     with open('train_val_test.json') as f:
+         split_info = json.load(f)
+     phase_list = ['train', 'val', 'test']
+     for phase in phase_list:
+         phase_path_list = split_info[phase]
+         for p in phase_path_list:
+             old_path = os.path.join(old_root, p)
+             if not os.path.exists(old_path):
+                 print(f'{old_path} not found, skipping')
+                 continue
+             new_path = os.path.join(new_root, phase, p)
+             dir_name = os.path.dirname(new_path)
+             if not os.path.isdir(dir_name):
+                 os.makedirs(dir_name, exist_ok=True)
+             shutil.move(old_path, new_path)
+
data_utils/train_val_test.json ADDED
The diff for this file is too large to render. See raw diff
 
data_utils/utils.py ADDED
@@ -0,0 +1,318 @@
1
+ import numpy as np
2
+ # import librosa #has to do this cause librosa is not supported on my server
3
+ import python_speech_features
4
+ from scipy.io import wavfile
5
+ from scipy import signal
6
+ import librosa
7
+ import torch
8
+ import torchaudio as ta
9
+ import torchaudio.functional as ta_F
10
+ import torchaudio.transforms as ta_T
11
+ # import pyloudnorm as pyln
12
+
13
+
14
+ def load_wav_old(audio_fn, sr = 16000):
15
+ sample_rate, sig = wavfile.read(audio_fn)
16
+ if sample_rate != sr:
17
+ result = int((sig.shape[0]) / sample_rate * sr)
18
+ x_resampled = signal.resample(sig, result)
19
+ x_resampled = x_resampled.astype(np.float64)
20
+ return x_resampled, sr
21
+
22
+ sig = sig / (2**15)
23
+ return sig, sample_rate
24
+
25
+
26
+ def get_mfcc(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
27
+
28
+ y, sr = librosa.load(audio_fn, sr=sr, mono=True)
29
+
30
+ if win_size is None:
31
+ hop_len=int(sr / fps)
32
+ else:
33
+ hop_len=int(sr / win_size)
34
+
35
+ n_fft=2048
36
+
37
+ C = librosa.feature.mfcc(
38
+ y = y,
39
+ sr = sr,
40
+ n_mfcc = n_mfcc,
41
+ hop_length = hop_len,
42
+ n_fft = n_fft
43
+ )
44
+
45
+ if C.shape[0] == n_mfcc:
46
+ C = C.transpose(1, 0)
47
+
48
+ return C
49
+
50
+
51
+ def get_melspec(audio_fn, eps=1e-6, fps = 25, sr=16000, n_mels=64):
52
+ raise NotImplementedError
53
+ '''
54
+ # y, sr = load_wav(audio_fn=audio_fn, sr=sr)
55
+
56
+ # hop_len = int(sr / fps)
57
+ # n_fft = 2048
58
+
59
+ # C = librosa.feature.melspectrogram(
60
+ # y = y,
61
+ # sr = sr,
62
+ # n_fft=n_fft,
63
+ # hop_length=hop_len,
64
+ # n_mels = n_mels,
65
+ # fmin=0,
66
+ # fmax=8000)
67
+
68
+
69
+ # mask = (C == 0).astype(np.float)
70
+ # C = mask * eps + (1-mask) * C
71
+
72
+ # C = np.log(C)
73
+ # # weird error may occur here
74
+ # assert not (np.isnan(C).any()), audio_fn
75
+ # if C.shape[0] == n_mels:
76
+ # C = C.transpose(1, 0)
77
+
78
+ # return C
79
+ '''
80
+
81
+ def extract_mfcc(audio,sample_rate=16000):
82
+ mfcc = zip(*python_speech_features.mfcc(audio,sample_rate, numcep=64, nfilt=64, nfft=2048, winstep=0.04))
83
+ mfcc = np.stack([np.array(i) for i in mfcc])
84
+ return mfcc
85
+
86
+ def get_mfcc_psf(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
87
+ y, sr = load_wav_old(audio_fn, sr=sr)
88
+
89
+ if y.shape.__len__() > 1:
90
+ y = (y[:,0]+y[:,1])/2
91
+
92
+ if win_size is None:
93
+ hop_len=int(sr / fps)
94
+ else:
95
+ hop_len=int(sr/ win_size)
96
+
97
+ n_fft=2048
98
+
99
+ #hard coded for 25 fps
100
+ if not smlpx:
101
+ C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=0.04)
102
+ else:
103
+ C = python_speech_features.mfcc(y, sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01/15)
104
+ # if C.shape[0] == n_mfcc:
105
+ # C = C.transpose(1, 0)
106
+
107
+ return C
108
+
109
+
110
+ def get_mfcc_psf_min(audio_fn, eps=1e-6, fps=25, smlpx=False, sr=16000, n_mfcc=64, win_size=None):
111
+ y, sr = load_wav_old(audio_fn, sr=sr)
112
+
113
+ if y.shape.__len__() > 1:
114
+ y = (y[:, 0] + y[:, 1]) / 2
115
+ n_fft = 2048
116
+
117
+ slice_len = 22000 * 5
118
+ slice = y.size // slice_len
119
+
120
+ C = []
121
+
122
+ for i in range(slice):
123
+ if i != (slice - 1):
124
+ feat = python_speech_features.mfcc(y[i*slice_len:(i+1)*slice_len], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
125
+ else:
126
+ feat = python_speech_features.mfcc(y[i * slice_len:], sr, numcep=n_mfcc, nfilt=n_mfcc, nfft=n_fft, winstep=1.01 / 15)
127
+
128
+ C.append(feat)
129
+
130
+ return C
131
+
132
+
133
+ def audio_chunking(audio: torch.Tensor, frame_rate: int = 30, chunk_size: int = 16000):
134
+ """
135
+ :param audio: 1 x T tensor containing a 16kHz audio signal
136
+ :param frame_rate: frame rate for video (we need one audio chunk per video frame)
137
+ :param chunk_size: number of audio samples per chunk
138
+ :return: num_chunks x chunk_size tensor containing sliced audio
139
+ """
140
+ samples_per_frame = chunk_size // frame_rate
141
+ padding = (chunk_size - samples_per_frame) // 2
142
+ audio = torch.nn.functional.pad(audio.unsqueeze(0), pad=[padding, padding]).squeeze(0)
143
+ anchor_points = list(range(chunk_size//2, audio.shape[-1]-chunk_size//2, samples_per_frame))
144
+ audio = torch.cat([audio[:, i-chunk_size//2:i+chunk_size//2] for i in anchor_points], dim=0)
145
+ return audio
146
+
147
+
148
+ def get_mfcc_ta(audio_fn, eps=1e-6, fps=15, smlpx=False, sr=16000, n_mfcc=64, win_size=None, type='mfcc', am=None, am_sr=None, encoder_choice='mfcc'):
149
+ if am is None:
150
+ audio, sr_0 = ta.load(audio_fn)
151
+ if sr != sr_0:
152
+ audio = ta.transforms.Resample(sr_0, sr)(audio)
153
+ if audio.shape[0] > 1:
154
+ audio = torch.mean(audio, dim=0, keepdim=True)
155
+
156
+ n_fft = 2048
157
+ if fps == 15:
158
+ hop_length = 1467
159
+ elif fps == 30:
160
+ hop_length = 734
161
+ win_length = hop_length * 2
162
+ n_mels = 256
163
+ n_mfcc = 64
164
+
165
+ if type == 'mfcc':
166
+ mfcc_transform = ta_T.MFCC(
167
+ sample_rate=sr,
168
+ n_mfcc=n_mfcc,
169
+ melkwargs={
170
+ "n_fft": n_fft,
171
+ "n_mels": n_mels,
172
+ # "win_length": win_length,
173
+ "hop_length": hop_length,
174
+ "mel_scale": "htk",
175
+ },
176
+ )
177
+ audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0,1).numpy()
178
+ elif type == 'mel':
179
+ # audio = 0.01 * audio / torch.mean(torch.abs(audio))
180
+ mel_transform = ta_T.MelSpectrogram(
181
+ sample_rate=sr, n_fft=n_fft, win_length=None, hop_length=hop_length, n_mels=n_mels
182
+ )
183
+ audio_ft = mel_transform(audio).squeeze(0).transpose(0,1).numpy()
184
+ # audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).transpose(0,1).numpy()
185
+ elif type == 'mel_mul':
186
+ audio = 0.01 * audio / torch.mean(torch.abs(audio))
187
+ audio = audio_chunking(audio, frame_rate=fps, chunk_size=sr)
188
+ mel_transform = ta_T.MelSpectrogram(
189
+ sample_rate=sr, n_fft=n_fft, win_length=int(sr/20), hop_length=int(sr/100), n_mels=n_mels
190
+ )
191
+ audio_ft = mel_transform(audio).squeeze(1)
192
+ audio_ft = torch.log(audio_ft.clamp(min=1e-10, max=None)).numpy()
193
+ else:
194
+ speech_array, sampling_rate = librosa.load(audio_fn, sr=16000)
195
+
196
+ if encoder_choice == 'faceformer':
197
+ # audio_ft = np.squeeze(am(speech_array, sampling_rate=16000).input_values).reshape(-1, 1)
198
+ audio_ft = speech_array.reshape(-1, 1)
199
+ elif encoder_choice == 'meshtalk':
200
+ audio_ft = 0.01 * speech_array / np.mean(np.abs(speech_array))
201
+ elif encoder_choice == 'onset':
202
+ audio_ft = librosa.onset.onset_detect(y=speech_array, sr=16000, units='time').reshape(-1, 1)
203
+ else:
204
+ audio, sr_0 = ta.load(audio_fn)
205
+ if sr != sr_0:
206
+ audio = ta.transforms.Resample(sr_0, sr)(audio)
207
+ if audio.shape[0] > 1:
208
+ audio = torch.mean(audio, dim=0, keepdim=True)
209
+
210
+ n_fft = 2048
211
+ if fps == 15:
212
+ hop_length = 1467
213
+ elif fps == 30:
214
+ hop_length = 734
215
+ win_length = hop_length * 2
216
+ n_mels = 256
217
+ n_mfcc = 64
218
+
219
+ mfcc_transform = ta_T.MFCC(
220
+ sample_rate=sr,
221
+ n_mfcc=n_mfcc,
222
+ melkwargs={
223
+ "n_fft": n_fft,
224
+ "n_mels": n_mels,
225
+ # "win_length": win_length,
226
+ "hop_length": hop_length,
227
+ "mel_scale": "htk",
228
+ },
229
+ )
230
+ audio_ft = mfcc_transform(audio).squeeze(dim=0).transpose(0, 1).numpy()
231
+ return audio_ft
232
+
233
+
234
+ def get_mfcc_sepa(audio_fn, fps=15, sr=16000):
235
+ audio, sr_0 = ta.load(audio_fn)
236
+ if sr != sr_0:
237
+ audio = ta.transforms.Resample(sr_0, sr)(audio)
238
+ if audio.shape[0] > 1:
239
+ audio = torch.mean(audio, dim=0, keepdim=True)
240
+
241
+ n_fft = 2048
242
+ if fps == 15:
243
+ hop_length = 1467
244
+ elif fps == 30:
245
+ hop_length = 734
246
+ n_mels = 256
247
+ n_mfcc = 64
248
+
249
+ mfcc_transform = ta_T.MFCC(
250
+ sample_rate=sr,
251
+ n_mfcc=n_mfcc,
252
+ melkwargs={
253
+ "n_fft": n_fft,
254
+ "n_mels": n_mels,
255
+ # "win_length": win_length,
256
+ "hop_length": hop_length,
257
+ "mel_scale": "htk",
258
+ },
259
+ )
260
+ audio_ft_0 = mfcc_transform(audio[0, :sr*2]).squeeze(dim=0).transpose(0,1).numpy()
261
+ audio_ft_1 = mfcc_transform(audio[0, sr*2:]).squeeze(dim=0).transpose(0,1).numpy()
262
+ audio_ft = np.concatenate((audio_ft_0, audio_ft_1), axis=0)
263
+ return audio_ft, audio_ft_0.shape[0]
264
+
265
+
266
+ def get_mfcc_old(wav_file):
267
+ sig, sample_rate = load_wav_old(wav_file)
268
+ mfcc = extract_mfcc(sig)
269
+ return mfcc
270
+
271
+
272
+ def smooth_geom(geom, mask: torch.Tensor = None, filter_size: int = 9, sigma: float = 2.0):
273
+ """
274
+ :param geom: T x V x 3 tensor containing a temporal sequence of length T with V vertices in each frame
275
+ :param mask: V-dimensional Tensor containing a mask with vertices to be smoothed
276
+ :param filter_size: size of the Gaussian filter
277
+ :param sigma: standard deviation of the Gaussian filter
278
+ :return: T x V x 3 tensor containing smoothed geometry (i.e., smoothed in the area indicated by the mask)
279
+ """
280
+ assert filter_size % 2 == 1, f"filter size must be odd but is {filter_size}"
281
+ # Gaussian smoothing (low-pass filtering)
282
+ fltr = np.arange(-(filter_size // 2), filter_size // 2 + 1)
283
+ fltr = np.exp(-0.5 * fltr ** 2 / sigma ** 2)
284
+ fltr = torch.Tensor(fltr) / np.sum(fltr)
285
+ # apply fltr
286
+ fltr = fltr.view(1, 1, -1).to(device=geom.device)
287
+ T, V = geom.shape[1], geom.shape[2]
288
+ g = torch.nn.functional.pad(
289
+ geom.permute(2, 0, 1).view(V, 1, T),
290
+ pad=[filter_size // 2, filter_size // 2], mode='replicate'
291
+ )
292
+ g = torch.nn.functional.conv1d(g, fltr).view(V, 1, T)
293
+ smoothed = g.permute(1, 2, 0).contiguous()
294
+ # blend smoothed signal with original signal
295
+ if mask is None:
296
+ return smoothed
297
+ else:
298
+ return smoothed * mask[None, :, None] + geom * (-mask[None, :, None] + 1)
299
+
300
+ if __name__ == '__main__':
301
+ audio_fn = '../sample_audio/clip000028_tCAkv4ggPgI.wav'
302
+
303
+ C = get_mfcc_psf(audio_fn)
304
+ print(C.shape)
305
+
306
+ C_2 = get_mfcc(audio_fn)
307
+ print(C_2.shape)
308
+
309
+ print(C)
310
+ print(C_2)
311
+ print((C == C_2).all())
312
+ # print(y.shape, sr)
313
+ # mel_spec = get_melspec(audio_fn)
314
+ # print(mel_spec.shape)
315
+ # mfcc = get_mfcc(audio_fn, sr = 16000)
316
+ # print(mfcc.shape)
317
+ # print(mel_spec.max(), mel_spec.min())
318
+ # print(mfcc.max(), mfcc.min())
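A hypothetical usage sketch (not part of this commit); the wav path is an assumption and any 16 kHz-compatible audio file works:

    from data_utils.utils import get_mfcc_ta

    feat = get_mfcc_ta('demo_audio/sample.wav', fps=30, sr=16000, type='mfcc')
    print(feat.shape)  # (n_frames, 64): one 64-dim MFCC vector per hop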
evaluation/FGD.py ADDED
@@ -0,0 +1,199 @@
1
+ import time
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from scipy import linalg
7
+ import math
8
+ from data_utils.rotation_conversion import axis_angle_to_matrix, matrix_to_rotation_6d
9
+
10
+ import warnings
11
+ warnings.filterwarnings("ignore", category=RuntimeWarning) # ignore warnings
12
+
13
+
14
+ change_angle = torch.tensor([6.0181e-05, 5.1597e-05, 2.1344e-04, 2.1899e-04])
15
+ class EmbeddingSpaceEvaluator:
16
+ def __init__(self, ae, vae, device):
17
+
18
+ # init embed net
19
+ self.ae = ae
20
+ # self.vae = vae
21
+
22
+ # storage
23
+ self.real_feat_list = []
24
+ self.generated_feat_list = []
25
+ self.real_joints_list = []
26
+ self.generated_joints_list = []
27
+ self.real_6d_list = []
28
+ self.generated_6d_list = []
29
+ self.audio_beat_list = []
30
+
31
+ def reset(self):
32
+ self.real_feat_list = []
33
+ self.generated_feat_list = []
34
+
35
+ def get_no_of_samples(self):
36
+ return len(self.real_feat_list)
37
+
38
+ def push_samples(self, generated_poses, real_poses):
39
+ # self.net.eval()
40
+ # convert poses to latent features
41
+ real_feat, real_poses = self.ae.extract(real_poses)
42
+ generated_feat, generated_poses = self.ae.extract(generated_poses)
43
+
44
+ num_joints = real_poses.shape[2] // 3
45
+
46
+ real_feat = real_feat.squeeze()
47
+ generated_feat = generated_feat.reshape(generated_feat.shape[0]*generated_feat.shape[1], -1)
48
+
49
+ self.real_feat_list.append(real_feat.data.cpu().numpy())
50
+ self.generated_feat_list.append(generated_feat.data.cpu().numpy())
51
+
52
+ # real_poses = matrix_to_rotation_6d(axis_angle_to_matrix(real_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
53
+ # generated_poses = matrix_to_rotation_6d(axis_angle_to_matrix(generated_poses.reshape(-1, 3))).reshape(-1, num_joints, 6)
54
+ #
55
+ # self.real_feat_list.append(real_poses.data.cpu().numpy())
56
+ # self.generated_feat_list.append(generated_poses.data.cpu().numpy())
57
+
58
+ def push_joints(self, generated_poses, real_poses):
59
+ self.real_joints_list.append(real_poses.data.cpu())
60
+ self.generated_joints_list.append(generated_poses.squeeze().data.cpu())
61
+
62
+ def push_aud(self, aud):
63
+ self.audio_beat_list.append(aud.squeeze().data.cpu())
64
+
65
+ def get_MAAC(self):
66
+ ang_vel_list = []
67
+ for real_joints in self.real_joints_list:
68
+ real_joints[:, 15:21] = real_joints[:, 16:22]
69
+ vec = real_joints[:, 15:21] - real_joints[:, 13:19]
70
+ inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
71
+ inner_product = torch.clamp(inner_product, -1, 1, out=None)
72
+ angle = torch.acos(inner_product) / math.pi
73
+ ang_vel = (angle[1:] - angle[:-1]).abs().mean(dim=0)
74
+ ang_vel_list.append(ang_vel.unsqueeze(dim=0))
75
+ all_vel = torch.cat(ang_vel_list, dim=0)
76
+ MAAC = all_vel.mean(dim=0)
77
+ return MAAC
78
+
79
+ def get_BCscore(self):
80
+ thres = 0.01
81
+ sigma = 0.1
82
+ sum_1 = 0
83
+ total_beat = 0
84
+ for joints, audio_beat_time in zip(self.generated_joints_list, self.audio_beat_list):
85
+ motion_beat_time = []
86
+ if joints.dim() == 4:
87
+ joints = joints[0]
88
+ joints[:, 15:21] = joints[:, 16:22]
89
+ vec = joints[:, 15:21] - joints[:, 13:19]
90
+ inner_product = torch.einsum('kij,kij->ki', [vec[:, 2:], vec[:, :-2]])
91
+ inner_product = torch.clamp(inner_product, -1, 1, out=None)
92
+ angle = torch.acos(inner_product) / math.pi
93
+ ang_vel = (angle[1:] - angle[:-1]).abs() / change_angle / len(change_angle)
94
+
95
+ angle_diff = torch.cat((torch.zeros(1, 4), ang_vel), dim=0)
96
+
97
+ sum_2 = 0
98
+ for i in range(angle_diff.shape[1]):
99
+ motion_beat_time = []
100
+ for t in range(1, joints.shape[0]-1):
101
+ if (angle_diff[t][i] < angle_diff[t - 1][i] and angle_diff[t][i] < angle_diff[t + 1][i]):
102
+ if (angle_diff[t - 1][i] - angle_diff[t][i] >= thres or angle_diff[t + 1][i] - angle_diff[
103
+ t][i] >= thres):
104
+ motion_beat_time.append(float(t) / 30.0)
105
+ if (len(motion_beat_time) == 0):
106
+ continue
107
+ motion_beat_time = torch.tensor(motion_beat_time)
108
+ sum = 0
109
+ for audio in audio_beat_time:
110
+ sum += np.power(math.e, -(np.power((audio.item() - motion_beat_time), 2)).min() / (2 * sigma * sigma))
111
+ sum_2 = sum_2 + sum
112
+ total_beat = total_beat + len(audio_beat_time)
113
+ sum_1 = sum_1 + sum_2
114
+ return sum_1/total_beat
115
+
116
+
117
+ def get_scores(self):
118
+ generated_feats = np.vstack(self.generated_feat_list)
119
+ real_feats = np.vstack(self.real_feat_list)
120
+
121
+ def frechet_distance(samples_A, samples_B):
122
+ A_mu = np.mean(samples_A, axis=0)
123
+ A_sigma = np.cov(samples_A, rowvar=False)
124
+ B_mu = np.mean(samples_B, axis=0)
125
+ B_sigma = np.cov(samples_B, rowvar=False)
126
+ try:
127
+ frechet_dist = self.calculate_frechet_distance(A_mu, A_sigma, B_mu, B_sigma)
128
+ except ValueError:
129
+ frechet_dist = 1e+10
130
+ return frechet_dist
131
+
132
+ ####################################################################
133
+ # frechet distance
134
+ frechet_dist = frechet_distance(generated_feats, real_feats)
135
+
136
+ ####################################################################
137
+ # distance between real and generated samples on the latent feature space
138
+ dists = []
139
+ for i in range(real_feats.shape[0]):
140
+ d = np.sum(np.absolute(real_feats[i] - generated_feats[i])) # MAE
141
+ dists.append(d)
142
+ feat_dist = np.mean(dists)
143
+
144
+ return frechet_dist, feat_dist
145
+
146
+ @staticmethod
147
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
148
+ """ from https://github.com/mseitzer/pytorch-fid/blob/master/fid_score.py """
149
+ """Numpy implementation of the Frechet Distance.
150
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
151
+ and X_2 ~ N(mu_2, C_2) is
152
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
153
+ Stable version by Dougal J. Sutherland.
154
+ Params:
155
+ -- mu1 : Numpy array containing the activations of a layer of the
156
+ inception net (like returned by the function 'get_predictions')
157
+ for generated samples.
158
+ -- mu2 : The sample mean over activations, precalculated on an
159
+ representative data set.
160
+ -- sigma1: The covariance matrix over activations for generated samples.
161
+ -- sigma2: The covariance matrix over activations, precalculated on an
162
+ representative data set.
163
+ Returns:
164
+ -- : The Frechet Distance.
165
+ """
166
+
167
+ mu1 = np.atleast_1d(mu1)
168
+ mu2 = np.atleast_1d(mu2)
169
+
170
+ sigma1 = np.atleast_2d(sigma1)
171
+ sigma2 = np.atleast_2d(sigma2)
172
+
173
+ assert mu1.shape == mu2.shape, \
174
+ 'Training and test mean vectors have different lengths'
175
+ assert sigma1.shape == sigma2.shape, \
176
+ 'Training and test covariances have different dimensions'
177
+
178
+ diff = mu1 - mu2
179
+
180
+ # Product might be almost singular
181
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
182
+ if not np.isfinite(covmean).all():
183
+ msg = ('fid calculation produces singular product; '
184
+ 'adding %s to diagonal of cov estimates') % eps
185
+ print(msg)
186
+ offset = np.eye(sigma1.shape[0]) * eps
187
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
188
+
189
+ # Numerical error might give slight imaginary component
190
+ if np.iscomplexobj(covmean):
191
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
192
+ m = np.max(np.abs(covmean.imag))
193
+ raise ValueError('Imaginary component {}'.format(m))
194
+ covmean = covmean.real
195
+
196
+ tr_covmean = np.trace(covmean)
197
+
198
+ return (diff.dot(diff) + np.trace(sigma1) +
199
+ np.trace(sigma2) - 2 * tr_covmean)
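A quick sanity check (not part of this commit): the Frechet distance between two large samples drawn from the same Gaussian should be close to zero. It calls the static method directly and assumes the repo root is importable:

    import numpy as np
    from evaluation.FGD import EmbeddingSpaceEvaluator

    rng = np.random.default_rng(0)
    a = rng.standard_normal((2000, 32))  # stand-in for real latent features
    b = rng.standard_normal((2000, 32))  # stand-in for generated latent features
    d = EmbeddingSpaceEvaluator.calculate_frechet_distance(
        a.mean(axis=0), np.cov(a, rowvar=False),
        b.mean(axis=0), np.cov(b, rowvar=False))
    print(d)  # a small positive value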
evaluation/__init__.py ADDED
File without changes
evaluation/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (181 Bytes). View file
 
evaluation/__pycache__/metrics.cpython-37.pyc ADDED
Binary file (3.81 kB). View file
 
evaluation/diversity_LVD.py ADDED
@@ -0,0 +1,64 @@
1
+ '''
2
+ LVD: different initial pose
3
+ diversity: same initial pose
4
+ '''
5
+ import os
6
+ import sys
7
+ sys.path.append(os.getcwd())
8
+
9
+ from glob import glob
10
+
11
+ from argparse import ArgumentParser
12
+ import json
13
+
14
+ from evaluation.util import *
15
+ from evaluation.metrics import *
16
+ from tqdm import tqdm
17
+
18
+ parser = ArgumentParser()
19
+ parser.add_argument('--speaker', required=True, type=str)
20
+ parser.add_argument('--post_fix', nargs='+', default=['base'], type=str)
21
+ args = parser.parse_args()
22
+
23
+ speaker = args.speaker
24
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
25
+
26
+ LVD_list = []
27
+ diversity_list = []
28
+
29
+ for aud in tqdm(test_audios):
30
+ base_name = os.path.splitext(aud)[0]
31
+ gt_path = get_full_path(aud, speaker, 'val')
32
+ _, gt_poses, _ = get_gts(gt_path)
33
+ gt_poses = gt_poses[np.newaxis,...]
34
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
35
+ for post_fix in args.post_fix:
36
+ pred_path = base_name + '_'+post_fix+'.json'
37
+ pred_poses = np.array(json.load(open(pred_path)))
38
+ # print(pred_poses.shape)#(B, seq_len, 108)
39
+ pred_poses = cvt25(pred_poses, gt_poses)
40
+ # print(pred_poses.shape)#(B, seq, pose_dim)
41
+
42
+ gt_valid_points = hand_points(gt_poses)
43
+ pred_valid_points = hand_points(pred_poses)
44
+
45
+ lvd = LVD(gt_valid_points, pred_valid_points)
46
+ # div = diversity(pred_valid_points)
47
+
48
+ LVD_list.append(lvd)
49
+ # diversity_list.append(div)
50
+
51
+ # gt_velocity = peak_velocity(gt_valid_points, order=2)
52
+ # pred_velocity = peak_velocity(pred_valid_points, order=2)
53
+
54
+ # gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
55
+ # pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
56
+
57
+ # gt_consistency_list.append(gt_consistency)
58
+ # pred_consistency_list.append(pred_consistency)
59
+
60
+ lvd = np.mean(LVD_list)
61
+ # diversity_list = np.mean(diversity_list)
62
+
63
+ print('LVD:', lvd)
64
+ # print("diversity:", diversity_list)
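A hypothetical invocation (the speaker id and post_fix are assumptions; the script expects pose_dataset/videos/test_audios/<speaker>/*.wav and matching *_<post_fix>.json predictions):

    python evaluation/diversity_LVD.py --speaker oliver --post_fix base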
evaluation/get_quality_samples.py ADDED
@@ -0,0 +1,62 @@
1
+ '''
2
+ '''
3
+ import os
4
+ import sys
5
+ sys.path.append(os.getcwd())
6
+
7
+ from glob import glob
8
+
9
+ from argparse import ArgumentParser
10
+ import json
11
+
12
+ from evaluation.util import *
13
+ from evaluation.metrics import *
14
+ from tqdm import tqdm
15
+
16
+ parser = ArgumentParser()
17
+ parser.add_argument('--speaker', required=True, type=str)
18
+ parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
19
+ args = parser.parse_args()
20
+
21
+ speaker = args.speaker
22
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
23
+
24
+ quality_samples={'gt':[]}
25
+ for post_fix in args.post_fix:
26
+ quality_samples[post_fix] = []
27
+
28
+ for aud in tqdm(test_audios):
29
+ base_name = os.path.splitext(aud)[0]
30
+ gt_path = get_full_path(aud, speaker, 'val')
31
+ _, gt_poses, _ = get_gts(gt_path)
32
+ gt_poses = gt_poses[np.newaxis,...]
33
+ gt_valid_points = valid_points(gt_poses)
34
+ # print(gt_valid_points.shape)
35
+ quality_samples['gt'].append(gt_valid_points)
36
+
37
+ for post_fix in args.post_fix:
38
+ pred_path = base_name + '_'+post_fix+'.json'
39
+ pred_poses = np.array(json.load(open(pred_path)))
40
+ # print(pred_poses.shape)#(B, seq_len, 108)
41
+ pred_poses = cvt25(pred_poses, gt_poses)
42
+ # print(pred_poses.shape)#(B, seq, pose_dim)
43
+
44
+ pred_valid_points = valid_points(pred_poses)[0:1]
45
+ quality_samples[post_fix].append(pred_valid_points)
46
+
47
+ quality_samples['gt'] = np.concatenate(quality_samples['gt'], axis=1)
48
+ for post_fix in args.post_fix:
49
+ quality_samples[post_fix] = np.concatenate(quality_samples[post_fix], axis=1)
50
+
51
+ print('gt:', quality_samples['gt'].shape)
52
+ quality_samples['gt'] = quality_samples['gt'].tolist()
53
+ for post_fix in args.post_fix:
54
+ print(post_fix, ':', quality_samples[post_fix].shape)
55
+ quality_samples[post_fix] = quality_samples[post_fix].tolist()
56
+
57
+ save_dir = '../../experiments/'
58
+ os.makedirs(save_dir, exist_ok=True)
59
+ save_name = os.path.join(save_dir, 'quality_samples_%s.json'%(speaker))
60
+ with open(save_name, 'w') as f:
61
+ json.dump(quality_samples, f)
62
+
evaluation/metrics.py ADDED
@@ -0,0 +1,109 @@
1
+ '''
2
+ Warning: these metrics are for reference only and may have limited significance.
3
+ '''
4
+ import os
5
+ import sys
6
+ sys.path.append(os.getcwd())
7
+ import numpy as np
8
+ import torch
9
+
10
+ from data_utils.lower_body import rearrange, symmetry
11
+ import torch.nn.functional as F
12
+
13
+ def data_driven_baselines(gt_kps):
14
+ '''
15
+ gt_kps: T, D
16
+ '''
17
+ gt_velocity = np.abs(gt_kps[1:] - gt_kps[:-1])
18
+
19
+ mean= np.mean(gt_velocity, axis=0)[np.newaxis] #(1, D)
20
+ mean = np.mean(np.abs(gt_velocity-mean))
21
+ last_step = gt_kps[1] - gt_kps[0]
22
+ last_step = last_step[np.newaxis] #(1, D)
23
+ last_step = np.mean(np.abs(gt_velocity-last_step))
24
+ return last_step, mean
25
+
26
+ def Batch_LVD(gt_kps, pr_kps, symmetrical, weight):
27
+ if gt_kps.shape[0] > pr_kps.shape[1]:
28
+ length = pr_kps.shape[1]
29
+ else:
30
+ length = gt_kps.shape[0]
31
+ gt_kps = gt_kps[:length]
32
+ pr_kps = pr_kps[:, :length]
33
+ global symmetry
34
+ symmetry = torch.tensor(symmetry).bool()
35
+
36
+ if symmetrical:
37
+ # rearrange for compute symmetric. ns means non-symmetrical joints, ys means symmetrical joints.
38
+ gt_kps = gt_kps[:, rearrange]
39
+ ns_gt_kps = gt_kps[:, ~symmetry]
40
+ ys_gt_kps = gt_kps[:, symmetry]
41
+ ys_gt_kps = ys_gt_kps.reshape(ys_gt_kps.shape[0], -1, 2, 3)
42
+ ns_gt_velocity = (ns_gt_kps[1:] - ns_gt_kps[:-1]).norm(p=2, dim=-1)
43
+ ys_gt_velocity = (ys_gt_kps[1:] - ys_gt_kps[:-1]).norm(p=2, dim=-1)
44
+ left_gt_vel = ys_gt_velocity[:, :, 0].sum(dim=-1)
45
+ right_gt_vel = ys_gt_velocity[:, :, 1].sum(dim=-1)
46
+ move_side = torch.where(left_gt_vel>right_gt_vel, torch.ones(left_gt_vel.shape).cuda(), torch.zeros(left_gt_vel.shape).cuda())
47
+ ys_gt_velocity = torch.mul(ys_gt_velocity[:, :, 0].transpose(0,1), move_side) + torch.mul(ys_gt_velocity[:, :, 1].transpose(0,1), ~move_side.bool())
48
+ ys_gt_velocity = ys_gt_velocity.transpose(0,1)
49
+ gt_velocity = torch.cat([ns_gt_velocity, ys_gt_velocity], dim=1)
50
+
51
+ pr_kps = pr_kps[:, :, rearrange]
52
+ ns_pr_kps = pr_kps[:, :, ~symmetry]
53
+ ys_pr_kps = pr_kps[:, :, symmetry]
54
+ ys_pr_kps = ys_pr_kps.reshape(ys_pr_kps.shape[0], ys_pr_kps.shape[1], -1, 2, 3)
55
+ ns_pr_velocity = (ns_pr_kps[:, 1:] - ns_pr_kps[:, :-1]).norm(p=2, dim=-1)
56
+ ys_pr_velocity = (ys_pr_kps[:, 1:] - ys_pr_kps[:, :-1]).norm(p=2, dim=-1)
57
+ left_pr_vel = ys_pr_velocity[:, :, :, 0].sum(dim=-1)
58
+ right_pr_vel = ys_pr_velocity[:, :, :, 1].sum(dim=-1)
59
+ move_side = torch.where(left_pr_vel > right_pr_vel, torch.ones(left_pr_vel.shape).cuda(),
60
+ torch.zeros(left_pr_vel.shape).cuda())
61
+ ys_pr_velocity = torch.mul(ys_pr_velocity[..., 0].permute(2, 0, 1), move_side) + torch.mul(
62
+ ys_pr_velocity[..., 1].permute(2, 0, 1), ~move_side.long())
63
+ ys_pr_velocity = ys_pr_velocity.permute(1, 2, 0)
64
+ pr_velocity = torch.cat([ns_pr_velocity, ys_pr_velocity], dim=2)
65
+ else:
66
+ gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
67
+ pr_velocity = (pr_kps[:, 1:] - pr_kps[:, :-1]).norm(p=2, dim=-1)
68
+
69
+ if weight:
70
+ w = F.softmax(gt_velocity.sum(dim=1).normal_(), dim=0)
71
+ else:
72
+ w = 1 / gt_velocity.shape[0]
73
+
74
+ v_diff = ((pr_velocity - gt_velocity).abs().sum(dim=-1) * w).sum(dim=-1).mean()
75
+
76
+ return v_diff
77
+
78
+
79
+ def LVD(gt_kps, pr_kps, symmetrical=False, weight=False):
80
+ gt_kps = gt_kps.squeeze()
81
+ pr_kps = pr_kps.squeeze()
82
+ if len(pr_kps.shape) == 4:
83
+ return Batch_LVD(gt_kps, pr_kps, symmetrical, weight)
84
+ # length = np.minimum(gt_kps.shape[0], pr_kps.shape[0])
85
+ length = gt_kps.shape[0]-10
86
+ # gt_kps = gt_kps[25:length]
87
+ # pr_kps = pr_kps[25:length] #(T, D)
88
+ # if pr_kps.shape[0] < gt_kps.shape[0]:
89
+ # pr_kps = np.pad(pr_kps, [[0, int(gt_kps.shape[0]-pr_kps.shape[0])], [0, 0]], mode='constant')
90
+
91
+ gt_velocity = (gt_kps[1:] - gt_kps[:-1]).norm(p=2, dim=-1)
92
+ pr_velocity = (pr_kps[1:] - pr_kps[:-1]).norm(p=2, dim=-1)
93
+
94
+ return (pr_velocity-gt_velocity).abs().sum(dim=-1).mean()
95
+
96
+ def diversity(kps):
97
+ '''
98
+ kps: bs, seq, dim
99
+ '''
100
+ dis_list = []
101
+ #the distance between each pair
102
+ for i in range(kps.shape[0]):
103
+ for j in range(i+1, kps.shape[0]):
104
+ seq_i = kps[i]
105
+ seq_j = kps[j]
106
+
107
+ dis = np.mean(np.abs(seq_i - seq_j))
108
+ dis_list.append(dis)
109
+ return np.mean(dis_list)
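A minimal sketch (not part of this commit) of calling LVD on two joint sequences; the (frames, joints, 3) layout is an assumption for illustration:

    import torch
    from evaluation.metrics import LVD

    gt = torch.randn(120, 47, 3)  # ground-truth joints, assumed (frames, joints, xyz)
    pr = torch.randn(120, 47, 3)  # predicted joints of the same shape
    print(LVD(gt, pr))            # |velocity difference| summed over joints, averaged over frames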
evaluation/mode_transition.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.getcwd())
4
+
5
+ from glob import glob
6
+
7
+ from argparse import ArgumentParser
8
+ import json
9
+
10
+ from evaluation.util import *
11
+ from evaluation.metrics import *
12
+ from tqdm import tqdm
13
+
14
+ parser = ArgumentParser()
15
+ parser.add_argument('--speaker', required=True, type=str)
16
+ parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17
+ args = parser.parse_args()
18
+
19
+ speaker = args.speaker
20
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21
+
22
+ precision_list=[]
23
+ recall_list=[]
24
+ accuracy_list=[]
25
+
26
+ for aud in tqdm(test_audios):
27
+ base_name = os.path.splitext(aud)[0]
28
+ gt_path = get_full_path(aud, speaker, 'val')
29
+ _, gt_poses, _ = get_gts(gt_path)
30
+ if gt_poses.shape[0] < 50:
31
+ continue
32
+ gt_poses = gt_poses[np.newaxis,...]
33
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
34
+ for post_fix in args.post_fix:
35
+ pred_path = base_name + '_'+post_fix+'.json'
36
+ pred_poses = np.array(json.load(open(pred_path)))
37
+ # print(pred_poses.shape)#(B, seq_len, 108)
38
+ pred_poses = cvt25(pred_poses, gt_poses)
39
+ # print(pred_poses.shape)#(B, seq, pose_dim)
40
+
41
+ gt_valid_points = valid_points(gt_poses)
42
+ pred_valid_points = valid_points(pred_poses)
43
+
44
+ # print(gt_valid_points.shape, pred_valid_points.shape)
45
+
46
+ gt_mode_transition_seq = mode_transition_seq(gt_valid_points, speaker)#(B, N)
47
+ pred_mode_transition_seq = mode_transition_seq(pred_valid_points, speaker)#(B, N)
48
+
49
+ # baseline = np.random.randint(0, 2, size=pred_mode_transition_seq.shape)
50
+ # pred_mode_transition_seq = baseline
51
+ precision, recall, accuracy = mode_transition_consistency(pred_mode_transition_seq, gt_mode_transition_seq)
52
+ precision_list.append(precision)
53
+ recall_list.append(recall)
54
+ accuracy_list.append(accuracy)
55
+ print(len(precision_list), len(recall_list), len(accuracy_list))
56
+ precision_list = np.mean(precision_list)
57
+ recall_list = np.mean(recall_list)
58
+ accuracy_list = np.mean(accuracy_list)
59
+
60
+ print('precision, recall, accu:', precision_list, recall_list, accuracy_list)
evaluation/peak_velocity.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.getcwd())
4
+
5
+ from glob import glob
6
+
7
+ from argparse import ArgumentParser
8
+ import json
9
+
10
+ from evaluation.util import *
11
+ from evaluation.metrics import *
12
+ from tqdm import tqdm
13
+
14
+ parser = ArgumentParser()
15
+ parser.add_argument('--speaker', required=True, type=str)
16
+ parser.add_argument('--post_fix', nargs='+', default=['paper_model'], type=str)
17
+ args = parser.parse_args()
18
+
19
+ speaker = args.speaker
20
+ test_audios = sorted(glob('pose_dataset/videos/test_audios/%s/*.wav'%(speaker)))
21
+
22
+ gt_consistency_list=[]
23
+ pred_consistency_list=[]
24
+
25
+ for aud in tqdm(test_audios):
26
+ base_name = os.path.splitext(aud)[0]
27
+ gt_path = get_full_path(aud, speaker, 'val')
28
+ _, gt_poses, _ = get_gts(gt_path)
29
+ gt_poses = gt_poses[np.newaxis,...]
30
+ # print(gt_poses.shape)#(seq_len, 135*2)pose, lhand, rhand, face
31
+ for post_fix in args.post_fix:
32
+ pred_path = base_name + '_'+post_fix+'.json'
33
+ pred_poses = np.array(json.load(open(pred_path)))
34
+ # print(pred_poses.shape)#(B, seq_len, 108)
35
+ pred_poses = cvt25(pred_poses, gt_poses)
36
+ # print(pred_poses.shape)#(B, seq, pose_dim)
37
+
38
+ gt_valid_points = hand_points(gt_poses)
39
+ pred_valid_points = hand_points(pred_poses)
40
+
41
+ gt_velocity = peak_velocity(gt_valid_points, order=2)
42
+ pred_velocity = peak_velocity(pred_valid_points, order=2)
43
+
44
+ gt_consistency = velocity_consistency(gt_velocity, pred_velocity)
45
+ pred_consistency = velocity_consistency(pred_velocity, gt_velocity)
46
+
47
+ gt_consistency_list.append(gt_consistency)
48
+ pred_consistency_list.append(pred_consistency)
49
+
50
+ gt_consistency_list = np.concatenate(gt_consistency_list)
51
+ pred_consistency_list = np.concatenate(pred_consistency_list)
52
+
53
+ print(gt_consistency_list.max(), gt_consistency_list.min())
54
+ print(pred_consistency_list.max(), pred_consistency_list.min())
55
+ print(np.mean(gt_consistency_list), np.mean(pred_consistency_list))
56
+ print(np.std(gt_consistency_list), np.std(pred_consistency_list))
57
+
58
+ draw_cdf(gt_consistency_list, save_name='%s_gt.jpg'%(speaker), color='slateblue')
59
+ draw_cdf(pred_consistency_list, save_name='%s_pred.jpg'%(speaker), color='lightskyblue')
60
+
61
+ to_excel(gt_consistency_list, '%s_gt.xlsx'%(speaker))
62
+ to_excel(pred_consistency_list, '%s_pred.xlsx'%(speaker))
63
+
64
+ np.save('%s_gt.npy'%(speaker), gt_consistency_list)
65
+ np.save('%s_pred.npy'%(speaker), pred_consistency_list)
evaluation/util.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ from glob import glob
+ import numpy as np
+ import json
+ from matplotlib import pyplot as plt
+ import pandas as pd
+
+ def get_gts(clip):
+     '''
+     clip: abs path to the clip dir
+     '''
+     keypoints_files = sorted(glob(os.path.join(clip, 'keypoints_new/person_1')+'/*.json'))
+
+     upper_body_points = list(np.arange(0, 25))
+     poses = []
+     confs = []
+     neck_to_nose_len = []
+     mean_position = []
+     for kp_file in keypoints_files:
+         kp_load = json.load(open(kp_file, 'r'))['people'][0]
+         posepts = kp_load['pose_keypoints_2d']
+         lhandpts = kp_load['hand_left_keypoints_2d']
+         rhandpts = kp_load['hand_right_keypoints_2d']
+         facepts = kp_load['face_keypoints_2d']
+
+         neck = np.array(posepts).reshape(-1,3)[1]
+         nose = np.array(posepts).reshape(-1,3)[0]
+         x_offset = abs(neck[0]-nose[0])
+         y_offset = abs(neck[1]-nose[1])
+         neck_to_nose_len.append(y_offset)
+         mean_position.append([neck[0],neck[1]])
+
+         keypoints = np.array(posepts+lhandpts+rhandpts+facepts).reshape(-1,3)[:,:2]
+
+         upper_body = keypoints[upper_body_points, :]
+         hand_points = keypoints[25:, :]
+         keypoints = np.vstack([upper_body, hand_points])
+
+         poses.append(keypoints)
+
+     if len(neck_to_nose_len) > 0:
+         scale_factor = np.mean(neck_to_nose_len)
+     else:
+         raise ValueError(clip)
+     mean_position = np.mean(np.array(mean_position), axis=0)
+
+     unlocalized_poses = np.array(poses).copy()
+     localized_poses = []
+     for i in range(len(poses)):
+         keypoints = poses[i]
+         neck = keypoints[1].copy()
+
+         keypoints[:, 0] = (keypoints[:, 0] - neck[0]) / scale_factor
+         keypoints[:, 1] = (keypoints[:, 1] - neck[1]) / scale_factor
+         localized_poses.append(keypoints.reshape(-1))
+
+     localized_poses = np.array(localized_poses)
+     return unlocalized_poses, localized_poses, (scale_factor, mean_position)
+
+ def get_full_path(wav_name, speaker, split):
+     '''
+     get clip path from aud file
+     '''
+     wav_name = os.path.basename(wav_name)
+     wav_name = os.path.splitext(wav_name)[0]
+     clip_name, vid_name = wav_name[:10], wav_name[11:]
+
+     full_path = os.path.join('pose_dataset/videos/', speaker, 'clips', vid_name, 'images/half', split, clip_name)
+
+     assert os.path.isdir(full_path), full_path
+
+     return full_path
+
+ def smooth(res):
+     '''
+     res: (B, seq_len, pose_dim)
+     '''
+     window = [res[:, 7, :], res[:, 8, :], res[:, 9, :], res[:, 10, :], res[:, 11, :], res[:, 12, :]]
+     w_size = 7
+     for i in range(10, res.shape[1]-3):
+         window.append(res[:, i+3, :])
+         if len(window) > w_size:
+             window = window[1:]
+
+         if (i%25) in [22, 23, 24, 0, 1, 2, 3]:
+             res[:, i, :] = np.mean(window, axis=1)
+
+     return res
+
+ def cvt25(pred_poses, gt_poses=None):
+     '''
+     gt_poses: (1, seq_len, 270), 135 * 2
+     pred_poses: (B, seq_len, 108), 54 * 2
+     '''
+     if gt_poses is None:
+         gt_poses = np.zeros_like(pred_poses)
+     else:
+         gt_poses = gt_poses.repeat(pred_poses.shape[0], axis=0)
+
+     length = min(pred_poses.shape[1], gt_poses.shape[1])
+     pred_poses = pred_poses[:, :length, :]
+     gt_poses = gt_poses[:, :length, :]
+     gt_poses = gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1, 2)
+     pred_poses = pred_poses.reshape(pred_poses.shape[0], pred_poses.shape[1], -1, 2)
+
+     gt_poses[:, :, [1, 2, 3, 4, 5, 6, 7], :] = pred_poses[:, :, 1:8, :]
+     gt_poses[:, :, 25:25+21+21, :] = pred_poses[:, :, 12:, :]
+
+     return gt_poses.reshape(gt_poses.shape[0], gt_poses.shape[1], -1)
+
+ def hand_points(seq):
+     '''
+     seq: (B, seq_len, 135*2)
+     hands only
+     '''
+     hand_idx = [1, 2, 3, 4, 5, 6, 7] + list(range(25, 25+21+21))
+     seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
+     return seq[:, :, hand_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
+
+ def valid_points(seq):
+     '''
+     hands with some head points
+     '''
+     valid_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + list(range(25, 25+21+21))
+     seq = seq.reshape(seq.shape[0], seq.shape[1], -1, 2)
+
+     seq = seq[:, :, valid_idx, :].reshape(seq.shape[0], seq.shape[1], -1)
+     assert seq.shape[-1] == 108, seq.shape
+     return seq
+
+ def draw_cdf(seq, save_name='cdf.jpg', color='slateblue'):
+     plt.figure()
+     plt.hist(seq, bins=100, range=(0, 100), color=color)
+     plt.savefig(save_name)
+
+ def to_excel(seq, save_name='res.xlsx'):
+     '''
+     seq: (T)
+     '''
+     df = pd.DataFrame(seq)
+     writer = pd.ExcelWriter(save_name)
+     df.to_excel(writer, 'sheet1')
+     writer.save()
+     writer.close()
+
+
+ if __name__ == '__main__':
+     random_data = np.random.randint(0, 10, 100)
+     draw_cdf(random_data)
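
The helpers above are what the consistency script earlier in this diff relies on: `cvt25` writes the 54-point predictions back into the 135-point OpenPose layout of the ground truth, and `hand_points` / `valid_points` then select the arm-and-hand subsets. A minimal shape-check sketch, assuming the repo root is on `PYTHONPATH` so the file imports as `evaluation.util`, with random arrays standing in for real keypoints:

```python
# Shape-check sketch (not part of the commit): random stand-ins for real data.
import numpy as np
from evaluation.util import cvt25, hand_points, valid_points

B, T = 4, 88                        # hypothetical batch of generated sequences
pred = np.random.randn(B, T, 108)   # 54 upper-body/hand points, (x, y) each
gt = np.random.randn(1, T, 270)     # 135 full OpenPose points, (x, y) each

merged = cvt25(pred, gt)            # predictions written into the 135-point GT layout
print(merged.shape)                 # (4, 88, 270)
print(hand_points(merged).shape)    # (4, 88, 98)  -> arm + both hands only
print(valid_points(merged).shape)   # (4, 88, 108) -> arm/head subset + both hands
```
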
losses/__init__.py ADDED
@@ -0,0 +1 @@
+ from .losses import *
losses/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (174 Bytes).
 
losses/__pycache__/losses.cpython-37.pyc ADDED
Binary file (3.53 kB).
 
losses/losses.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import sys
+
+ sys.path.append(os.getcwd())
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ class KeypointLoss(nn.Module):
+     def __init__(self):
+         super(KeypointLoss, self).__init__()
+
+     def forward(self, pred_seq, gt_seq, gt_conf=None):
+         #pred_seq: (B, C, T)
+         if gt_conf is not None:
+             gt_conf = gt_conf >= 0.01
+             return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
+         else:
+             return F.mse_loss(pred_seq, gt_seq)
+
+
+ class KLLoss(nn.Module):
+     def __init__(self, kl_tolerance):
+         super(KLLoss, self).__init__()
+         self.kl_tolerance = kl_tolerance
+
+     def forward(self, mu, var, mul=1):
+         kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
+         kld_loss = -0.5 * torch.sum(1 + var - mu**2 - var.exp(), dim=1)
+         # kld_loss = -0.5 * torch.sum(1 + (var-1) - (mu) ** 2 - (var-1).exp(), dim=1)
+         if self.kl_tolerance is not None:
+             # above_line = kld_loss[kld_loss > self.kl_tolerance]
+             # if len(above_line) > 0:
+             #     kld_loss = torch.mean(kld_loss)
+             # else:
+             #     kld_loss = 0
+             kld_loss = torch.where(kld_loss > kl_tolerance, kld_loss, torch.tensor(kl_tolerance, device='cuda'))
+         # else:
+         kld_loss = torch.mean(kld_loss)
+         return kld_loss
+
+
+ class L2KLLoss(nn.Module):
+     def __init__(self, kl_tolerance):
+         super(L2KLLoss, self).__init__()
+         self.kl_tolerance = kl_tolerance
+
+     def forward(self, x):
+         # TODO: check
+         kld_loss = torch.sum(x ** 2, dim=1)
+         if self.kl_tolerance is not None:
+             above_line = kld_loss[kld_loss > self.kl_tolerance]
+             if len(above_line) > 0:
+                 kld_loss = torch.mean(kld_loss)
+             else:
+                 kld_loss = 0
+         else:
+             kld_loss = torch.mean(kld_loss)
+         return kld_loss
+
+ class L2RegLoss(nn.Module):
+     def __init__(self):
+         super(L2RegLoss, self).__init__()
+
+     def forward(self, x):
+         #TODO: check
+         return torch.sum(x**2)
+
+
+ class L2Loss(nn.Module):
+     def __init__(self):
+         super(L2Loss, self).__init__()
+
+     def forward(self, x):
+         # TODO: check
+         return torch.sum(x ** 2)
+
+
+ class AudioLoss(nn.Module):
+     def __init__(self):
+         super(AudioLoss, self).__init__()
+
+     def forward(self, dynamics, gt_poses):
+         #pay attention, normalized
+         mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
+         gt = gt_poses - mean
+         return F.mse_loss(dynamics, gt)
+
+ L1Loss = nn.L1Loss
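
Note that `KLLoss` treats its `var` argument as a log-variance and floors each sample's KL against a scaled tolerance before averaging (a free-bits style constraint); its `forward` also places the tolerance tensor on `'cuda'`. A CPU-only sketch of the same computation, with random tensors standing in for encoder outputs:

```python
# Sketch of what KLLoss computes (not part of the commit); runs without a GPU.
import torch

mu = torch.randn(8, 64)       # hypothetical latent means, (B, latent_dim)
logvar = torch.randn(8, 64)   # hypothetical log-variances

kl_tolerance, mul = 0.02, 1
tol = kl_tolerance * mul * logvar.shape[1] / 64
# Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over the latent dimension.
kld = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1)
# Free-bits style floor, as in KLLoss.forward, then the batch mean.
kld = torch.where(kld > tol, kld, torch.full_like(kld, tol))
print(kld.mean())
```
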
nets/LS3DCG.py ADDED
@@ -0,0 +1,414 @@
+ '''
+ not exactly the same as the official repo but the results are good
+ '''
+ import sys
+ import os
+
+ from data_utils.lower_body import c_index_3d, c_index_6d
+
+ sys.path.append(os.getcwd())
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ import math
+
+ from nets.base import TrainWrapperBaseClass
+ from nets.layers import SeqEncoder1D
+ from losses import KeypointLoss, L1Loss, KLLoss
+ from data_utils.utils import get_melspec, get_mfcc_psf, get_mfcc_ta
+ from nets.utils import denormalize
+
+ class Conv1d_tf(nn.Conv1d):
+     """
+     Conv1d with the padding behavior from TF
+     modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
+     """
+
+     def __init__(self, *args, **kwargs):
+         super(Conv1d_tf, self).__init__(*args, **kwargs)
+         self.padding = kwargs.get("padding", "same")
+
+     def _compute_padding(self, input, dim):
+         input_size = input.size(dim + 2)
+         filter_size = self.weight.size(dim + 2)
+         effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
+         out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
+         total_padding = max(
+             0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
+         )
+         additional_padding = int(total_padding % 2 != 0)
+
+         return additional_padding, total_padding
+
+     def forward(self, input):
+         if self.padding == "VALID":
+             return F.conv1d(
+                 input,
+                 self.weight,
+                 self.bias,
+                 self.stride,
+                 padding=0,
+                 dilation=self.dilation,
+                 groups=self.groups,
+             )
+         rows_odd, padding_rows = self._compute_padding(input, dim=0)
+         if rows_odd:
+             input = F.pad(input, [0, rows_odd])
+
+         return F.conv1d(
+             input,
+             self.weight,
+             self.bias,
+             self.stride,
+             padding=(padding_rows // 2),
+             dilation=self.dilation,
+             groups=self.groups,
+         )
+
+
+ def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
+     if k is None and s is None:
+         if not downsample:
+             k = 3
+             s = 1
+         else:
+             k = 4
+             s = 2
+
+     if type == '1d':
+         conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
+         if norm == 'bn':
+             norm_block = nn.BatchNorm1d(out_channels)
+         elif norm == 'ln':
+             norm_block = nn.LayerNorm(out_channels)
+     elif type == '2d':
+         conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
+         norm_block = nn.BatchNorm2d(out_channels)
+     else:
+         assert False
+
+     return nn.Sequential(
+         conv_block,
+         norm_block,
+         nn.LeakyReLU(0.2, True)
+     )
+
+ class Decoder(nn.Module):
+     def __init__(self, in_ch, out_ch):
+         super(Decoder, self).__init__()
+         self.up1 = nn.Sequential(
+             ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
+             ConvNormRelu(in_ch // 2, in_ch // 2),
+             nn.Upsample(scale_factor=2, mode='nearest')
+         )
+         self.up2 = nn.Sequential(
+             ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
+             ConvNormRelu(in_ch // 4, in_ch // 4),
+             nn.Upsample(scale_factor=2, mode='nearest')
+         )
+         self.up3 = nn.Sequential(
+             ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
+             ConvNormRelu(in_ch // 8, in_ch // 8),
+             nn.Conv1d(in_ch // 8, out_ch, 1, 1)
+         )
+
+     def forward(self, x, x1, x2, x3):
+         x = F.interpolate(x, x3.shape[2])
+         x = torch.cat([x, x3], dim=1)
+         x = self.up1(x)
+         x = F.interpolate(x, x2.shape[2])
+         x = torch.cat([x, x2], dim=1)
+         x = self.up2(x)
+         x = F.interpolate(x, x1.shape[2])
+         x = torch.cat([x, x1], dim=1)
+         x = self.up3(x)
+         return x
+
+
+ class EncoderDecoder(nn.Module):
+     def __init__(self, n_frames, each_dim):
+         super().__init__()
+         self.n_frames = n_frames
+
+         self.down1 = nn.Sequential(
+             ConvNormRelu(64, 64, '1d', False),
+             ConvNormRelu(64, 128, '1d', False),
+         )
+         self.down2 = nn.Sequential(
+             ConvNormRelu(128, 128, '1d', False),
+             ConvNormRelu(128, 256, '1d', False),
+         )
+         self.down3 = nn.Sequential(
+             ConvNormRelu(256, 256, '1d', False),
+             ConvNormRelu(256, 512, '1d', False),
+         )
+         self.down4 = nn.Sequential(
+             ConvNormRelu(512, 512, '1d', False),
+             ConvNormRelu(512, 1024, '1d', False),
+         )
+
+         self.down = nn.MaxPool1d(kernel_size=2)
+         self.up = nn.Upsample(scale_factor=2, mode='nearest')
+
+         self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
+         self.body_decoder = Decoder(1024, each_dim[1])
+         self.hand_decoder = Decoder(1024, each_dim[2])
+
+     def forward(self, spectrogram, time_steps=None):
+         if time_steps is None:
+             time_steps = self.n_frames
+
+         x1 = self.down1(spectrogram)
+         x = self.down(x1)
+         x2 = self.down2(x)
+         x = self.down(x2)
+         x3 = self.down3(x)
+         x = self.down(x3)
+         x = self.down4(x)
+         x = self.up(x)
+
+         face = self.face_decoder(x, x1, x2, x3)
+         body = self.body_decoder(x, x1, x2, x3)
+         hand = self.hand_decoder(x, x1, x2, x3)
+
+         return face, body, hand
+
+
+ class Generator(nn.Module):
+     def __init__(self,
+                  each_dim,
+                  training=False,
+                  device=None
+                  ):
+         super().__init__()
+
+         self.training = training
+         self.device = device
+
+         self.encoderdecoder = EncoderDecoder(15, each_dim)
+
+     def forward(self, in_spec, time_steps=None):
+         if time_steps is not None:
+             self.gen_length = time_steps
+
+         face, body, hand = self.encoderdecoder(in_spec)
+         out = torch.cat([face, body, hand], dim=1)
+         out = out.transpose(1, 2)
+
+         return out
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, input_dim):
+         super().__init__()
+         self.net = nn.Sequential(
+             ConvNormRelu(input_dim, 128, '1d'),
+             ConvNormRelu(128, 256, '1d'),
+             nn.MaxPool1d(kernel_size=2),
+             ConvNormRelu(256, 256, '1d'),
+             ConvNormRelu(256, 512, '1d'),
+             nn.MaxPool1d(kernel_size=2),
+             ConvNormRelu(512, 512, '1d'),
+             ConvNormRelu(512, 1024, '1d'),
+             nn.MaxPool1d(kernel_size=2),
+             nn.Conv1d(1024, 1, 1, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         x = x.transpose(1, 2)
+
+         out = self.net(x)
+         return out
+
+
+ class TrainWrapper(TrainWrapperBaseClass):
+     def __init__(self, args, config) -> None:
+         self.args = args
+         self.config = config
+         self.device = torch.device(self.args.gpu)
+         self.global_step = 0
+         self.convert_to_6d = self.config.Data.pose.convert_to_6d
+         self.init_params()
+
+         self.generator = Generator(
+             each_dim=self.each_dim,
+             training=not self.args.infer,
+             device=self.device,
+         ).to(self.device)
+         self.discriminator = Discriminator(
+             input_dim=self.each_dim[1] + self.each_dim[2] + 64
+         ).to(self.device)
+         if self.convert_to_6d:
+             self.c_index = c_index_6d
+         else:
+             self.c_index = c_index_3d
+         self.MSELoss = KeypointLoss().to(self.device)
+         self.L1Loss = L1Loss().to(self.device)
+         super().__init__(args, config)
+
+     def init_params(self):
+         scale = 1
+
+         global_orient = round(0 * scale)
+         leye_pose = reye_pose = round(0 * scale)
+         jaw_pose = round(3 * scale)
+         body_pose = round((63 - 24) * scale)
+         left_hand_pose = right_hand_pose = round(45 * scale)
+
+         expression = 100
+
+         b_j = 0
+         jaw_dim = jaw_pose
+         b_e = b_j + jaw_dim
+         eye_dim = leye_pose + reye_pose
+         b_b = b_e + eye_dim
+         body_dim = global_orient + body_pose
+         b_h = b_b + body_dim
+         hand_dim = left_hand_pose + right_hand_pose
+         b_f = b_h + hand_dim
+         face_dim = expression
+
+         self.dim_list = [b_j, b_e, b_b, b_h, b_f]
+         self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
+         self.pose = int(self.full_dim / round(3 * scale))
+         self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
+
+     def __call__(self, bat):
+         assert (not self.args.infer), "infer mode"
+         self.global_step += 1
+
+         loss_dict = {}
+
+         aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
+         expression = bat['expression'].to(self.device).to(torch.float32)
+         jaw = poses[:, :3, :]
+         poses = poses[:, self.c_index, :]
+
+         pred = self.generator(in_spec=aud)
+
+         D_loss, D_loss_dict = self.get_loss(
+             pred_poses=pred.detach(),
+             gt_poses=poses,
+             aud=aud,
+             mode='training_D',
+         )
+
+         self.discriminator_optimizer.zero_grad()
+         D_loss.backward()
+         self.discriminator_optimizer.step()
+
+         G_loss, G_loss_dict = self.get_loss(
+             pred_poses=pred,
+             gt_poses=poses,
+             aud=aud,
+             expression=expression,
+             jaw=jaw,
+             mode='training_G',
+         )
+         self.generator_optimizer.zero_grad()
+         G_loss.backward()
+         self.generator_optimizer.step()
+
+         total_loss = None
+         loss_dict = {}
+         for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
+             loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)
+
+         return total_loss, loss_dict
+
+     def get_loss(self,
+                  pred_poses,
+                  gt_poses,
+                  aud=None,
+                  jaw=None,
+                  expression=None,
+                  mode='training_G',
+                  ):
+         loss_dict = {}
+         aud = aud.transpose(1, 2)
+         gt_poses = gt_poses.transpose(1, 2)
+         gt_aud = torch.cat([gt_poses, aud], dim=2)
+         pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)
+
+         if mode == 'training_D':
+             dis_real = self.discriminator(gt_aud)
+             dis_fake = self.discriminator(pred_aud)
+             dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
+                 torch.zeros_like(dis_fake).to(self.device), dis_fake)
+             loss_dict['dis'] = dis_error
+
+             return dis_error, loss_dict
+         elif mode == 'training_G':
+             jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
+             face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
+             body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
+             hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
+             l1_loss = jaw_loss + face_loss + body_loss + hand_loss
+
+             dis_output = self.discriminator(pred_aud)
+             gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
+             gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error
+
+             loss_dict['gen'] = gen_error
+             loss_dict['jaw_loss'] = jaw_loss
+             loss_dict['face_loss'] = face_loss
+             loss_dict['body_loss'] = body_loss
+             loss_dict['hand_loss'] = hand_loss
+             return gen_loss, loss_dict
+         else:
+             raise ValueError(mode)
+
+     def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
+         output = []
+         assert self.args.infer, "train mode"
+         self.generator.eval()
+
+         if self.config.Data.pose.normalization:
+             assert norm_stats is not None
+             data_mean = norm_stats[0]
+             data_std = norm_stats[1]
+
+         pre_length = self.config.Data.pose.pre_pose_length
+         generate_length = self.config.Data.pose.generate_length
+         # assert pre_length == initial_pose.shape[-1]
+         # pre_poses = initial_pose.permute(0, 2, 1).to(self.device).to(torch.float32)
+         # B = pre_poses.shape[0]
+
+         aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
+         num_poses_to_generate = aud_feat.shape[-1]
+         aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
+         aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)
+
+         with torch.no_grad():
+             pred_poses = self.generator(aud_feat)
+         pred_poses = pred_poses.cpu().numpy()
+         output = pred_poses.squeeze()
+
+         return output
+
+     def generate(self, aud, id):
+         self.generator.eval()
+         pred_poses = self.generator(aud)
+         return pred_poses
+
+
+ if __name__ == '__main__':
+     from trainer.options import parse_args
+
+     parser = parse_args()
+     args = parser.parse_args(
+         ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
+          '--infer'])
+
+     generator = TrainWrapper(args)
+
+     aud_fn = '../sample_audio/jon.wav'
+     initial_pose = torch.randn(64, 108, 4)
+     norm_stats = (np.random.randn(108), np.random.randn(108))
+     output = generator.infer_on_audio(aud_fn, initial_pose, norm_stats)
+
+     print(output.shape)
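
`Conv1d_tf` reproduces TensorFlow's `SAME` padding, so each `ConvNormRelu` block keeps `ceil(L / stride)` time steps regardless of kernel size. A standalone sketch of the padding arithmetic mirrored from `_compute_padding` (no repo imports, so it runs on its own):

```python
# Sketch (not part of the commit) of the TF-"SAME" padding arithmetic in Conv1d_tf.
def same_padding(input_size: int, kernel: int, stride: int, dilation: int = 1):
    effective_kernel = (kernel - 1) * dilation + 1
    out_size = (input_size + stride - 1) // stride                       # ceil(input / stride)
    total = max(0, (out_size - 1) * stride + effective_kernel - input_size)
    odd = int(total % 2 != 0)   # Conv1d_tf pads this extra element on the right before conv1d
    return out_size, total, odd

print(same_padding(88, kernel=3, stride=1))   # (88, 2, 0) -> sequence length preserved
print(same_padding(88, kernel=4, stride=2))   # (44, 2, 0) -> halved, like a downsample block
```
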
nets/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .smplx_face import TrainWrapper as s2g_face
+ from .smplx_body_vq import TrainWrapper as s2g_body_vq
+ from .smplx_body_pixel import TrainWrapper as s2g_body_pixel
+ from .body_ae import TrainWrapper as s2g_body_ae
+ from .LS3DCG import TrainWrapper as LS3DCG
+ from .base import TrainWrapperBaseClass
+
+ from .utils import normalize, denormalize
nets/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (407 Bytes).
 
nets/__pycache__/base.cpython-37.pyc ADDED
Binary file (2.29 kB).
 
nets/__pycache__/init_model.cpython-37.pyc ADDED
Binary file (460 Bytes).
 
nets/__pycache__/layers.cpython-37.pyc ADDED
Binary file (22.7 kB).
 
nets/__pycache__/smplx_body_pixel.cpython-37.pyc ADDED
Binary file (9.55 kB).
 
nets/__pycache__/smplx_body_vq.cpython-37.pyc ADDED
Binary file (7.89 kB).