thanks to damo ❤
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- damo/dreamtalk/.mdl +0 -0
- damo/dreamtalk/.msc +0 -0
- damo/dreamtalk/README.md +131 -0
- damo/dreamtalk/checkpoints/denoising_network.pth +3 -0
- damo/dreamtalk/checkpoints/renderer.pt +3 -0
- damo/dreamtalk/configs/default.py +91 -0
- damo/dreamtalk/configuration.json +11 -0
- damo/dreamtalk/core/networks/__init__.py +14 -0
- damo/dreamtalk/core/networks/diffusion_net.py +340 -0
- damo/dreamtalk/core/networks/diffusion_util.py +131 -0
- damo/dreamtalk/core/networks/disentangle_decoder.py +240 -0
- damo/dreamtalk/core/networks/dynamic_conv.py +156 -0
- damo/dreamtalk/core/networks/dynamic_fc_decoder.py +178 -0
- damo/dreamtalk/core/networks/dynamic_linear.py +50 -0
- damo/dreamtalk/core/networks/generator.py +309 -0
- damo/dreamtalk/core/networks/mish.py +51 -0
- damo/dreamtalk/core/networks/self_attention_pooling.py +53 -0
- damo/dreamtalk/core/networks/transformer.py +293 -0
- damo/dreamtalk/core/utils.py +456 -0
- damo/dreamtalk/data/audio/German1.wav +0 -0
- damo/dreamtalk/data/audio/German2.wav +0 -0
- damo/dreamtalk/data/audio/German3.wav +0 -0
- damo/dreamtalk/data/audio/German4.wav +0 -0
- damo/dreamtalk/data/audio/acknowledgement_chinese.m4a +0 -0
- damo/dreamtalk/data/audio/acknowledgement_english.m4a +0 -0
- damo/dreamtalk/data/audio/chinese1_haierlizhi.wav +0 -0
- damo/dreamtalk/data/audio/chinese2_guanyu.wav +0 -0
- damo/dreamtalk/data/audio/french1.wav +0 -0
- damo/dreamtalk/data/audio/french2.wav +0 -0
- damo/dreamtalk/data/audio/french3.wav +0 -0
- damo/dreamtalk/data/audio/italian1.wav +0 -0
- damo/dreamtalk/data/audio/italian2.wav +0 -0
- damo/dreamtalk/data/audio/italian3.wav +0 -0
- damo/dreamtalk/data/audio/japan1.wav +0 -0
- damo/dreamtalk/data/audio/japan2.wav +0 -0
- damo/dreamtalk/data/audio/japan3.wav +0 -0
- damo/dreamtalk/data/audio/korean1.wav +0 -0
- damo/dreamtalk/data/audio/korean2.wav +0 -0
- damo/dreamtalk/data/audio/korean3.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_narrative.wav +0 -0
- damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav +0 -0
- damo/dreamtalk/data/audio/out_of_domain_narrative.wav +0 -0
- damo/dreamtalk/data/audio/spanish1.wav +0 -0
- damo/dreamtalk/data/audio/spanish2.wav +0 -0
- damo/dreamtalk/data/audio/spanish3.wav +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat filter=lfs diff=lfs merge=lfs -text
+damo/dreamtalk/media/teaser.gif filter=lfs diff=lfs merge=lfs -text
damo/dreamtalk/.mdl
ADDED
Binary file (37 Bytes).

damo/dreamtalk/.msc
ADDED
Binary file (12.5 kB).
damo/dreamtalk/README.md
ADDED
@@ -0,0 +1,131 @@

# DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models

<a href='https://dreamtalk-project.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2312.09767'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/VF4vlE6ZqWQ)

DreamTalk is a diffusion-based audio-driven expressive talking head generation framework that can produce high-quality talking head videos across diverse speaking styles. DreamTalk exhibits robust performance with a diverse array of inputs, including songs, speech in multiple languages, noisy audio, and out-of-domain portraits.

![figure1](media/teaser.gif "teaser")

## News
- __[2023.12]__ Release inference code and pretrained checkpoint.

## Install Dependencies
```
pip install dlib
```

## Installation

Some pre-generated videos are already included in the `output_video` folder; you can run the script below and compare your results against them.

```python
import os

from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline

pipe = pipeline(
    task=Tasks.text_to_video_synthesis,
    model='damo/dreamtalk',
    style_clip_path="data/style_clip/3DMM/M030_front_surprised_level3_001.mat",
    pose_path="data/pose/RichardShelby_front_neutral_level1_001.mat",
    model_revision='master',
)
inputs = {
    "output_name": "songbie_yk_male",
    "wav_path": "data/audio/acknowledgement_english.m4a",
    "img_crop": True,
    "image_path": "data/src_img/uncropped/male_face.png",
    "max_gen_len": 20,
}
pipe(input=inputs)
print("end")
```

`wav_path` is the path of the input audio.

`style_clip_path` is the expression reference file, extracted from an expressive video; it controls the facial expressions of the generated video.

`pose_path` is the head motion reference file, extracted from a video; it controls the head motion of the generated video.

`image_path` is the speaker portrait, ideally a frontal face. Any input resolution is supported in principle; the image will be cropped to $256\times256$.

`max_gen_len` is the maximum video duration in seconds; if the input audio is longer than this, it will be truncated.

`output_name` is the output name; the generated video is written to the `output_video` folder and intermediate results to the `tmp` folder.

If the input image is already $256\times256$ and framed well enough that no cropping is needed, you can use `disable_img_crop` to skip the cropping step.

## Download Checkpoints
Download the checkpoint of the denoising network:
* [ModelScope](tmp)

Download the checkpoint of the renderer:
* [ModelScope](tmp)

Put the downloaded checkpoints into the `checkpoints` folder.

## Inference
Run the script:

```
python inference_for_demo_video.py \
--wav_path data/audio/acknowledgement_english.m4a \
--style_clip_path data/style_clip/3DMM/M030_front_neutral_level1_001.mat \
--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
--image_path data/src_img/uncropped/male_face.png \
--cfg_scale 1.0 \
--max_gen_len 30 \
--output_name acknowledgement_english@M030_front_neutral_level1_001@male_face
```

`wav_path` specifies the input audio. Audio file extensions such as wav, mp3, m4a, and mp4 (video with sound) should all be compatible.

`style_clip_path` specifies the reference speaking style and `pose_path` specifies the head pose. They are 3DMM parameter sequences extracted from reference videos. You can follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract 3DMM parameters from your own videos. Note that the video frame rate should be 25 FPS. Besides, videos used for head pose reference should first be cropped to $256\times256$ using scripts in [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
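The bundled `.mat` reference files can be inspected before use. A minimal sketch (an illustration, not part of the repo; it assumes the files are standard MATLAB-format `.mat` readable by `scipy.io.loadmat` and makes no assumption about the key names inside):

```python
import scipy.io as sio

# Print the array shapes stored in a pose/style reference file.
mat = sio.loadmat("data/pose/RichardShelby_front_neutral_level1_001.mat")
print({k: getattr(v, "shape", type(v)) for k, v in mat.items() if not k.startswith("__")})
```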
`image_path` specifies the input portrait. Its resolution should be larger than $256\times256$. Frontal portraits, with the face directly facing forward and not tilted to one side, usually achieve satisfactory results. The input portrait will be cropped to $256\times256$. If your portrait is already cropped to $256\times256$ and you want to disable cropping, use the option `--disable_img_crop` like this:

```
python inference_for_demo_video.py \
--wav_path data/audio/acknowledgement_chinese.m4a \
--style_clip_path data/style_clip/3DMM/M030_front_surprised_level3_001.mat \
--pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
--image_path data/src_img/cropped/zp1.png \
--disable_img_crop \
--cfg_scale 1.0 \
--max_gen_len 30 \
--output_name acknowledgement_chinese@M030_front_surprised_level3_001@zp1
```

`cfg_scale` controls the scale of classifier-free guidance. It can adjust the intensity of speaking styles.
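For reference, the guidance computation in `core/networks/diffusion_net.py` below blends the conditional and unconditional denoiser outputs with scale $s$ (`cfg_scale`):

$$\hat{x}_0 = \hat{x}_0^{\mathrm{uncond}} + s\left(\hat{x}_0^{\mathrm{cond}} - \hat{x}_0^{\mathrm{uncond}}\right)$$

so `--cfg_scale 1.0` reduces to the plain conditional prediction, while larger values extrapolate further toward the reference style.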
`max_gen_len` is the maximum video generation duration, measured in seconds. If the input audio exceeds this length, it will be truncated.

The generated video will be named `$(output_name).mp4` and put in the `output_video` folder. Intermediate results, including the cropped portrait, will be in the `tmp/$(output_name)` folder.

Sample inputs are presented in the `data` folder. Due to copyright issues, we are unable to include the songs we have used in this folder.

## Acknowledgements

We extend our heartfelt thanks for the invaluable contributions made by preceding works to the development of DreamTalk. This includes, but is not limited to: [PIRenderer](https://github.com/RenYurui/PIRender), [AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face), [StyleTalk](https://github.com/FuxiVirtualHuman/styletalk), [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch), [Wav2vec2.0](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-english), [diffusion-point-cloud](https://github.com/luost26/diffusion-point-cloud), and [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). We are dedicated to advancing upon these foundational works with the utmost respect for their original contributions.

## Citation
If you find this codebase useful for your research, please use the following entry.
```BibTeX
@article{ma2023dreamtalk,
  title={DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models},
  author={Ma, Yifeng and Zhang, Shiwei and Wang, Jiayu and Wang, Xiang and Zhang, Yingya and Deng, Zhidong},
  journal={arXiv preprint arXiv:2312.09767},
  year={2023}
}
```
damo/dreamtalk/checkpoints/denoising_network.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:93864d1316f60e75b40bd820707bb2464f790b1636ae2b9275ee500d41c4e3ae
size 47908943
damo/dreamtalk/checkpoints/renderer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a67014839d42d592255c9fc3b3ceecbcd62c27ce0c0a89ed6628292447404242
size 335281551
damo/dreamtalk/configs/default.py
ADDED
@@ -0,0 +1,91 @@
from yacs.config import CfgNode as CN


_C = CN()
_C.TAG = "style_id_emotion"
_C.DECODER_TYPE = "DisentangleDecoder"
_C.CONTENT_ENCODER_TYPE = "ContentW2VEncoder"
_C.STYLE_ENCODER_TYPE = "StyleEncoder"

_C.DIFFNET_TYPE = "DiffusionNet"

_C.WIN_SIZE = 5
_C.D_MODEL = 256

_C.DATASET = CN()
_C.DATASET.FACE3D_DIM = 64
_C.DATASET.NUM_FRAMES = 64
_C.DATASET.STYLE_MAX_LEN = 256

_C.TRAIN = CN()
_C.TRAIN.FACE3D_LATENT = CN()
_C.TRAIN.FACE3D_LATENT.TYPE = "face3d"

_C.DIFFUSION = CN()
_C.DIFFUSION.PREDICT_WHAT = "x0"  # noise | x0
_C.DIFFUSION.SCHEDULE = CN()
_C.DIFFUSION.SCHEDULE.NUM_STEPS = 1000
_C.DIFFUSION.SCHEDULE.BETA_1 = 1e-4
_C.DIFFUSION.SCHEDULE.BETA_T = 0.02
_C.DIFFUSION.SCHEDULE.MODE = "linear"

_C.CONTENT_ENCODER = CN()
_C.CONTENT_ENCODER.d_model = _C.D_MODEL
_C.CONTENT_ENCODER.nhead = 8
_C.CONTENT_ENCODER.num_encoder_layers = 3
_C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
_C.CONTENT_ENCODER.dropout = 0.1
_C.CONTENT_ENCODER.activation = "relu"
_C.CONTENT_ENCODER.normalize_before = False
_C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1

_C.STYLE_ENCODER = CN()
_C.STYLE_ENCODER.d_model = _C.D_MODEL
_C.STYLE_ENCODER.nhead = 8
_C.STYLE_ENCODER.num_encoder_layers = 3
_C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
_C.STYLE_ENCODER.dropout = 0.1
_C.STYLE_ENCODER.activation = "relu"
_C.STYLE_ENCODER.normalize_before = False
_C.STYLE_ENCODER.pos_embed_len = _C.DATASET.STYLE_MAX_LEN
_C.STYLE_ENCODER.aggregate_method = (
    "self_attention_pooling"  # average | self_attention_pooling
)
# _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM

_C.DECODER = CN()
_C.DECODER.d_model = _C.D_MODEL
_C.DECODER.nhead = 8
_C.DECODER.num_decoder_layers = 3
_C.DECODER.dim_feedforward = 4 * _C.D_MODEL
_C.DECODER.dropout = 0.1
_C.DECODER.activation = "relu"
_C.DECODER.normalize_before = False
_C.DECODER.return_intermediate_dec = False
_C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
_C.DECODER.network_type = "TransformerDecoder"
_C.DECODER.dynamic_K = None
_C.DECODER.dynamic_ratio = None
# _C.DECODER.output_dim = _C.DATASET.FACE3D_DIM
# LSFM basis:
# _C.DECODER.upper_face3d_indices = tuple(list(range(19)) + list(range(46, 51)))
# _C.DECODER.lower_face3d_indices = tuple(range(19, 46))
# BFM basis:
# fmt: off
_C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
# fmt: on
_C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]

_C.CF_GUIDANCE = CN()
_C.CF_GUIDANCE.TRAINING = True
_C.CF_GUIDANCE.INFERENCE = True
_C.CF_GUIDANCE.NULL_PROB = 0.1
_C.CF_GUIDANCE.SCALE = 1.0

_C.INFERENCE = CN()
_C.INFERENCE.CHECKPOINT = "checkpoints/denoising_network.pth"


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    return _C.clone()
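A minimal usage sketch of this config module (the override path is hypothetical; `merge_from_file` and `freeze` are standard yacs `CfgNode` methods, and `disentangle_decoder.py` below uses the same merge-then-freeze pattern):

```python
from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()                         # independent clone of _C above
cfg.merge_from_file("configs/my_override.yaml")  # hypothetical experiment override
cfg.freeze()                                     # make the config read-only
print(cfg.DIFFUSION.SCHEDULE.NUM_STEPS)          # 1000 unless overridden
```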
damo/dreamtalk/configuration.json
ADDED
@@ -0,0 +1,11 @@
{
    "framework": "pytorch",
    "task": "text-to-video-synthesis",
    "model": {
        "type": "Dreamtalk-Generation"
    },
    "pipeline": {
        "type": "Dreamtalk-generation-pipe"
    },
    "allow_remote": true
}
damo/dreamtalk/core/networks/__init__.py
ADDED
@@ -0,0 +1,14 @@
from core.networks.generator import (
    StyleEncoder,
    Decoder,
    ContentW2VEncoder,
)
from core.networks.disentangle_decoder import DisentangleDecoder


def get_network(name: str):
    obj = globals().get(name)
    if obj is None:
        raise KeyError("Unknown Network: %s" % name)
    else:
        return obj
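A brief sketch of how this registry is consumed (mirroring `NoisePredictor` in `diffusion_util.py` below, which resolves `cfg.DECODER_TYPE` the same way):

```python
from core.networks import get_network

decoder_cls = get_network("DisentangleDecoder")  # looked up in this module's globals()
# get_network("NoSuchNet") raises KeyError: "Unknown Network: NoSuchNet"
```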
damo/dreamtalk/core/networks/diffusion_net.py
ADDED
@@ -0,0 +1,340 @@
import math
import torch
import torch.nn.functional as F
from torch.nn import Module
from core.networks.diffusion_util import VarianceSchedule
import numpy as np


def face3d_raw_to_norm(face3d_raw, exp_min, exp_max):
    """

    Args:
        face3d_raw (_type_): (B, L, C_face3d)
        exp_min (_type_): (C_face3d)
        exp_max (_type_): (C_face3d)

    Returns:
        _type_: (B, L, C_face3d) in [-1, 1]
    """
    exp_min_expand = exp_min[None, None, :]
    exp_max_expand = exp_max[None, None, :]
    face3d_norm_01 = (face3d_raw - exp_min_expand) / (exp_max_expand - exp_min_expand)
    face3d_norm = face3d_norm_01 * 2 - 1
    return face3d_norm


def face3d_norm_to_raw(face3d_norm, exp_min, exp_max):
    """

    Args:
        face3d_norm (_type_): (B, L, C_face3d)
        exp_min (_type_): (C_face3d)
        exp_max (_type_): (C_face3d)

    Returns:
        _type_: (B, L, C_face3d)
    """
    exp_min_expand = exp_min[None, None, :]
    exp_max_expand = exp_max[None, None, :]
    face3d_norm_01 = (face3d_norm + 1) / 2
    face3d_raw = face3d_norm_01 * (exp_max_expand - exp_min_expand) + exp_min_expand
    return face3d_raw


class DiffusionNet(Module):
    def __init__(self, cfg, net, var_sched: VarianceSchedule):
        super().__init__()
        self.cfg = cfg
        self.net = net
        self.var_sched = var_sched
        self.face3d_latent_type = self.cfg.TRAIN.FACE3D_LATENT.TYPE
        self.predict_what = self.cfg.DIFFUSION.PREDICT_WHAT

        if self.cfg.CF_GUIDANCE.TRAINING:
            null_style_clip = torch.zeros(
                self.cfg.DATASET.STYLE_MAX_LEN, self.cfg.DATASET.FACE3D_DIM
            )
            self.register_buffer("null_style_clip", null_style_clip)

            null_pad_mask = torch.tensor([False] * self.cfg.DATASET.STYLE_MAX_LEN)
            self.register_buffer("null_pad_mask", null_pad_mask)

    def _face3d_to_latent(self, face3d):
        latent = None
        if self.face3d_latent_type == "face3d":
            latent = face3d
        elif self.face3d_latent_type == "normalized_face3d":
            latent = face3d_raw_to_norm(
                face3d, exp_min=self.exp_min, exp_max=self.exp_max
            )
        else:
            raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
        return latent

    def _latent_to_face3d(self, latent):
        face3d = None
        if self.face3d_latent_type == "face3d":
            face3d = latent
        elif self.face3d_latent_type == "normalized_face3d":
            latent = torch.clamp(latent, min=-1, max=1)
            face3d = face3d_norm_to_raw(
                latent, exp_min=self.exp_min, exp_max=self.exp_max
            )
        else:
            raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
        return face3d

    def ddim_sample(
        self,
        audio,
        style_clip,
        style_pad_mask,
        output_dim,
        flexibility=0.0,
        ret_traj=False,
        use_cf_guidance=False,
        cfg_scale=2.0,
        ddim_num_step=50,
        ready_style_code=None,
    ):
        """

        Args:
            audio (_type_): (B, L, W) or (B, L, W, C)
            style_clip (_type_): (B, L_clipmax, C_face3d)
            style_pad_mask : (B, L_clipmax)
            output_dim (_type_): int
            flexibility (float, optional): _description_. Defaults to 0.0.
            ret_traj (bool, optional): _description_. Defaults to False.

        Returns:
            _type_: (B, L, C_face)
        """
        if self.predict_what != "x0":
            raise NotImplementedError(self.predict_what)

        if ready_style_code is not None and use_cf_guidance:
            raise NotImplementedError("not implement cfg for ready style code")

        c = self.var_sched.num_steps // ddim_num_step
        time_steps = torch.tensor(
            np.asarray(list(range(0, self.var_sched.num_steps, c))) + 1
        )
        assert len(time_steps) == ddim_num_step
        prev_time_steps = torch.cat((torch.tensor([0]), time_steps[:-1]))

        batch_size, output_len = audio.shape[:2]
        # batch_size = context.size(0)
        context = {
            "audio": audio,
            "style_clip": style_clip,
            "style_pad_mask": style_pad_mask,
            "ready_style_code": ready_style_code,
        }
        if use_cf_guidance:
            uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
                batch_size, 1, 1
            )
            uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)

            context_double = {
                "audio": torch.cat([audio] * 2, dim=0),
                "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
                "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
                "ready_style_code": None
                if ready_style_code is None
                else torch.cat(
                    [
                        ready_style_code,
                        self.net.style_encoder(uncond_style_clip, uncond_pad_mask),
                    ],
                    dim=0,
                ),
            }

        x_t = torch.randn([batch_size, output_len, output_dim]).to(audio.device)

        for idx in list(range(ddim_num_step))[::-1]:
            t = time_steps[idx]
            t_prev = prev_time_steps[idx]
            ddim_alpha = self.var_sched.alpha_bars[t]
            ddim_alpha_prev = self.var_sched.alpha_bars[t_prev]

            t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
            if use_cf_guidance:
                x_t_double = torch.cat([x_t] * 2, dim=0)
                t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
                cond_output, uncond_output = self.net(
                    x_t_double, t=t_tensor_double, **context_double
                ).chunk(2)
                diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
            else:
                diff_output = self.net(x_t, t=t_tensor, **context)

            pred_x0 = diff_output
            eps = (x_t - torch.sqrt(ddim_alpha) * pred_x0) / torch.sqrt(1 - ddim_alpha)
            c1 = torch.sqrt(ddim_alpha_prev)
            c2 = torch.sqrt(1 - ddim_alpha_prev)

            x_t = c1 * pred_x0 + c2 * eps

        latent_output = x_t
        face3d_output = self._latent_to_face3d(latent_output)
        return face3d_output

    def sample(
        self,
        audio,
        style_clip,
        style_pad_mask,
        output_dim,
        flexibility=0.0,
        ret_traj=False,
        use_cf_guidance=False,
        cfg_scale=2.0,
        sample_method="ddpm",
        ddim_num_step=50,
        ready_style_code=None,
    ):
        # sample_method = kwargs["sample_method"]
        if sample_method == "ddpm":
            if ready_style_code is not None:
                raise NotImplementedError("ready style code in ddpm")
            return self.ddpm_sample(
                audio,
                style_clip,
                style_pad_mask,
                output_dim,
                flexibility=flexibility,
                ret_traj=ret_traj,
                use_cf_guidance=use_cf_guidance,
                cfg_scale=cfg_scale,
            )
        elif sample_method == "ddim":
            return self.ddim_sample(
                audio,
                style_clip,
                style_pad_mask,
                output_dim,
                flexibility=flexibility,
                ret_traj=ret_traj,
                use_cf_guidance=use_cf_guidance,
                cfg_scale=cfg_scale,
                ddim_num_step=ddim_num_step,
                ready_style_code=ready_style_code,
            )

    def ddpm_sample(
        self,
        audio,
        style_clip,
        style_pad_mask,
        output_dim,
        flexibility=0.0,
        ret_traj=False,
        use_cf_guidance=False,
        cfg_scale=2.0,
    ):
        """

        Args:
            audio (_type_): (B, L, W) or (B, L, W, C)
            style_clip (_type_): (B, L_clipmax, C_face3d)
            style_pad_mask : (B, L_clipmax)
            output_dim (_type_): int
            flexibility (float, optional): _description_. Defaults to 0.0.
            ret_traj (bool, optional): _description_. Defaults to False.

        Returns:
            _type_: (B, L, C_face)
        """
        batch_size, output_len = audio.shape[:2]
        # batch_size = context.size(0)
        context = {
            "audio": audio,
            "style_clip": style_clip,
            "style_pad_mask": style_pad_mask,
        }
        if use_cf_guidance:
            uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
                batch_size, 1, 1
            )
            uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
            context_double = {
                "audio": torch.cat([audio] * 2, dim=0),
                "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
                "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
            }

        x_T = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
        traj = {self.var_sched.num_steps: x_T}
        for t in range(self.var_sched.num_steps, 0, -1):
            alpha = self.var_sched.alphas[t]
            alpha_bar = self.var_sched.alpha_bars[t]
            alpha_bar_prev = self.var_sched.alpha_bars[t - 1]
            sigma = self.var_sched.get_sigmas(t, flexibility)

            z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
            x_t = traj[t]
            t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
            if use_cf_guidance:
                x_t_double = torch.cat([x_t] * 2, dim=0)
                t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
                cond_output, uncond_output = self.net(
                    x_t_double, t=t_tensor_double, **context_double
                ).chunk(2)
                diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
            else:
                diff_output = self.net(x_t, t=t_tensor, **context)

            if self.predict_what == "noise":
                c0 = 1.0 / torch.sqrt(alpha)
                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                x_next = c0 * (x_t - c1 * diff_output) + sigma * z
            elif self.predict_what == "x0":
                d0 = torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)
                d1 = torch.sqrt(alpha_bar_prev) * (1 - alpha) / (1 - alpha_bar)
                x_next = d0 * x_t + d1 * diff_output + sigma * z
            traj[t - 1] = x_next.detach()
            traj[t] = traj[t].cpu()
            if not ret_traj:
                del traj[t]

        if ret_traj:
            raise NotImplementedError
            return traj
        else:
            latent_output = traj[0]
            face3d_output = self._latent_to_face3d(latent_output)
            return face3d_output


if __name__ == "__main__":
    from core.networks.diffusion_util import NoisePredictor, VarianceSchedule

    # Illustrative smoke test: DiffusionNet and NoisePredictor both take a cfg
    # as their first argument, and get_loss lives in the training code, so this
    # block does not run as written.
    diffnet = DiffusionNet(
        net=NoisePredictor(),
        var_sched=VarianceSchedule(
            num_steps=500, beta_1=1e-4, beta_T=0.02, mode="linear"
        ),
    )

    import torch

    gt_face3d = torch.randn(16, 64, 64)
    audio = torch.randn(16, 64, 11)
    style_clip = torch.randn(16, 256, 64)
    style_pad_mask = torch.ones(16, 256)

    context = {
        "audio": audio,
        "style_clip": style_clip,
        "style_pad_mask": style_pad_mask,
    }

    loss = diffnet.get_loss(gt_face3d, context)

    print("hello")
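For reference, the `ddim_sample` loop above is the deterministic DDIM update specialized to an $x_0$-predicting network: from the prediction $\hat{x}_0$ at step $t$ it recovers the implied noise and re-noises to the previous step $t'$,

$$\epsilon = \frac{x_t - \sqrt{\bar{\alpha}_t}\,\hat{x}_0}{\sqrt{1-\bar{\alpha}_t}}, \qquad x_{t'} = \sqrt{\bar{\alpha}_{t'}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t'}}\,\epsilon,$$

which is exactly `x_t = c1 * pred_x0 + c2 * eps` with $\eta = 0$ (no stochastic term).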
damo/dreamtalk/core/networks/diffusion_util.py
ADDED
@@ -0,0 +1,131 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module
from core.networks import get_network
from core.utils import sinusoidal_embedding


class VarianceSchedule(Module):
    def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
        super().__init__()
        assert mode in ("linear",)
        self.num_steps = num_steps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.mode = mode

        if mode == "linear":
            betas = torch.linspace(beta_1, beta_T, steps=num_steps)

        betas = torch.cat([torch.zeros([1]), betas], dim=0)  # Padding

        alphas = 1 - betas
        log_alphas = torch.log(alphas)
        for i in range(1, log_alphas.size(0)):  # 1 to T
            log_alphas[i] += log_alphas[i - 1]
        alpha_bars = log_alphas.exp()

        sigmas_flex = torch.sqrt(betas)
        sigmas_inflex = torch.zeros_like(sigmas_flex)
        for i in range(1, sigmas_flex.size(0)):
            sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
                i
            ]
        sigmas_inflex = torch.sqrt(sigmas_inflex)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bars", alpha_bars)
        self.register_buffer("sigmas_flex", sigmas_flex)
        self.register_buffer("sigmas_inflex", sigmas_inflex)

    def uniform_sample_t(self, batch_size):
        ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
        return ts.tolist()

    def get_sigmas(self, t, flexibility):
        assert 0 <= flexibility and flexibility <= 1
        sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
            1 - flexibility
        )
        return sigmas


class NoisePredictor(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
        self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)

        style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
        cfg.defrost()
        cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
        cfg.freeze()
        self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)

        decoder_class = get_network(cfg.DECODER_TYPE)
        cfg.defrost()
        cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
        cfg.freeze()
        self.decoder = decoder_class(**cfg.DECODER)

        self.content_xt_to_decoder_input_wo_time = nn.Sequential(
            nn.Linear(cfg.D_MODEL + cfg.DATASET.FACE3D_DIM, cfg.D_MODEL),
            nn.ReLU(),
            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
            nn.ReLU(),
            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
        )

        self.time_sinusoidal_dim = cfg.D_MODEL
        self.time_embed_net = nn.Sequential(
            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
            nn.SiLU(),
            nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
        )

    def forward(self, x_t, t, audio, style_clip, style_pad_mask, ready_style_code=None):
        """_summary_

        Args:
            x_t (_type_): (B, L, C_face)
            t (_type_): (B,) dtype:float32
            audio (_type_): (B, L, W)
            style_clip (_type_): (B, L_clipmax, C_face3d)
            style_pad_mask : (B, L_clipmax)
            ready_style_code: (B, C_model)
        Returns:
            e_theta : (B, L, C_face)
        """
        W = audio.shape[2]
        content = self.content_encoder(audio)
        # (B, L, W, C_model)
        x_t_expand = x_t.unsqueeze(2).repeat(1, 1, W, 1)
        # (B, L, C_face) -> (B, L, W, C_face)
        content_xt_concat = torch.cat((content, x_t_expand), dim=3)
        # (B, L, W, C_model+C_face)
        decoder_input_without_time = self.content_xt_to_decoder_input_wo_time(
            content_xt_concat
        )
        # (B, L, W, C_model)

        time_sinusoidal = sinusoidal_embedding(t, self.time_sinusoidal_dim)
        # (B, C_embed)
        time_embedding = self.time_embed_net(time_sinusoidal)
        # (B, C_model)
        B, C = time_embedding.shape
        time_embed_expand = time_embedding.view(B, 1, 1, C)
        decoder_input = decoder_input_without_time + time_embed_expand
        # (B, L, W, C_model)

        if ready_style_code is not None:
            style_code = ready_style_code
        else:
            style_code = self.style_encoder(style_clip, style_pad_mask)
            # (B, C_model)

        e_theta = self.decoder(decoder_input, style_code)
        # (B, L, C_face)
        return e_theta
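For reference, `sigmas_inflex` above is the standard deviation of the DDPM posterior $q(x_{t-1} \mid x_t, x_0)$ and `sigmas_flex` its upper bound $\sqrt{\beta_t}$; `get_sigmas` interpolates between them with the `flexibility` parameter $\lambda \in [0, 1]$:

$$\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t, \qquad \sigma_t(\lambda) = \lambda\sqrt{\beta_t} + (1-\lambda)\sqrt{\tilde{\beta}_t}.$$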
damo/dreamtalk/core/networks/disentangle_decoder.py
ADDED
@@ -0,0 +1,240 @@
import torch
from torch import nn

from .transformer import (
    PositionalEncoding,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
from core.utils import _reset_parameters


def get_decoder_network(
    network_type,
    d_model,
    nhead,
    dim_feedforward,
    dropout,
    activation,
    normalize_before,
    num_decoder_layers,
    return_intermediate_dec,
    dynamic_K,
    dynamic_ratio,
):
    decoder = None
    if network_type == "TransformerDecoder":
        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        norm = nn.LayerNorm(d_model)
        decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            norm,
            return_intermediate_dec,
        )
    elif network_type == "DynamicFCDecoder":
        d_style = d_model
        decoder_layer = DynamicFCDecoderLayer(
            d_model,
            nhead,
            d_style,
            dynamic_K,
            dynamic_ratio,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
        )
        norm = nn.LayerNorm(d_model)
        decoder = DynamicFCDecoder(
            decoder_layer, num_decoder_layers, norm, return_intermediate_dec
        )
    elif network_type == "DynamicFCEncoder":
        # NOTE: DynamicFCEncoderLayer / DynamicFCEncoder are not imported or
        # defined in this snapshot of the repo; taking this branch would raise
        # a NameError.
        d_style = d_model
        decoder_layer = DynamicFCEncoderLayer(
            d_model,
            nhead,
            d_style,
            dynamic_K,
            dynamic_ratio,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
        )
        norm = nn.LayerNorm(d_model)
        decoder = DynamicFCEncoder(decoder_layer, num_decoder_layers, norm)

    else:
        raise ValueError(f"Invalid network_type {network_type}")

    return decoder


class DisentangleDecoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        pos_embed_len=80,
        upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
        lower_face3d_indices=tuple(range(19, 46)),
        network_type="None",
        dynamic_K=None,
        dynamic_ratio=None,
        **_,
    ) -> None:
        super().__init__()

        self.upper_face3d_indices = upper_face3d_indices
        self.lower_face3d_indices = lower_face3d_indices

        # upper_decoder_layer = TransformerDecoderLayer(
        #     d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        # )
        # upper_decoder_norm = nn.LayerNorm(d_model)
        # self.upper_decoder = TransformerDecoder(
        #     upper_decoder_layer,
        #     num_decoder_layers,
        #     upper_decoder_norm,
        #     return_intermediate=return_intermediate_dec,
        # )
        self.upper_decoder = get_decoder_network(
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )
        _reset_parameters(self.upper_decoder)

        # lower_decoder_layer = TransformerDecoderLayer(
        #     d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        # )
        # lower_decoder_norm = nn.LayerNorm(d_model)
        # self.lower_decoder = TransformerDecoder(
        #     lower_decoder_layer,
        #     num_decoder_layers,
        #     lower_decoder_norm,
        #     return_intermediate=return_intermediate_dec,
        # )
        self.lower_decoder = get_decoder_network(
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )
        _reset_parameters(self.lower_decoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        tail_hidden_dim = d_model // 2
        self.upper_tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
        )
        self.lower_tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
        )

    def forward(self, content, style_code):
        """

        Args:
            content (_type_): (B, num_frames, window, C_dmodel)
            style_code (_type_): (B, C_dmodel)

        Returns:
            face3d: (B, L_clip, C_3dmm)
        """
        B, N, W, C = content.shape
        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)

        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)
        tgt = torch.zeros_like(style)
        pos_embed = self.pos_embed(W)
        pos_embed = pos_embed.permute(1, 0, 2)

        upper_face3d_feat = self.upper_decoder(
            tgt, content, pos=pos_embed, query_pos=style
        )[0]
        # (W, B*N, C)
        upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
            :, :, W // 2, :
        ]
        # (B, N, C)
        upper_face3d = self.upper_tail_fc(upper_face3d_feat)
        # (B, N, C_exp)

        lower_face3d_feat = self.lower_decoder(
            tgt, content, pos=pos_embed, query_pos=style
        )[0]
        lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
            :, :, W // 2, :
        ]
        lower_face3d = self.lower_tail_fc(lower_face3d_feat)
        C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
        face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
        face3d[:, :, self.upper_face3d_indices] = upper_face3d
        face3d[:, :, self.lower_face3d_indices] = lower_face3d
        return face3d


if __name__ == "__main__":
    import sys

    sys.path.append("/home/mayifeng/Research/styleTH")

    from configs.default import get_cfg_defaults

    cfg = get_cfg_defaults()
    cfg.merge_from_file("configs/styleTH_unpair_lsfm_emotion.yaml")
    cfg.freeze()

    # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)

    # dummy_audio = torch.randint(0, 41, (5, 64, 11))
    # dummy_content = content_encoder(dummy_audio)

    # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
    # dummy_face3d_seq = torch.randn(5, 64, 64)
    # dummy_style_code = style_encoder(dummy_face3d_seq)

    decoder = DisentangleDecoder(**cfg.DECODER)
    dummy_content = torch.randn(5, 64, 11, 256)
    dummy_style = torch.randn(5, 256)
    dummy_output = decoder(dummy_content, dummy_style)

    print("hello")
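The final scatter in `DisentangleDecoder.forward` merges the two half-face predictions into one coefficient vector by index. A standalone sketch of just that step, with toy indices (the real lists are the BFM index lists in `configs/default.py`):

```python
import torch

upper_idx, lower_idx = [1, 3, 4], [0, 2]   # toy split; real lists come from cfg.DECODER
upper = torch.randn(2, 8, len(upper_idx))  # (B, N, C_upper) from the upper decoder branch
lower = torch.randn(2, 8, len(lower_idx))  # (B, N, C_lower) from the lower decoder branch
face3d = torch.zeros(2, 8, len(upper_idx) + len(lower_idx))
face3d[:, :, upper_idx] = upper            # advanced indexing writes each branch back
face3d[:, :, lower_idx] = lower
```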
damo/dreamtalk/core/networks/dynamic_conv.py
ADDED
@@ -0,0 +1,156 @@
import math

import torch
from torch import nn
from torch.nn import functional as F


class Attention(nn.Module):
    def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
        super().__init__()
        # self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.temprature = temperature
        assert cond_planes > ratio
        hidden_planes = cond_planes // ratio
        self.net = nn.Sequential(
            nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
        )

        if init_weight:
            self._initialize_weights()

    def update_temprature(self):
        if self.temprature > 1:
            self.temprature -= 1

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, cond):
        """

        Args:
            cond (_type_): (B, C_style)

        Returns:
            _type_: (B, K)
        """

        # att = self.avgpool(cond)  # bs,dim,1,1
        att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
        att = self.net(att).view(cond.shape[0], -1)  # bs,K
        return F.softmax(att / self.temprature, -1)


class DynamicConv(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        cond_planes,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        K=4,
        temperature=30,
        ratio=4,
        init_weight=True,
    ):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.cond_planes = cond_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.K = K
        self.init_weight = init_weight
        self.attention = Attention(
            cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
        )

        self.weight = nn.Parameter(
            torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
        )
        if bias:
            self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
        else:
            self.bias = None

        if self.init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
                if fan_in != 0:
                    bound = 1 / math.sqrt(fan_in)
                    nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, cond):
        """

        Args:
            x (_type_): (B, C_in, L, 1)
            cond (_type_): (B, C_style)

        Returns:
            _type_: (B, C_out, L, 1)
        """
        bs, in_planels, h, w = x.shape
        softmax_att = self.attention(cond)  # bs,K
        x = x.view(1, -1, h, w)
        weight = self.weight.view(self.K, -1)  # K,-1
        aggregate_weight = torch.mm(softmax_att, weight).view(
            bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
        )  # bs*out_p,in_p,k,k

        if self.bias is not None:
            bias = self.bias.view(self.K, -1)  # K,out_p
            aggregate_bias = torch.mm(softmax_att, bias).view(-1)  # bs*out_p
            output = F.conv2d(
                x,  # 1, bs*in_p, L, 1
                weight=aggregate_weight,
                bias=aggregate_bias,
                stride=self.stride,
                padding=self.padding,
                groups=self.groups * bs,
                dilation=self.dilation,
            )
        else:
            output = F.conv2d(
                x,
                weight=aggregate_weight,
                bias=None,
                stride=self.stride,
                padding=self.padding,
                groups=self.groups * bs,
                dilation=self.dilation,
            )

        output = output.view(bs, self.out_planes, h, w)
        return output


if __name__ == "__main__":
    # DynamicConv requires a conditioning vector; cond_planes=128 and the cond
    # tensor below are arbitrary fill-ins so the smoke test actually runs.
    input = torch.randn(3, 32, 64, 64)
    cond = torch.randn(3, 128)
    m = DynamicConv(in_planes=32, out_planes=64, cond_planes=128, kernel_size=3, stride=1, padding=1, bias=True)
    out = m(input, cond)
    print(out.shape)
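For reference, `DynamicConv` above follows the dynamic-convolution pattern: the style condition $c$ is mapped to a temperature-softened softmax over $K$ kernel banks, a per-sample kernel is assembled, and one grouped convolution applies every batch element's kernel at once:

$$\pi = \mathrm{softmax}\!\left(f(c)/\tau\right) \in \mathbb{R}^{K}, \qquad W(c) = \sum_{k=1}^{K} \pi_k W_k,$$

implemented by `torch.mm(softmax_att, weight)` plus `groups=self.groups * bs`, so each batch element is convolved with its own aggregated kernel.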
damo/dreamtalk/core/networks/dynamic_fc_decoder.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from core.networks.transformer import _get_activation_fn, _get_clones
|
5 |
+
from core.networks.dynamic_linear import DynamicLinear
|
6 |
+
|
7 |
+
|
8 |
+
class DynamicFCDecoderLayer(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
d_model,
|
12 |
+
nhead,
|
13 |
+
d_style,
|
14 |
+
dynamic_K,
|
15 |
+
dynamic_ratio,
|
16 |
+
dim_feedforward=2048,
|
17 |
+
dropout=0.1,
|
18 |
+
activation="relu",
|
19 |
+
normalize_before=False,
|
20 |
+
):
|
21 |
+
super().__init__()
|
22 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
23 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
24 |
+
# Implementation of Feedforward model
|
25 |
+
# self.linear1 = nn.Linear(d_model, dim_feedforward)
|
26 |
+
self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
|
27 |
+
self.dropout = nn.Dropout(dropout)
|
28 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
29 |
+
# self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)
|
30 |
+
|
31 |
+
self.norm1 = nn.LayerNorm(d_model)
|
32 |
+
self.norm2 = nn.LayerNorm(d_model)
|
33 |
+
self.norm3 = nn.LayerNorm(d_model)
|
34 |
+
self.dropout1 = nn.Dropout(dropout)
|
35 |
+
self.dropout2 = nn.Dropout(dropout)
|
36 |
+
self.dropout3 = nn.Dropout(dropout)
|
37 |
+
|
38 |
+
self.activation = _get_activation_fn(activation)
|
39 |
+
self.normalize_before = normalize_before
|
40 |
+
|
41 |
+
def with_pos_embed(self, tensor, pos):
|
42 |
+
return tensor if pos is None else tensor + pos
|
43 |
+
|
44 |
+
def forward_post(
|
45 |
+
self,
|
46 |
+
tgt,
|
47 |
+
memory,
|
48 |
+
style,
|
49 |
+
tgt_mask=None,
|
50 |
+
memory_mask=None,
|
51 |
+
tgt_key_padding_mask=None,
|
52 |
+
memory_key_padding_mask=None,
|
53 |
+
pos=None,
|
54 |
+
query_pos=None,
|
55 |
+
):
|
56 |
+
# q = k = self.with_pos_embed(tgt, query_pos)
|
57 |
+
tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
|
58 |
+
tgt = tgt + self.dropout1(tgt2)
|
59 |
+
tgt = self.norm1(tgt)
|
60 |
+
tgt2 = self.multihead_attn(
|
61 |
+
query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
|
62 |
+
)[0]
|
63 |
+
tgt = tgt + self.dropout2(tgt2)
|
64 |
+
tgt = self.norm2(tgt)
|
65 |
+
# tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
|
66 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
|
67 |
+
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    # def forward_pre(
    #     self,
    #     tgt,
    #     memory,
    #     tgt_mask=None,
    #     memory_mask=None,
    #     tgt_key_padding_mask=None,
    #     memory_key_padding_mask=None,
    #     pos=None,
    #     query_pos=None,
    # ):
    #     tgt2 = self.norm1(tgt)
    #     # q = k = self.with_pos_embed(tgt2, query_pos)
    #     tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
    #     tgt = tgt + self.dropout1(tgt2)
    #     tgt2 = self.norm2(tgt)
    #     tgt2 = self.multihead_attn(
    #         query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
    #     )[0]
    #     tgt = tgt + self.dropout2(tgt2)
    #     tgt2 = self.norm3(tgt)
    #     tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
    #     tgt = tgt + self.dropout3(tgt2)
    #     return tgt

    def forward(
        self,
        tgt,
        memory,
        style,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        if self.normalize_before:
            raise NotImplementedError
            # return self.forward_pre(
            #     tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
            # )
        return self.forward_post(
            tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
        )


class DynamicFCDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos=None,
        query_pos=None,
    ):
        style = query_pos[0]
        # (B*N, C)
        output = tgt + pos + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                style,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


if __name__ == "__main__":
    query = torch.randn(11, 1024, 256)
    content = torch.randn(11, 1024, 256)
    style = torch.randn(1024, 256)
    pos = torch.randn(11, 1, 256)
    m = DynamicFCDecoderLayer(256, 4, 256, 4, 4, 1024)

    out = m(query, content, style, pos=pos)
    print(out.shape)
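Note the calling convention here: DynamicFCDecoder.forward recovers the style code from the first time step of query_pos (style = query_pos[0]), so the caller is expected to pass a query_pos in which the same per-sequence style vector is repeated along the window axis. A minimal sketch of that convention, with shapes borrowed from the __main__ demo above (the expand mirrors what Decoder in generator.py does):

import torch

# one style vector per sequence, repeated across the window dimension W
style_code = torch.randn(1024, 256)                      # (B*N, C)
query_pos = style_code.unsqueeze(0).expand(11, -1, -1)   # (W, B*N, C)

# forward() reads the style back off the first step:
assert torch.equal(query_pos[0], style_code)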
damo/dreamtalk/core/networks/dynamic_linear.py
ADDED
@@ -0,0 +1,50 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from core.networks.dynamic_conv import DynamicConv


class DynamicLinear(nn.Module):
    def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
        super().__init__()

        self.dynamic_conv = DynamicConv(
            in_planes,
            out_planes,
            cond_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            K=K,
            ratio=ratio,
            temperature=temperature,
            init_weight=init_weight,
        )

    def forward(self, x, cond):
        """

        Args:
            x (_type_): (L, B, C_in)
            cond (_type_): (B, C_style)

        Returns:
            _type_: (L, B, C_out)
        """
        x = x.permute(1, 2, 0).unsqueeze(-1)
        out = self.dynamic_conv(x, cond)
        # (B, C_out, L, 1)
        # squeeze only the trailing singleton dim; a bare squeeze() would also
        # drop the batch dim when B == 1
        out = out.squeeze(-1).permute(2, 0, 1)
        return out


if __name__ == "__main__":
    input = torch.randn(11, 1024, 255)
    cond = torch.randn(1024, 256)
    m = DynamicLinear(255, 1000, 256, K=7, temperature=5, ratio=8)
    out = m(input, cond)
    print(out.shape)
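DynamicLinear is a thin adapter: it reshapes (L, B, C_in) into an image-like (B, C_in, L, 1) tensor so the condition-dependent 1x1 DynamicConv can act as a per-sample linear layer. For readers without dynamic_conv.py at hand, here is a self-contained sketch of the underlying idea, K expert weight matrices mixed by a softmax over condition-derived logits; the router below is a simplified stand-in, not the repo's DynamicConv:

import torch
from torch import nn

class NaiveDynamicLinear(nn.Module):
    """Minimal sketch of condition-dependent weight mixing (not the repo's DynamicConv)."""

    def __init__(self, c_in, c_out, c_cond, K=4, temperature=30):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(K, c_out, c_in) * 0.02)  # K expert weights
        self.router = nn.Linear(c_cond, K)  # maps the condition to K mixing logits
        self.temperature = temperature

    def forward(self, x, cond):
        # x: (L, B, c_in), cond: (B, c_cond)
        attn = torch.softmax(self.router(cond) / self.temperature, dim=-1)  # (B, K)
        w = torch.einsum("bk,koi->boi", attn, self.weights)  # per-sample weight (B, c_out, c_in)
        return torch.einsum("lbi,boi->lbo", x, w)  # (L, B, c_out)

x = torch.randn(11, 2, 255)
cond = torch.randn(2, 256)
print(NaiveDynamicLinear(255, 1000, 256)(x, cond).shape)  # torch.Size([11, 2, 1000])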
damo/dreamtalk/core/networks/generator.py
ADDED
@@ -0,0 +1,309 @@
import torch
from torch import nn

from .transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    PositionalEncoding,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from core.utils import _reset_parameters
from core.networks.self_attention_pooling import SelfAttentionPooling


# class ContentEncoder(nn.Module):
#     def __init__(
#         self,
#         d_model=512,
#         nhead=8,
#         num_encoder_layers=6,
#         dim_feedforward=2048,
#         dropout=0.1,
#         activation="relu",
#         normalize_before=False,
#         pos_embed_len=80,
#         ph_embed_dim=128,
#     ):
#         super().__init__()

#         encoder_layer = TransformerEncoderLayer(
#             d_model, nhead, dim_feedforward, dropout, activation, normalize_before
#         )
#         encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
#         self.encoder = TransformerEncoder(
#             encoder_layer, num_encoder_layers, encoder_norm
#         )

#         _reset_parameters(self.encoder)

#         self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

#         self.ph_embedding = nn.Embedding(41, ph_embed_dim)
#         self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)

#     def forward(self, x):
#         """

#         Args:
#             x (_type_): (B, num_frames, window)

#         Returns:
#             content: (B, num_frames, window, C_dmodel)
#         """
#         x_embedding = self.ph_embedding(x)
#         x_embedding = self.increase_embed_dim(x_embedding)
#         # (B, N, W, C)
#         B, N, W, C = x_embedding.shape
#         x_embedding = x_embedding.reshape(B * N, W, C)
#         x_embedding = x_embedding.permute(1, 0, 2)
#         # (W, B*N, C)

#         pos = self.pos_embed(W)
#         pos = pos.permute(1, 0, 2)
#         # (W, 1, C)

#         content = self.encoder(x_embedding, pos=pos)
#         # (W, B*N, C)
#         content = content.permute(1, 0, 2).reshape(B, N, W, C)
#         # (B, N, W, C)

#         return content


class ContentW2VEncoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        pos_embed_len=80,
        ph_embed_dim=128,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        _reset_parameters(self.encoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        self.increase_embed_dim = nn.Linear(1024, d_model)

    def forward(self, x):
        """
        Args:
            x (_type_): (B, num_frames, window, C_wav2vec)

        Returns:
            content: (B, num_frames, window, C_dmodel)
        """
        x_embedding = self.increase_embed_dim(
            x
        )  # [16, 64, 11, 1024] -> [16, 64, 11, 256]
        # (B, N, W, C)
        B, N, W, C = x_embedding.shape
        x_embedding = x_embedding.reshape(B * N, W, C)
        x_embedding = x_embedding.permute(1, 0, 2)  # [11, 1024, 256]
        # (W, B*N, C)

        pos = self.pos_embed(W)
        pos = pos.permute(1, 0, 2)  # [11, 1, 256]
        # (W, 1, C)

        content = self.encoder(x_embedding, pos=pos)  # [11, 1024, 256]
        # (W, B*N, C)
        content = content.permute(1, 0, 2).reshape(B, N, W, C)
        # (B, N, W, C)

        return content


class StyleEncoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        pos_embed_len=80,
        input_dim=128,
        aggregate_method="average",
    ):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )
        _reset_parameters(self.encoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        self.increase_embed_dim = nn.Linear(input_dim, d_model)

        self.aggregate_method = None
        if aggregate_method == "self_attention_pooling":
            self.aggregate_method = SelfAttentionPooling(d_model)
        elif aggregate_method == "average":
            pass
        else:
            raise ValueError(f"Invalid aggregate method {aggregate_method}")

    def forward(self, x, pad_mask=None):
        """

        Args:
            x (_type_): (B, num_frames(L), C_exp)
            pad_mask: (B, num_frames)

        Returns:
            style_code: (B, C_model)
        """
        x = self.increase_embed_dim(x)
        # (B, L, C)
        x = x.permute(1, 0, 2)
        # (L, B, C)

        pos = self.pos_embed(x.shape[0])
        pos = pos.permute(1, 0, 2)
        # (L, 1, C)

        style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
        # (L, B, C)

        if self.aggregate_method is not None:
            permute_style = style.permute(1, 0, 2)
            # (B, L, C)
            style_code = self.aggregate_method(permute_style, pad_mask)
            return style_code

        if pad_mask is None:
            style = style.permute(1, 2, 0)
            # (B, C, L)
            style_code = style.mean(2)
            # (B, C)
        else:
            permute_style = style.permute(1, 0, 2)
            # (B, L, C)
            permute_style[pad_mask] = 0
            sum_style_code = permute_style.sum(dim=1)
            # (B, C)
            valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
            # (B, 1)
            style_code = sum_style_code / valid_token_num
            # (B, C)

        return style_code


class Decoder(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        pos_embed_len=80,
        output_dim=64,
        **_,
    ) -> None:
        super().__init__()

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )
        _reset_parameters(self.decoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        tail_hidden_dim = d_model // 2
        self.tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, output_dim),
        )

    def forward(self, content, style_code):
        """

        Args:
            content (_type_): (B, num_frames, window, C_dmodel)
            style_code (_type_): (B, C_dmodel)

        Returns:
            face3d: (B, num_frames, C_3dmm)
        """
        B, N, W, C = content.shape
        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)

        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
        # (W, B*N, C)
        tgt = torch.zeros_like(style)
        pos_embed = self.pos_embed(W)
        pos_embed = pos_embed.permute(1, 0, 2)
        face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
        # (W, B*N, C)
        face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
        # (B, N, C)
        face3d = self.tail_fc(face3d_feat)
        # (B, N, C_exp)
        return face3d


if __name__ == "__main__":
    import sys

    sys.path.append("/home/mayifeng/Research/styleTH")

    from configs.default import get_cfg_defaults

    cfg = get_cfg_defaults()
    cfg.merge_from_file("configs/styleTH_bp.yaml")
    cfg.freeze()

    # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)

    # dummy_audio = torch.randint(0, 41, (5, 64, 11))
    # dummy_content = content_encoder(dummy_audio)

    # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
    # dummy_face3d_seq = torch.randn(5, 64, 64)
    # dummy_style_code = style_encoder(dummy_face3d_seq)

    decoder = Decoder(**cfg.DECODER)
    dummy_content = torch.randn(5, 64, 11, 512)
    dummy_style = torch.randn(5, 512)
    dummy_output = decoder(dummy_content, dummy_style)

    print("hello")
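Reading the three modules together: ContentW2VEncoder turns per-frame wav2vec windows into content features, StyleEncoder compresses a reference 3DMM expression clip into a single style code, and Decoder fuses the two into per-frame expression coefficients. A minimal end-to-end sketch, assuming the classes above are importable; d_model=256 and the toy shapes are illustrative, the real hyperparameters come from the YAML configs:

import torch

content_enc = ContentW2VEncoder(d_model=256)
style_enc = StyleEncoder(d_model=256, input_dim=64)
decoder = Decoder(d_model=256, output_dim=64)

audio_windows = torch.randn(2, 16, 11, 1024)  # (B, num_frames, window, C_wav2vec)
style_clip = torch.randn(2, 64, 64)           # (B, L_style, C_exp), L_style <= pos_embed_len

content = content_enc(audio_windows)          # (2, 16, 11, 256)
style_code = style_enc(style_clip)            # (2, 256)
face3d = decoder(content, style_code)         # (2, 16, 64) expression coefficients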
damo/dreamtalk/core/networks/mish.py
ADDED
@@ -0,0 +1,51 @@
"""
Applies the mish function element-wise:
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
"""

# import pytorch
import torch
import torch.nn.functional as F
from torch import nn


@torch.jit.script
def mish(input):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    See additional documentation for the Mish class.
    """
    return input * torch.tanh(F.softplus(input))


class Mish(nn.Module):
    """
    Applies the mish function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Shape:
        - Input: (N, *) where * means any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
    """

    def __init__(self):
        """
        Init method.
        """
        super().__init__()

    def forward(self, input):
        """
        Forward pass of the function.
        """
        # compare the version numerically; a plain string comparison
        # mis-orders versions, e.g. "1.10" < "1.9"
        major, minor = (int(v) for v in torch.__version__.split(".")[:2])
        if (major, minor) >= (1, 9):
            return F.mish(input)
        else:
            return mish(input)
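The fallback matters only on PyTorch builds older than 1.9; on newer builds the handwritten formula and the built-in F.mish agree to float precision, which a one-line check confirms:

import torch
import torch.nn.functional as F

x = torch.randn(1000)
assert torch.allclose(x * torch.tanh(F.softplus(x)), F.mish(x), atol=1e-6)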
damo/dreamtalk/core/networks/self_attention_pooling.py
ADDED
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from core.networks.mish import Mish


class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """

    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
        self.softmax = nn.functional.softmax

    def forward(self, batch_rep, att_mask=None):
        """
        N: batch size, T: sequence length, H: hidden dimension
        input:
            batch_rep : size (N, T, H)
        attention_weight:
            att_w : size (N, T, 1)
        att_mask:
            att_mask: size (N, T); if True, mask this item.
        return:
            utter_rep: size (N, H)
        """

        att_logits = self.W(batch_rep).squeeze(-1)
        # (N, T)
        if att_mask is not None:
            att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
            # (N, T)
            att_logits = att_mask_logits + att_logits

        att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)

        return utter_rep


if __name__ == "__main__":
    batch = torch.randn(8, 64, 256)
    self_attn_pool = SelfAttentionPooling(256)
    att_mask = torch.zeros(8, 64)
    att_mask[:, 60:] = 1
    att_mask = att_mask.to(torch.bool)
    output = self_attn_pool(batch, att_mask)
    # (8, 256)

    print("hello")
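Because masked logits are pushed to -100000 before the softmax, masked frames end up with effectively zero attention weight, so their content cannot leak into the pooled vector. A quick check of that property, reusing the SelfAttentionPooling class above:

import torch

pool = SelfAttentionPooling(16)
x = torch.randn(2, 10, 16)
mask = torch.zeros(2, 10, dtype=torch.bool)
mask[:, 6:] = True  # mask out the last four steps

out_before = pool(x, mask)
x[:, 6:] = 100.0    # corrupt only the masked steps
out_after = pool(x, mask)
print(torch.allclose(out_before, out_after, atol=1e-4))  # True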
damo/dreamtalk/core/networks/transformer.py
ADDED
@@ -0,0 +1,293 @@
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import copy


class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, winsize):
        return self.pos_table[:, :winsize].clone().detach()


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=True):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, opt, src, query_embed, pos_embed):
        # NOTE: opt is unused here but kept in the signature
        # flatten NxCxHxW to HWxNxC

        src = src.permute(1, 0, 2)
        pos_embed = pos_embed.permute(1, 0, 2)
        query_embed = query_embed.permute(1, 0, 2)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, pos=pos_embed)

        hs = self.decoder(tgt, memory,
                          pos=pos_embed, query_pos=query_embed)
        return hs


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
        output = src + pos

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
                pos=None,
                query_pos=None):
        output = tgt + pos + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask=None,
                     src_key_padding_mask=None,
                     pos=None):
        # q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask=None,
                    src_key_padding_mask=None,
                    pos=None):
        src2 = self.norm1(src)
        # q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask=None,
                src_key_padding_mask=None,
                pos=None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask=None,
                     memory_mask=None,
                     tgt_key_padding_mask=None,
                     memory_key_padding_mask=None,
                     pos=None,
                     query_pos=None):
        # q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=tgt,
                                   key=memory,
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask=None,
                    memory_mask=None,
                    tgt_key_padding_mask=None,
                    memory_key_padding_mask=None,
                    pos=None,
                    query_pos=None):
        tgt2 = self.norm1(tgt)
        # q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=tgt2,
                                   key=memory,
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
                pos=None,
                query_pos=None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
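The # TODO in _get_sinusoid_encoding_table asks for a torch-native construction. A sketch of one, checked against the numpy-built buffer; even dimensions carry sin and odd ones cos, both at frequency 10000^(-2*(i//2)/d_hid):

import torch

def sinusoid_table_torch(n_position, d_hid):
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)  # (P, 1)
    hid = torch.arange(d_hid, dtype=torch.float32)                         # (H,)
    table = position / torch.pow(10000.0, 2 * (hid // 2) / d_hid)          # (P, H)
    table[:, 0::2] = torch.sin(table[:, 0::2])
    table[:, 1::2] = torch.cos(table[:, 1::2])
    return table.unsqueeze(0)  # (1, P, H), same layout as pos_table

pe = PositionalEncoding(d_hid=256, n_position=80)
print(torch.allclose(pe.pos_table, sinusoid_table_torch(80, 256), atol=1e-4))  # True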
damo/dreamtalk/core/utils.py
ADDED
@@ -0,0 +1,456 @@
import os
import argparse
from collections import defaultdict
import logging
import pickle
import json

import numpy as np
import torch
from torch import nn
from scipy.io import loadmat

from configs.default import get_cfg_defaults
import dlib
import cv2


def _reset_parameters(model):
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)


def get_video_style(video_name, style_type):
    person_id, direction, emotion, level, *_ = video_name.split("_")
    if style_type == "id_dir_emo_level":
        style = "_".join([person_id, direction, emotion, level])
    elif style_type == "emotion":
        style = emotion
    elif style_type == "id":
        style = person_id
    else:
        raise ValueError("Unknown style type")

    return style


def get_style_video_lists(video_list, style_type):
    style2video_list = defaultdict(list)
    for video in video_list:
        style = get_video_style(video, style_type)
        style2video_list[style].append(video)

    return style2video_list


def get_face3d_clip(
    video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
):
    """_summary_

    Args:
        video_name (_type_): _description_
        video_root_dir (_type_): _description_
        num_frames (_type_): _description_
        start_idx (_type_): "random", "middle", or an int
        dtype (_type_, optional): _description_. Defaults to torch.float32.

    Raises:
        ValueError: _description_
        ValueError: _description_

    Returns:
        _type_: _description_
    """
    video_path = os.path.join(video_root_dir, video_name)
    if video_path[-3:] == "mat":
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    length = face3d_exp.shape[0]
    clip_num_frames = num_frames
    if start_idx == "random":
        clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
    elif start_idx == "middle":
        clip_start_idx = (length - clip_num_frames + 1) // 2
    elif isinstance(start_idx, int):
        clip_start_idx = start_idx
    else:
        raise ValueError(f"Invalid start_idx {start_idx}")

    face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
    face3d_clip = torch.tensor(face3d_clip, dtype=dtype)

    return face3d_clip


def get_video_style_clip(
    video_name,
    video_root_dir,
    style_max_len,
    start_idx="random",
    dtype=torch.float32,
    return_start_idx=False,
):
    video_path = os.path.join(video_root_dir, video_name)
    if video_path[-3:] == "mat":
        face3d_all = loadmat(video_path)["coeff"]
        face3d_exp = face3d_all[:, 80:144]  # expression 3DMM range
    elif video_path[-3:] == "txt":
        face3d_exp = np.loadtxt(video_path)
    else:
        raise ValueError("Invalid 3DMM file extension")

    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)

    length = face3d_exp.shape[0]
    if length >= style_max_len:
        clip_num_frames = style_max_len
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - clip_num_frames + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
        pad_mask = torch.tensor([False] * style_max_len)
    else:
        clip_start_idx = None
        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))

    if return_start_idx:
        return face3d_clip, pad_mask, clip_start_idx
    else:
        return face3d_clip, pad_mask


def get_video_style_clip_from_np(
    face3d_exp,
    style_max_len,
    start_idx="random",
    dtype=torch.float32,
    return_start_idx=False,
):
    face3d_exp = torch.tensor(face3d_exp, dtype=dtype)

    length = face3d_exp.shape[0]
    if length >= style_max_len:
        clip_num_frames = style_max_len
        if start_idx == "random":
            clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
        elif start_idx == "middle":
            clip_start_idx = (length - clip_num_frames + 1) // 2
        elif isinstance(start_idx, int):
            clip_start_idx = start_idx
        else:
            raise ValueError(f"Invalid start_idx {start_idx}")

        face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
        pad_mask = torch.tensor([False] * style_max_len)
    else:
        clip_start_idx = None
        padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
        face3d_clip = torch.cat((face3d_exp, padding), dim=0)
        pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))

    if return_start_idx:
        return face3d_clip, pad_mask, clip_start_idx
    else:
        return face3d_clip, pad_mask


def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
    """

    Args:
        audio_feat (np.ndarray): (N, 1024)
        start_idx (_type_): _description_
        num_frames (_type_): _description_
    """
    center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
    audio_window_list = []
    padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
    for center_idx in center_idx_list:
        cur_audio_window = []
        for i in range(center_idx - win_size, center_idx + win_size + 1):
            if i < 0:
                cur_audio_window.append(padding)
            elif i >= len(audio_feat):
                cur_audio_window.append(padding)
            else:
                cur_audio_window.append(audio_feat[i])
        cur_audio_win_array = np.stack(cur_audio_window, axis=0)
        audio_window_list.append(cur_audio_win_array)

    audio_window_array = np.stack(audio_window_list, axis=0)
    return audio_window_array


def setup_config():
    parser = argparse.ArgumentParser(description="voice2pose main program")
    parser.add_argument(
        "--config_file", default="", metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--resume_from", type=str, default=None, help="the checkpoint to resume from"
    )
    parser.add_argument(
        "--test_only", action="store_true", help="perform testing and evaluation only"
    )
    parser.add_argument(
        "--demo_input", type=str, default=None, help="path to input for demo"
    )
    parser.add_argument(
        "--checkpoint", type=str, default=None, help="the checkpoint to test with"
    )
    parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        help="local rank for DistributedDataParallel",
    )
    parser.add_argument(
        "--master_port",
        type=str,
        default="12345",
    )
    parser.add_argument(
        "--max_audio_len",
        type=int,
        default=450,
        help="max_audio_len for inference",
    )
    parser.add_argument(
        "--ddim_num_step",
        type=int,
        default=10,
    )
    parser.add_argument(
        "--inference_seed",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--inference_sample_method",
        type=str,
        default="ddim",
    )
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return args, cfg


def setup_logger(base_path, exp_name):
    rootLogger = logging.getLogger()
    rootLogger.setLevel(logging.INFO)

    logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")

    log_path = "{0}/{1}.log".format(base_path, exp_name)
    fileHandler = logging.FileHandler(log_path)
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(logFormatter)
    rootLogger.addHandler(consoleHandler)
    rootLogger.handlers[0].setLevel(logging.INFO)

    logging.info("log path: %s" % log_path)


def cosine_loss(a, v, y, logloss=nn.BCELoss()):
    d = nn.functional.cosine_similarity(a, v)
    loss = logloss(d.unsqueeze(1), y)
    return loss


def get_pose_params(mat_path):
    """Get pose parameters from mat file

    Args:
        mat_path (str): path of mat file

    Returns:
        pose_params (numpy.ndarray): shape (L_video, 9): angle, translation, crop parameters
    """
    mat_dict = loadmat(mat_path)

    np_3dmm = mat_dict["coeff"]
    angles = np_3dmm[:, 224:227]
    translations = np_3dmm[:, 254:257]

    np_trans_params = mat_dict["transform_params"]
    crop = np_trans_params[:, -3:]

    pose_params = np.concatenate((angles, translations, crop), axis=1)

    return pose_params


def sinusoidal_embedding(timesteps, dim):
    """

    Args:
        timesteps (_type_): (B,)
        dim (_type_): (C_embed)

    Returns:
        _type_: (B, C_embed)
    """
    # check input
    half = dim // 2
    timesteps = timesteps.float()

    # compute sinusoidal embedding
    sinusoid = torch.outer(
        timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
    )
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
    return x


# NOTE: redefines the identical helper above (the docstrings differ only in the
# example shape); this later definition is the one that wins at import time.
def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
    """

    Args:
        audio_feat (np.ndarray): (250, 1024)
        start_idx (_type_): _description_
        num_frames (_type_): _description_
    """
    center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
    audio_window_list = []
    padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
    for center_idx in center_idx_list:
        cur_audio_window = []
        for i in range(center_idx - win_size, center_idx + win_size + 1):
            if i < 0:
                cur_audio_window.append(padding)
            elif i >= len(audio_feat):
                cur_audio_window.append(padding)
            else:
                cur_audio_window.append(audio_feat[i])
        cur_audio_win_array = np.stack(cur_audio_window, axis=0)
        audio_window_list.append(cur_audio_win_array)

    audio_window_array = np.stack(audio_window_list, axis=0)
    return audio_window_array


def reshape_audio_feat(style_audio_all_raw, stride):
    """_summary_

    Args:
        style_audio_all_raw (_type_): (stride * L, C)
        stride (_type_): int

    Returns:
        _type_: (L, C * stride)
    """
    style_audio_all_raw = style_audio_all_raw[
        : style_audio_all_raw.shape[0] // stride * stride
    ]
    style_audio_all_raw = style_audio_all_raw.reshape(
        style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
    )
    style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
    return style_audio_all


import random


def get_derangement_tuple(n):
    while True:
        v = [i for i in range(n)]
        for j in range(n - 1, -1, -1):
            p = random.randint(0, j)
            if v[p] == j:
                break
            else:
                v[j], v[p] = v[p], v[j]
        else:
            if v[0] != 0:
                return tuple(v)


def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
    left, top, right, bot = bbox
    width = right - left
    height = bot - top

    width_increase = max(
        increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
    )
    height_increase = max(
        increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
    )

    left_t = int(left - width_increase * width)
    top_t = int(top - height_increase * height)
    right_t = int(right + width_increase * width)
    bot_t = int(bot + height_increase * height)

    left_oob = -min(0, left_t)
    right_oob = right - min(right_t, w)
    top_oob = -min(0, top_t)
    bot_oob = bot - min(bot_t, h)

    if max(left_oob, right_oob, top_oob, bot_oob) > 0:
        max_w = max(left_oob, right_oob)
        max_h = max(top_oob, bot_oob)
        if max_w > max_h:
            return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
        else:
            return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
    else:
        return (left_t, top_t, right_t, bot_t)


def crop_src_image(src_img, save_img, increase_ratio, detector=None):
    if detector is None:
        detector = dlib.get_frontal_face_detector()

    img = cv2.imread(src_img)
    faces = detector(img, 0)
    h, width, _ = img.shape
    if len(faces) > 0:
        bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
        l = bbox[3] - bbox[1]
        bbox[1] = bbox[1] - l * 0.1
        bbox[3] = bbox[3] - l * 0.1
        bbox[1] = max(0, bbox[1])
        bbox[3] = min(h, bbox[3])
        bbox = compute_aspect_preserved_bbox(
            tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
        )
        img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
        img = cv2.resize(img, (256, 256))
        cv2.imwrite(save_img, img)
    else:
        raise ValueError("No face detected in the input image")
        # img = cv2.resize(img, (256, 256))
        # cv2.imwrite(save_img, img)
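A worked example of get_wav2vec_audio_window: each output frame is centered at audio step 2*idx (consistent with audio features sampled at twice the video frame rate), and steps that fall outside the feature sequence are zero-padded:

import numpy as np

audio_feat = np.arange(20, dtype=np.float32).reshape(20, 1)  # toy (N, 1) feature track
win = get_wav2vec_audio_window(audio_feat, start_idx=0, num_frames=3, win_size=2)
print(win.shape)     # (3, 5, 1): a 5-step window per frame
print(win[0, :, 0])  # [0. 0. 0. 1. 2.] -- the leading zeros are padding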
damo/dreamtalk/data/audio/German1.wav
ADDED
Binary file (279 kB).

damo/dreamtalk/data/audio/German2.wav
ADDED
Binary file (219 kB).

damo/dreamtalk/data/audio/German3.wav
ADDED
Binary file (240 kB).

damo/dreamtalk/data/audio/German4.wav
ADDED
Binary file (219 kB).

damo/dreamtalk/data/audio/acknowledgement_chinese.m4a
ADDED
Binary file (537 kB).

damo/dreamtalk/data/audio/acknowledgement_english.m4a
ADDED
Binary file (511 kB).

damo/dreamtalk/data/audio/chinese1_haierlizhi.wav
ADDED
Binary file (420 kB).

damo/dreamtalk/data/audio/chinese2_guanyu.wav
ADDED
Binary file (638 kB).

damo/dreamtalk/data/audio/french1.wav
ADDED
Binary file (220 kB).

damo/dreamtalk/data/audio/french2.wav
ADDED
Binary file (177 kB).

damo/dreamtalk/data/audio/french3.wav
ADDED
Binary file (168 kB).

damo/dreamtalk/data/audio/italian1.wav
ADDED
Binary file (285 kB).

damo/dreamtalk/data/audio/italian2.wav
ADDED
Binary file (170 kB).

damo/dreamtalk/data/audio/italian3.wav
ADDED
Binary file (197 kB).

damo/dreamtalk/data/audio/japan1.wav
ADDED
Binary file (197 kB).

damo/dreamtalk/data/audio/japan2.wav
ADDED
Binary file (231 kB).

damo/dreamtalk/data/audio/japan3.wav
ADDED
Binary file (234 kB).

damo/dreamtalk/data/audio/korean1.wav
ADDED
Binary file (328 kB).

damo/dreamtalk/data/audio/korean2.wav
ADDED
Binary file (210 kB).

damo/dreamtalk/data/audio/korean3.wav
ADDED
Binary file (148 kB).

damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav
ADDED
Binary file (206 kB).

damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav
ADDED
Binary file (206 kB).

damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav
ADDED
Binary file (206 kB).

damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav
ADDED
Binary file (206 kB).

damo/dreamtalk/data/audio/noisy_audio_narrative.wav
ADDED
Binary file (206 kB).

damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav
ADDED
Binary file (206 kB).

damo/dreamtalk/data/audio/out_of_domain_narrative.wav
ADDED
Binary file (445 kB).

damo/dreamtalk/data/audio/spanish1.wav
ADDED
Binary file (144 kB).

damo/dreamtalk/data/audio/spanish2.wav
ADDED
Binary file (150 kB).

damo/dreamtalk/data/audio/spanish3.wav
ADDED
Binary file (212 kB).