cantabile-kwok committed · Commit 05005db
Parent(s): 8bd60fe

prepare demo page
This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +51 -0
- pretrained/WavLM-Large.pt +3 -0
- pretrained/config.yml +201 -0
- pretrained/generator.ckpt +3 -0
- pretrained/vq-wav2vec_kmeans.pt +3 -0
- requirements.txt +25 -0
- vec2wav2/__init__.py +3 -0
- vec2wav2/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/__pycache__/__init__.cpython-311.pyc +0 -0
- vec2wav2/__pycache__/__init__.cpython-39.pyc +0 -0
- vec2wav2/bin/.DS_Store +0 -0
- vec2wav2/bin/__init__.py +0 -0
- vec2wav2/bin/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/bin/__pycache__/vc.cpython-310.pyc +0 -0
- vec2wav2/bin/decode.py +163 -0
- vec2wav2/bin/gradio_app.py +51 -0
- vec2wav2/bin/train.py +1007 -0
- vec2wav2/bin/vc.py +128 -0
- vec2wav2/datasets/__init__.py +1 -0
- vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc +0 -0
- vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc +0 -0
- vec2wav2/datasets/scp_dataset.py +300 -0
- vec2wav2/distributed/__init__.py +0 -0
- vec2wav2/distributed/launch.py +163 -0
- vec2wav2/layers/__init__.py +6 -0
- vec2wav2/layers/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/__init__.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/activations.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/upsample.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/upsample.cpython-39.pyc +0 -0
- vec2wav2/layers/activations.py +197 -0
- vec2wav2/layers/causal_conv.py +66 -0
- vec2wav2/layers/pqmf.py +150 -0
- vec2wav2/layers/residual_block.py +222 -0
- vec2wav2/layers/residual_stack.py +85 -0
- vec2wav2/layers/tade_res_block.py +160 -0
- vec2wav2/layers/upsample.py +194 -0
- vec2wav2/losses/__init__.py +4 -0
app.py
ADDED
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import gradio as gr
import logging
import yaml
import soundfile as sf
import os
from pathlib import Path
from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args

# Create Gradio interface
def create_interface():
    args = vc_args()
    logger = configure_logging(args.verbose)
    voice_converter = VoiceConverter(
        expdir=args.expdir,
        token_extractor=args.token_extractor,
        prompt_extractor=args.prompt_extractor,
        prompt_output_layer=args.prompt_output_layer,
        checkpoint=args.checkpoint,
        script_logger=logger
    )
    with gr.Blocks(title="Voice Conversion") as demo:
        gr.Markdown("# vec2wav 2.0 Voice Conversion Demo")
        gr.Markdown("Upload source audio and target speaker audio to convert the voice.")

        with gr.Row():
            source_audio = gr.Audio(label="Source Audio", type="filepath")
            target_audio = gr.Audio(label="Target Speaker Audio", type="filepath")

        examples = [
            ["examples/Zuckerberg.wav", "examples/Rachel.wav"],
            ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"]
        ]
        gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio])

        convert_btn = gr.Button("Convert Voice")
        output_audio = gr.Audio(label="Converted Audio")

        convert_btn.click(
            fn=voice_converter.voice_conversion,
            inputs=[source_audio, target_audio],
            outputs=output_audio
        )

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
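The entry point above launches Gradio with a public share link. As a hedged aside (not part of this commit), the same interface can also be served locally by passing standard Blocks.launch options instead of share=True; the import path below assumes app.py at the repository root and that create_interface can resolve its command-line arguments.

# Minimal local-serving sketch (assumption: dependencies installed and pretrained/ present).
from app import create_interface  # app.py as added in this commit

demo = create_interface()
# server_name/server_port are standard Gradio launch options; share=False keeps the demo local.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
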
pretrained/WavLM-Large.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
size 1261965425
pretrained/config.yml
ADDED
@@ -0,0 +1,201 @@
allow_cache: false
batch_frames: 3600
config: conf/ctxv2w.v1.yaml
crop_max_frames: 100
discriminator_adv_loss_params:
  average_by_discriminators: false
discriminator_grad_norm: -1
discriminator_optimizer_params:
  betas:
  - 0.5
  - 0.9
  lr: 0.0002
  weight_decay: 0.0
discriminator_optimizer_type: Adam
discriminator_params:
  follow_official_norm: true
  period_discriminator_params:
    bias: true
    channels: 32
    downsample_scales:
    - 3
    - 3
    - 3
    - 3
    - 1
    in_channels: 1
    kernel_sizes:
    - 5
    - 3
    max_downsample_channels: 1024
    nonlinear_activation: LeakyReLU
    nonlinear_activation_params:
      negative_slope: 0.1
    out_channels: 1
    use_spectral_norm: false
    use_weight_norm: true
  periods:
  - 2
  - 3
  - 5
  - 7
  - 11
  scale_discriminator_params:
    bias: true
    channels: 128
    downsample_scales:
    - 4
    - 4
    - 4
    - 4
    - 1
    in_channels: 1
    kernel_sizes:
    - 15
    - 41
    - 5
    - 3
    max_downsample_channels: 1024
    max_groups: 16
    nonlinear_activation: LeakyReLU
    nonlinear_activation_params:
      negative_slope: 0.1
    out_channels: 1
  scale_downsample_pooling: AvgPool1d
  scale_downsample_pooling_params:
    kernel_size: 4
    padding: 2
    stride: 2
  scales: 3
discriminator_scheduler_params:
  gamma: 0.5
  milestones:
  - 200000
  - 400000
  - 600000
  - 800000
discriminator_scheduler_type: MultiStepLR
discriminator_train_start_steps: 0
discriminator_type: HiFiGANMultiScaleMultiPeriodDiscriminator
distributed: true
dropout_features: 0.0
eval_interval_steps: 100000
feat_match_loss_params:
  average_by_discriminators: false
  average_by_layers: false
  include_final_outputs: false
frontend_mel_prediction_stop_steps: 200000
frontend_params:
  conformer_params:
    activation_type: swish
    attention_dim: 184
    attention_dropout_rate: 0.2
    attention_heads: 2
    cnn_module_kernel: 31
    concat_after: false
    dropout_rate: 0.2
    linear_units: 1536
    macaron_style: true
    normalize_before: true
    num_blocks: 2
    pos_enc_layer_type: rel_pos
    positional_dropout_rate: 0.2
    positionwise_conv_kernel_size: 3
    positionwise_layer_type: conv1d
    selfattention_layer_type: rel_selfattn
    use_cnn_module: true
  prompt_channels: 1024
  vqvec_channels: 512
generator_adv_loss_params:
  average_by_discriminators: false
generator_grad_norm: -1
generator_optimizer_params:
  betas:
  - 0.5
  - 0.9
  lr: 0.0002
  weight_decay: 0.0
generator_optimizer_type: Adam
generator_params:
  bias: true
  channels: 512
  condition_dim: 1024
  in_channels: 184
  kernel_size: 7
  nonlinear_activation: snakebeta-condition
  out_channels: 1
  resblock: '1'
  resblock_dilations:
  - - 1
    - 3
    - 5
  - - 1
    - 3
    - 5
  - - 1
    - 3
    - 5
  resblock_kernel_sizes:
  - 3
  - 7
  - 11
  snake_logscale: true
  upsample_kernel_sizes:
  - 16
  - 10
  - 6
  - 4
  upsample_scales:
  - 8
  - 5
  - 3
  - 2
  use_additional_convs: true
  use_weight_norm: true
generator_scheduler_params:
  gamma: 0.5
  milestones:
  - 200000
  - 400000
  - 600000
  - 800000
generator_scheduler_type: MultiStepLR
generator_train_start_steps: 1
generator_type: BigVGAN
hop_size: 240
lambda_adv: 1.0
lambda_aux: 45.0
lambda_feat_match: 2.0
lambda_frontend_mel_prediction: 60
log_interval_steps: 1000
max_num_frames: 3000
mel_loss_params:
  fft_size: 2048
  fmax: 8000
  fmin: 40
  fs: 24000
  hop_size: 300
  log_base: null
  num_mels: 80
  win_length: 1200
  window: hann
min_num_frames: 600
num_mels: 80
num_save_intermediate_results: 4
num_workers: 8
outdir: exp/train_all_ctxv2w.v1
pin_memory: true
pretrain: ''
prompt_fold_by_2: true
prompt_net_type: ConvPromptPrenet
rank: 0
sampling_rate: 24000
save_interval_steps: 10000
use_feat_match_loss: true
use_mel_loss: true
use_stft_loss: false
verbose: 1
version: 0.5.3
vq_codebook: feats/vqidx/codebook.npy
win_length: 697
world_size: 4
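For reference, this config is consumed as a plain dictionary; a minimal loading sketch (mirroring how decode.py later in this commit resolves config.yml next to the checkpoint) follows, using only paths and keys that appear in this commit.

# Minimal sketch: load the pretrained config the same way decode.py does.
import yaml

with open("pretrained/config.yml") as f:
    config = yaml.load(f, Loader=yaml.Loader)

# A few fields used downstream by the decoding/training code in this commit.
print(config["sampling_rate"])   # 24000
print(config["hop_size"])        # 240
print(config["generator_type"])  # BigVGAN
print(config["vq_codebook"])     # feats/vqidx/codebook.npy
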
pretrained/generator.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a10b9df62462bbf48382970ffba267b458b00b361bcb245701e3d3c0b6bd19f
size 161604549
pretrained/vq-wav2vec_kmeans.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c975a93479dc5f3cfc4339032e1547c6034eddd15eb1cba73364c20786b42a5a
size 336509919
requirements.txt
ADDED
@@ -0,0 +1,25 @@
torchaudio==0.13.1
auraloss==0.4.0
cython==3.0.10
einops
debugpy==1.8.0
fairseq==0.12.2
filelock~=3.12.2
h5py
kaldiio~=2.18.0
librosa==0.8.1
matplotlib~=3.4.3
nltk==3.8.1
numpy
pathlib~=1.0.1
pyyaml~=6.0
scikit-learn
scipy~=1.7.1
setuptools==65.6.3
six==1.16.0
soundfile~=0.10.3.post1
sox
tensorboard
tensorboardx~=2.5.1
tqdm~=4.62.3
transformers==4.42.3
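A small optional sketch (not part of the commit) that checks a few of the pinned packages at runtime using only the standard library:

# Verify some pinned dependencies from requirements.txt are importable at the expected versions.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("torchaudio", "fairseq", "librosa", "transformers"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg} is not installed")
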
vec2wav2/__init__.py
ADDED
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-

__version__ = ""
vec2wav2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (211 Bytes).

vec2wav2/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (289 Bytes).

vec2wav2/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (225 Bytes).

vec2wav2/bin/.DS_Store
ADDED
Binary file (6.15 kB).

vec2wav2/bin/__init__.py
ADDED
File without changes

vec2wav2/bin/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (199 Bytes).

vec2wav2/bin/__pycache__/vc.cpython-310.pyc
ADDED
Binary file (4.76 kB).
vec2wav2/bin/decode.py
ADDED
@@ -0,0 +1,163 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

# Modified by Yiwei Guo, 2024

"""Decode with trained vec2wav Generator."""

import argparse
import logging
import os
import time

import numpy as np
import soundfile as sf
import torch
import yaml

from tqdm import tqdm

from vec2wav2.datasets import MelSCPDataset
from vec2wav2.utils import load_model, load_feat_codebook, idx2vec


def set_loglevel(verbose):
    # set logger
    if verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")


def main():
    """Run decoding process."""
    parser = argparse.ArgumentParser(
        description="Decode from audio tokens and acoustic prompts with trained vec2wav model "
        "(See detail in vec2wav2/bin/decode.py)."
    )
    parser.add_argument(
        "--feats-scp",
        "--scp",
        default=None,
        type=str,
        required=True,
        help="kaldi-style feats.scp file. "
    )
    parser.add_argument(
        "--prompt-scp",
        default=None,
        type=str,
        help="kaldi-style prompt.scp file. Similar to feats.scp."
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save generated speech.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="checkpoint file to be loaded.",
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()
    set_loglevel(args.verbose)

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # get dataset
    dataset = MelSCPDataset(
        vqidx_scp=args.feats_scp,
        prompt_scp=args.prompt_scp,
        return_utt_id=True,
    )
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using {'GPU' if torch.cuda.is_available() else 'CPU'}.")

    model = load_model(args.checkpoint, config)
    logging.info(f"Loaded model parameters from {args.checkpoint}.")

    model.backend.remove_weight_norm()
    model = model.eval().to(device)

    # load vq codebook
    feat_codebook, feat_codebook_numgroups = load_feat_codebook(np.load(config["vq_codebook"], allow_pickle=True), device)

    # start generation
    total_rtf = 0.0
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for idx, batch in enumerate(pbar, 1):
            utt_id, vqidx, prompt = batch[0], batch[1], batch[2]

            vqidx = torch.tensor(vqidx).to(device)  # (L, G)
            prompt = torch.tensor(prompt).unsqueeze(0).to(device)  # (1, L', D')

            vqidx = vqidx.long()
            vqvec = idx2vec(feat_codebook, vqidx, feat_codebook_numgroups).unsqueeze(0)  # (1, L, D)

            # generate
            start = time.time()
            y = model.inference(vqvec, prompt)[-1].view(-1)
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            tgt_dir = os.path.dirname(os.path.join(config["outdir"], f"{utt_id}.wav"))
            os.makedirs(tgt_dir, exist_ok=True)
            basename = os.path.basename(f"{utt_id}.wav")
            # save as PCM 16 bit wav file
            sf.write(
                os.path.join(tgt_dir, basename),
                y.cpu().numpy(),
                config["sampling_rate"],
                "PCM_16",
            )

    # report average RTF
    logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")


if __name__ == "__main__":
    main()
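The script above is argument-driven; a hedged invocation sketch using only the flags it defines follows (file paths are placeholders, except the checkpoint shipped in this commit), showing one way to run offline decoding once the kaldi-style scp files have been prepared separately.

# Hypothetical offline decoding run; feats.scp / prompt.scp must be prepared beforehand.
import subprocess

subprocess.run([
    "python", "-m", "vec2wav2.bin.decode",
    "--feats-scp", "feats.scp",                  # kaldi-style VQ index scp
    "--prompt-scp", "prompt.scp",                # kaldi-style prompt scp
    "--checkpoint", "pretrained/generator.ckpt",
    "--outdir", "generated",
    "--verbose", "1",
], check=True)
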
vec2wav2/bin/gradio_app.py
ADDED
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import gradio as gr
import logging
import yaml
import soundfile as sf
import os
from pathlib import Path
from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args

# Create Gradio interface
def create_interface():
    args = vc_args()
    logger = configure_logging(args.verbose)
    voice_converter = VoiceConverter(
        expdir=args.expdir,
        token_extractor=args.token_extractor,
        prompt_extractor=args.prompt_extractor,
        prompt_output_layer=args.prompt_output_layer,
        checkpoint=args.checkpoint,
        script_logger=logger
    )
    with gr.Blocks(title="Voice Conversion") as demo:
        gr.Markdown("# vec2wav 2.0 Voice Conversion Demo")
        gr.Markdown("Upload source audio and target speaker audio to convert the voice.")

        with gr.Row():
            source_audio = gr.Audio(label="Source Audio", type="filepath")
            target_audio = gr.Audio(label="Target Speaker Audio", type="filepath")

        examples = [
            ["examples/Zuckerberg.wav", "examples/Rachel.wav"],
            ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"]
        ]
        gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio])

        convert_btn = gr.Button("Convert Voice")
        output_audio = gr.Audio(label="Converted Audio")

        convert_btn.click(
            fn=voice_converter.voice_conversion,
            inputs=[source_audio, target_audio],
            outputs=output_audio
        )

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
vec2wav2/bin/train.py
ADDED
@@ -0,0 +1,1007 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

# Modified by Yiwei Guo, 2024

"""Train vec2wav."""

import argparse
import logging
import os
import sys
import random

from collections import defaultdict

import matplotlib
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
import yaml
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from tqdm import tqdm

import vec2wav2
import vec2wav2.models
import vec2wav2.optimizers
from torch.utils.data.distributed import DistributedSampler

from vec2wav2.datasets import AudioMelSCPDataset
from vec2wav2.layers import PQMF
from vec2wav2.losses import DiscriminatorAdversarialLoss
from vec2wav2.losses import FeatureMatchLoss
from vec2wav2.losses import GeneratorAdversarialLoss
from vec2wav2.losses import MelSpectrogramLoss
from vec2wav2.losses import MultiResolutionSTFTLoss
from vec2wav2.utils import crop_seq, load_feat_codebook, idx2vec

from vec2wav2.utils.espnet_utils import pad_list, make_non_pad_mask

# set to avoid matplotlib error in CLI environment
matplotlib.use("Agg")


def set_loglevel(verbose):
    # set logger
    if verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            stream=sys.stdout,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")


class Trainer(object):
    """Customized trainer module for Parallel WaveGAN training."""

    def __init__(
        self,
        steps,
        epochs,
        data_loader,
        sampler,
        model,
        criterion,
        optimizer,
        scheduler,
        config,
        device=torch.device("cpu"),
    ):
        """Initialize trainer.

        Args:
            steps (int): Initial global steps.
            epochs (int): Initial global epochs.
            data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
            model (dict): Dict of models. It must contain "generator" and "discriminator" models.
            criterion (dict): Dict of criteria. It must contain "stft" and "mse" criteria.
            optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
            scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
            config (dict): Config dict loaded from yaml format configuration file.
            device (torch.device): Pytorch device instance.

        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.sampler = sampler
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        self.finish_train = False
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)

        # load vq codebook
        feat_codebook_path = self.config["vq_codebook"]

        self.feat_codebook, self.feat_codebook_numgroups = load_feat_codebook(np.load(feat_codebook_path, allow_pickle=True), device)

    def run(self):
        """Run training."""
        self.tqdm = tqdm(initial=self.steps, total=self.config["train_max_steps"], desc="[train]")
        while True:
            # train one epoch
            self._train_epoch()

            # check whether training is finished
            if self.finish_train:
                break

        self.tqdm.close()
        logging.info("Finished training.")

    def save_checkpoint(self, checkpoint_path):
        """Save checkpoint.
        Args:
            checkpoint_path (str): Checkpoint path to be saved.
        """
        state_dict = {
            "optimizer": {
                "generator": self.optimizer["generator"].state_dict(),
                "discriminator": self.optimizer["discriminator"].state_dict(),
            },
            "scheduler": {
                "generator": self.scheduler["generator"].state_dict(),
                "discriminator": self.scheduler["discriminator"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }
        if self.config["distributed"]:
            state_dict["model"] = {
                "generator": self.model["generator"].module.state_dict(),
                "discriminator": self.model["discriminator"].module.state_dict(),
            }
        else:
            state_dict["model"] = {
                "generator": self.model["generator"].state_dict(),
                "discriminator": self.model["discriminator"].state_dict(),
            }

        if not os.path.exists(os.path.dirname(checkpoint_path)):
            os.makedirs(os.path.dirname(checkpoint_path))
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(self, checkpoint_path, load_only_params=False):
        """Load checkpoint.

        Args:
            checkpoint_path (str): Checkpoint path to be loaded.
            load_only_params (bool): Whether to load only model parameters.

        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        if self.config["distributed"]:
            self.model["generator"].module.load_state_dict(
                state_dict["model"]["generator"]
            )
            self.model["discriminator"].module.load_state_dict(
                state_dict["model"]["discriminator"]
            )
        else:
            self.model["generator"].load_state_dict(state_dict["model"]["generator"])
            self.model["discriminator"].load_state_dict(
                state_dict["model"]["discriminator"]
            )
        if not load_only_params:
            self.steps = state_dict["steps"]
            self.epochs = state_dict["epochs"]
            self.optimizer["generator"].load_state_dict(state_dict["optimizer"]["generator"])
            self.optimizer["discriminator"].load_state_dict(state_dict["optimizer"]["discriminator"])
            self.scheduler["generator"].load_state_dict(state_dict["scheduler"]["generator"])
            self.scheduler["discriminator"].load_state_dict(state_dict["scheduler"]["discriminator"])

    def _train_step(self, batch):
        """Train model one step."""
        # parse batch
        vqidx, mel, prompt, y, xlens, prompt_lens = batch
        vqidx = vqidx.to(self.device)
        mel = mel.to(self.device)
        prompt = prompt.to(self.device)
        vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups)  # (B, L, D)
        y = y.unsqueeze(-2).to(self.device)  # (B, 1, T)

        # build mask
        mask = make_non_pad_mask(xlens).to(self.device)  # (B, L)
        prompt_mask = make_non_pad_mask(prompt_lens).to(self.device)  # (B, L_prompt)

        # crop wav sequence
        crop_xlen = min(self.config["crop_max_frames"], min(xlens))
        x_offsets = [np.random.randint(0, l - crop_xlen + 1) for l in xlens]
        crop_ylen = crop_xlen * self.config["hop_size"]
        y_offsets = [o * self.config["hop_size"] for o in x_offsets]
        y = crop_seq(y, y_offsets, crop_ylen)

        #######################
        #      Generator      #
        #######################
        if self.steps > self.config.get("generator_train_start_steps", 0):
            mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets)  # (B, L, 80), (B, C, T)

            # initialize
            gen_loss, aux_loss = 0.0, 0.0

            # frontend mel prediction loss
            if self.steps <= self.config.get("frontend_mel_prediction_stop_steps", 0):
                frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)),
                                                   torch.masked_select(mel_, mask.unsqueeze(-1)))
                self.total_train_loss["train/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item()
                gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss

            # multi-resolution stft loss
            if self.config["use_stft_loss"]:
                sc_loss, mag_loss = self.criterion["stft"](y_, y)
                aux_loss += sc_loss + mag_loss
                self.total_train_loss["train/spectral_convergence_loss"] += sc_loss.item()
                self.total_train_loss["train/log_stft_magnitude_loss"] += mag_loss.item()

            # subband multi-resolution stft loss
            if self.config["use_subband_stft_loss"]:
                aux_loss *= 0.5  # for balancing with subband stft loss
                y_mb = self.criterion["pqmf"].analysis(y)
                y_mb_ = self.criterion["pqmf"].analysis(y_)
                sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
                aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
                self.total_train_loss["train/sub_spectral_convergence_loss"] += sub_sc_loss.item()
                self.total_train_loss["train/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()

            # mel spectrogram loss
            if self.config["use_mel_loss"]:
                mel_loss = self.criterion["mel"](y_, y)
                aux_loss += mel_loss
                self.total_train_loss["train/mel_loss"] += mel_loss.item()

            # weighting aux loss
            gen_loss += self.config.get("lambda_aux", 1.0) * aux_loss

            # adversarial loss
            if self.steps > self.config["discriminator_train_start_steps"]:
                p_ = self.model["discriminator"](y_)
                adv_loss = self.criterion["gen_adv"](p_)
                self.total_train_loss["train/adversarial_loss"] += adv_loss.item()

                # feature matching loss
                if self.config["use_feat_match_loss"]:
                    # no need to track gradients
                    with torch.no_grad():
                        p = self.model["discriminator"](y)
                    fm_loss = self.criterion["feat_match"](p_, p)
                    self.total_train_loss["train/feature_matching_loss"] += fm_loss.item()
                    adv_loss += self.config["lambda_feat_match"] * fm_loss

                # add adversarial loss to generator loss
                gen_loss += self.config["lambda_adv"] * adv_loss

            self.total_train_loss["train/generator_loss"] += gen_loss.item()

            # update generator
            self.optimizer["generator"].zero_grad()
            gen_loss.backward()
            if self.config["generator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["generator"].parameters(),
                    self.config["generator_grad_norm"],
                )
            self.optimizer["generator"].step()
            self.scheduler["generator"].step()

        #######################
        #    Discriminator    #
        #######################
        if self.steps > self.config["discriminator_train_start_steps"]:
            # re-compute y_ which leads to better quality
            with torch.no_grad():
                # logging.info(f"{vqvec.shape, prompt.shape, mask.shape, prompt_mask.shape}")
                _, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets)  # (B, L, 80), (B, C, T)

            if self.config["generator_params"]["out_channels"] > 1:
                y_ = self.criterion["pqmf"].synthesis(y_)

            # discriminator loss
            p = self.model["discriminator"](y)
            p_ = self.model["discriminator"](y_.detach())
            real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
            dis_loss = real_loss + fake_loss
            self.total_train_loss["train/real_loss"] += real_loss.item()
            self.total_train_loss["train/fake_loss"] += fake_loss.item()
            self.total_train_loss["train/discriminator_loss"] += dis_loss.item()

            # update discriminator
            self.optimizer["discriminator"].zero_grad()
            dis_loss.backward()
            if self.config["discriminator_grad_norm"] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model["discriminator"].parameters(),
                    self.config["discriminator_grad_norm"],
                )
            self.optimizer["discriminator"].step()
            self.scheduler["discriminator"].step()

        # update counts
        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    def _train_epoch(self):
        """Train model one epoch."""
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            if self.config["rank"] == 0:
                self._check_log_interval()
                self._check_eval_interval()
                self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        logging.info(
            f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
            f"({self.train_steps_per_epoch} steps per epoch)."
        )

        # needed for shuffle in distributed training
        if self.config["distributed"]:
            self.sampler["train"].set_epoch(self.epochs)

    @torch.no_grad()
    def _eval_step(self, batch):
        """Evaluate model one step."""
        # parse batch
        vqidx, mel, prompt, y, xlens, prompt_lens = batch
        vqidx = vqidx.to(self.device).long()
        mel = mel.to(self.device)
        prompt = prompt.to(self.device)
        vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups)
        y = y.unsqueeze(-2).to(self.device)  # (B, 1, T)

        # build mask
        mask = make_non_pad_mask(xlens).to(self.device)  # (B, L)
        prompt_mask = make_non_pad_mask(prompt_lens).to(self.device)  # (B, L_prompt)

        #######################
        #      Generator      #
        #######################
        mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask)  # (B, L, 80), (B, C, T)

        # reconstruct the signal from multi-band signal
        if self.config["generator_params"]["out_channels"] > 1:
            y_mb_ = y_
            y_ = self.criterion["pqmf"].synthesis(y_mb_)

        # initialize
        gen_loss = 0.0
        aux_loss = 0.0

        # frontend mel prediction loss
        frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)),
                                           torch.masked_select(mel_, mask.unsqueeze(-1)))
        self.total_eval_loss["eval/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item()
        gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss

        # multi-resolution stft loss
        if self.config["use_stft_loss"]:
            sc_loss, mag_loss = self.criterion["stft"](y_, y)
            aux_loss += sc_loss + mag_loss
            self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
            self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item()

        # subband multi-resolution stft loss
        if self.config.get("use_subband_stft_loss", False):
            aux_loss *= 0.5  # for balancing with subband stft loss
            y_mb = self.criterion["pqmf"].analysis(y)
            sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
            self.total_eval_loss["eval/sub_spectral_convergence_loss"] += sub_sc_loss.item()
            self.total_eval_loss["eval/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
            aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)

        # mel spectrogram loss
        if self.config["use_mel_loss"]:
            mel_loss = self.criterion["mel"](y_, y)
            aux_loss += mel_loss
            self.total_eval_loss["eval/mel_loss"] += mel_loss.item()

        # weighting stft loss
        gen_loss += aux_loss * self.config.get("lambda_aux", 1.0)

        # adversarial loss
        p_ = self.model["discriminator"](y_)
        adv_loss = self.criterion["gen_adv"](p_)
        gen_loss += self.config["lambda_adv"] * adv_loss

        # feature matching loss
        if self.config["use_feat_match_loss"]:
            p = self.model["discriminator"](y)
            fm_loss = self.criterion["feat_match"](p_, p)
            self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item()
            gen_loss += (
                self.config["lambda_adv"] * self.config["lambda_feat_match"] * fm_loss
            )

        #######################
        #    Discriminator    #
        #######################
        p = self.model["discriminator"](y)
        p_ = self.model["discriminator"](y_)

        # discriminator loss
        real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
        dis_loss = real_loss + fake_loss

        # add to total eval loss
        self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
        self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
        self.total_eval_loss["eval/real_loss"] += real_loss.item()
        self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
        self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()

    def _eval_epoch(self):
        """Evaluate model one epoch."""
        logging.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for key in self.model.keys():
            self.model[key].eval()

        # calculate loss for each batch
        for eval_steps_per_epoch, batch in enumerate(tqdm(self.data_loader["dev"], desc="[eval]"), 1):
            # eval one step
            self._eval_step(batch)

        logging.info(
            f"(Steps: {self.steps}) Finished evaluation "
            f"({eval_steps_per_epoch} steps per epoch)."
        )

        # average loss
        for key in self.total_eval_loss.keys():
            self.total_eval_loss[key] /= eval_steps_per_epoch
            logging.info(f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}.")

        # record
        self._write_to_tensorboard(self.total_eval_loss)

        # reset
        self.total_eval_loss = defaultdict(float)

        # restore mode
        for key in self.model.keys():
            self.model[key].train()

    def _write_to_tensorboard(self, loss):
        """Write to tensorboard."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        if self.steps % self.config["save_interval_steps"] == 0:
            self.save_checkpoint(os.path.join(self.config["outdir"],
                                              f"checkpoint-{self.steps}steps.pkl"))
            logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss.keys():
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logging.info(f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}.")
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        if self.steps >= self.config["train_max_steps"]:
            self.finish_train = True


class Collator(object):
    """Customized collator for Pytorch DataLoader in training."""

    def __init__(
        self,
        hop_size=256,
        win_length=1024,
        sampling_rate=16000,
        prompt_dim=1024,
        prompt_fold_by_2=False
    ):
        """Initialize customized collator for PyTorch DataLoader.

        Args:
            hop_size (int): Hop size of features, in sampling points.
            win_length (int): window length of features.
            sampling_rate (int): sampling rate of waveform data
            prompt_dim (int): number of prompt embedding dimensions
        """
        self.hop_size = hop_size
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.prompt_dim = prompt_dim
        if prompt_fold_by_2:
            self.prompt_len_factor = 2
        else:
            self.prompt_len_factor = 1

    def construct_prompt(self, mel_lens):
        prompt_lens = [random.randint(int(l / (3 * self.prompt_len_factor)), int(l / (2 * self.prompt_len_factor))) for l in mel_lens]
        prompt_starts = []
        is_from_start = []
        for ml, pl in zip(mel_lens, prompt_lens):
            if random.random() > 0.5:
                # from start
                prompt_start = random.randint(0, 1 * self.sampling_rate // (self.hop_size * self.prompt_len_factor))
                is_from_start.append(True)
            else:
                # from ending
                prompt_start = random.randint((ml - 1 * self.sampling_rate // self.hop_size) // self.prompt_len_factor, ml // self.prompt_len_factor) - pl
                is_from_start.append(False)
            prompt_starts.append(prompt_start)
        return prompt_lens, prompt_starts, is_from_start

    def __call__(self, batch):
        """Convert into batch tensors.

        Args:
            batch (list): list of tuple of the pair of audio and features.

        This collator will automatically determine the prompt segment (acoustic context) for each utterance.
        The prompt is cut off from the current utterance, ranging from one third to half of the original utterance.
        The prompt can be cut from either the start or the end of the utterance, within a 1 second margin.
        The other features include 2-dim VQ features (2 is the number of groups), and D-dim prompts (e.g. WavLM features).

        Returns:
            Tensor ys: waveform batch (B, T).
            Tensors vqs, mels: Auxiliary feature batch (B, C, T'), where T' = T / hop_size.
            Tensor prompts: prompt feature batch (B, C, T'')
            List c_lengths, prompt_lengths: list of lengths
        """
        batch = batch[0]

        # check length
        batch = [self._adjust_length(*b) for b in batch]
        ys, vqs, mels, prompts_old = list(map(list, zip(*batch)))  # [(a,b), (c,d)] -> [a, c], [b, d]

        batch_size = len(vqs)

        prompt_lengths, prompt_starts, is_from_starts = self.construct_prompt([len(m) for m in mels])
        c_lengths = []
        prompts = torch.zeros(batch_size, max(prompt_lengths), self.prompt_dim)
        for i in range(batch_size):
            prompts[i, :prompt_lengths[i]] = torch.tensor(prompts_old[i][prompt_starts[i]:prompt_starts[i]+prompt_lengths[i], :])
            if is_from_starts[i]:
                start_idx = (prompt_starts[i] + prompt_lengths[i])*self.prompt_len_factor
                mels[i] = mels[i][start_idx:]
                vqs[i] = vqs[i][start_idx:]
                ys[i] = ys[i][start_idx * self.hop_size:]
            else:
                end_idx = prompt_starts[i]*self.prompt_len_factor
                mels[i] = mels[i][:end_idx]
                vqs[i] = vqs[i][:end_idx]
                ys[i] = ys[i][:end_idx * self.hop_size]
            c_lengths.append(len(mels[i]))

        vqs = pad_list([torch.tensor(c) for c in vqs], pad_value=0)  # (B, L, Groups)
        vqs = vqs.long()
        mels = pad_list([torch.tensor(c) for c in mels], pad_value=0)  # (B, L, 80)

        ys = pad_list([torch.tensor(y, dtype=torch.float) for y in ys], pad_value=0)[:, :mels.size(1) * self.hop_size]  # (B, T)
        assert ys.size(1) == mels.size(1) * self.hop_size == vqs.size(1) * self.hop_size

        return vqs, mels, prompts, ys, c_lengths, prompt_lengths

    def _adjust_length(self, x, c, *args):
        """Adjust the audio and feature lengths.

        Note:
            Basically we assume that the length of x and c are adjusted
            through preprocessing stage, but if we use other library processed
            features, this process will be needed.

        """
        if len(x) > len(c) * self.hop_size:
            x = x[(self.win_length - self.hop_size) // 2:]
            x = x[:len(c) * self.hop_size]

        # check the length is valid
        assert len(x) == len(c) * self.hop_size

        return x, c, *args


def main(rank, n_gpus):
    """Run training process."""
    parser = argparse.ArgumentParser(
        description="Train vec2wav2 (See detail in vec2wav2/bin/train.py)."
    )
    parser.add_argument(
        "--train-wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file for training. "
    )
    parser.add_argument(
        "--train-vqidx-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for training. "
    )
    parser.add_argument(
        "--train-mel-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for training. "
    )
    parser.add_argument(
        "--train-prompt-scp",
        default=None,
        type=str,
        help="prompt scp (in this case, utt to path)"
    )
    parser.add_argument(
        "--train-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for training.",
    )
    parser.add_argument(
        "--train-num-frames",
        default=None,
        type=str,
        help="kaldi-style utt2num_frames file for training.",
    )
    parser.add_argument(
        "--dev-wav-scp",
        default=None,
        type=str,
        help="kaldi-style wav.scp file for validation. "
    )
    parser.add_argument(
        "--dev-vqidx-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for validation. "
    )
    parser.add_argument(
        "--dev-mel-scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file for validation. "
    )
    parser.add_argument(
        "--dev-prompt-scp",
        default=None,
        type=str,
        help="prompt scp (in this case, utt to path)"
    )
    parser.add_argument(
        "--dev-segments",
        default=None,
        type=str,
        help="kaldi-style segments file for validation.",
    )
    parser.add_argument(
        "--dev-num-frames",
        default=None,
        type=str,
        help="kaldi-style utt2num_frames file for validation.",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save checkpoints.",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--pretrain",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to load pretrained params. (default="")',
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    parser.add_argument("--vq-codebook", default=None, type=str)
    # parser.add_argument("--sampling-rate", type=int)
    # parser.add_argument("--num-mels", type=int)
    # parser.add_argument("--hop-size", type=int)
    # parser.add_argument("--win-length", type=int)
    args = parser.parse_args()

    # init distributed training
    device = torch.device("cuda")
    # effective when using fixed size inputs
    # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
    torch.backends.cudnn.benchmark = True
    # setup for distributed training
    # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
    if n_gpus == 1:
        assert rank == 0

    set_loglevel(args.verbose)

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # init process group
    logging.info("Synchronizing between all workers.")
    torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank)
    torch.cuda.set_device(rank)
    logging.info("Finished init process group.")

    # load and save config
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))
    config['rank'] = rank
    config['distributed'] = True
    config['world_size'] = n_gpus
    config["version"] = vec2wav2.__version__  # add version info
    if rank == 0:
        with open(os.path.join(args.outdir, "config.yml"), "w") as f:
            yaml.dump(config, f, Dumper=yaml.Dumper)
        for key, value in config.items():
            logging.info(f"{key} = {value}")

    # get dataset
    train_dataset = AudioMelSCPDataset(
        wav_scp=args.train_wav_scp,
        vqidx_scp=args.train_vqidx_scp,
        mel_scp=args.train_mel_scp,
        prompt_scp=args.train_prompt_scp,
        utt2num_frames=args.train_num_frames,
        segments=args.train_segments,
        batch_frames=config.get("batch_frames", None),
        batch_size=config.get("batch_size", None),
        min_num_frames=config.get("min_num_frames", None),
        max_num_frames=config.get("max_num_frames", None),
        allow_cache=config.get("allow_cache", False),  # keep compatibility
        length_tolerance=config.get("length_tolerance", 2),
        prompt_fold_by_2=config.get("prompt_fold_by_2", True)
    )
    if rank == 0:
        logging.info(f"The number of training batches = {len(train_dataset)}.")
    dev_dataset = AudioMelSCPDataset(
        wav_scp=args.dev_wav_scp,
        vqidx_scp=args.dev_vqidx_scp,
        mel_scp=args.dev_mel_scp,
        prompt_scp=args.dev_prompt_scp,
        utt2num_frames=args.dev_num_frames,
        segments=args.dev_segments,
        min_num_frames=config.get("min_num_frames", None),
        max_num_frames=config.get("max_num_frames", None),
        allow_cache=config.get("allow_cache", False),  # keep compatibility
        length_tolerance=config.get("length_tolerance", 2),
        prompt_fold_by_2=config.get("prompt_fold_by_2", True)
    )
    if rank == 0:
        logging.info(f"The number of development batches = {len(dev_dataset)}.")
    dataset = {
        "train": train_dataset,
        "dev": dev_dataset,
    }

    # get data loader
    collator = Collator(
        hop_size=config["hop_size"],
        win_length=config["win_length"],
        sampling_rate=config["sampling_rate"],
        prompt_dim=config['frontend_params']['prompt_channels'],
        prompt_fold_by_2=config.get("prompt_fold_by_2", True)
    )

    sampler = {
        "train": DistributedSampler(
            dataset=dataset["train"],
            num_replicas=n_gpus,
            rank=rank,
            shuffle=True,
        ),
        "dev": DistributedSampler(
            dataset=dataset["dev"],
            num_replicas=n_gpus,
            rank=rank,
            shuffle=False,
        )}
    data_loader = {
        "train": DataLoader(
            dataset=dataset["train"],
            shuffle=False,
            collate_fn=collator,
            num_workers=config["num_workers"],
            sampler=sampler["train"],
            pin_memory=config["pin_memory"],
        ),
        "dev": DataLoader(
            dataset=dataset["dev"],
            shuffle=False,
            collate_fn=collator,
            num_workers=config["num_workers"],
            sampler=sampler["dev"],
            pin_memory=config["pin_memory"],
        ),
    }

    # define models
    generator_class = getattr(
        vec2wav2.models,
        # keep compatibility
        config.get("generator_type", "ParallelWaveGANGenerator"),
    )
    discriminator_class = getattr(
        vec2wav2.models,
        # keep compatibility
        config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
    )
    model = {
        "generator": vec2wav2.models.VEC2WAV2Generator(
            vec2wav2.models.CTXVEC2WAVFrontend(config["prompt_net_type"], config["num_mels"], **config["frontend_params"]),
            generator_class(**config["generator_params"])
        ).to(device),
        "discriminator": discriminator_class(
            **config["discriminator_params"],
        ).to(device),
    }

    # define criteria
    criterion = {
        "gen_adv": GeneratorAdversarialLoss(
            # keep compatibility
|
880 |
+
**config.get("generator_adv_loss_params", {})
|
881 |
+
).to(device),
|
882 |
+
"dis_adv": DiscriminatorAdversarialLoss(
|
883 |
+
# keep compatibility
|
884 |
+
**config.get("discriminator_adv_loss_params", {})
|
885 |
+
).to(device),
|
886 |
+
}
|
887 |
+
if config.get("use_stft_loss", True): # keep compatibility
|
888 |
+
config["use_stft_loss"] = True
|
889 |
+
criterion["stft"] = MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device)
|
890 |
+
if config.get("use_subband_stft_loss", False): # keep compatibility
|
891 |
+
assert config["generator_params"]["out_channels"] > 1
|
892 |
+
criterion["sub_stft"] = MultiResolutionSTFTLoss(**config["subband_stft_loss_params"]).to(device)
|
893 |
+
else:
|
894 |
+
config["use_subband_stft_loss"] = False
|
895 |
+
if config.get("use_feat_match_loss", False): # keep compatibility
|
896 |
+
criterion["feat_match"] = FeatureMatchLoss(
|
897 |
+
# keep compatibility
|
898 |
+
**config.get("feat_match_loss_params", {}),
|
899 |
+
).to(device)
|
900 |
+
else:
|
901 |
+
config["use_feat_match_loss"] = False
|
902 |
+
if config.get("use_mel_loss", False): # keep compatibility
|
903 |
+
criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],).to(device)
|
904 |
+
else:
|
905 |
+
config["use_mel_loss"] = False
|
906 |
+
|
907 |
+
# define optimizers and schedulers
|
908 |
+
generator_optimizer_class = getattr(
|
909 |
+
vec2wav2.optimizers,
|
910 |
+
# keep compatibility
|
911 |
+
config.get("generator_optimizer_type", "RAdam"),
|
912 |
+
)
|
913 |
+
discriminator_optimizer_class = getattr(
|
914 |
+
vec2wav2.optimizers,
|
915 |
+
# keep compatibility
|
916 |
+
config.get("discriminator_optimizer_type", "RAdam"),
|
917 |
+
)
|
918 |
+
optimizer = {
|
919 |
+
"generator": generator_optimizer_class(
|
920 |
+
model["generator"].parameters(),
|
921 |
+
**config["generator_optimizer_params"],
|
922 |
+
),
|
923 |
+
"discriminator": discriminator_optimizer_class(
|
924 |
+
model["discriminator"].parameters(),
|
925 |
+
**config["discriminator_optimizer_params"],
|
926 |
+
),
|
927 |
+
}
|
928 |
+
generator_scheduler_class = getattr(
|
929 |
+
torch.optim.lr_scheduler,
|
930 |
+
# keep compatibility
|
931 |
+
config.get("generator_scheduler_type", "StepLR"),
|
932 |
+
)
|
933 |
+
discriminator_scheduler_class = getattr(
|
934 |
+
torch.optim.lr_scheduler,
|
935 |
+
# keep compatibility
|
936 |
+
config.get("discriminator_scheduler_type", "StepLR"),
|
937 |
+
)
|
938 |
+
scheduler = {
|
939 |
+
"generator": generator_scheduler_class(
|
940 |
+
optimizer=optimizer["generator"],
|
941 |
+
**config["generator_scheduler_params"],
|
942 |
+
),
|
943 |
+
"discriminator": discriminator_scheduler_class(
|
944 |
+
optimizer=optimizer["discriminator"],
|
945 |
+
**config["discriminator_scheduler_params"],
|
946 |
+
),
|
947 |
+
}
|
948 |
+
from torch.nn.parallel import DistributedDataParallel
|
949 |
+
model["generator"] = DistributedDataParallel(model["generator"], device_ids=[rank], find_unused_parameters=True)
|
950 |
+
model["discriminator"] = DistributedDataParallel(model["discriminator"], device_ids=[rank], find_unused_parameters=True)
|
951 |
+
|
952 |
+
if rank == 0:
|
953 |
+
# show settings
|
954 |
+
logging.info(model["generator"])
|
955 |
+
logging.info(f"Generator has nparams: {sum([p.numel() for p in model['generator'].parameters()])}")
|
956 |
+
logging.info(model["discriminator"])
|
957 |
+
logging.info(f"Discriminator has nparams: {sum([p.numel() for p in model['discriminator'].parameters()])}")
|
958 |
+
logging.info(optimizer["generator"])
|
959 |
+
logging.info(optimizer["discriminator"])
|
960 |
+
|
961 |
+
# define trainer
|
962 |
+
trainer = Trainer(
|
963 |
+
steps=0,
|
964 |
+
epochs=0,
|
965 |
+
data_loader=data_loader,
|
966 |
+
sampler=sampler,
|
967 |
+
model=model,
|
968 |
+
criterion=criterion,
|
969 |
+
optimizer=optimizer,
|
970 |
+
scheduler=scheduler,
|
971 |
+
config=config,
|
972 |
+
device=device,
|
973 |
+
)
|
974 |
+
|
975 |
+
# load pretrained parameters from checkpoint
|
976 |
+
if len(args.pretrain) != 0:
|
977 |
+
trainer.load_checkpoint(args.pretrain, load_only_params=True)
|
978 |
+
if rank == 0:
|
979 |
+
logging.info(f"Successfully load parameters from {args.pretrain}.")
|
980 |
+
|
981 |
+
# resume from checkpoint
|
982 |
+
if len(args.resume) != 0:
|
983 |
+
trainer.load_checkpoint(args.resume)
|
984 |
+
if rank == 0:
|
985 |
+
logging.info(f"Successfully resumed from {args.resume}.")
|
986 |
+
|
987 |
+
# run training loop
|
988 |
+
try:
|
989 |
+
trainer.run()
|
990 |
+
finally:
|
991 |
+
if rank == 0:
|
992 |
+
trainer.save_checkpoint(os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
|
993 |
+
logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
|
994 |
+
|
995 |
+
|
996 |
+
if __name__ == "__main__":
|
997 |
+
assert torch.cuda.is_available(), "CPU training is not allowed."
|
998 |
+
n_gpus = torch.cuda.device_count()
|
999 |
+
print(f"============> using {n_gpus} GPUS")
|
1000 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
1001 |
+
os.environ["MASTER_PORT"] = "8000"
|
1002 |
+
|
1003 |
+
mp.spawn(
|
1004 |
+
main,
|
1005 |
+
nprocs=n_gpus,
|
1006 |
+
args=(n_gpus,)
|
1007 |
+
)
|
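The sampler and data-loader wiring above follows the standard PyTorch distributed pattern; a minimal self-contained sketch of that pattern (the toy dataset, world size, and batch size are assumptions for illustration, not part of this commit) is:

```python
# Sketch of the DistributedSampler + DataLoader pattern used in train.py above.
# TensorDataset is only a toy stand-in for AudioMelSCPDataset.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def build_loader(rank, world_size, batch_size=4):
    dataset = TensorDataset(torch.randn(32, 10))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # shuffle stays False in the DataLoader because the sampler already shuffles
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
```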
vec2wav2/bin/vc.py
ADDED
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2024 Yiwei Guo

""" Run VC inference with trained model """

import vec2wav2
from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
import torch
import logging
import argparse
from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
import soundfile as sf
import yaml
import os


def configure_logging(verbose):
    if verbose:
        logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.DEBUG)
        logging.getLogger().setLevel(logging.DEBUG)
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.ERROR)
        logging.getLogger().setLevel(logging.ERROR)
        logging.basicConfig(level=logging.ERROR)

    script_logger = logging.getLogger("script_logger")
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
    script_logger.addHandler(handler)
    script_logger.setLevel(logging.INFO)
    script_logger.propagate = False
    return script_logger


def vc_args():
    parser = argparse.ArgumentParser()
    # required arguments
    parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
                        help="source wav path")
    parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
                        help="target speaker prompt path")
    parser.add_argument("-o", "--output", default="output.wav", type=str,
                        help="path of the output wav file")

    # optional arguments
    parser.add_argument("--expdir", default="pretrained/", type=str,
                        help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
    parser.add_argument('--checkpoint', default=None, type=str, help="checkpoint path (.pkl). If provided, will override expdir.")
    parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
                        help="checkpoint or model flag of input token extractor")
    parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
                        help="checkpoint or model flag of speaker prompt extractor")
    parser.add_argument("--prompt-output-layer", default=6, type=int,
                        help="output layer when prompt is extracted from WavLM.")

    parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")

    args = parser.parse_args()
    return args


class VoiceConverter:
    def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt",
                 prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
                 checkpoint=None, script_logger=None):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.script_logger = script_logger
        self.log_if_possible(f"Using device: {self.device}")

        # set up token extractor
        self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
        feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device)
        self.feat_codebook = feat_codebook
        self.feat_codebook_numgroups = feat_codebook_numgroups
        self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")

        # set up prompt extractor
        self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer)
        self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")

        # load VC model
        self.config_path = os.path.join(expdir, "config.yml")
        with open(self.config_path) as f:
            self.config = yaml.load(f, Loader=yaml.Loader)
        if checkpoint is not None:
            checkpoint = os.path.join(expdir, checkpoint)
        else:
            checkpoint = os.path.join(expdir, "generator.ckpt")
        self.model = load_model(checkpoint, self.config)
        self.log_if_possible(f"Successfully set up VC model from {checkpoint}")

        self.model.backend.remove_weight_norm()
        self.model.eval().to(self.device)

    @torch.no_grad()
    def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
        self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
        source_wav = read_wav_16k(source_audio)
        target_wav = read_wav_16k(target_audio)
        vq_idx = self.token_extractor.extract(source_wav).long().to(self.device)

        vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0)
        prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device)
        converted = self.model.inference(vqvec, prompt)[-1].view(-1)
        sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate'])
        self.log_if_possible(f"Saved audio file to {output_path}")
        return output_path

    def log_if_possible(self, msg):
        if self.script_logger is not None:
            self.script_logger.info(msg)


if __name__ == "__main__":
    args = vc_args()
    script_logger = configure_logging(args.verbose)

    source_wav = read_wav_16k(args.source)
    target_prompt = read_wav_16k(args.target)

    with torch.no_grad():
        voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
                                         prompt_extractor=args.prompt_extractor, prompt_output_layer=args.prompt_output_layer,
                                         checkpoint=args.checkpoint, script_logger=script_logger)
        voice_converter.voice_conversion(args.source, args.target, args.output)
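For reference, a minimal usage sketch of the VoiceConverter class defined in vc.py above (not part of the committed file; the wav paths are simply the script's own example defaults):

```python
# Sketch: offline voice conversion with the pretrained demo checkpoints.
from vec2wav2.bin.vc import VoiceConverter, configure_logging

logger = configure_logging(verbose=False)
vc = VoiceConverter(expdir="pretrained/", script_logger=logger)  # loads generator.ckpt + config.yml
vc.voice_conversion("examples/source.wav", "examples/target.wav", output_path="converted.wav")
```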
vec2wav2/datasets/__init__.py
ADDED
@@ -0,0 +1 @@
from .scp_dataset import *  # NOQA
vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (288 Bytes).
vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (241 Bytes).
vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc
ADDED
Binary file (8.4 kB).
vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc
ADDED
Binary file (8.95 kB).
vec2wav2/datasets/scp_dataset.py
ADDED
@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

# Modified by Yiwei Guo, 2024

"""Dataset modules based on kaldi-style scp files."""

import logging
import random
import copy
from multiprocessing import Manager

import kaldiio
import numpy as np

from torch.utils.data import Dataset
from tqdm import tqdm
from vec2wav2.utils import HDF5ScpLoader
from vec2wav2.utils import NpyScpLoader


def _get_feats_scp_loader(feats_scp):
    # read the first line of feats.scp file
    with open(feats_scp) as f:
        key, value = f.readlines()[0].replace("\n", "").split()

    # check scp type
    if ":" in value:
        value_1, value_2 = value.split(":")
        if value_1.endswith(".ark"):
            # kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index
            return kaldiio.load_scp(feats_scp)
        elif value_1.endswith(".h5"):
            # hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats
            return HDF5ScpLoader(feats_scp)
        else:
            raise ValueError("Not supported feats.scp type.")
    else:
        if value.endswith(".h5"):
            # hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5
            return HDF5ScpLoader(feats_scp)
        elif value.endswith(".npy"):
            # npy case: utt_id_1 /path/to/utt_id_1.npy
            return NpyScpLoader(feats_scp)
        else:
            raise ValueError("Not supported feats.scp type.")


class AudioMelSCPDataset(Dataset):
    """PyTorch compatible audio and feat dataset based on kaldi-style scp files."""

    def __init__(
        self,
        wav_scp,
        vqidx_scp,
        mel_scp,
        prompt_scp,
        utt2num_frames=None,
        segments=None,
        batch_frames=None,
        batch_size=None,
        min_num_frames=None,
        max_num_frames=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
        length_tolerance=2,
        prompt_fold_by_2=True
    ):
        """Initialize dataset.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            vqidx_scp (str): Kaldi-style feats.scp file.
            mel_scp (str): Kaldi-style feats.scp file.
            segments (str): Kaldi-style segments file.
            min_num_frames (int): Threshold to remove short feature files.
            max_num_frames (int): Threshold to remove long feature files.
            return_utt_id (bool): Whether to return utterance id.
            return_sampling_rate (bool): Whether to return sampling rate.
            allow_cache (bool): Whether to allow cache of the loaded files.
            prompt_fold_by_2 (bool): If true, the prompt has half the length of the vqidx sequence.

        """
        # load scp as lazy dict
        self.audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
        self.vqidx_loader = _get_feats_scp_loader(vqidx_scp)
        self.mel_loader = _get_feats_scp_loader(mel_scp)

        self.prompt_loader = _get_feats_scp_loader(prompt_scp)

        self.utt_ids = list(self.mel_loader.keys())
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache

        utt2num_frames_loader = None
        if utt2num_frames is not None:
            with open(utt2num_frames, 'r') as f:
                utt2num_frames_loader = dict([(x.split()[0], int(x.split()[1])) for x in f.readlines()])
        else:
            utt2num_frames_loader = dict([(k, mel.shape[0]) for k, mel in self.mel_loader.items()])

        self.utt2num_frames_loader = utt2num_frames_loader

        # filter by threshold
        if (min_num_frames or max_num_frames) is not None:
            mel_lengths = [utt2num_frames_loader[key] for key in self.utt_ids]
            idxs = [
                idx
                for idx in range(len(self.utt_ids))
                if (min_num_frames and mel_lengths[idx] >= min_num_frames) and (max_num_frames and mel_lengths[idx] <= max_num_frames)
            ]
            if len(self.utt_ids) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(self.utt_ids)} -> {len(idxs)})."
                )
            self.utt_ids = [self.utt_ids[idx] for idx in idxs]

        # batchify
        if batch_frames is not None:
            self.batches = self.batchify(utt2num_frames_loader, batch_frames=batch_frames)
        elif batch_size is not None:
            self.batches = self.batchify(utt2num_frames_loader, batch_size=batch_size)
        else:
            self.batches = [[utt_id] for utt_id in self.utt_ids]

        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.dict()
        self.length_tolerance = length_tolerance
        if prompt_fold_by_2:
            self.prompt_len_factor = 2
        else:
            self.prompt_len_factor = 1

    def batchify(self, utt2num_frames_loader, batch_frames=None, batch_size=None, min_batch_size=1, drop_last=True):

        assert batch_size is None or batch_size > min_batch_size

        batches = []
        batch = []
        accum_num_frames = 0
        utt_ids_set = set(self.utt_ids)
        for utt_id, mel_length in tqdm(sorted(list(utt2num_frames_loader.items()), key=lambda x: x[1], reverse=True)):
            if utt_id not in utt_ids_set:
                continue
            if (batch_frames is not None and accum_num_frames + mel_length > batch_frames and len(batch) > min_batch_size) or (batch_size is not None and len(batch) == batch_size):
                batches.append(batch)
                batch = []
                accum_num_frames = 0
            batch.append(utt_id)
            accum_num_frames += mel_length
        if len(batch) > min_batch_size and not drop_last:
            batches.append(batch)
        return batches

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray or tuple: Audio signal (T,) (w/ sampling rate if return_sampling_rate = True).
            ndarrays: Features (T', C).

        """
        batch = self.batches[idx]
        batch_items = []

        for utt_id in batch:
            if self.allow_cache and self.caches.get(utt_id) is not None:
                items = self.caches[utt_id]
            else:
                fs, audio = self.audio_loader[utt_id]
                mel = self.mel_loader[utt_id]
                prompt = self.prompt_loader[utt_id]
                vqidx = self.vqidx_loader[utt_id]

                min_len = min(len(mel), len(vqidx), len(prompt)*self.prompt_len_factor)
                assert ((abs(len(mel) - min_len) <= self.length_tolerance) and
                        (abs(len(vqidx) - min_len) <= self.length_tolerance) and
                        (abs(len(prompt)*self.prompt_len_factor - min_len) <= self.length_tolerance)), \
                    f"Audio feature lengths difference exceeds length tolerance for {utt_id}"
                mel, vqidx, prompt = mel[:min_len], vqidx[:min_len], prompt[:min_len//self.prompt_len_factor]

                # normalize audio signal to be [-1, 1]
                audio = audio.astype(np.float32)
                audio /= 1 << (16 - 1)  # assume that wav is PCM 16 bit

                if self.return_sampling_rate:
                    audio = (audio, fs)

                if self.return_utt_id:
                    items = utt_id, audio, vqidx, mel, prompt
                else:
                    items = audio, vqidx, mel, prompt

                if self.allow_cache:
                    self.caches[utt_id] = items

            batch_items.append(items)

        return batch_items

    def __len__(self):
        """Return dataset length.
        Returns:
            int: The length of dataset.
        """
        return len(self.batches)


class MelSCPDataset(Dataset):
    """PyTorch compatible feat dataset based on kaldi-style scp files."""

    def __init__(
        self,
        vqidx_scp,
        prompt_scp,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            vqidx_scp (str): Kaldi-style feats.scp file.
            prompt_scp (str): Kaldi-style scp file. In this file, every utt is associated with its prompt's mel-spectrogram.
            min_num_frames (int): Threshold to remove short feature files.
            max_num_frames (int): Threshold to remove long feature files.
            return_utt_id (bool): Whether to return utterance id.
            allow_cache (bool): Whether to allow cache of the loaded files.
        """
        # load scp as lazy dict
        vqidx_loader = _get_feats_scp_loader(vqidx_scp)
        self.prompt_loader = _get_feats_scp_loader(prompt_scp)
        # self.prompt_loader = dict()
        # with open(prompt_scp, 'r') as fr:
        #     for line in fr.readlines():
        #         terms = line.strip().split()
        #         self.prompt_loader[terms[0]] = terms[1]
        vqidx_keys = list(set(self.prompt_loader.keys()) & set(vqidx_loader.keys()))

        # NOTE: this dataset does not apply filtering, because it is usually used for decoding

        self.vqidx_loader = vqidx_loader
        self.utt_ids = vqidx_keys
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache

        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.utt_ids))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray: Feature (T', C).

        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]

        utt_id = self.utt_ids[idx]
        vqidx = self.vqidx_loader[utt_id].astype(int)

        # prompt = torch.load(self.prompt_loader[utt_id]).float().numpy()
        prompt = self.prompt_loader[utt_id]

        if self.return_utt_id:
            items = utt_id, vqidx, prompt
        else:
            items = vqidx, prompt

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)
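A short usage sketch for AudioMelSCPDataset as the training script consumes it (the scp paths below are hypothetical placeholders, not files in this commit):

```python
# Sketch only. Each item of this dataset is already a batch: a list of
# (audio, vqidx, mel, prompt) tuples, because frame-based batching happens
# inside the dataset rather than in the DataLoader.
from vec2wav2.datasets import AudioMelSCPDataset

dataset = AudioMelSCPDataset(
    wav_scp="data/train/wav.scp",        # hypothetical paths
    vqidx_scp="feats/train/vqidx.scp",
    mel_scp="feats/train/mel.scp",
    prompt_scp="feats/train/prompt.scp",
    batch_frames=3600,
)
first_batch = dataset[0]
print(f"{len(first_batch)} utterances in the first batch")
```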
vec2wav2/distributed/__init__.py
ADDED
File without changes
vec2wav2/distributed/launch.py
ADDED
@@ -0,0 +1,163 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Distributed process launcher.

This code is modified from https://github.com/pytorch/pytorch/blob/v1.3.0/torch/distributed/launch.py.

"""
import os
import subprocess
import sys

from argparse import ArgumentParser
from argparse import REMAINDER


def parse_args():
    """Parse arguments."""
    parser = ArgumentParser(
        description="PyTorch distributed training launch "
        "helper utility that will spawn up "
        "multiple distributed processes"
    )

    # Optional arguments for the launch helper
    parser.add_argument(
        "--nnodes",
        type=int,
        default=1,
        help="The number of nodes to use for distributed " "training",
    )
    parser.add_argument(
        "--node_rank",
        type=int,
        default=0,
        help="The rank of the node for multi-node distributed " "training",
    )
    parser.add_argument(
        "--nproc_per_node",
        type=int,
        default=1,
        help="The number of processes to launch on each node, "
        "for GPU training, this is recommended to be set "
        "to the number of GPUs in your system so that "
        "each process can be bound to a single GPU.",
    )
    parser.add_argument(
        "--master_addr",
        default="127.0.0.1",
        type=str,
        help="Master node (rank 0)'s address, should be either "
        "the IP address or the hostname of node 0, for "
        "single node multi-proc training, the "
        "--master_addr can simply be 127.0.0.1",
    )
    parser.add_argument(
        "--master_port",
        default=29500,
        type=int,
        help="Master node (rank 0)'s free port that needs to "
        "be used for communication during distributed "
        "training",
    )
    parser.add_argument(
        "--use_env",
        default=False,
        action="store_true",
        help="Use environment variable to pass "
        "'local rank'. For legacy reasons, the default value is False. "
        "If set to True, the script will not pass "
        "--local_rank as argument, and will instead set LOCAL_RANK.",
    )
    parser.add_argument(
        "-m",
        "--module",
        default=False,
        action="store_true",
        help="Changes each process to interpret the launch script "
        "as a python module, executing with the same behavior as "
        "'python -m'.",
    )
    parser.add_argument(
        "-c",
        "--command",
        default=False,
        action="store_true",
        help="Changes each process to interpret the launch script " "as a command.",
    )

    # positional
    parser.add_argument(
        "training_script",
        type=str,
        help="The full path to the single GPU training "
        "program/script/command to be launched in parallel, "
        "followed by all the arguments for the "
        "training script",
    )

    # rest from the training program
    parser.add_argument("training_script_args", nargs=REMAINDER)
    return parser.parse_args()


def main():
    """Launch distributed processes."""
    args = parse_args()

    # world size in terms of number of processes
    dist_world_size = args.nproc_per_node * args.nnodes

    # set PyTorch distributed related environmental variables
    current_env = os.environ.copy()
    current_env["MASTER_ADDR"] = args.master_addr
    current_env["MASTER_PORT"] = str(args.master_port)
    current_env["WORLD_SIZE"] = str(dist_world_size)

    processes = []

    if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1:
        current_env["OMP_NUM_THREADS"] = str(1)
        print(
            "*****************************************\n"
            "Setting OMP_NUM_THREADS environment variable for each process "
            "to be {} in default, to avoid your system being overloaded, "
            "please further tune the variable for optimal performance in "
            "your application as needed. \n"
            "*****************************************".format(
                current_env["OMP_NUM_THREADS"]
            )
        )

    for local_rank in range(0, args.nproc_per_node):
        # each process's rank
        dist_rank = args.nproc_per_node * args.node_rank + local_rank
        current_env["RANK"] = str(dist_rank)
        current_env["LOCAL_RANK"] = str(local_rank)

        # spawn the processes
        if args.command:
            cmd = [args.training_script]
        else:
            cmd = [sys.executable, "-u"]
            if args.module:
                cmd.append("-m")
            cmd.append(args.training_script)

        if not args.use_env:
            cmd.append("--local_rank={}".format(local_rank))

        cmd.extend(args.training_script_args)

        process = subprocess.Popen(cmd, env=current_env)
        processes.append(process)

    for process in processes:
        process.wait()
        if process.returncode != 0:
            raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)


if __name__ == "__main__":
    main()
vec2wav2/layers/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .causal_conv import *  # NOQA
from .pqmf import *  # NOQA
from .residual_block import *  # NOQA
from .residual_stack import *  # NOQA
from .tade_res_block import *  # NOQA
from .upsample import *  # NOQA
vec2wav2/layers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (344 Bytes).
vec2wav2/layers/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (353 Bytes).
vec2wav2/layers/__pycache__/activations.cpython-310.pyc
ADDED
Binary file (6.64 kB).
vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc
ADDED
Binary file (2.23 kB).
vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc
ADDED
Binary file (2.24 kB).
vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc
ADDED
Binary file (4.14 kB).
vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc
ADDED
Binary file (4.14 kB).
vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc
ADDED
Binary file (6.21 kB).
vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc
ADDED
Binary file (6.18 kB).
vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc
ADDED
Binary file (2.51 kB).
vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc
ADDED
Binary file (2.51 kB).
vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc
ADDED
Binary file (3.59 kB).
vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc
ADDED
Binary file (3.56 kB).
vec2wav2/layers/__pycache__/upsample.cpython-310.pyc
ADDED
Binary file (6.01 kB).
vec2wav2/layers/__pycache__/upsample.cpython-39.pyc
ADDED
Binary file (6 kB).
vec2wav2/layers/activations.py
ADDED
@@ -0,0 +1,197 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

# Modified by Yiwei Guo, 2024
# including conditioned snakebeta activation

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x, cond=None):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBetaWithCondition(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Condition: (B, D), where D-dimension will be mapped to C dimensions
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
        - condition_alpha_prenet - trainable parameter that controls alpha and beta using condition
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256, 128)
        >>> x = torch.randn(256)
        >>> cond = torch.randn(128)
        >>> x = a1(x, cond)
    '''
    def __init__(self, in_features, condition_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: dimension of the input
            - condition_features: dimension of the condition vectors
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha, beta will be trained along with the rest of your model.
        '''
        super(SnakeBetaWithCondition, self).__init__()
        self.in_features = in_features

        self.condition_alpha_prenet = torch.nn.Linear(condition_features, in_features)
        # self.condition_beta_prenet = torch.nn.Linear(condition_features, in_features)

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x, condition):
        '''
        condition: [B, D]
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)

        condition = torch.tanh(self.condition_alpha_prenet(condition).unsqueeze(-1))  # Same prenet for both alpha and beta, to save parameters
        alpha = alpha + condition
        beta = beta + 0.5 * condition  # multiply by 0.5 to avoid beta becoming too small

        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
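A quick shape-check sketch for the activations above (tensor sizes are arbitrary assumptions):

```python
# Sketch: Snake-family activations map (B, C, T) -> (B, C, T).
import torch
from vec2wav2.layers.activations import Snake, SnakeBetaWithCondition

x = torch.randn(2, 256, 100)                      # (B, C, T)
y = Snake(in_features=256)(x)                     # same shape as x

cond_act = SnakeBetaWithCondition(in_features=256, condition_features=128)
cond = torch.randn(2, 128)                        # (B, D) condition vector
z = cond_act(x, cond)                             # also (B, C, T)
assert y.shape == x.shape and z.shape == x.shape
```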
vec2wav2/layers/causal_conv.py
ADDED
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Causal convolution layer modules."""


import torch


class CausalConv1d(torch.nn.Module):
    """CausalConv1d module with customized initialization."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        bias=True,
        pad="ConstantPad1d",
        pad_params={"value": 0.0},
    ):
        """Initialize CausalConv1d module."""
        super(CausalConv1d, self).__init__()
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size, dilation=dilation, bias=bias
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        return self.conv(self.pad(x))[:, :, : x.size(2)]


class CausalConvTranspose1d(torch.nn.Module):
    """CausalConvTranspose1d module with customized initialization."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        """Initialize CausalConvTranspose1d module."""
        super(CausalConvTranspose1d, self).__init__()
        self.deconv = torch.nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride, bias=bias
        )
        self.stride = stride

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).

        Returns:
            Tensor: Output tensor (B, out_channels, T_out).

        """
        return self.deconv(x)[:, :, : -self.stride]
vec2wav2/layers/pqmf.py
ADDED
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Pseudo QMF modules."""

import numpy as np
import torch
import torch.nn.functional as F

from scipy.signal import kaiser


def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps.
        cutoff_ratio (float): Cut-off frequency ratio.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be even."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # make initial filter
    omega_c = np.pi * cutoff_ratio
    with np.errstate(invalid="ignore"):
        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
            np.pi * (np.arange(taps + 1) - 0.5 * taps)
        )
    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    h = h_i * w

    return h


class PQMF(torch.nn.Module):
    """PQMF module.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
        """Initialize PQMF module.

        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super(PQMF, self).__init__()

        # build analysis & synthesis filter coefficients
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        for k in range(subbands):
            h_analysis[k] = (
                2
                * h_proto
                * np.cos(
                    (2 * k + 1)
                    * (np.pi / (2 * subbands))
                    * (np.arange(taps + 1) - (taps / 2))
                    + (-1) ** k * np.pi / 4
                )
            )
            h_synthesis[k] = (
                2
                * h_proto
                * np.cos(
                    (2 * k + 1)
                    * (np.pi / (2 * subbands))
                    * (np.arange(taps + 1) - (taps / 2))
                    - (-1) ** k * np.pi / 4
                )
            )

        # convert to tensor
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffer
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # filter for downsampling & upsampling
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        for k in range(subbands):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(x, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
        # Not sure this is the correct way, it is better to check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        x = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands
        )
        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
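A round-trip sketch for the PQMF analysis/synthesis pair above (signal length is an arbitrary assumption; reconstruction is near-perfect rather than exact):

```python
# Sketch: split a waveform into 4 subbands and resynthesize it.
import torch
from vec2wav2.layers.pqmf import PQMF

pqmf = PQMF(subbands=4)
wav = torch.randn(1, 1, 16000)          # (B, 1, T)
subbands = pqmf.analysis(wav)           # (B, 4, T // 4)
recon = pqmf.synthesis(subbands)        # (B, 1, T)
print(subbands.shape, recon.shape)
```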
vec2wav2/layers/residual_block.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-

"""Residual block modules.

References:
    - https://github.com/r9y9/wavenet_vocoder
    - https://github.com/jik876/hifi-gan

"""

import math

import torch
import torch.nn.functional as F


class Conv1d(torch.nn.Conv1d):
    """Conv1d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module."""
        super(Conv1d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Initialize 1x1 Conv1d module."""
        super(Conv1d1x1, self).__init__(
            in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
        )


class WaveNetResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(
        self,
        kernel_size=3,
        residual_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        dropout=0.0,
        dilation=1,
        bias=True,
        use_causal_conv=False,
    ):
        """Initialize WaveNetResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            gate_channels (int): Number of channels for the gated activation.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
            dropout (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            use_causal_conv (bool): Whether to use causal or non-causal convolution.

        """
        super().__init__()
        self.dropout = dropout
        # no future time stamps available
        if use_causal_conv:
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        # dilation conv
        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
        else:
            self.conv1x1_aux = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)

    def forward(self, x, c):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # remove future time steps if causal convolution is used
        x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x

        # split into two parts for gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection
        x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)

        return x, s


class HiFiGANResidualBlock(torch.nn.Module):
    """Residual block module in HiFiGAN."""

    def __init__(
        self,
        kernel_size=3,
        channels=512,
        dilations=(1, 3, 5),
        bias=True,
        use_additional_convs=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.1},
    ):
        """Initialize HiFiGANResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels for convolution layer.
            dilations (List[int]): List of dilation factors.
            use_additional_convs (bool): Whether to use additional convolution layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.

        """
        super().__init__()
        self.use_additional_convs = use_additional_convs
        self.convs1 = torch.nn.ModuleList()
        if use_additional_convs:
            self.convs2 = torch.nn.ModuleList()
        assert kernel_size % 2 == 1, "Kernel size must be odd number."
        for dilation in dilations:
            self.convs1 += [
                torch.nn.Sequential(
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                    torch.nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        bias=bias,
                        padding=(kernel_size - 1) // 2 * dilation,
                    ),
                )
            ]
            if use_additional_convs:
                self.convs2 += [
                    torch.nn.Sequential(
                        getattr(torch.nn, nonlinear_activation)(
                            **nonlinear_activation_params
                        ),
                        torch.nn.Conv1d(
                            channels,
                            channels,
                            kernel_size,
                            1,
                            dilation=1,
                            bias=bias,
                            padding=(kernel_size - 1) // 2,
                        ),
                    )
                ]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        for idx in range(len(self.convs1)):
            xt = self.convs1[idx](x)
            if self.use_additional_convs:
                xt = self.convs2[idx](xt)
            x = xt + x
        return x
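To make the shape contracts in the docstrings above concrete, a small usage sketch with made-up sizes:

    import torch
    from vec2wav2.layers.residual_block import HiFiGANResidualBlock, WaveNetResidualBlock

    hifi = HiFiGANResidualBlock(kernel_size=3, channels=64, dilations=(1, 3, 5))
    x = torch.randn(2, 64, 100)          # (B, channels, T)
    assert hifi(x).shape == x.shape      # dilated convs preserve the time resolution

    wn = WaveNetResidualBlock(residual_channels=64, gate_channels=128,
                              skip_channels=64, aux_channels=80)
    x = torch.randn(2, 64, 100)          # (B, residual_channels, T)
    c = torch.randn(2, 80, 100)          # local conditioning, same length as x
    res, skip = wn(x, c)                 # (2, 64, 100) and (2, 64, 100)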
vec2wav2/layers/residual_stack.py
ADDED
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Residual stack module in MelGAN."""

import torch

from vec2wav2.layers import CausalConv1d


class ResidualStack(torch.nn.Module):
    """Residual stack module introduced in MelGAN."""

    def __init__(
        self,
        kernel_size=3,
        channels=32,
        dilation=1,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        pad="ReflectionPad1d",
        pad_params={},
        use_causal_conv=False,
    ):
        """Initialize ResidualStack module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels of convolution layers.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(ResidualStack, self).__init__()

        # define residual stack part
        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            self.stack = torch.nn.Sequential(
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
                torch.nn.Conv1d(
                    channels, channels, kernel_size, dilation=dilation, bias=bias
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                torch.nn.Conv1d(channels, channels, 1, bias=bias),
            )
        else:
            self.stack = torch.nn.Sequential(
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    dilation=dilation,
                    bias=bias,
                    pad=pad,
                    pad_params=pad_params,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                torch.nn.Conv1d(channels, channels, 1, bias=bias),
            )

        # define extra layer for skip connection
        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        return self.stack(c) + self.skip_layer(c)
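A quick sketch of the MelGAN residual stack above, with illustrative sizes:

    import torch
    from vec2wav2.layers.residual_stack import ResidualStack

    stack = ResidualStack(kernel_size=3, channels=32, dilation=2)
    c = torch.randn(1, 32, 200)          # (B, channels, T)
    assert stack(c).shape == c.shape     # dilated stack plus 1x1 skip keeps the shape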
vec2wav2/layers/tade_res_block.py
ADDED
@@ -0,0 +1,160 @@
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""StyleMelGAN's TADEResBlock Modules."""

from functools import partial

import torch


class TADELayer(torch.nn.Module):
    """TADE Layer module."""

    def __init__(
        self,
        in_channels=64,
        aux_channels=80,
        kernel_size=9,
        bias=True,
        upsample_factor=2,
        upsample_mode="nearest",
    ):
        """Initialize TADE layer."""
        super().__init__()
        self.norm = torch.nn.InstanceNorm1d(in_channels)
        self.aux_conv = torch.nn.Sequential(
            torch.nn.Conv1d(
                aux_channels,
                in_channels,
                kernel_size,
                1,
                bias=bias,
                padding=(kernel_size - 1) // 2,
            ),
            # NOTE(kan-bayashi): Use non-linear activation?
        )
        self.gated_conv = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels,
                in_channels * 2,
                kernel_size,
                1,
                bias=bias,
                padding=(kernel_size - 1) // 2,
            ),
            # NOTE(kan-bayashi): Use non-linear activation?
        )
        self.upsample = torch.nn.Upsample(
            scale_factor=upsample_factor, mode=upsample_mode
        )

    def forward(self, x, c):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            c (Tensor): Auxiliary input tensor (B, aux_channels, T').

        Returns:
            Tensor: Output tensor (B, in_channels, T * in_upsample_factor).
            Tensor: Upsampled aux tensor (B, in_channels, T * aux_upsample_factor).

        """
        x = self.norm(x)
        c = self.upsample(c)
        c = self.aux_conv(c)
        cg = self.gated_conv(c)
        cg1, cg2 = cg.split(cg.size(1) // 2, dim=1)
        # NOTE(kan-bayashi): Use upsample for noise input here?
        y = cg1 * self.upsample(x) + cg2
        # NOTE(kan-bayashi): Return upsampled aux here?
        return y, c


class TADEResBlock(torch.nn.Module):
    """TADEResBlock module."""

    def __init__(
        self,
        in_channels=64,
        aux_channels=80,
        kernel_size=9,
        dilation=2,
        bias=True,
        upsample_factor=2,
        upsample_mode="nearest",
        gated_function="softmax",
    ):
        """Initialize TADEResBlock module."""
        super().__init__()
        self.tade1 = TADELayer(
            in_channels=in_channels,
            aux_channels=aux_channels,
            kernel_size=kernel_size,
            bias=bias,
            # NOTE(kan-bayashi): Use upsample in the first TADE layer?
            upsample_factor=1,
            upsample_mode=upsample_mode,
        )
        self.gated_conv1 = torch.nn.Conv1d(
            in_channels,
            in_channels * 2,
            kernel_size,
            1,
            bias=bias,
            padding=(kernel_size - 1) // 2,
        )
        self.tade2 = TADELayer(
            in_channels=in_channels,
            aux_channels=in_channels,
            kernel_size=kernel_size,
            bias=bias,
            upsample_factor=upsample_factor,
            upsample_mode=upsample_mode,
        )
        self.gated_conv2 = torch.nn.Conv1d(
            in_channels,
            in_channels * 2,
            kernel_size,
            1,
            bias=bias,
            dilation=dilation,
            padding=(kernel_size - 1) // 2 * dilation,
        )
        self.upsample = torch.nn.Upsample(
            scale_factor=upsample_factor, mode=upsample_mode
        )
        if gated_function == "softmax":
            self.gated_function = partial(torch.softmax, dim=1)
        elif gated_function == "sigmoid":
            self.gated_function = torch.sigmoid
        else:
            raise ValueError(f"{gated_function} is not supported.")

    def forward(self, x, c):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).
            c (Tensor): Auxiliary input tensor (B, aux_channels, T').

        Returns:
            Tensor: Output tensor (B, in_channels, T * in_upsample_factor).
            Tensor: Upsampled auxiliary tensor (B, in_channels, T * in_upsample_factor).

        """
        residual = x

        x, c = self.tade1(x, c)
        x = self.gated_conv1(x)
        xa, xb = x.split(x.size(1) // 2, dim=1)
        x = self.gated_function(xa) * torch.tanh(xb)

        x, c = self.tade2(x, c)
        x = self.gated_conv2(x)
        xa, xb = x.split(x.size(1) // 2, dim=1)
        x = self.gated_function(xa) * torch.tanh(xb)

        # NOTE(kan-bayashi): Return upsampled aux here?
        return self.upsample(residual) + x, c
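An illustrative call of the TADE residual block above, showing how the default upsample_factor=2 doubles the time axis (the sizes are arbitrary):

    import torch
    from vec2wav2.layers.tade_res_block import TADEResBlock

    block = TADEResBlock(in_channels=64, aux_channels=80, upsample_factor=2)
    x = torch.randn(1, 64, 50)           # (B, in_channels, T)
    c = torch.randn(1, 80, 50)           # auxiliary features at the input resolution
    y, c_up = block(x, c)                # y: (1, 64, 100), c_up: (1, 64, 100)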
vec2wav2/layers/upsample.py
ADDED
@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-

"""Upsampling module.

This code is modified from https://github.com/r9y9/wavenet_vocoder.

"""

import numpy as np
import torch
import torch.nn.functional as F

from vec2wav2.layers import Conv1d


class Stretch2d(torch.nn.Module):
    """Stretch2d module."""

    def __init__(self, x_scale, y_scale, mode="nearest"):
        """Initialize Stretch2d module.

        Args:
            x_scale (int): X scaling factor (Time axis in spectrogram).
            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
            mode (str): Interpolation mode.

        """
        super(Stretch2d, self).__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
        )


class Conv2d(torch.nn.Conv2d):
    """Conv2d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv2d module."""
        super(Conv2d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class UpsampleNetwork(torch.nn.Module):
    """Upsampling network module."""

    def __init__(
        self,
        upsample_scales,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        """Initialize upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.

        """
        super(UpsampleNetwork, self).__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_scales:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (
                freq_axis_kernel_size - 1
            ) % 2 == 0, "Not support even number freq axis kernel size."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(
                    **nonlinear_activation_params
                )
                self.up_layers += [nonlinear]

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, C, T).

        Returns:
            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).

        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            if self.use_causal_conv and isinstance(f, Conv2d):
                c = f(c)[..., : c.size(-1)]
            else:
                c = f(c)
        return c.squeeze(1)  # (B, C, T')


class ConvInUpsampleNetwork(torch.nn.Module):
    """Convolution + upsampling network module."""

    def __init__(
        self,
        upsample_scales,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super(ConvInUpsampleNetwork, self).__init__()
        self.aux_context_window = aux_context_window
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        # To capture wide-context information in conditional features
        kernel_size = (
            aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        )
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = Conv1d(
            aux_channels, aux_channels, kernel_size=kernel_size, bias=False
        )
        self.upsample = UpsampleNetwork(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, C, T').

        Returns:
            Tensor: Upsampled tensor (B, C, T),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.

        """
        c_ = self.conv_in(c)
        c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)
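A shape sketch for ConvInUpsampleNetwork above; the upsample_scales and context window are illustrative assumptions (together they correspond to a total hop of 100 samples), not values prescribed by this commit:

    import torch
    from vec2wav2.layers.upsample import ConvInUpsampleNetwork

    up = ConvInUpsampleNetwork(upsample_scales=[4, 5, 5],
                               aux_channels=80, aux_context_window=2)
    c = torch.randn(1, 80, 104)          # (B, C, T'), already padded by 2 frames per side
    out = up(c)                          # (1, 80, (104 - 2 * 2) * 4 * 5 * 5) = (1, 80, 10000)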
vec2wav2/losses/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .adversarial_loss import *  # NOQA
from .feat_match_loss import *  # NOQA
from .mel_loss import *  # NOQA
from .stft_loss import *  # NOQA