camenduru committed
Commit f4153a9
1 Parent(s): 7c5611b

thanks to damo ❤

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete list.
Files changed (50)
  1. .gitattributes +2 -0
  2. damo/dreamtalk/.mdl +0 -0
  3. damo/dreamtalk/.msc +0 -0
  4. damo/dreamtalk/README.md +131 -0
  5. damo/dreamtalk/checkpoints/denoising_network.pth +3 -0
  6. damo/dreamtalk/checkpoints/renderer.pt +3 -0
  7. damo/dreamtalk/configs/default.py +91 -0
  8. damo/dreamtalk/configuration.json +11 -0
  9. damo/dreamtalk/core/networks/__init__.py +14 -0
  10. damo/dreamtalk/core/networks/diffusion_net.py +340 -0
  11. damo/dreamtalk/core/networks/diffusion_util.py +131 -0
  12. damo/dreamtalk/core/networks/disentangle_decoder.py +240 -0
  13. damo/dreamtalk/core/networks/dynamic_conv.py +156 -0
  14. damo/dreamtalk/core/networks/dynamic_fc_decoder.py +178 -0
  15. damo/dreamtalk/core/networks/dynamic_linear.py +50 -0
  16. damo/dreamtalk/core/networks/generator.py +309 -0
  17. damo/dreamtalk/core/networks/mish.py +51 -0
  18. damo/dreamtalk/core/networks/self_attention_pooling.py +53 -0
  19. damo/dreamtalk/core/networks/transformer.py +293 -0
  20. damo/dreamtalk/core/utils.py +456 -0
  21. damo/dreamtalk/data/audio/German1.wav +0 -0
  22. damo/dreamtalk/data/audio/German2.wav +0 -0
  23. damo/dreamtalk/data/audio/German3.wav +0 -0
  24. damo/dreamtalk/data/audio/German4.wav +0 -0
  25. damo/dreamtalk/data/audio/acknowledgement_chinese.m4a +0 -0
  26. damo/dreamtalk/data/audio/acknowledgement_english.m4a +0 -0
  27. damo/dreamtalk/data/audio/chinese1_haierlizhi.wav +0 -0
  28. damo/dreamtalk/data/audio/chinese2_guanyu.wav +0 -0
  29. damo/dreamtalk/data/audio/french1.wav +0 -0
  30. damo/dreamtalk/data/audio/french2.wav +0 -0
  31. damo/dreamtalk/data/audio/french3.wav +0 -0
  32. damo/dreamtalk/data/audio/italian1.wav +0 -0
  33. damo/dreamtalk/data/audio/italian2.wav +0 -0
  34. damo/dreamtalk/data/audio/italian3.wav +0 -0
  35. damo/dreamtalk/data/audio/japan1.wav +0 -0
  36. damo/dreamtalk/data/audio/japan2.wav +0 -0
  37. damo/dreamtalk/data/audio/japan3.wav +0 -0
  38. damo/dreamtalk/data/audio/korean1.wav +0 -0
  39. damo/dreamtalk/data/audio/korean2.wav +0 -0
  40. damo/dreamtalk/data/audio/korean3.wav +0 -0
  41. damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav +0 -0
  42. damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav +0 -0
  43. damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav +0 -0
  44. damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav +0 -0
  45. damo/dreamtalk/data/audio/noisy_audio_narrative.wav +0 -0
  46. damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav +0 -0
  47. damo/dreamtalk/data/audio/out_of_domain_narrative.wav +0 -0
  48. damo/dreamtalk/data/audio/spanish1.wav +0 -0
  49. damo/dreamtalk/data/audio/spanish2.wav +0 -0
  50. damo/dreamtalk/data/audio/spanish3.wav +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ damo/dreamtalk/data/pose/RichardShelby_front_neutral_level1_001.mat filter=lfs diff=lfs merge=lfs -text
+ damo/dreamtalk/media/teaser.gif filter=lfs diff=lfs merge=lfs -text
damo/dreamtalk/.mdl ADDED
Binary file (37 Bytes).
 
damo/dreamtalk/.msc ADDED
Binary file (12.5 kB).
 
damo/dreamtalk/README.md ADDED
@@ -0,0 +1,131 @@
+ # DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models
+
+ <a href='https://dreamtalk-project.github.io/'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2312.09767'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/VF4vlE6ZqWQ)
+
+ DreamTalk is a diffusion-based, audio-driven expressive talking head generation framework that produces high-quality talking head videos across diverse speaking styles. DreamTalk exhibits robust performance with a diverse array of inputs, including songs, speech in multiple languages, noisy audio, and out-of-domain portraits.
+
+ ![figure1](media/teaser.gif "teaser")
+
+ ## News
+ - __[2023.12]__ Release inference code and pretrained checkpoint.
+
+ ## Install Dependencies
+ ```
+ pip install dlib
+ ```
+
+ ## ModelScope Inference
+
+ Some pre-generated results are already included in the `output_video` folder; you can run the script below and compare your output against them.
+
+ ```python
+ from modelscope.utils.constant import Tasks
+ from modelscope.pipelines import pipeline
+ import os
+
+ pipe = pipeline(
+     task=Tasks.text_to_video_synthesis,
+     model='damo/dreamtalk',
+     style_clip_path="data/style_clip/3DMM/M030_front_surprised_level3_001.mat",
+     pose_path="data/pose/RichardShelby_front_neutral_level1_001.mat",
+     model_revision='master',
+ )
+ inputs = {
+     "output_name": "songbie_yk_male",
+     "wav_path": "data/audio/acknowledgement_english.m4a",
+     "img_crop": True,
+     "image_path": "data/src_img/uncropped/male_face.png",
+     "max_gen_len": 20,
+ }
+ pipe(input=inputs)
+ print("end")
+ ```
+
+ `wav_path` is the path of the input audio.
+
+ `style_clip_path` is the expression reference, extracted from an emotional video; it controls the facial expressions of the generated video.
+
+ `pose_path` is the head-motion reference, extracted from a video; it controls the head motion of the generated video.
+
+ `image_path` is the speaker portrait, ideally a frontal face. Any input resolution is supported in principle; the image will be cropped to $256\times256$.
+
+ `max_gen_len` is the maximum video duration in seconds; if the input audio is longer than this, it will be truncated.
+
+ `output_name` is the output name; the generated video is written to the `output_video` folder, and intermediate results to the `tmp` folder.
+
+ If the input image is already $256\times256$ and needs no cropping, you can skip the cropping step with `disable_img_crop`, as shown below:
+
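A minimal sketch, assuming the ModelScope pipeline skips cropping when `"img_crop"` is set to `False` (the CLI equivalent is `--disable_img_crop`) and reusing the already-cropped sample portrait `data/src_img/cropped/zp1.png`; the output name is arbitrary:

```python
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline

pipe = pipeline(
    task=Tasks.text_to_video_synthesis,
    model='damo/dreamtalk',
    style_clip_path="data/style_clip/3DMM/M030_front_surprised_level3_001.mat",
    pose_path="data/pose/RichardShelby_front_neutral_level1_001.mat",
    model_revision='master',
)
inputs = {
    "output_name": "acknowledgement_chinese_zp1",
    "wav_path": "data/audio/acknowledgement_chinese.m4a",
    "img_crop": False,  # assumption: skips cropping, mirroring the CLI flag --disable_img_crop
    "image_path": "data/src_img/cropped/zp1.png",
    "max_gen_len": 20,
}
pipe(input=inputs)
```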
+ ## Download Checkpoints
+ Download the checkpoint of the denoising network:
+ * [ModelScope](tmp)
+
+ Download the checkpoint of the renderer:
+ * [ModelScope](tmp)
+
+ Put the downloaded checkpoints into the `checkpoints` folder.
+
+ ## Inference
+ Run the script:
+
+ ```
+ python inference_for_demo_video.py \
+ --wav_path data/audio/acknowledgement_english.m4a \
+ --style_clip_path data/style_clip/3DMM/M030_front_neutral_level1_001.mat \
+ --pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+ --image_path data/src_img/uncropped/male_face.png \
+ --cfg_scale 1.0 \
+ --max_gen_len 30 \
+ --output_name acknowledgement_english@M030_front_neutral_level1_001@male_face
+ ```
+
+ `wav_path` specifies the input audio. Common audio file formats such as wav, mp3, m4a, and mp4 (video with sound) should all be compatible.
+
+ `style_clip_path` specifies the reference speaking style, and `pose_path` specifies the head pose. Both are 3DMM parameter sequences extracted from reference videos. You can follow [PIRenderer](https://github.com/RenYurui/PIRender) to extract 3DMM parameters from your own videos. Note that the video frame rate should be 25 FPS. In addition, videos used for head pose reference should first be cropped to $256\times256$ using the scripts in [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing).
+
+ `image_path` specifies the input portrait. Its resolution should be larger than $256\times256$. Frontal portraits, with the face directly facing forward and not tilted to one side, usually achieve satisfactory results. The input portrait will be cropped to $256\times256$. If your portrait is already cropped to $256\times256$ and you want to disable cropping, use the option `--disable_img_crop` like this:
+
+ ```
+ python inference_for_demo_video.py \
+ --wav_path data/audio/acknowledgement_chinese.m4a \
+ --style_clip_path data/style_clip/3DMM/M030_front_surprised_level3_001.mat \
+ --pose_path data/pose/RichardShelby_front_neutral_level1_001.mat \
+ --image_path data/src_img/cropped/zp1.png \
+ --disable_img_crop \
+ --cfg_scale 1.0 \
+ --max_gen_len 30 \
+ --output_name acknowledgement_chinese@M030_front_surprised_level3_001@zp1
+ ```
+
+ `cfg_scale` controls the scale of classifier-free guidance. It can adjust the intensity of speaking styles.
+
+ `max_gen_len` is the maximum video generation duration, measured in seconds. If the input audio exceeds this length, it will be truncated.
+
+ The generated video will be named `$(output_name).mp4` and placed in the `output_video` folder. Intermediate results, including the cropped portrait, will be in the `tmp/$(output_name)` folder.
+
+ Sample inputs are provided in the `data` folder. Due to copyright issues, we are unable to include the songs we have used in this folder.
+
+ ## Acknowledgements
+
+ We extend our heartfelt thanks for the invaluable contributions made by preceding works to the development of DreamTalk. These include, but are not limited to:
+ [PIRenderer](https://github.com/RenYurui/PIRender),
+ [AVCT](https://github.com/FuxiVirtualHuman/AAAI22-one-shot-talking-face),
+ [StyleTalk](https://github.com/FuxiVirtualHuman/styletalk),
+ [Deep3DFaceRecon_pytorch](https://github.com/sicxu/Deep3DFaceRecon_pytorch),
+ [Wav2vec2.0](https://huggingface.co/jonatasgrosman/wav2vec2-large-xlsr-53-english),
+ [diffusion-point-cloud](https://github.com/luost26/diffusion-point-cloud),
+ and [FOMM video preprocessing](https://github.com/AliaksandrSiarohin/video-preprocessing). We are dedicated to advancing upon these foundational works with the utmost respect for their original contributions.
+
+ ## Citation
+ If you find this codebase useful for your research, please use the following entry.
+ ```BibTeX
+ @article{ma2023dreamtalk,
+   title={DreamTalk: When Expressive Talking Head Generation Meets Diffusion Probabilistic Models},
+   author={Ma, Yifeng and Zhang, Shiwei and Wang, Jiayu and Wang, Xiang and Zhang, Yingya and Deng, Zhidong},
+   journal={arXiv preprint arXiv:2312.09767},
+   year={2023}
+ }
+ ```
damo/dreamtalk/checkpoints/denoising_network.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93864d1316f60e75b40bd820707bb2464f790b1636ae2b9275ee500d41c4e3ae
+ size 47908943
damo/dreamtalk/checkpoints/renderer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a67014839d42d592255c9fc3b3ceecbcd62c27ce0c0a89ed6628292447404242
+ size 335281551
damo/dreamtalk/configs/default.py ADDED
@@ -0,0 +1,91 @@
+ from yacs.config import CfgNode as CN
+
+
+ _C = CN()
+ _C.TAG = "style_id_emotion"
+ _C.DECODER_TYPE = "DisentangleDecoder"
+ _C.CONTENT_ENCODER_TYPE = "ContentW2VEncoder"
+ _C.STYLE_ENCODER_TYPE = "StyleEncoder"
+
+ _C.DIFFNET_TYPE = "DiffusionNet"
+
+ _C.WIN_SIZE = 5
+ _C.D_MODEL = 256
+
+ _C.DATASET = CN()
+ _C.DATASET.FACE3D_DIM = 64
+ _C.DATASET.NUM_FRAMES = 64
+ _C.DATASET.STYLE_MAX_LEN = 256
+
+ _C.TRAIN = CN()
+ _C.TRAIN.FACE3D_LATENT = CN()
+ _C.TRAIN.FACE3D_LATENT.TYPE = "face3d"
+
+ _C.DIFFUSION = CN()
+ _C.DIFFUSION.PREDICT_WHAT = "x0"  # noise | x0
+ _C.DIFFUSION.SCHEDULE = CN()
+ _C.DIFFUSION.SCHEDULE.NUM_STEPS = 1000
+ _C.DIFFUSION.SCHEDULE.BETA_1 = 1e-4
+ _C.DIFFUSION.SCHEDULE.BETA_T = 0.02
+ _C.DIFFUSION.SCHEDULE.MODE = "linear"
+
+ _C.CONTENT_ENCODER = CN()
+ _C.CONTENT_ENCODER.d_model = _C.D_MODEL
+ _C.CONTENT_ENCODER.nhead = 8
+ _C.CONTENT_ENCODER.num_encoder_layers = 3
+ _C.CONTENT_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+ _C.CONTENT_ENCODER.dropout = 0.1
+ _C.CONTENT_ENCODER.activation = "relu"
+ _C.CONTENT_ENCODER.normalize_before = False
+ _C.CONTENT_ENCODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+
+ _C.STYLE_ENCODER = CN()
+ _C.STYLE_ENCODER.d_model = _C.D_MODEL
+ _C.STYLE_ENCODER.nhead = 8
+ _C.STYLE_ENCODER.num_encoder_layers = 3
+ _C.STYLE_ENCODER.dim_feedforward = 4 * _C.D_MODEL
+ _C.STYLE_ENCODER.dropout = 0.1
+ _C.STYLE_ENCODER.activation = "relu"
+ _C.STYLE_ENCODER.normalize_before = False
+ _C.STYLE_ENCODER.pos_embed_len = _C.DATASET.STYLE_MAX_LEN
+ _C.STYLE_ENCODER.aggregate_method = (
+     "self_attention_pooling"  # average | self_attention_pooling
+ )
+ # _C.STYLE_ENCODER.input_dim = _C.DATASET.FACE3D_DIM
+
+ _C.DECODER = CN()
+ _C.DECODER.d_model = _C.D_MODEL
+ _C.DECODER.nhead = 8
+ _C.DECODER.num_decoder_layers = 3
+ _C.DECODER.dim_feedforward = 4 * _C.D_MODEL
+ _C.DECODER.dropout = 0.1
+ _C.DECODER.activation = "relu"
+ _C.DECODER.normalize_before = False
+ _C.DECODER.return_intermediate_dec = False
+ _C.DECODER.pos_embed_len = 2 * _C.WIN_SIZE + 1
+ _C.DECODER.network_type = "TransformerDecoder"
+ _C.DECODER.dynamic_K = None
+ _C.DECODER.dynamic_ratio = None
+ # _C.DECODER.output_dim = _C.DATASET.FACE3D_DIM
+ # LSFM basis:
+ # _C.DECODER.upper_face3d_indices = tuple(list(range(19)) + list(range(46, 51)))
+ # _C.DECODER.lower_face3d_indices = tuple(range(19, 46))
+ # BFM basis:
+ # fmt: off
+ _C.DECODER.upper_face3d_indices = [6, 8, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
+ # fmt: on
+ _C.DECODER.lower_face3d_indices = [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13, 14]
+
+ _C.CF_GUIDANCE = CN()
+ _C.CF_GUIDANCE.TRAINING = True
+ _C.CF_GUIDANCE.INFERENCE = True
+ _C.CF_GUIDANCE.NULL_PROB = 0.1
+ _C.CF_GUIDANCE.SCALE = 1.0
+
+ _C.INFERENCE = CN()
+ _C.INFERENCE.CHECKPOINT = "checkpoints/denoising_network.pth"
+
+
+ def get_cfg_defaults():
+     """Get a yacs CfgNode object with default values for my_project."""
+     return _C.clone()
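For reference, a minimal sketch of how this yacs config is typically consumed; the experiment YAML name below is illustrative:

```python
from configs.default import get_cfg_defaults

cfg = get_cfg_defaults()                            # clone of the defaults defined above
cfg.merge_from_file("configs/my_experiment.yaml")   # hypothetical YAML overriding selected keys
cfg.merge_from_list(["CF_GUIDANCE.SCALE", 2.0])     # or override individual keys in code
cfg.freeze()                                        # make the node immutable before use
print(cfg.DIFFUSION.SCHEDULE.NUM_STEPS, cfg.DECODER.network_type)
```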
damo/dreamtalk/configuration.json ADDED
@@ -0,0 +1,11 @@
+ {
+     "framework": "pytorch",
+     "task": "text-to-video-synthesis",
+     "model": {
+         "type": "Dreamtalk-Generation"
+     },
+     "pipeline": {
+         "type": "Dreamtalk-generation-pipe"
+     },
+     "allow_remote": true
+ }
damo/dreamtalk/core/networks/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from core.networks.generator import (
+     StyleEncoder,
+     Decoder,
+     ContentW2VEncoder,
+ )
+ from core.networks.disentangle_decoder import DisentangleDecoder
+
+
+ def get_network(name: str):
+     obj = globals().get(name)
+     if obj is None:
+         raise KeyError("Unknown Network: %s" % name)
+     else:
+         return obj
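`get_network` is a small registry that resolves the type names used in `configs/default.py` (`DECODER_TYPE`, `CONTENT_ENCODER_TYPE`, `STYLE_ENCODER_TYPE`) against this module's globals; a minimal usage sketch:

```python
from core.networks import get_network

decoder_cls = get_network("DisentangleDecoder")  # same string as cfg.DECODER_TYPE
style_cls = get_network("StyleEncoder")          # same string as cfg.STYLE_ENCODER_TYPE
# An unknown name raises KeyError("Unknown Network: ...").
```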
damo/dreamtalk/core/networks/diffusion_net.py ADDED
@@ -0,0 +1,340 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn import Module
5
+ from core.networks.diffusion_util import VarianceSchedule
6
+ import numpy as np
7
+
8
+
9
+ def face3d_raw_to_norm(face3d_raw, exp_min, exp_max):
10
+ """
11
+
12
+ Args:
13
+ face3d_raw (_type_): (B, L, C_face3d)
14
+ exp_min (_type_): (C_face3d)
15
+ exp_max (_type_): (C_face3d)
16
+
17
+ Returns:
18
+ _type_: (B, L, C_face3d) in [-1, 1]
19
+ """
20
+ exp_min_expand = exp_min[None, None, :]
21
+ exp_max_expand = exp_max[None, None, :]
22
+ face3d_norm_01 = (face3d_raw - exp_min_expand) / (exp_max_expand - exp_min_expand)
23
+ face3d_norm = face3d_norm_01 * 2 - 1
24
+ return face3d_norm
25
+
26
+
27
+ def face3d_norm_to_raw(face3d_norm, exp_min, exp_max):
28
+ """
29
+
30
+ Args:
31
+ face3d_norm (_type_): (B, L, C_face3d)
32
+ exp_min (_type_): (C_face3d)
33
+ exp_max (_type_): (C_face3d)
34
+
35
+ Returns:
36
+ _type_: (B, L, C_face3d)
37
+ """
38
+ exp_min_expand = exp_min[None, None, :]
39
+ exp_max_expand = exp_max[None, None, :]
40
+ face3d_norm_01 = (face3d_norm + 1) / 2
41
+ face3d_raw = face3d_norm_01 * (exp_max_expand - exp_min_expand) + exp_min_expand
42
+ return face3d_raw
43
+
44
+
45
+ class DiffusionNet(Module):
46
+ def __init__(self, cfg, net, var_sched: VarianceSchedule):
47
+ super().__init__()
48
+ self.cfg = cfg
49
+ self.net = net
50
+ self.var_sched = var_sched
51
+ self.face3d_latent_type = self.cfg.TRAIN.FACE3D_LATENT.TYPE
52
+ self.predict_what = self.cfg.DIFFUSION.PREDICT_WHAT
53
+
54
+ if self.cfg.CF_GUIDANCE.TRAINING:
55
+ null_style_clip = torch.zeros(
56
+ self.cfg.DATASET.STYLE_MAX_LEN, self.cfg.DATASET.FACE3D_DIM
57
+ )
58
+ self.register_buffer("null_style_clip", null_style_clip)
59
+
60
+ null_pad_mask = torch.tensor([False] * self.cfg.DATASET.STYLE_MAX_LEN)
61
+ self.register_buffer("null_pad_mask", null_pad_mask)
62
+
63
+ def _face3d_to_latent(self, face3d):
64
+ latent = None
65
+ if self.face3d_latent_type == "face3d":
66
+ latent = face3d
67
+ elif self.face3d_latent_type == "normalized_face3d":
68
+ latent = face3d_raw_to_norm(
69
+ face3d, exp_min=self.exp_min, exp_max=self.exp_max
70
+ )
71
+ else:
72
+ raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
73
+ return latent
74
+
75
+ def _latent_to_face3d(self, latent):
76
+ face3d = None
77
+ if self.face3d_latent_type == "face3d":
78
+ face3d = latent
79
+ elif self.face3d_latent_type == "normalized_face3d":
80
+ latent = torch.clamp(latent, min=-1, max=1)
81
+ face3d = face3d_norm_to_raw(
82
+ latent, exp_min=self.exp_min, exp_max=self.exp_max
83
+ )
84
+ else:
85
+ raise ValueError(f"Invalid face3d latent type: {self.face3d_latent_type}")
86
+ return face3d
87
+
88
+ def ddim_sample(
89
+ self,
90
+ audio,
91
+ style_clip,
92
+ style_pad_mask,
93
+ output_dim,
94
+ flexibility=0.0,
95
+ ret_traj=False,
96
+ use_cf_guidance=False,
97
+ cfg_scale=2.0,
98
+ ddim_num_step=50,
99
+ ready_style_code=None,
100
+ ):
101
+ """
102
+
103
+ Args:
104
+ audio (_type_): (B, L, W) or (B, L, W, C)
105
+ style_clip (_type_): (B, L_clipmax, C_face3d)
106
+ style_pad_mask : (B, L_clipmax)
107
+ pose_dim (_type_): int
108
+ flexibility (float, optional): _description_. Defaults to 0.0.
109
+ ret_traj (bool, optional): _description_. Defaults to False.
110
+
111
+
112
+ Returns:
113
+ _type_: (B, L, C_face)
114
+ """
115
+ if self.predict_what != "x0":
116
+ raise NotImplementedError(self.predict_what)
117
+
118
+ if ready_style_code is not None and use_cf_guidance:
119
+ raise NotImplementedError("not implement cfg for ready style code")
120
+
121
+ c = self.var_sched.num_steps // ddim_num_step
122
+ time_steps = torch.tensor(
123
+ np.asarray(list(range(0, self.var_sched.num_steps, c))) + 1
124
+ )
125
+ assert len(time_steps) == ddim_num_step
126
+ prev_time_steps = torch.cat((torch.tensor([0]), time_steps[:-1]))
127
+
128
+ batch_size, output_len = audio.shape[:2]
129
+ # batch_size = context.size(0)
130
+ context = {
131
+ "audio": audio,
132
+ "style_clip": style_clip,
133
+ "style_pad_mask": style_pad_mask,
134
+ "ready_style_code": ready_style_code,
135
+ }
136
+ if use_cf_guidance:
137
+ uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
138
+ batch_size, 1, 1
139
+ )
140
+ uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
141
+
142
+ context_double = {
143
+ "audio": torch.cat([audio] * 2, dim=0),
144
+ "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
145
+ "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
146
+ "ready_style_code": None
147
+ if ready_style_code is None
148
+ else torch.cat(
149
+ [
150
+ ready_style_code,
151
+ self.net.style_encoder(uncond_style_clip, uncond_pad_mask),
152
+ ],
153
+ dim=0,
154
+ ),
155
+ }
156
+
157
+ x_t = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
158
+
159
+ for idx in list(range(ddim_num_step))[::-1]:
160
+ t = time_steps[idx]
161
+ t_prev = prev_time_steps[idx]
162
+ ddim_alpha = self.var_sched.alpha_bars[t]
163
+ ddim_alpha_prev = self.var_sched.alpha_bars[t_prev]
164
+
165
+ t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
166
+ if use_cf_guidance:
167
+ x_t_double = torch.cat([x_t] * 2, dim=0)
168
+ t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
169
+ cond_output, uncond_output = self.net(
170
+ x_t_double, t=t_tensor_double, **context_double
171
+ ).chunk(2)
172
+ diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
173
+ else:
174
+ diff_output = self.net(x_t, t=t_tensor, **context)
175
+
176
+ pred_x0 = diff_output
177
+ eps = (x_t - torch.sqrt(ddim_alpha) * pred_x0) / torch.sqrt(1 - ddim_alpha)
178
+ c1 = torch.sqrt(ddim_alpha_prev)
179
+ c2 = torch.sqrt(1 - ddim_alpha_prev)
180
+
181
+ x_t = c1 * pred_x0 + c2 * eps
182
+
183
+ latent_output = x_t
184
+ face3d_output = self._latent_to_face3d(latent_output)
185
+ return face3d_output
186
+
187
+ def sample(
188
+ self,
189
+ audio,
190
+ style_clip,
191
+ style_pad_mask,
192
+ output_dim,
193
+ flexibility=0.0,
194
+ ret_traj=False,
195
+ use_cf_guidance=False,
196
+ cfg_scale=2.0,
197
+ sample_method="ddpm",
198
+ ddim_num_step=50,
199
+ ready_style_code=None,
200
+ ):
201
+ # sample_method = kwargs["sample_method"]
202
+ if sample_method == "ddpm":
203
+ if ready_style_code is not None:
204
+ raise NotImplementedError("ready style code in ddpm")
205
+ return self.ddpm_sample(
206
+ audio,
207
+ style_clip,
208
+ style_pad_mask,
209
+ output_dim,
210
+ flexibility=flexibility,
211
+ ret_traj=ret_traj,
212
+ use_cf_guidance=use_cf_guidance,
213
+ cfg_scale=cfg_scale,
214
+ )
215
+ elif sample_method == "ddim":
216
+ return self.ddim_sample(
217
+ audio,
218
+ style_clip,
219
+ style_pad_mask,
220
+ output_dim,
221
+ flexibility=flexibility,
222
+ ret_traj=ret_traj,
223
+ use_cf_guidance=use_cf_guidance,
224
+ cfg_scale=cfg_scale,
225
+ ddim_num_step=ddim_num_step,
226
+ ready_style_code=ready_style_code,
227
+ )
228
+
229
+ def ddpm_sample(
230
+ self,
231
+ audio,
232
+ style_clip,
233
+ style_pad_mask,
234
+ output_dim,
235
+ flexibility=0.0,
236
+ ret_traj=False,
237
+ use_cf_guidance=False,
238
+ cfg_scale=2.0,
239
+ ):
240
+ """
241
+
242
+ Args:
243
+ audio (_type_): (B, L, W) or (B, L, W, C)
244
+ style_clip (_type_): (B, L_clipmax, C_face3d)
245
+ style_pad_mask : (B, L_clipmax)
246
+ pose_dim (_type_): int
247
+ flexibility (float, optional): _description_. Defaults to 0.0.
248
+ ret_traj (bool, optional): _description_. Defaults to False.
249
+
250
+
251
+ Returns:
252
+ _type_: (B, L, C_face)
253
+ """
254
+ batch_size, output_len = audio.shape[:2]
255
+ # batch_size = context.size(0)
256
+ context = {
257
+ "audio": audio,
258
+ "style_clip": style_clip,
259
+ "style_pad_mask": style_pad_mask,
260
+ }
261
+ if use_cf_guidance:
262
+ uncond_style_clip = self.null_style_clip.unsqueeze(0).repeat(
263
+ batch_size, 1, 1
264
+ )
265
+ uncond_pad_mask = self.null_pad_mask.unsqueeze(0).repeat(batch_size, 1)
266
+ context_double = {
267
+ "audio": torch.cat([audio] * 2, dim=0),
268
+ "style_clip": torch.cat([style_clip, uncond_style_clip], dim=0),
269
+ "style_pad_mask": torch.cat([style_pad_mask, uncond_pad_mask], dim=0),
270
+ }
271
+
272
+ x_T = torch.randn([batch_size, output_len, output_dim]).to(audio.device)
273
+ traj = {self.var_sched.num_steps: x_T}
274
+ for t in range(self.var_sched.num_steps, 0, -1):
275
+ alpha = self.var_sched.alphas[t]
276
+ alpha_bar = self.var_sched.alpha_bars[t]
277
+ alpha_bar_prev = self.var_sched.alpha_bars[t - 1]
278
+ sigma = self.var_sched.get_sigmas(t, flexibility)
279
+
280
+ z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
281
+ x_t = traj[t]
282
+ t_tensor = torch.tensor([t] * batch_size).to(audio.device).float()
283
+ if use_cf_guidance:
284
+ x_t_double = torch.cat([x_t] * 2, dim=0)
285
+ t_tensor_double = torch.cat([t_tensor] * 2, dim=0)
286
+ cond_output, uncond_output = self.net(
287
+ x_t_double, t=t_tensor_double, **context_double
288
+ ).chunk(2)
289
+ diff_output = uncond_output + cfg_scale * (cond_output - uncond_output)
290
+ else:
291
+ diff_output = self.net(x_t, t=t_tensor, **context)
292
+
293
+ if self.predict_what == "noise":
294
+ c0 = 1.0 / torch.sqrt(alpha)
295
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
296
+ x_next = c0 * (x_t - c1 * diff_output) + sigma * z
297
+ elif self.predict_what == "x0":
298
+ d0 = torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)
299
+ d1 = torch.sqrt(alpha_bar_prev) * (1 - alpha) / (1 - alpha_bar)
300
+ x_next = d0 * x_t + d1 * diff_output + sigma * z
301
+ traj[t - 1] = x_next.detach()
302
+ traj[t] = traj[t].cpu()
303
+ if not ret_traj:
304
+ del traj[t]
305
+
306
+ if ret_traj:
307
+ raise NotImplementedError
308
+ return traj
309
+ else:
310
+ latent_output = traj[0]
311
+ face3d_output = self._latent_to_face3d(latent_output)
312
+ return face3d_output
313
+
314
+
315
+ if __name__ == "__main__":
316
+ from core.networks.diffusion_util import NoisePredictor, VarianceSchedule
317
+
318
+ diffnet = DiffusionNet(
319
+ net=NoisePredictor(),
320
+ var_sched=VarianceSchedule(
321
+ num_steps=500, beta_1=1e-4, beta_T=0.02, mode="linear"
322
+ ),
323
+ )
324
+
325
+ import torch
326
+
327
+ gt_face3d = torch.randn(16, 64, 64)
328
+ audio = torch.randn(16, 64, 11)
329
+ style_clip = torch.randn(16, 256, 64)
330
+ style_pad_mask = torch.ones(16, 256)
331
+
332
+ context = {
333
+ "audio": audio,
334
+ "style_clip": style_clip,
335
+ "style_pad_mask": style_pad_mask,
336
+ }
337
+
338
+ loss = diffnet.get_loss(gt_face3d, context)
339
+
340
+ print("hello")
damo/dreamtalk/core/networks/diffusion_util.py ADDED
@@ -0,0 +1,131 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import Module
5
+ from core.networks import get_network
6
+ from core.utils import sinusoidal_embedding
7
+
8
+
9
+ class VarianceSchedule(Module):
10
+ def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
11
+ super().__init__()
12
+ assert mode in ("linear",)
13
+ self.num_steps = num_steps
14
+ self.beta_1 = beta_1
15
+ self.beta_T = beta_T
16
+ self.mode = mode
17
+
18
+ if mode == "linear":
19
+ betas = torch.linspace(beta_1, beta_T, steps=num_steps)
20
+
21
+ betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding
22
+
23
+ alphas = 1 - betas
24
+ log_alphas = torch.log(alphas)
25
+ for i in range(1, log_alphas.size(0)): # 1 to T
26
+ log_alphas[i] += log_alphas[i - 1]
27
+ alpha_bars = log_alphas.exp()
28
+
29
+ sigmas_flex = torch.sqrt(betas)
30
+ sigmas_inflex = torch.zeros_like(sigmas_flex)
31
+ for i in range(1, sigmas_flex.size(0)):
32
+ sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
33
+ i
34
+ ]
35
+ sigmas_inflex = torch.sqrt(sigmas_inflex)
36
+
37
+ self.register_buffer("betas", betas)
38
+ self.register_buffer("alphas", alphas)
39
+ self.register_buffer("alpha_bars", alpha_bars)
40
+ self.register_buffer("sigmas_flex", sigmas_flex)
41
+ self.register_buffer("sigmas_inflex", sigmas_inflex)
42
+
43
+ def uniform_sample_t(self, batch_size):
44
+ ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
45
+ return ts.tolist()
46
+
47
+ def get_sigmas(self, t, flexibility):
48
+ assert 0 <= flexibility and flexibility <= 1
49
+ sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
50
+ 1 - flexibility
51
+ )
52
+ return sigmas
53
+
54
+
55
+ class NoisePredictor(nn.Module):
56
+ def __init__(self, cfg):
57
+ super().__init__()
58
+
59
+ content_encoder_class = get_network(cfg.CONTENT_ENCODER_TYPE)
60
+ self.content_encoder = content_encoder_class(**cfg.CONTENT_ENCODER)
61
+
62
+ style_encoder_class = get_network(cfg.STYLE_ENCODER_TYPE)
63
+ cfg.defrost()
64
+ cfg.STYLE_ENCODER.input_dim = cfg.DATASET.FACE3D_DIM
65
+ cfg.freeze()
66
+ self.style_encoder = style_encoder_class(**cfg.STYLE_ENCODER)
67
+
68
+ decoder_class = get_network(cfg.DECODER_TYPE)
69
+ cfg.defrost()
70
+ cfg.DECODER.output_dim = cfg.DATASET.FACE3D_DIM
71
+ cfg.freeze()
72
+ self.decoder = decoder_class(**cfg.DECODER)
73
+
74
+ self.content_xt_to_decoder_input_wo_time = nn.Sequential(
75
+ nn.Linear(cfg.D_MODEL + cfg.DATASET.FACE3D_DIM, cfg.D_MODEL),
76
+ nn.ReLU(),
77
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
78
+ nn.ReLU(),
79
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
80
+ )
81
+
82
+ self.time_sinusoidal_dim = cfg.D_MODEL
83
+ self.time_embed_net = nn.Sequential(
84
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
85
+ nn.SiLU(),
86
+ nn.Linear(cfg.D_MODEL, cfg.D_MODEL),
87
+ )
88
+
89
+ def forward(self, x_t, t, audio, style_clip, style_pad_mask, ready_style_code=None):
90
+ """_summary_
91
+
92
+ Args:
93
+ x_t (_type_): (B, L, C_face)
94
+ t (_type_): (B,) dtype:float32
95
+ audio (_type_): (B, L, W)
96
+ style_clip (_type_): (B, L_clipmax, C_face3d)
97
+ style_pad_mask : (B, L_clipmax)
98
+ ready_style_code: (B, C_model)
99
+ Returns:
100
+ e_theta : (B, L, C_face)
101
+ """
102
+ W = audio.shape[2]
103
+ content = self.content_encoder(audio)
104
+ # (B, L, W, C_model)
105
+ x_t_expand = x_t.unsqueeze(2).repeat(1, 1, W, 1)
106
+ # (B, L, C_face) -> (B, L, W, C_face)
107
+ content_xt_concat = torch.cat((content, x_t_expand), dim=3)
108
+ # (B, L, W, C_model+C_face)
109
+ decoder_input_without_time = self.content_xt_to_decoder_input_wo_time(
110
+ content_xt_concat
111
+ )
112
+ # (B, L, W, C_model)
113
+
114
+ time_sinusoidal = sinusoidal_embedding(t, self.time_sinusoidal_dim)
115
+ # (B, C_embed)
116
+ time_embedding = self.time_embed_net(time_sinusoidal)
117
+ # (B, C_model)
118
+ B, C = time_embedding.shape
119
+ time_embed_expand = time_embedding.view(B, 1, 1, C)
120
+ decoder_input = decoder_input_without_time + time_embed_expand
121
+ # (B, L, W, C_model)
122
+
123
+ if ready_style_code is not None:
124
+ style_code = ready_style_code
125
+ else:
126
+ style_code = self.style_encoder(style_clip, style_pad_mask)
127
+ # (B, C_model)
128
+
129
+ e_theta = self.decoder(decoder_input, style_code)
130
+ # (B, L, C_face)
131
+ return e_theta
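For reference, the buffers registered by `VarianceSchedule` above correspond to the usual DDPM quantities, with $\beta_t$ linearly spaced from `beta_1` to `beta_T` and the `flexibility` argument written as $\lambda$:

```latex
\bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s),
\qquad
\tilde{\sigma}_t^{2} = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t,
\qquad
\sigma_t(\lambda) = \lambda\,\sqrt{\beta_t} + (1 - \lambda)\,\tilde{\sigma}_t
```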
damo/dreamtalk/core/networks/disentangle_decoder.py ADDED
@@ -0,0 +1,240 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .transformer import (
5
+ PositionalEncoding,
6
+ TransformerDecoderLayer,
7
+ TransformerDecoder,
8
+ )
9
+ from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
10
+ from core.utils import _reset_parameters
11
+
12
+
13
+ def get_decoder_network(
14
+ network_type,
15
+ d_model,
16
+ nhead,
17
+ dim_feedforward,
18
+ dropout,
19
+ activation,
20
+ normalize_before,
21
+ num_decoder_layers,
22
+ return_intermediate_dec,
23
+ dynamic_K,
24
+ dynamic_ratio,
25
+ ):
26
+ decoder = None
27
+ if network_type == "TransformerDecoder":
28
+ decoder_layer = TransformerDecoderLayer(
29
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
30
+ )
31
+ norm = nn.LayerNorm(d_model)
32
+ decoder = TransformerDecoder(
33
+ decoder_layer,
34
+ num_decoder_layers,
35
+ norm,
36
+ return_intermediate_dec,
37
+ )
38
+ elif network_type == "DynamicFCDecoder":
39
+ d_style = d_model
40
+ decoder_layer = DynamicFCDecoderLayer(
41
+ d_model,
42
+ nhead,
43
+ d_style,
44
+ dynamic_K,
45
+ dynamic_ratio,
46
+ dim_feedforward,
47
+ dropout,
48
+ activation,
49
+ normalize_before,
50
+ )
51
+ norm = nn.LayerNorm(d_model)
52
+ decoder = DynamicFCDecoder(
53
+ decoder_layer, num_decoder_layers, norm, return_intermediate_dec
54
+ )
55
+ elif network_type == "DynamicFCEncoder":
56
+ d_style = d_model
57
+ decoder_layer = DynamicFCEncoderLayer(
58
+ d_model,
59
+ nhead,
60
+ d_style,
61
+ dynamic_K,
62
+ dynamic_ratio,
63
+ dim_feedforward,
64
+ dropout,
65
+ activation,
66
+ normalize_before,
67
+ )
68
+ norm = nn.LayerNorm(d_model)
69
+ decoder = DynamicFCEncoder(decoder_layer, num_decoder_layers, norm)
70
+
71
+ else:
72
+ raise ValueError(f"Invalid network_type {network_type}")
73
+
74
+ return decoder
75
+
76
+
77
+ class DisentangleDecoder(nn.Module):
78
+ def __init__(
79
+ self,
80
+ d_model=512,
81
+ nhead=8,
82
+ num_decoder_layers=3,
83
+ dim_feedforward=2048,
84
+ dropout=0.1,
85
+ activation="relu",
86
+ normalize_before=False,
87
+ return_intermediate_dec=False,
88
+ pos_embed_len=80,
89
+ upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
90
+ lower_face3d_indices=tuple(range(19, 46)),
91
+ network_type="None",
92
+ dynamic_K=None,
93
+ dynamic_ratio=None,
94
+ **_,
95
+ ) -> None:
96
+ super().__init__()
97
+
98
+ self.upper_face3d_indices = upper_face3d_indices
99
+ self.lower_face3d_indices = lower_face3d_indices
100
+
101
+ # upper_decoder_layer = TransformerDecoderLayer(
102
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
103
+ # )
104
+ # upper_decoder_norm = nn.LayerNorm(d_model)
105
+ # self.upper_decoder = TransformerDecoder(
106
+ # upper_decoder_layer,
107
+ # num_decoder_layers,
108
+ # upper_decoder_norm,
109
+ # return_intermediate=return_intermediate_dec,
110
+ # )
111
+ self.upper_decoder = get_decoder_network(
112
+ network_type,
113
+ d_model,
114
+ nhead,
115
+ dim_feedforward,
116
+ dropout,
117
+ activation,
118
+ normalize_before,
119
+ num_decoder_layers,
120
+ return_intermediate_dec,
121
+ dynamic_K,
122
+ dynamic_ratio,
123
+ )
124
+ _reset_parameters(self.upper_decoder)
125
+
126
+ # lower_decoder_layer = TransformerDecoderLayer(
127
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
128
+ # )
129
+ # lower_decoder_norm = nn.LayerNorm(d_model)
130
+ # self.lower_decoder = TransformerDecoder(
131
+ # lower_decoder_layer,
132
+ # num_decoder_layers,
133
+ # lower_decoder_norm,
134
+ # return_intermediate=return_intermediate_dec,
135
+ # )
136
+ self.lower_decoder = get_decoder_network(
137
+ network_type,
138
+ d_model,
139
+ nhead,
140
+ dim_feedforward,
141
+ dropout,
142
+ activation,
143
+ normalize_before,
144
+ num_decoder_layers,
145
+ return_intermediate_dec,
146
+ dynamic_K,
147
+ dynamic_ratio,
148
+ )
149
+ _reset_parameters(self.lower_decoder)
150
+
151
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
152
+
153
+ tail_hidden_dim = d_model // 2
154
+ self.upper_tail_fc = nn.Sequential(
155
+ nn.Linear(d_model, tail_hidden_dim),
156
+ nn.ReLU(),
157
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
158
+ nn.ReLU(),
159
+ nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
160
+ )
161
+ self.lower_tail_fc = nn.Sequential(
162
+ nn.Linear(d_model, tail_hidden_dim),
163
+ nn.ReLU(),
164
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
165
+ nn.ReLU(),
166
+ nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
167
+ )
168
+
169
+ def forward(self, content, style_code):
170
+ """
171
+
172
+ Args:
173
+ content (_type_): (B, num_frames, window, C_dmodel)
174
+ style_code (_type_): (B, C_dmodel)
175
+
176
+ Returns:
177
+ face3d: (B, L_clip, C_3dmm)
178
+ """
179
+ B, N, W, C = content.shape
180
+ style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
181
+ style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
182
+ # (W, B*N, C)
183
+
184
+ content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
185
+ # (W, B*N, C)
186
+ tgt = torch.zeros_like(style)
187
+ pos_embed = self.pos_embed(W)
188
+ pos_embed = pos_embed.permute(1, 0, 2)
189
+
190
+ upper_face3d_feat = self.upper_decoder(
191
+ tgt, content, pos=pos_embed, query_pos=style
192
+ )[0]
193
+ # (W, B*N, C)
194
+ upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
195
+ :, :, W // 2, :
196
+ ]
197
+ # (B, N, C)
198
+ upper_face3d = self.upper_tail_fc(upper_face3d_feat)
199
+ # (B, N, C_exp)
200
+
201
+ lower_face3d_feat = self.lower_decoder(
202
+ tgt, content, pos=pos_embed, query_pos=style
203
+ )[0]
204
+ lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[
205
+ :, :, W // 2, :
206
+ ]
207
+ lower_face3d = self.lower_tail_fc(lower_face3d_feat)
208
+ C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
209
+ face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
210
+ face3d[:, :, self.upper_face3d_indices] = upper_face3d
211
+ face3d[:, :, self.lower_face3d_indices] = lower_face3d
212
+ return face3d
213
+
214
+
215
+ if __name__ == "__main__":
216
+ import sys
217
+
218
+ sys.path.append("/home/mayifeng/Research/styleTH")
219
+
220
+ from configs.default import get_cfg_defaults
221
+
222
+ cfg = get_cfg_defaults()
223
+ cfg.merge_from_file("configs/styleTH_unpair_lsfm_emotion.yaml")
224
+ cfg.freeze()
225
+
226
+ # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)
227
+
228
+ # dummy_audio = torch.randint(0, 41, (5, 64, 11))
229
+ # dummy_content = content_encoder(dummy_audio)
230
+
231
+ # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
232
+ # dummy_face3d_seq = torch.randn(5, 64, 64)
233
+ # dummy_style_code = style_encoder(dummy_face3d_seq)
234
+
235
+ decoder = DisentangleDecoder(**cfg.DECODER)
236
+ dummy_content = torch.randn(5, 64, 11, 256)
237
+ dummy_style = torch.randn(5, 256)
238
+ dummy_output = decoder(dummy_content, dummy_style)
239
+
240
+ print("hello")
damo/dreamtalk/core/networks/dynamic_conv.py ADDED
@@ -0,0 +1,156 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class Attention(nn.Module):
9
+ def __init__(self, cond_planes, ratio, K, temperature=30, init_weight=True):
10
+ super().__init__()
11
+ # self.avgpool = nn.AdaptiveAvgPool2d(1)
12
+ self.temprature = temperature
13
+ assert cond_planes > ratio
14
+ hidden_planes = cond_planes // ratio
15
+ self.net = nn.Sequential(
16
+ nn.Conv2d(cond_planes, hidden_planes, kernel_size=1, bias=False),
17
+ nn.ReLU(),
18
+ nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False),
19
+ )
20
+
21
+ if init_weight:
22
+ self._initialize_weights()
23
+
24
+ def update_temprature(self):
25
+ if self.temprature > 1:
26
+ self.temprature -= 1
27
+
28
+ def _initialize_weights(self):
29
+ for m in self.modules():
30
+ if isinstance(m, nn.Conv2d):
31
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
32
+ if m.bias is not None:
33
+ nn.init.constant_(m.bias, 0)
34
+ if isinstance(m, nn.BatchNorm2d):
35
+ nn.init.constant_(m.weight, 1)
36
+ nn.init.constant_(m.bias, 0)
37
+
38
+ def forward(self, cond):
39
+ """
40
+
41
+ Args:
42
+ cond (_type_): (B, C_style)
43
+
44
+ Returns:
45
+ _type_: (B, K)
46
+ """
47
+
48
+ # att = self.avgpool(cond) # bs,dim,1,1
49
+ att = cond.view(cond.shape[0], cond.shape[1], 1, 1)
50
+ att = self.net(att).view(cond.shape[0], -1) # bs,K
51
+ return F.softmax(att / self.temprature, -1)
52
+
53
+
54
+ class DynamicConv(nn.Module):
55
+ def __init__(
56
+ self,
57
+ in_planes,
58
+ out_planes,
59
+ cond_planes,
60
+ kernel_size,
61
+ stride,
62
+ padding=0,
63
+ dilation=1,
64
+ groups=1,
65
+ bias=True,
66
+ K=4,
67
+ temperature=30,
68
+ ratio=4,
69
+ init_weight=True,
70
+ ):
71
+ super().__init__()
72
+ self.in_planes = in_planes
73
+ self.out_planes = out_planes
74
+ self.cond_planes = cond_planes
75
+ self.kernel_size = kernel_size
76
+ self.stride = stride
77
+ self.padding = padding
78
+ self.dilation = dilation
79
+ self.groups = groups
80
+ self.bias = bias
81
+ self.K = K
82
+ self.init_weight = init_weight
83
+ self.attention = Attention(
84
+ cond_planes=cond_planes, ratio=ratio, K=K, temperature=temperature, init_weight=init_weight
85
+ )
86
+
87
+ self.weight = nn.Parameter(
88
+ torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True
89
+ )
90
+ if bias:
91
+ self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
92
+ else:
93
+ self.bias = None
94
+
95
+ if self.init_weight:
96
+ self._initialize_weights()
97
+
98
+ def _initialize_weights(self):
99
+ for i in range(self.K):
100
+ nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
101
+ if self.bias is not None:
102
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
103
+ if fan_in != 0:
104
+ bound = 1 / math.sqrt(fan_in)
105
+ nn.init.uniform_(self.bias, -bound, bound)
106
+
107
+ def forward(self, x, cond):
108
+ """
109
+
110
+ Args:
111
+ x (_type_): (B, C_in, L, 1)
112
+ cond (_type_): (B, C_style)
113
+
114
+ Returns:
115
+ _type_: (B, C_out, L, 1)
116
+ """
117
+ bs, in_planels, h, w = x.shape
118
+ softmax_att = self.attention(cond) # bs,K
119
+ x = x.view(1, -1, h, w)
120
+ weight = self.weight.view(self.K, -1) # K,-1
121
+ aggregate_weight = torch.mm(softmax_att, weight).view(
122
+ bs * self.out_planes, self.in_planes // self.groups, self.kernel_size, self.kernel_size
123
+ ) # bs*out_p,in_p,k,k
124
+
125
+ if self.bias is not None:
126
+ bias = self.bias.view(self.K, -1) # K,out_p
127
+ aggregate_bias = torch.mm(softmax_att, bias).view(-1) # bs*out_p
128
+ output = F.conv2d(
129
+ x, # 1, bs*in_p, L, 1
130
+ weight=aggregate_weight,
131
+ bias=aggregate_bias,
132
+ stride=self.stride,
133
+ padding=self.padding,
134
+ groups=self.groups * bs,
135
+ dilation=self.dilation,
136
+ )
137
+ else:
138
+ output = F.conv2d(
139
+ x,
140
+ weight=aggregate_weight,
141
+ bias=None,
142
+ stride=self.stride,
143
+ padding=self.padding,
144
+ groups=self.groups * bs,
145
+ dilation=self.dilation,
146
+ )
147
+
148
+ output = output.view(bs, self.out_planes, h, w)
149
+ return output
150
+
151
+
152
+ if __name__ == "__main__":
+ input = torch.randn(3, 32, 64, 64)
+ cond = torch.randn(3, 128)
+ m = DynamicConv(in_planes=32, out_planes=64, cond_planes=128, kernel_size=3, stride=1, padding=1, bias=True)
+ out = m(input, cond)
+ print(out.shape)
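For reference, `DynamicConv` above is a conditionally parameterized convolution: `Attention` maps the conditioning vector $c$ (in this repo, the style code) to softmax mixture weights over $K$ kernel banks, the per-sample kernel and bias are their convex combinations, and the whole batch is evaluated in one grouped `F.conv2d` call by folding the batch dimension into the group dimension:

```latex
\pi_k(c) = \operatorname{softmax}_k\!\bigl(f(c)/\tau\bigr),
\qquad
W(c) = \sum_{k=1}^{K} \pi_k(c)\, W_k,
\qquad
b(c) = \sum_{k=1}^{K} \pi_k(c)\, b_k
```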
damo/dreamtalk/core/networks/dynamic_fc_decoder.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ from core.networks.transformer import _get_activation_fn, _get_clones
5
+ from core.networks.dynamic_linear import DynamicLinear
6
+
7
+
8
+ class DynamicFCDecoderLayer(nn.Module):
9
+ def __init__(
10
+ self,
11
+ d_model,
12
+ nhead,
13
+ d_style,
14
+ dynamic_K,
15
+ dynamic_ratio,
16
+ dim_feedforward=2048,
17
+ dropout=0.1,
18
+ activation="relu",
19
+ normalize_before=False,
20
+ ):
21
+ super().__init__()
22
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
23
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
24
+ # Implementation of Feedforward model
25
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
26
+ self.linear1 = DynamicLinear(d_model, dim_feedforward, d_style, K=dynamic_K, ratio=dynamic_ratio)
27
+ self.dropout = nn.Dropout(dropout)
28
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
29
+ # self.linear2 = DynamicLinear(dim_feedforward, d_model, d_style, K=dynamic_K, ratio=dynamic_ratio)
30
+
31
+ self.norm1 = nn.LayerNorm(d_model)
32
+ self.norm2 = nn.LayerNorm(d_model)
33
+ self.norm3 = nn.LayerNorm(d_model)
34
+ self.dropout1 = nn.Dropout(dropout)
35
+ self.dropout2 = nn.Dropout(dropout)
36
+ self.dropout3 = nn.Dropout(dropout)
37
+
38
+ self.activation = _get_activation_fn(activation)
39
+ self.normalize_before = normalize_before
40
+
41
+ def with_pos_embed(self, tensor, pos):
42
+ return tensor if pos is None else tensor + pos
43
+
44
+ def forward_post(
45
+ self,
46
+ tgt,
47
+ memory,
48
+ style,
49
+ tgt_mask=None,
50
+ memory_mask=None,
51
+ tgt_key_padding_mask=None,
52
+ memory_key_padding_mask=None,
53
+ pos=None,
54
+ query_pos=None,
55
+ ):
56
+ # q = k = self.with_pos_embed(tgt, query_pos)
57
+ tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
58
+ tgt = tgt + self.dropout1(tgt2)
59
+ tgt = self.norm1(tgt)
60
+ tgt2 = self.multihead_attn(
61
+ query=tgt, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
62
+ )[0]
63
+ tgt = tgt + self.dropout2(tgt2)
64
+ tgt = self.norm2(tgt)
65
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))), style)
66
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt, style))))
67
+ tgt = tgt + self.dropout3(tgt2)
68
+ tgt = self.norm3(tgt)
69
+ return tgt
70
+
71
+ # def forward_pre(
72
+ # self,
73
+ # tgt,
74
+ # memory,
75
+ # tgt_mask=None,
76
+ # memory_mask=None,
77
+ # tgt_key_padding_mask=None,
78
+ # memory_key_padding_mask=None,
79
+ # pos=None,
80
+ # query_pos=None,
81
+ # ):
82
+ # tgt2 = self.norm1(tgt)
83
+ # # q = k = self.with_pos_embed(tgt2, query_pos)
84
+ # tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
85
+ # tgt = tgt + self.dropout1(tgt2)
86
+ # tgt2 = self.norm2(tgt)
87
+ # tgt2 = self.multihead_attn(
88
+ # query=tgt2, key=memory, value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask
89
+ # )[0]
90
+ # tgt = tgt + self.dropout2(tgt2)
91
+ # tgt2 = self.norm3(tgt)
92
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
93
+ # tgt = tgt + self.dropout3(tgt2)
94
+ # return tgt
95
+
96
+ def forward(
97
+ self,
98
+ tgt,
99
+ memory,
100
+ style,
101
+ tgt_mask=None,
102
+ memory_mask=None,
103
+ tgt_key_padding_mask=None,
104
+ memory_key_padding_mask=None,
105
+ pos=None,
106
+ query_pos=None,
107
+ ):
108
+ if self.normalize_before:
109
+ raise NotImplementedError
110
+ # return self.forward_pre(
111
+ # tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
112
+ # )
113
+ return self.forward_post(
114
+ tgt, memory, style, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos
115
+ )
116
+
117
+
118
+ class DynamicFCDecoder(nn.Module):
119
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
120
+ super().__init__()
121
+ self.layers = _get_clones(decoder_layer, num_layers)
122
+ self.num_layers = num_layers
123
+ self.norm = norm
124
+ self.return_intermediate = return_intermediate
125
+
126
+ def forward(
127
+ self,
128
+ tgt,
129
+ memory,
130
+ tgt_mask=None,
131
+ memory_mask=None,
132
+ tgt_key_padding_mask=None,
133
+ memory_key_padding_mask=None,
134
+ pos=None,
135
+ query_pos=None,
136
+ ):
137
+ style = query_pos[0]
138
+ # (B*N, C)
139
+ output = tgt + pos + query_pos
140
+
141
+ intermediate = []
142
+
143
+ for layer in self.layers:
144
+ output = layer(
145
+ output,
146
+ memory,
147
+ style,
148
+ tgt_mask=tgt_mask,
149
+ memory_mask=memory_mask,
150
+ tgt_key_padding_mask=tgt_key_padding_mask,
151
+ memory_key_padding_mask=memory_key_padding_mask,
152
+ pos=pos,
153
+ query_pos=query_pos,
154
+ )
155
+ if self.return_intermediate:
156
+ intermediate.append(self.norm(output))
157
+
158
+ if self.norm is not None:
159
+ output = self.norm(output)
160
+ if self.return_intermediate:
161
+ intermediate.pop()
162
+ intermediate.append(output)
163
+
164
+ if self.return_intermediate:
165
+ return torch.stack(intermediate)
166
+
167
+ return output.unsqueeze(0)
168
+
169
+
170
+ if __name__ == "__main__":
171
+ query = torch.randn(11, 1024, 256)
172
+ content = torch.randn(11, 1024, 256)
173
+ style = torch.randn(1024, 256)
174
+ pos = torch.randn(11, 1, 256)
175
+ m = DynamicFCDecoderLayer(256, 4, 256, 4, 4, 1024)
176
+
177
+ out = m(query, content, style, pos=pos)
178
+ print(out.shape)
damo/dreamtalk/core/networks/dynamic_linear.py ADDED
@@ -0,0 +1,50 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from core.networks.dynamic_conv import DynamicConv
8
+
9
+
10
+ class DynamicLinear(nn.Module):
11
+ def __init__(self, in_planes, out_planes, cond_planes, bias=True, K=4, temperature=30, ratio=4, init_weight=True):
12
+ super().__init__()
13
+
14
+ self.dynamic_conv = DynamicConv(
15
+ in_planes,
16
+ out_planes,
17
+ cond_planes,
18
+ kernel_size=1,
19
+ stride=1,
20
+ padding=0,
21
+ bias=bias,
22
+ K=K,
23
+ ratio=ratio,
24
+ temperature=temperature,
25
+ init_weight=init_weight,
26
+ )
27
+
28
+ def forward(self, x, cond):
29
+ """
30
+
31
+ Args:
32
+ x (_type_): (L, B, C_in)
33
+ cond (_type_): (B, C_style)
34
+
35
+ Returns:
36
+ _type_: (L, B, C_out)
37
+ """
38
+ x = x.permute(1, 2, 0).unsqueeze(-1)
39
+ out = self.dynamic_conv(x, cond)
40
+ # (B, C_out, L, 1)
41
+ out = out.squeeze().permute(2, 0, 1)
42
+ return out
43
+
44
+
45
+ if __name__ == "__main__":
46
+ input = torch.randn(11, 1024, 255)
47
+ cond = torch.randn(1024, 256)
48
+ m = DynamicLinear(255, 1000, 256, K=7, temperature=5, ratio=8)
49
+ out = m(input, cond)
50
+ print(out.shape)
damo/dreamtalk/core/networks/generator.py ADDED
@@ -0,0 +1,309 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .transformer import (
5
+ TransformerEncoder,
6
+ TransformerEncoderLayer,
7
+ PositionalEncoding,
8
+ TransformerDecoderLayer,
9
+ TransformerDecoder,
10
+ )
11
+ from core.utils import _reset_parameters
12
+ from core.networks.self_attention_pooling import SelfAttentionPooling
13
+
14
+
15
+ # class ContentEncoder(nn.Module):
16
+ # def __init__(
17
+ # self,
18
+ # d_model=512,
19
+ # nhead=8,
20
+ # num_encoder_layers=6,
21
+ # dim_feedforward=2048,
22
+ # dropout=0.1,
23
+ # activation="relu",
24
+ # normalize_before=False,
25
+ # pos_embed_len=80,
26
+ # ph_embed_dim=128,
27
+ # ):
28
+ # super().__init__()
29
+
30
+ # encoder_layer = TransformerEncoderLayer(
31
+ # d_model, nhead, dim_feedforward, dropout, activation, normalize_before
32
+ # )
33
+ # encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
34
+ # self.encoder = TransformerEncoder(
35
+ # encoder_layer, num_encoder_layers, encoder_norm
36
+ # )
37
+
38
+ # _reset_parameters(self.encoder)
39
+
40
+ # self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
41
+
42
+ # self.ph_embedding = nn.Embedding(41, ph_embed_dim)
43
+ # self.increase_embed_dim = nn.Linear(ph_embed_dim, d_model)
44
+
45
+ # def forward(self, x):
46
+ # """
47
+
48
+ # Args:
49
+ # x (_type_): (B, num_frames, window)
50
+
51
+ # Returns:
52
+ # content: (B, num_frames, window, C_dmodel)
53
+ # """
54
+ # x_embedding = self.ph_embedding(x)
55
+ # x_embedding = self.increase_embed_dim(x_embedding)
56
+ # # (B, N, W, C)
57
+ # B, N, W, C = x_embedding.shape
58
+ # x_embedding = x_embedding.reshape(B * N, W, C)
59
+ # x_embedding = x_embedding.permute(1, 0, 2)
60
+ # # (W, B*N, C)
61
+
62
+ # pos = self.pos_embed(W)
63
+ # pos = pos.permute(1, 0, 2)
64
+ # # (W, 1, C)
65
+
66
+ # content = self.encoder(x_embedding, pos=pos)
67
+ # # (W, B*N, C)
68
+ # content = content.permute(1, 0, 2).reshape(B, N, W, C)
69
+ # # (B, N, W, C)
70
+
71
+ # return content
72
+
73
+
74
+ class ContentW2VEncoder(nn.Module):
75
+ def __init__(
76
+ self,
77
+ d_model=512,
78
+ nhead=8,
79
+ num_encoder_layers=6,
80
+ dim_feedforward=2048,
81
+ dropout=0.1,
82
+ activation="relu",
83
+ normalize_before=False,
84
+ pos_embed_len=80,
85
+ ph_embed_dim=128,
86
+ ):
87
+ super().__init__()
88
+
89
+ encoder_layer = TransformerEncoderLayer(
90
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
91
+ )
92
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
93
+ self.encoder = TransformerEncoder(
94
+ encoder_layer, num_encoder_layers, encoder_norm
95
+ )
96
+
97
+ _reset_parameters(self.encoder)
98
+
99
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
100
+
101
+ self.increase_embed_dim = nn.Linear(1024, d_model)
102
+
103
+ def forward(self, x):
104
+ """
105
+ Args:
106
+ x (_type_): (B, num_frames, window, C_wav2vec)
107
+
108
+ Returns:
109
+ content: (B, num_frames, window, C_dmodel)
110
+ """
111
+ x_embedding = self.increase_embed_dim(
112
+ x
113
+ ) # [16, 64, 11, 1024] -> [16, 64, 11, 256]
114
+ # (B, N, W, C)
115
+ B, N, W, C = x_embedding.shape
116
+ x_embedding = x_embedding.reshape(B * N, W, C)
117
+ x_embedding = x_embedding.permute(1, 0, 2) # [11, 1024, 256]
118
+ # (W, B*N, C)
119
+
120
+ pos = self.pos_embed(W)
121
+ pos = pos.permute(1, 0, 2) # [11, 1, 256]
122
+ # (W, 1, C)
123
+
124
+ content = self.encoder(x_embedding, pos=pos) # [11, 1024, 256]
125
+ # (W, B*N, C)
126
+ content = content.permute(1, 0, 2).reshape(B, N, W, C)
127
+ # (B, N, W, C)
128
+
129
+ return content
130
+
131
+
132
+ class StyleEncoder(nn.Module):
133
+ def __init__(
134
+ self,
135
+ d_model=512,
136
+ nhead=8,
137
+ num_encoder_layers=6,
138
+ dim_feedforward=2048,
139
+ dropout=0.1,
140
+ activation="relu",
141
+ normalize_before=False,
142
+ pos_embed_len=80,
143
+ input_dim=128,
144
+ aggregate_method="average",
145
+ ):
146
+ super().__init__()
147
+ encoder_layer = TransformerEncoderLayer(
148
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
149
+ )
150
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
151
+ self.encoder = TransformerEncoder(
152
+ encoder_layer, num_encoder_layers, encoder_norm
153
+ )
154
+ _reset_parameters(self.encoder)
155
+
156
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
157
+
158
+ self.increase_embed_dim = nn.Linear(input_dim, d_model)
159
+
160
+ self.aggregate_method = None
161
+ if aggregate_method == "self_attention_pooling":
162
+ self.aggregate_method = SelfAttentionPooling(d_model)
163
+ elif aggregate_method == "average":
164
+ pass
165
+ else:
166
+ raise ValueError(f"Invalid aggregate method {aggregate_method}")
167
+
168
+ def forward(self, x, pad_mask=None):
169
+ """
170
+
171
+ Args:
172
+ x (_type_): (B, num_frames(L), C_exp)
173
+ pad_mask: (B, num_frames)
174
+
175
+ Returns:
176
+ style_code: (B, C_model)
177
+ """
178
+ x = self.increase_embed_dim(x)
179
+ # (B, L, C)
180
+ x = x.permute(1, 0, 2)
181
+ # (L, B, C)
182
+
183
+ pos = self.pos_embed(x.shape[0])
184
+ pos = pos.permute(1, 0, 2)
185
+ # (L, 1, C)
186
+
187
+ style = self.encoder(x, pos=pos, src_key_padding_mask=pad_mask)
188
+ # (L, B, C)
189
+
190
+ if self.aggregate_method is not None:
191
+ permute_style = style.permute(1, 0, 2)
192
+ # (B, L, C)
193
+ style_code = self.aggregate_method(permute_style, pad_mask)
194
+ return style_code
195
+
196
+ if pad_mask is None:
197
+ style = style.permute(1, 2, 0)
198
+ # (B, C, L)
199
+ style_code = style.mean(2)
200
+ # (B, C)
201
+ else:
202
+ permute_style = style.permute(1, 0, 2)
203
+ # (B, L, C)
204
+ permute_style[pad_mask] = 0
205
+ sum_style_code = permute_style.sum(dim=1)
206
+ # (B, C)
207
+ valid_token_num = (~pad_mask).sum(dim=1).unsqueeze(-1)
208
+ # (B, 1)
209
+ style_code = sum_style_code / valid_token_num
210
+ # (B, C)
211
+
212
+ return style_code
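The masked-average branch above is the subtle part of StyleEncoder.forward: padded frames are zeroed out and the sum is divided by the count of valid frames only. A minimal sketch of just that arithmetic on plain tensors (hypothetical shapes, independent of the encoder itself):

import torch

B, L, C = 2, 5, 4
style = torch.randn(L, B, C)                                    # encoder output, (L, B, C)
pad_mask = torch.tensor([[False, False, False, True, True],
                         [False, False, False, False, False]])  # (B, L), True = padded frame
permute_style = style.permute(1, 0, 2).clone()                  # (B, L, C)
permute_style[pad_mask] = 0                                     # zero the padded frames
style_code = permute_style.sum(dim=1) / (~pad_mask).sum(dim=1, keepdim=True)
# reference: average only the valid frames of each sequence
reference = torch.stack([style[:3, 0].mean(dim=0), style[:, 1].mean(dim=0)])
assert torch.allclose(style_code, reference, atol=1e-6)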
213
+
214
+
215
+ class Decoder(nn.Module):
216
+ def __init__(
217
+ self,
218
+ d_model=512,
219
+ nhead=8,
220
+ num_decoder_layers=3,
221
+ dim_feedforward=2048,
222
+ dropout=0.1,
223
+ activation="relu",
224
+ normalize_before=False,
225
+ return_intermediate_dec=False,
226
+ pos_embed_len=80,
227
+ output_dim=64,
228
+ **_,
229
+ ) -> None:
230
+ super().__init__()
231
+
232
+ decoder_layer = TransformerDecoderLayer(
233
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
234
+ )
235
+ decoder_norm = nn.LayerNorm(d_model)
236
+ self.decoder = TransformerDecoder(
237
+ decoder_layer,
238
+ num_decoder_layers,
239
+ decoder_norm,
240
+ return_intermediate=return_intermediate_dec,
241
+ )
242
+ _reset_parameters(self.decoder)
243
+
244
+ self.pos_embed = PositionalEncoding(d_model, pos_embed_len)
245
+
246
+ tail_hidden_dim = d_model // 2
247
+ self.tail_fc = nn.Sequential(
248
+ nn.Linear(d_model, tail_hidden_dim),
249
+ nn.ReLU(),
250
+ nn.Linear(tail_hidden_dim, tail_hidden_dim),
251
+ nn.ReLU(),
252
+ nn.Linear(tail_hidden_dim, output_dim),
253
+ )
254
+
255
+ def forward(self, content, style_code):
256
+ """
257
+
258
+ Args:
259
+ content (_type_): (B, num_frames, window, C_dmodel)
260
+ style_code (_type_): (B, C_dmodel)
261
+
262
+ Returns:
263
+ face3d: (B, num_frames, C_3dmm)
264
+ """
265
+ B, N, W, C = content.shape
266
+ style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
267
+ style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)
268
+ # (W, B*N, C)
269
+
270
+ content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)
271
+ # (W, B*N, C)
272
+ tgt = torch.zeros_like(style)
273
+ pos_embed = self.pos_embed(W)
274
+ pos_embed = pos_embed.permute(1, 0, 2)
275
+ face3d_feat = self.decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
276
+ # (W, B*N, C)
277
+ face3d_feat = face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
278
+ # (B, N, C)
279
+ face3d = self.tail_fc(face3d_feat)
280
+ # (B, N, C_exp)
281
+ return face3d
282
+
283
+
284
+ if __name__ == "__main__":
285
+ import sys
286
+
287
+ sys.path.append("/home/mayifeng/Research/styleTH")
288
+
289
+ from configs.default import get_cfg_defaults
290
+
291
+ cfg = get_cfg_defaults()
292
+ cfg.merge_from_file("configs/styleTH_bp.yaml")
293
+ cfg.freeze()
294
+
295
+ # content_encoder = ContentEncoder(**cfg.CONTENT_ENCODER)
296
+
297
+ # dummy_audio = torch.randint(0, 41, (5, 64, 11))
298
+ # dummy_content = content_encoder(dummy_audio)
299
+
300
+ # style_encoder = StyleEncoder(**cfg.STYLE_ENCODER)
301
+ # dummy_face3d_seq = torch.randn(5, 64, 64)
302
+ # dummy_style_code = style_encoder(dummy_face3d_seq)
303
+
304
+ decoder = Decoder(**cfg.DECODER)
305
+ dummy_content = torch.randn(5, 64, 11, 512)
306
+ dummy_style = torch.randn(5, 512)
307
+ dummy_output = decoder(dummy_content, dummy_style)
308
+
309
+ print("hello")
damo/dreamtalk/core/networks/mish.py ADDED
@@ -0,0 +1,51 @@
1
+ """
2
+ Applies the mish function element-wise:
3
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
4
+ """
5
+
6
+ # import pytorch
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ @torch.jit.script
12
+ def mish(input):
13
+ """
14
+ Applies the mish function element-wise:
15
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
16
+ See additional documentation for mish class.
17
+ """
18
+ return input * torch.tanh(F.softplus(input))
19
+
20
+ class Mish(nn.Module):
21
+ """
22
+ Applies the mish function element-wise:
23
+ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
24
+
25
+ Shape:
26
+ - Input: (N, *) where * means, any number of additional
27
+ dimensions
28
+ - Output: (N, *), same shape as the input
29
+
30
+ Examples:
31
+ >>> m = Mish()
32
+ >>> input = torch.randn(2)
33
+ >>> output = m(input)
34
+
35
+ Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
36
+ """
37
+
38
+ def __init__(self):
39
+ """
40
+ Init method.
41
+ """
42
+ super().__init__()
43
+
44
+ def forward(self, input):
45
+ """
46
+ Forward pass of the function.
47
+ """
48
+ if hasattr(F, "mish"):  # native mish exists for torch >= 1.9; comparing torch.__version__ as a string would misorder "1.10" < "1.9"
49
+ return F.mish(input)
50
+ else:
51
+ return mish(input)
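A quick check of the scripted fallback against the native op, assuming a torch build recent enough to ship F.mish (1.9 or later):

import torch
import torch.nn.functional as F

x = torch.randn(8)
reference = x * torch.tanh(F.softplus(x))   # mish(x) = x * tanh(softplus(x))
assert torch.allclose(F.mish(x), reference, atol=1e-6)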
damo/dreamtalk/core/networks/self_attention_pooling.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from core.networks.mish import Mish
4
+
5
+
6
+ class SelfAttentionPooling(nn.Module):
7
+ """
8
+ Implementation of SelfAttentionPooling
9
+ Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
10
+ https://arxiv.org/pdf/2008.01077v1.pdf
11
+ """
12
+
13
+ def __init__(self, input_dim):
14
+ super(SelfAttentionPooling, self).__init__()
15
+ self.W = nn.Sequential(nn.Linear(input_dim, input_dim), Mish(), nn.Linear(input_dim, 1))
16
+ self.softmax = nn.functional.softmax
17
+
18
+ def forward(self, batch_rep, att_mask=None):
19
+ """
20
+ N: batch size, T: sequence length, H: Hidden dimension
21
+ input:
22
+ batch_rep : size (N, T, H)
23
+ attention_weight:
24
+ att_w : size (N, T, 1)
25
+ att_mask:
26
+ att_mask: size (N, T): if True, mask this item.
27
+ return:
28
+ utter_rep: size (N, H)
29
+ """
30
+
31
+ att_logits = self.W(batch_rep).squeeze(-1)
32
+ # (N, T)
33
+ if att_mask is not None:
34
+ att_mask_logits = att_mask.to(dtype=batch_rep.dtype) * -100000.0
35
+ # (N, T)
36
+ att_logits = att_mask_logits + att_logits
37
+
38
+ att_w = self.softmax(att_logits, dim=-1).unsqueeze(-1)
39
+ utter_rep = torch.sum(batch_rep * att_w, dim=1)
40
+
41
+ return utter_rep
42
+
43
+
44
+ if __name__ == "__main__":
45
+ batch = torch.randn(8, 64, 256)
46
+ self_attn_pool = SelfAttentionPooling(256)
47
+ att_mask = torch.zeros(8, 64)
48
+ att_mask[:, 60:] = 1
49
+ att_mask = att_mask.to(torch.bool)
50
+ output = self_attn_pool(batch, att_mask)
51
+ # (8, 256)
52
+
53
+ print("hello")
damo/dreamtalk/core/networks/transformer.py ADDED
@@ -0,0 +1,293 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ import copy
6
+
7
+
8
+ class PositionalEncoding(nn.Module):
9
+
10
+ def __init__(self, d_hid, n_position=200):
11
+ super(PositionalEncoding, self).__init__()
12
+
13
+ # Not a parameter
14
+ self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
15
+
16
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
17
+ ''' Sinusoid position encoding table '''
18
+ # TODO: make it with torch instead of numpy
19
+
20
+ def get_position_angle_vec(position):
21
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
22
+
23
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
24
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
25
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
26
+
27
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
28
+
29
+ def forward(self, winsize):
30
+ return self.pos_table[:, :winsize].clone().detach()
31
+
32
+ def _get_activation_fn(activation):
33
+ """Return an activation function given a string"""
34
+ if activation == "relu":
35
+ return F.relu
36
+ if activation == "gelu":
37
+ return F.gelu
38
+ if activation == "glu":
39
+ return F.glu
40
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
41
+
42
+ def _get_clones(module, N):
43
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
44
+
45
+ class Transformer(nn.Module):
46
+
47
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
48
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
49
+ activation="relu", normalize_before=False,
50
+ return_intermediate_dec=True):
51
+ super().__init__()
52
+
53
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
54
+ dropout, activation, normalize_before)
55
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
56
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
57
+
58
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
59
+ dropout, activation, normalize_before)
60
+ decoder_norm = nn.LayerNorm(d_model)
61
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
62
+ return_intermediate=return_intermediate_dec)
63
+
64
+ self._reset_parameters()
65
+
66
+ self.d_model = d_model
67
+ self.nhead = nhead
68
+
69
+ def _reset_parameters(self):
70
+ for p in self.parameters():
71
+ if p.dim() > 1:
72
+ nn.init.xavier_uniform_(p)
73
+
74
+ def forward(self, opt, src, query_embed, pos_embed):
75
+ # flatten NxCxHxW to HWxNxC
76
+
77
+ src = src.permute(1, 0, 2)
78
+ pos_embed = pos_embed.permute(1, 0, 2)
79
+ query_embed = query_embed.permute(1, 0, 2)
80
+
81
+ tgt = torch.zeros_like(query_embed)
82
+ memory = self.encoder(src, pos=pos_embed)
83
+
84
+ hs = self.decoder(tgt, memory,
85
+ pos=pos_embed, query_pos=query_embed)
86
+ return hs
87
+
88
+
89
+ class TransformerEncoder(nn.Module):
90
+
91
+ def __init__(self, encoder_layer, num_layers, norm=None):
92
+ super().__init__()
93
+ self.layers = _get_clones(encoder_layer, num_layers)
94
+ self.num_layers = num_layers
95
+ self.norm = norm
96
+
97
+ def forward(self, src, mask = None, src_key_padding_mask = None, pos = None):
98
+ output = src+pos
99
+
100
+ for layer in self.layers:
101
+ output = layer(output, src_mask=mask,
102
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
103
+
104
+ if self.norm is not None:
105
+ output = self.norm(output)
106
+
107
+ return output
108
+
109
+
110
+ class TransformerDecoder(nn.Module):
111
+
112
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
113
+ super().__init__()
114
+ self.layers = _get_clones(decoder_layer, num_layers)
115
+ self.num_layers = num_layers
116
+ self.norm = norm
117
+ self.return_intermediate = return_intermediate
118
+
119
+ def forward(self, tgt, memory, tgt_mask = None, memory_mask = None, tgt_key_padding_mask = None,
120
+ memory_key_padding_mask = None,
121
+ pos = None,
122
+ query_pos = None):
123
+ output = tgt+pos+query_pos
124
+
125
+ intermediate = []
126
+
127
+ for layer in self.layers:
128
+ output = layer(output, memory, tgt_mask=tgt_mask,
129
+ memory_mask=memory_mask,
130
+ tgt_key_padding_mask=tgt_key_padding_mask,
131
+ memory_key_padding_mask=memory_key_padding_mask,
132
+ pos=pos, query_pos=query_pos)
133
+ if self.return_intermediate:
134
+ intermediate.append(self.norm(output))
135
+
136
+ if self.norm is not None:
137
+ output = self.norm(output)
138
+ if self.return_intermediate:
139
+ intermediate.pop()
140
+ intermediate.append(output)
141
+
142
+ if self.return_intermediate:
143
+ return torch.stack(intermediate)
144
+
145
+ return output.unsqueeze(0)
146
+
147
+
148
+ class TransformerEncoderLayer(nn.Module):
149
+
150
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
151
+ activation="relu", normalize_before=False):
152
+ super().__init__()
153
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
154
+ # Implementation of Feedforward model
155
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
156
+ self.dropout = nn.Dropout(dropout)
157
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
158
+
159
+ self.norm1 = nn.LayerNorm(d_model)
160
+ self.norm2 = nn.LayerNorm(d_model)
161
+ self.dropout1 = nn.Dropout(dropout)
162
+ self.dropout2 = nn.Dropout(dropout)
163
+
164
+ self.activation = _get_activation_fn(activation)
165
+ self.normalize_before = normalize_before
166
+
167
+ def with_pos_embed(self, tensor, pos):
168
+ return tensor if pos is None else tensor + pos
169
+
170
+ def forward_post(self,
171
+ src,
172
+ src_mask = None,
173
+ src_key_padding_mask = None,
174
+ pos = None):
175
+ # q = k = self.with_pos_embed(src, pos)
176
+ src2 = self.self_attn(src, src, value=src, attn_mask=src_mask,
177
+ key_padding_mask=src_key_padding_mask)[0]
178
+ src = src + self.dropout1(src2)
179
+ src = self.norm1(src)
180
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
181
+ src = src + self.dropout2(src2)
182
+ src = self.norm2(src)
183
+ return src
184
+
185
+ def forward_pre(self, src,
186
+ src_mask = None,
187
+ src_key_padding_mask = None,
188
+ pos = None):
189
+ src2 = self.norm1(src)
190
+ # q = k = self.with_pos_embed(src2, pos)
191
+ src2 = self.self_attn(src2, src2, value=src2, attn_mask=src_mask,
192
+ key_padding_mask=src_key_padding_mask)[0]
193
+ src = src + self.dropout1(src2)
194
+ src2 = self.norm2(src)
195
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
196
+ src = src + self.dropout2(src2)
197
+ return src
198
+
199
+ def forward(self, src,
200
+ src_mask = None,
201
+ src_key_padding_mask = None,
202
+ pos = None):
203
+ if self.normalize_before:
204
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
205
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
206
+
207
+
208
+ class TransformerDecoderLayer(nn.Module):
209
+
210
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
211
+ activation="relu", normalize_before=False):
212
+ super().__init__()
213
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
214
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
215
+ # Implementation of Feedforward model
216
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
217
+ self.dropout = nn.Dropout(dropout)
218
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
219
+
220
+ self.norm1 = nn.LayerNorm(d_model)
221
+ self.norm2 = nn.LayerNorm(d_model)
222
+ self.norm3 = nn.LayerNorm(d_model)
223
+ self.dropout1 = nn.Dropout(dropout)
224
+ self.dropout2 = nn.Dropout(dropout)
225
+ self.dropout3 = nn.Dropout(dropout)
226
+
227
+ self.activation = _get_activation_fn(activation)
228
+ self.normalize_before = normalize_before
229
+
230
+ def with_pos_embed(self, tensor, pos):
231
+ return tensor if pos is None else tensor + pos
232
+
233
+ def forward_post(self, tgt, memory,
234
+ tgt_mask = None,
235
+ memory_mask = None,
236
+ tgt_key_padding_mask = None,
237
+ memory_key_padding_mask = None,
238
+ pos = None,
239
+ query_pos = None):
240
+ # q = k = self.with_pos_embed(tgt, query_pos)
241
+ tgt2 = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
242
+ key_padding_mask=tgt_key_padding_mask)[0]
243
+ tgt = tgt + self.dropout1(tgt2)
244
+ tgt = self.norm1(tgt)
245
+ tgt2 = self.multihead_attn(query=tgt,
246
+ key=memory,
247
+ value=memory, attn_mask=memory_mask,
248
+ key_padding_mask=memory_key_padding_mask)[0]
249
+ tgt = tgt + self.dropout2(tgt2)
250
+ tgt = self.norm2(tgt)
251
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
252
+ tgt = tgt + self.dropout3(tgt2)
253
+ tgt = self.norm3(tgt)
254
+ return tgt
255
+
256
+ def forward_pre(self, tgt, memory,
257
+ tgt_mask = None,
258
+ memory_mask = None,
259
+ tgt_key_padding_mask = None,
260
+ memory_key_padding_mask = None,
261
+ pos = None,
262
+ query_pos = None):
263
+ tgt2 = self.norm1(tgt)
264
+ # q = k = self.with_pos_embed(tgt2, query_pos)
265
+ tgt2 = self.self_attn(tgt2, tgt2, value=tgt2, attn_mask=tgt_mask,
266
+ key_padding_mask=tgt_key_padding_mask)[0]
267
+ tgt = tgt + self.dropout1(tgt2)
268
+ tgt2 = self.norm2(tgt)
269
+ tgt2 = self.multihead_attn(query=tgt2,
270
+ key=memory,
271
+ value=memory, attn_mask=memory_mask,
272
+ key_padding_mask=memory_key_padding_mask)[0]
273
+ tgt = tgt + self.dropout2(tgt2)
274
+ tgt2 = self.norm3(tgt)
275
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
276
+ tgt = tgt + self.dropout3(tgt2)
277
+ return tgt
278
+
279
+ def forward(self, tgt, memory,
280
+ tgt_mask = None,
281
+ memory_mask = None,
282
+ tgt_key_padding_mask = None,
283
+ memory_key_padding_mask = None,
284
+ pos = None,
285
+ query_pos = None):
286
+ if self.normalize_before:
287
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
288
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
289
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
290
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
291
+
292
+
293
+
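A minimal smoke test of the encoder path defined above, assuming this file is importable as core.networks.transformer (the package path used elsewhere in the repo); note that src is expected in (W, B, C) layout and pos in (W, 1, C):

import torch
from core.networks.transformer import (
    PositionalEncoding,
    TransformerEncoder,
    TransformerEncoderLayer,
)

d_model, nhead, W, B = 256, 8, 11, 4
layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward=512)
encoder = TransformerEncoder(layer, num_layers=2)
pos = PositionalEncoding(d_model, n_position=80)(W).permute(1, 0, 2)  # (W, 1, C)
src = torch.randn(W, B, d_model)                                      # (W, B, C)
out = encoder(src, pos=pos)
assert out.shape == (W, B, d_model)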
damo/dreamtalk/core/utils.py ADDED
@@ -0,0 +1,456 @@
1
+ import os
2
+ import argparse
3
+ from collections import defaultdict
4
+ import logging
5
+ import pickle
6
+ import json
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import nn
11
+ from scipy.io import loadmat
12
+
13
+ from configs.default import get_cfg_defaults
14
+ import dlib
15
+ import cv2
16
+
17
+
18
+ def _reset_parameters(model):
19
+ for p in model.parameters():
20
+ if p.dim() > 1:
21
+ nn.init.xavier_uniform_(p)
22
+
23
+
24
+ def get_video_style(video_name, style_type):
25
+ person_id, direction, emotion, level, *_ = video_name.split("_")
26
+ if style_type == "id_dir_emo_level":
27
+ style = "_".join([person_id, direction, emotion, level])
28
+ elif style_type == "emotion":
29
+ style = emotion
30
+ elif style_type == "id":
31
+ style = person_id
32
+ else:
33
+ raise ValueError("Unknown style type")
34
+
35
+ return style
36
+
37
+
38
+ def get_style_video_lists(video_list, style_type):
39
+ style2video_list = defaultdict(list)
40
+ for video in video_list:
41
+ style = get_video_style(video, style_type)
42
+ style2video_list[style].append(video)
43
+
44
+ return style2video_list
45
+
46
+
47
+ def get_face3d_clip(
48
+ video_name, video_root_dir, num_frames, start_idx, dtype=torch.float32
49
+ ):
50
+ """_summary_
51
+
52
+ Args:
53
+ video_name (_type_): _description_
54
+ video_root_dir (_type_): _description_
55
+ num_frames (_type_): _description_
56
+ start_idx (str or int): "random", "middle", or an int start frame index
57
+ dtype (_type_, optional): _description_. Defaults to torch.float32.
58
+
59
+ Raises:
60
+ ValueError: _description_
61
+ ValueError: _description_
62
+
63
+ Returns:
64
+ _type_: _description_
65
+ """
66
+ video_path = os.path.join(video_root_dir, video_name)
67
+ if video_path[-3:] == "mat":
68
+ face3d_all = loadmat(video_path)["coeff"]
69
+ face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
70
+ elif video_path[-3:] == "txt":
71
+ face3d_exp = np.loadtxt(video_path)
72
+ else:
73
+ raise ValueError("Invalid 3DMM file extension")
74
+
75
+ length = face3d_exp.shape[0]
76
+ clip_num_frames = num_frames
77
+ if start_idx == "random":
78
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
79
+ elif start_idx == "middle":
80
+ clip_start_idx = (length - clip_num_frames + 1) // 2
81
+ elif isinstance(start_idx, int):
82
+ clip_start_idx = start_idx
83
+ else:
84
+ raise ValueError(f"Invalid start_idx {start_idx}")
85
+
86
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
87
+ face3d_clip = torch.tensor(face3d_clip, dtype=dtype)
88
+
89
+ return face3d_clip
90
+
91
+
92
+ def get_video_style_clip(
93
+ video_name,
94
+ video_root_dir,
95
+ style_max_len,
96
+ start_idx="random",
97
+ dtype=torch.float32,
98
+ return_start_idx=False,
99
+ ):
100
+ video_path = os.path.join(video_root_dir, video_name)
101
+ if video_path[-3:] == "mat":
102
+ face3d_all = loadmat(video_path)["coeff"]
103
+ face3d_exp = face3d_all[:, 80:144] # expression 3DMM range
104
+ elif video_path[-3:] == "txt":
105
+ face3d_exp = np.loadtxt(video_path)
106
+ else:
107
+ raise ValueError("Invalid 3DMM file extension")
108
+
109
+ face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
110
+
111
+ length = face3d_exp.shape[0]
112
+ if length >= style_max_len:
113
+ clip_num_frames = style_max_len
114
+ if start_idx == "random":
115
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
116
+ elif start_idx == "middle":
117
+ clip_start_idx = (length - clip_num_frames + 1) // 2
118
+ elif isinstance(start_idx, int):
119
+ clip_start_idx = start_idx
120
+ else:
121
+ raise ValueError(f"Invalid start_idx {start_idx}")
122
+
123
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
124
+ pad_mask = torch.tensor([False] * style_max_len)
125
+ else:
126
+ clip_start_idx = None
127
+ padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
128
+ face3d_clip = torch.cat((face3d_exp, padding), dim=0)
129
+ pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
130
+
131
+ if return_start_idx:
132
+ return face3d_clip, pad_mask, clip_start_idx
133
+ else:
134
+ return face3d_clip, pad_mask
135
+
136
+
137
+ def get_video_style_clip_from_np(
138
+ face3d_exp,
139
+ style_max_len,
140
+ start_idx="random",
141
+ dtype=torch.float32,
142
+ return_start_idx=False,
143
+ ):
144
+ face3d_exp = torch.tensor(face3d_exp, dtype=dtype)
145
+
146
+ length = face3d_exp.shape[0]
147
+ if length >= style_max_len:
148
+ clip_num_frames = style_max_len
149
+ if start_idx == "random":
150
+ clip_start_idx = np.random.randint(low=0, high=length - clip_num_frames + 1)
151
+ elif start_idx == "middle":
152
+ clip_start_idx = (length - clip_num_frames + 1) // 2
153
+ elif isinstance(start_idx, int):
154
+ clip_start_idx = start_idx
155
+ else:
156
+ raise ValueError(f"Invalid start_idx {start_idx}")
157
+
158
+ face3d_clip = face3d_exp[clip_start_idx : clip_start_idx + clip_num_frames]
159
+ pad_mask = torch.tensor([False] * style_max_len)
160
+ else:
161
+ clip_start_idx = None
162
+ padding = torch.zeros(style_max_len - length, face3d_exp.shape[1])
163
+ face3d_clip = torch.cat((face3d_exp, padding), dim=0)
164
+ pad_mask = torch.tensor([False] * length + [True] * (style_max_len - length))
165
+
166
+ if return_start_idx:
167
+ return face3d_clip, pad_mask, clip_start_idx
168
+ else:
169
+ return face3d_clip, pad_mask
170
+
171
+
172
+ def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
173
+ """
174
+
175
+ Args:
176
+ audio_feat (np.ndarray): (N, 1024)
177
+ start_idx (_type_): _description_
178
+ num_frames (_type_): _description_
179
+ """
180
+ center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
181
+ audio_window_list = []
182
+ padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
183
+ for center_idx in center_idx_list:
184
+ cur_audio_window = []
185
+ for i in range(center_idx - win_size, center_idx + win_size + 1):
186
+ if i < 0:
187
+ cur_audio_window.append(padding)
188
+ elif i >= len(audio_feat):
189
+ cur_audio_window.append(padding)
190
+ else:
191
+ cur_audio_window.append(audio_feat[i])
192
+ cur_audio_win_array = np.stack(cur_audio_window, axis=0)
193
+ audio_window_list.append(cur_audio_win_array)
194
+
195
+ audio_window_array = np.stack(audio_window_list, axis=0)
196
+ return audio_window_array
197
+
198
+
199
+ def setup_config():
200
+ parser = argparse.ArgumentParser(description="voice2pose main program")
201
+ parser.add_argument(
202
+ "--config_file", default="", metavar="FILE", help="path to config file"
203
+ )
204
+ parser.add_argument(
205
+ "--resume_from", type=str, default=None, help="the checkpoint to resume from"
206
+ )
207
+ parser.add_argument(
208
+ "--test_only", action="store_true", help="perform testing and evaluation only"
209
+ )
210
+ parser.add_argument(
211
+ "--demo_input", type=str, default=None, help="path to input for demo"
212
+ )
213
+ parser.add_argument(
214
+ "--checkpoint", type=str, default=None, help="the checkpoint to test with"
215
+ )
216
+ parser.add_argument("--tag", type=str, default="", help="tag for the experiment")
217
+ parser.add_argument(
218
+ "opts",
219
+ help="Modify config options using the command-line",
220
+ default=None,
221
+ nargs=argparse.REMAINDER,
222
+ )
223
+ parser.add_argument(
224
+ "--local_rank",
225
+ type=int,
226
+ help="local rank for DistributedDataParallel",
227
+ )
228
+ parser.add_argument(
229
+ "--master_port",
230
+ type=str,
231
+ default="12345",
232
+ )
233
+ parser.add_argument(
234
+ "--max_audio_len",
235
+ type=int,
236
+ default=450,
237
+ help="max_audio_len for inference",
238
+ )
239
+ parser.add_argument(
240
+ "--ddim_num_step",
241
+ type=int,
242
+ default=10,
243
+ )
244
+ parser.add_argument(
245
+ "--inference_seed",
246
+ type=int,
247
+ default=1,
248
+ )
249
+ parser.add_argument(
250
+ "--inference_sample_method",
251
+ type=str,
252
+ default="ddim",
253
+ )
254
+ args = parser.parse_args()
255
+
256
+ cfg = get_cfg_defaults()
257
+ cfg.merge_from_file(args.config_file)
258
+ cfg.merge_from_list(args.opts)
259
+ cfg.freeze()
260
+ return args, cfg
261
+
262
+
263
+ def setup_logger(base_path, exp_name):
264
+ rootLogger = logging.getLogger()
265
+ rootLogger.setLevel(logging.INFO)
266
+
267
+ logFormatter = logging.Formatter("%(asctime)s [%(levelname)-0.5s] %(message)s")
268
+
269
+ log_path = "{0}/{1}.log".format(base_path, exp_name)
270
+ fileHandler = logging.FileHandler(log_path)
271
+ fileHandler.setFormatter(logFormatter)
272
+ rootLogger.addHandler(fileHandler)
273
+
274
+ consoleHandler = logging.StreamHandler()
275
+ consoleHandler.setFormatter(logFormatter)
276
+ rootLogger.addHandler(consoleHandler)
277
+ rootLogger.handlers[0].setLevel(logging.INFO)
278
+
279
+ logging.info("log path: %s" % log_path)
280
+
281
+
282
+ def cosine_loss(a, v, y, logloss=nn.BCELoss()):
283
+ d = nn.functional.cosine_similarity(a, v)
284
+ loss = logloss(d.unsqueeze(1), y)
285
+ return loss
286
+
287
+
288
+ def get_pose_params(mat_path):
289
+ """Get pose parameters from mat file
290
+
291
+ Args:
292
+ mat_path (str): path of mat file
293
+
294
+ Returns:
295
+ pose_params (numpy.ndarray): shape (L_video, 9), angle, translation, crop parameters
296
+ """
297
+ mat_dict = loadmat(mat_path)
298
+
299
+ np_3dmm = mat_dict["coeff"]
300
+ angles = np_3dmm[:, 224:227]
301
+ translations = np_3dmm[:, 254:257]
302
+
303
+ np_trans_params = mat_dict["transform_params"]
304
+ crop = np_trans_params[:, -3:]
305
+
306
+ pose_params = np.concatenate((angles, translations, crop), axis=1)
307
+
308
+ return pose_params
309
+
310
+
311
+ def sinusoidal_embedding(timesteps, dim):
312
+ """
313
+
314
+ Args:
315
+ timesteps (_type_): (B,)
316
+ dim (_type_): (C_embed)
317
+
318
+ Returns:
319
+ _type_: (B, C_embed)
320
+ """
321
+ # check input
322
+ half = dim // 2
323
+ timesteps = timesteps.float()
324
+
325
+ # compute sinusoidal embedding
326
+ sinusoid = torch.outer(
327
+ timesteps, torch.pow(10000, -torch.arange(half).to(timesteps).div(half))
328
+ )
329
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
330
+ if dim % 2 != 0:
331
+ x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1)
332
+ return x
333
+
334
+
335
+ def get_wav2vec_audio_window(audio_feat, start_idx, num_frames, win_size):
336
+ """
337
+
338
+ Args:
339
+ audio_feat (np.ndarray): (250, 1024)
340
+ start_idx (_type_): _description_
341
+ num_frames (_type_): _description_
342
+ """
343
+ center_idx_list = [2 * idx for idx in range(start_idx, start_idx + num_frames)]
344
+ audio_window_list = []
345
+ padding = np.zeros(audio_feat.shape[1], dtype=np.float32)
346
+ for center_idx in center_idx_list:
347
+ cur_audio_window = []
348
+ for i in range(center_idx - win_size, center_idx + win_size + 1):
349
+ if i < 0:
350
+ cur_audio_window.append(padding)
351
+ elif i >= len(audio_feat):
352
+ cur_audio_window.append(padding)
353
+ else:
354
+ cur_audio_window.append(audio_feat[i])
355
+ cur_audio_win_array = np.stack(cur_audio_window, axis=0)
356
+ audio_window_list.append(cur_audio_win_array)
357
+
358
+ audio_window_array = np.stack(audio_window_list, axis=0)
359
+ return audio_window_array
360
+
361
+
362
+ def reshape_audio_feat(style_audio_all_raw, stride):
363
+ """_summary_
364
+
365
+ Args:
366
+ style_audio_all_raw (_type_): (stride * L, C)
367
+ stride (_type_): int
368
+
369
+ Returns:
370
+ _type_: (L, C * stride)
371
+ """
372
+ style_audio_all_raw = style_audio_all_raw[
373
+ : style_audio_all_raw.shape[0] // stride * stride
374
+ ]
375
+ style_audio_all_raw = style_audio_all_raw.reshape(
376
+ style_audio_all_raw.shape[0] // stride, stride, style_audio_all_raw.shape[1]
377
+ )
378
+ style_audio_all = style_audio_all_raw.reshape(style_audio_all_raw.shape[0], -1)
379
+ return style_audio_all
380
+
381
+
382
+ import random
383
+
384
+
385
+ def get_derangement_tuple(n):
386
+ while True:
387
+ v = [i for i in range(n)]
388
+ for j in range(n - 1, -1, -1):
389
+ p = random.randint(0, j)
390
+ if v[p] == j:
391
+ break
392
+ else:
393
+ v[j], v[p] = v[p], v[j]
394
+ else:
395
+ if v[0] != 0:
396
+ return tuple(v)
397
+
398
+
399
+ def compute_aspect_preserved_bbox(bbox, increase_area, h, w):
400
+ left, top, right, bot = bbox
401
+ width = right - left
402
+ height = bot - top
403
+
404
+ width_increase = max(
405
+ increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)
406
+ )
407
+ height_increase = max(
408
+ increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)
409
+ )
410
+
411
+ left_t = int(left - width_increase * width)
412
+ top_t = int(top - height_increase * height)
413
+ right_t = int(right + width_increase * width)
414
+ bot_t = int(bot + height_increase * height)
415
+
416
+ left_oob = -min(0, left_t)
417
+ right_oob = right - min(right_t, w)
418
+ top_oob = -min(0, top_t)
419
+ bot_oob = bot - min(bot_t, h)
420
+
421
+ if max(left_oob, right_oob, top_oob, bot_oob) > 0:
422
+ max_w = max(left_oob, right_oob)
423
+ max_h = max(top_oob, bot_oob)
424
+ if max_w > max_h:
425
+ return left_t + max_w, top_t + max_w, right_t - max_w, bot_t - max_w
426
+ else:
427
+ return left_t + max_h, top_t + max_h, right_t - max_h, bot_t - max_h
428
+
429
+ else:
430
+ return (left_t, top_t, right_t, bot_t)
431
+
432
+
433
+ def crop_src_image(src_img, save_img, increase_ratio, detector=None):
434
+ if detector is None:
435
+ detector = dlib.get_frontal_face_detector()
436
+
437
+ img = cv2.imread(src_img)
438
+ faces = detector(img, 0)
439
+ h, width, _ = img.shape
440
+ if len(faces) > 0:
441
+ bbox = [faces[0].left(), faces[0].top(), faces[0].right(), faces[0].bottom()]
442
+ l = bbox[3] - bbox[1]
443
+ bbox[1] = bbox[1] - l * 0.1
444
+ bbox[3] = bbox[3] - l * 0.1
445
+ bbox[1] = max(0, bbox[1])
446
+ bbox[3] = min(h, bbox[3])
447
+ bbox = compute_aspect_preserved_bbox(
448
+ tuple(bbox), increase_ratio, img.shape[0], img.shape[1]
449
+ )
450
+ img = img[bbox[1] : bbox[3], bbox[0] : bbox[2]]
451
+ img = cv2.resize(img, (256, 256))
452
+ cv2.imwrite(save_img, img)
453
+ else:
454
+ raise ValueError("No face detected in the input image")
455
+ # img = cv2.resize(img, (256, 256))
456
+ # cv2.imwrite(save_img, img)
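get_wav2vec_audio_window pairs each video frame with a (2 * win_size + 1)-row slice of the wav2vec features, stepping the window center by two audio rows per video frame and zero-padding outside the clip. A small sketch of the expected output (hypothetical sizes; importing core.utils assumes its dlib/cv2 dependencies are installed):

import numpy as np
from core.utils import get_wav2vec_audio_window

audio_feat = np.random.randn(100, 1024).astype(np.float32)  # (N_audio, 1024) wav2vec features
win = get_wav2vec_audio_window(audio_feat, start_idx=0, num_frames=8, win_size=5)
assert win.shape == (8, 11, 1024)       # one (2*5+1)-row window per video frame
assert (win[0, :5] == 0).all()          # frame 0: first 5 rows fall before the clip -> zero padding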
damo/dreamtalk/data/audio/German1.wav ADDED
Binary file (279 kB). View file
 
damo/dreamtalk/data/audio/German2.wav ADDED
Binary file (219 kB). View file
 
damo/dreamtalk/data/audio/German3.wav ADDED
Binary file (240 kB). View file
 
damo/dreamtalk/data/audio/German4.wav ADDED
Binary file (219 kB). View file
 
damo/dreamtalk/data/audio/acknowledgement_chinese.m4a ADDED
Binary file (537 kB). View file
 
damo/dreamtalk/data/audio/acknowledgement_english.m4a ADDED
Binary file (511 kB). View file
 
damo/dreamtalk/data/audio/chinese1_haierlizhi.wav ADDED
Binary file (420 kB). View file
 
damo/dreamtalk/data/audio/chinese2_guanyu.wav ADDED
Binary file (638 kB). View file
 
damo/dreamtalk/data/audio/french1.wav ADDED
Binary file (220 kB). View file
 
damo/dreamtalk/data/audio/french2.wav ADDED
Binary file (177 kB). View file
 
damo/dreamtalk/data/audio/french3.wav ADDED
Binary file (168 kB). View file
 
damo/dreamtalk/data/audio/italian1.wav ADDED
Binary file (285 kB). View file
 
damo/dreamtalk/data/audio/italian2.wav ADDED
Binary file (170 kB). View file
 
damo/dreamtalk/data/audio/italian3.wav ADDED
Binary file (197 kB). View file
 
damo/dreamtalk/data/audio/japan1.wav ADDED
Binary file (197 kB). View file
 
damo/dreamtalk/data/audio/japan2.wav ADDED
Binary file (231 kB). View file
 
damo/dreamtalk/data/audio/japan3.wav ADDED
Binary file (234 kB). View file
 
damo/dreamtalk/data/audio/korean1.wav ADDED
Binary file (328 kB). View file
 
damo/dreamtalk/data/audio/korean2.wav ADDED
Binary file (210 kB). View file
 
damo/dreamtalk/data/audio/korean3.wav ADDED
Binary file (148 kB). View file
 
damo/dreamtalk/data/audio/noisy_audio_cafeter_snr_0.wav ADDED
Binary file (206 kB). View file
 
damo/dreamtalk/data/audio/noisy_audio_meeting_snr_0.wav ADDED
Binary file (206 kB). View file
 
damo/dreamtalk/data/audio/noisy_audio_meeting_snr_10.wav ADDED
Binary file (206 kB). View file
 
damo/dreamtalk/data/audio/noisy_audio_meeting_snr_20.wav ADDED
Binary file (206 kB). View file
 
damo/dreamtalk/data/audio/noisy_audio_narrative.wav ADDED
Binary file (206 kB). View file
 
damo/dreamtalk/data/audio/noisy_audio_office_snr_0.wav ADDED
Binary file (206 kB). View file
 
damo/dreamtalk/data/audio/out_of_domain_narrative.wav ADDED
Binary file (445 kB). View file
 
damo/dreamtalk/data/audio/spanish1.wav ADDED
Binary file (144 kB). View file
 
damo/dreamtalk/data/audio/spanish2.wav ADDED
Binary file (150 kB). View file
 
damo/dreamtalk/data/audio/spanish3.wav ADDED
Binary file (212 kB). View file