cantabile-kwok committed
Commit 05005db
1 Parent(s): 8bd60fe

prepare demo page

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +51 -0
  2. pretrained/WavLM-Large.pt +3 -0
  3. pretrained/config.yml +201 -0
  4. pretrained/generator.ckpt +3 -0
  5. pretrained/vq-wav2vec_kmeans.pt +3 -0
  6. requirements.txt +25 -0
  7. vec2wav2/__init__.py +3 -0
  8. vec2wav2/__pycache__/__init__.cpython-310.pyc +0 -0
  9. vec2wav2/__pycache__/__init__.cpython-311.pyc +0 -0
  10. vec2wav2/__pycache__/__init__.cpython-39.pyc +0 -0
  11. vec2wav2/bin/.DS_Store +0 -0
  12. vec2wav2/bin/__init__.py +0 -0
  13. vec2wav2/bin/__pycache__/__init__.cpython-310.pyc +0 -0
  14. vec2wav2/bin/__pycache__/vc.cpython-310.pyc +0 -0
  15. vec2wav2/bin/decode.py +163 -0
  16. vec2wav2/bin/gradio_app.py +51 -0
  17. vec2wav2/bin/train.py +1007 -0
  18. vec2wav2/bin/vc.py +128 -0
  19. vec2wav2/datasets/__init__.py +1 -0
  20. vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  21. vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  22. vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc +0 -0
  23. vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc +0 -0
  24. vec2wav2/datasets/scp_dataset.py +300 -0
  25. vec2wav2/distributed/__init__.py +0 -0
  26. vec2wav2/distributed/launch.py +163 -0
  27. vec2wav2/layers/__init__.py +6 -0
  28. vec2wav2/layers/__pycache__/__init__.cpython-310.pyc +0 -0
  29. vec2wav2/layers/__pycache__/__init__.cpython-39.pyc +0 -0
  30. vec2wav2/layers/__pycache__/activations.cpython-310.pyc +0 -0
  31. vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc +0 -0
  32. vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc +0 -0
  33. vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc +0 -0
  34. vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc +0 -0
  35. vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc +0 -0
  36. vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc +0 -0
  37. vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc +0 -0
  38. vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc +0 -0
  39. vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc +0 -0
  40. vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc +0 -0
  41. vec2wav2/layers/__pycache__/upsample.cpython-310.pyc +0 -0
  42. vec2wav2/layers/__pycache__/upsample.cpython-39.pyc +0 -0
  43. vec2wav2/layers/activations.py +197 -0
  44. vec2wav2/layers/causal_conv.py +66 -0
  45. vec2wav2/layers/pqmf.py +150 -0
  46. vec2wav2/layers/residual_block.py +222 -0
  47. vec2wav2/layers/residual_stack.py +85 -0
  48. vec2wav2/layers/tade_res_block.py +160 -0
  49. vec2wav2/layers/upsample.py +194 -0
  50. vec2wav2/losses/__init__.py +4 -0
app.py ADDED
@@ -0,0 +1,51 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ import gradio as gr
+ import logging
+ import yaml
+ import soundfile as sf
+ import os
+ from pathlib import Path
+ from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args
+
+ # Create Gradio interface
+ def create_interface():
+     args = vc_args()
+     logger = configure_logging(args.verbose)
+     voice_converter = VoiceConverter(
+         expdir=args.expdir,
+         token_extractor=args.token_extractor,
+         prompt_extractor=args.prompt_extractor,
+         prompt_output_layer=args.prompt_output_layer,
+         checkpoint=args.checkpoint,
+         script_logger=logger
+     )
+     with gr.Blocks(title="Voice Conversion") as demo:
+         gr.Markdown("# vec2wav 2.0 Voice Conversion Demo")
+         gr.Markdown("Upload source audio and target speaker audio to convert the voice.")
+
+         with gr.Row():
+             source_audio = gr.Audio(label="Source Audio", type="filepath")
+             target_audio = gr.Audio(label="Target Speaker Audio", type="filepath")
+
+         examples = [
+             ["examples/Zuckerberg.wav", "examples/Rachel.wav"],
+             ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"]
+         ]
+         gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio])
+
+         convert_btn = gr.Button("Convert Voice")
+         output_audio = gr.Audio(label="Converted Audio")
+
+         convert_btn.click(
+             fn=voice_converter.voice_conversion,
+             inputs=[source_audio, target_audio],
+             outputs=output_audio
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True)
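For local testing, the same interface can be served without creating a public share link. A minimal sketch (not part of the commit; `server_name` and `server_port` are standard `gradio` launch options, and the port value is only illustrative):

    from app import create_interface  # assumes the script is run from the repo root

    demo = create_interface()
    # Bind to localhost only and skip the public Gradio share URL.
    demo.launch(server_name="127.0.0.1", server_port=7860, share=False)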
pretrained/WavLM-Large.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
+ size 1261965425
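`pretrained/WavLM-Large.pt`, `pretrained/generator.ckpt`, and `pretrained/vq-wav2vec_kmeans.pt` are committed as Git LFS pointer files, so only the `oid`/`size` metadata lives in the repository. A small sketch (illustrative, not part of the commit) for verifying a pulled file against the pointer above:

    import hashlib
    import os

    path = "pretrained/WavLM-Large.pt"
    expected_oid = "6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f"
    expected_size = 1261965425

    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks to avoid loading the 1.2 GB checkpoint into memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)

    assert os.path.getsize(path) == expected_size
    assert sha256.hexdigest() == expected_oid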
pretrained/config.yml ADDED
@@ -0,0 +1,201 @@
+ allow_cache: false
+ batch_frames: 3600
+ config: conf/ctxv2w.v1.yaml
+ crop_max_frames: 100
+ discriminator_adv_loss_params:
+   average_by_discriminators: false
+ discriminator_grad_norm: -1
+ discriminator_optimizer_params:
+   betas:
+   - 0.5
+   - 0.9
+   lr: 0.0002
+   weight_decay: 0.0
+ discriminator_optimizer_type: Adam
+ discriminator_params:
+   follow_official_norm: true
+   period_discriminator_params:
+     bias: true
+     channels: 32
+     downsample_scales:
+     - 3
+     - 3
+     - 3
+     - 3
+     - 1
+     in_channels: 1
+     kernel_sizes:
+     - 5
+     - 3
+     max_downsample_channels: 1024
+     nonlinear_activation: LeakyReLU
+     nonlinear_activation_params:
+       negative_slope: 0.1
+     out_channels: 1
+     use_spectral_norm: false
+     use_weight_norm: true
+   periods:
+   - 2
+   - 3
+   - 5
+   - 7
+   - 11
+   scale_discriminator_params:
+     bias: true
+     channels: 128
+     downsample_scales:
+     - 4
+     - 4
+     - 4
+     - 4
+     - 1
+     in_channels: 1
+     kernel_sizes:
+     - 15
+     - 41
+     - 5
+     - 3
+     max_downsample_channels: 1024
+     max_groups: 16
+     nonlinear_activation: LeakyReLU
+     nonlinear_activation_params:
+       negative_slope: 0.1
+     out_channels: 1
+   scale_downsample_pooling: AvgPool1d
+   scale_downsample_pooling_params:
+     kernel_size: 4
+     padding: 2
+     stride: 2
+   scales: 3
+ discriminator_scheduler_params:
+   gamma: 0.5
+   milestones:
+   - 200000
+   - 400000
+   - 600000
+   - 800000
+ discriminator_scheduler_type: MultiStepLR
+ discriminator_train_start_steps: 0
+ discriminator_type: HiFiGANMultiScaleMultiPeriodDiscriminator
+ distributed: true
+ dropout_features: 0.0
+ eval_interval_steps: 100000
+ feat_match_loss_params:
+   average_by_discriminators: false
+   average_by_layers: false
+   include_final_outputs: false
+ frontend_mel_prediction_stop_steps: 200000
+ frontend_params:
+   conformer_params:
+     activation_type: swish
+     attention_dim: 184
+     attention_dropout_rate: 0.2
+     attention_heads: 2
+     cnn_module_kernel: 31
+     concat_after: false
+     dropout_rate: 0.2
+     linear_units: 1536
+     macaron_style: true
+     normalize_before: true
+     num_blocks: 2
+     pos_enc_layer_type: rel_pos
+     positional_dropout_rate: 0.2
+     positionwise_conv_kernel_size: 3
+     positionwise_layer_type: conv1d
+     selfattention_layer_type: rel_selfattn
+     use_cnn_module: true
+   prompt_channels: 1024
+   vqvec_channels: 512
+ generator_adv_loss_params:
+   average_by_discriminators: false
+ generator_grad_norm: -1
+ generator_optimizer_params:
+   betas:
+   - 0.5
+   - 0.9
+   lr: 0.0002
+   weight_decay: 0.0
+ generator_optimizer_type: Adam
+ generator_params:
+   bias: true
+   channels: 512
+   condition_dim: 1024
+   in_channels: 184
+   kernel_size: 7
+   nonlinear_activation: snakebeta-condition
+   out_channels: 1
+   resblock: '1'
+   resblock_dilations:
+   - - 1
+     - 3
+     - 5
+   - - 1
+     - 3
+     - 5
+   - - 1
+     - 3
+     - 5
+   resblock_kernel_sizes:
+   - 3
+   - 7
+   - 11
+   snake_logscale: true
+   upsample_kernel_sizes:
+   - 16
+   - 10
+   - 6
+   - 4
+   upsample_scales:
+   - 8
+   - 5
+   - 3
+   - 2
+   use_additional_convs: true
+   use_weight_norm: true
+ generator_scheduler_params:
+   gamma: 0.5
+   milestones:
+   - 200000
+   - 400000
+   - 600000
+   - 800000
+ generator_scheduler_type: MultiStepLR
+ generator_train_start_steps: 1
+ generator_type: BigVGAN
+ hop_size: 240
+ lambda_adv: 1.0
+ lambda_aux: 45.0
+ lambda_feat_match: 2.0
+ lambda_frontend_mel_prediction: 60
+ log_interval_steps: 1000
+ max_num_frames: 3000
+ mel_loss_params:
+   fft_size: 2048
+   fmax: 8000
+   fmin: 40
+   fs: 24000
+   hop_size: 300
+   log_base: null
+   num_mels: 80
+   win_length: 1200
+   window: hann
+ min_num_frames: 600
+ num_mels: 80
+ num_save_intermediate_results: 4
+ num_workers: 8
+ outdir: exp/train_all_ctxv2w.v1
+ pin_memory: true
+ pretrain: ''
+ prompt_fold_by_2: true
+ prompt_net_type: ConvPromptPrenet
+ rank: 0
+ sampling_rate: 24000
+ save_interval_steps: 10000
+ use_feat_match_loss: true
+ use_mel_loss: true
+ use_stft_loss: false
+ verbose: 1
+ version: 0.5.3
+ vq_codebook: feats/vqidx/codebook.npy
+ win_length: 697
+ world_size: 4
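The decoding and conversion scripts below load this file with PyYAML and read the synthesis parameters from it. A short sketch of inspecting the fields that matter most at inference time (values as recorded in this config):

    import yaml

    with open("pretrained/config.yml") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    print(config["generator_type"])   # BigVGAN
    print(config["sampling_rate"])    # 24000 Hz output
    print(config["hop_size"])         # 240 samples per frame, i.e. 10 ms at 24 kHz
    print(config["vq_codebook"])      # feats/vqidx/codebook.npy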
pretrained/generator.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a10b9df62462bbf48382970ffba267b458b00b361bcb245701e3d3c0b6bd19f
+ size 161604549
pretrained/vq-wav2vec_kmeans.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c975a93479dc5f3cfc4339032e1547c6034eddd15eb1cba73364c20786b42a5a
+ size 336509919
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ torchaudio==0.13.1
+ auraloss==0.4.0
+ cython==3.0.10
+ einops
+ debugpy==1.8.0
+ fairseq==0.12.2
+ filelock~=3.12.2
+ h5py
+ kaldiio~=2.18.0
+ librosa==0.8.1
+ matplotlib~=3.4.3
+ nltk==3.8.1
+ numpy
+ pathlib~=1.0.1
+ pyyaml~=6.0
+ scikit-learn
+ scipy~=1.7.1
+ setuptools==65.6.3
+ six==1.16.0
+ soundfile~=0.10.3.post1
+ sox
+ tensorboard
+ tensorboardx~=2.5.1
+ tqdm~=4.62.3
+ transformers==4.42.3
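Note that `torch` itself is not pinned here; it is pulled in transitively (e.g. by `torchaudio==0.13.1` and `fairseq`). A small, optional sketch for checking a few of the hard pins in an existing environment (`importlib.metadata` is in the standard library):

    from importlib.metadata import version

    pins = {"torchaudio": "0.13.1", "fairseq": "0.12.2", "librosa": "0.8.1", "transformers": "4.42.3"}
    for pkg, pinned in pins.items():
        installed = version(pkg)
        # Only the exact '==' pins are checked; '~=' requirements allow patch drift.
        print(f"{pkg}: pinned {pinned}, installed {installed}, match={installed == pinned}")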
vec2wav2/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # -*- coding: utf-8 -*-
+
+ __version__ = ""
vec2wav2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (211 Bytes).
vec2wav2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (289 Bytes).
vec2wav2/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (225 Bytes).
vec2wav2/bin/.DS_Store ADDED
Binary file (6.15 kB).
vec2wav2/bin/__init__.py ADDED
File without changes
vec2wav2/bin/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (199 Bytes).
vec2wav2/bin/__pycache__/vc.cpython-310.pyc ADDED
Binary file (4.76 kB).
vec2wav2/bin/decode.py ADDED
@@ -0,0 +1,163 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ # Modified by Yiwei Guo, 2024
+
+ """Decode with trained vec2wav Generator."""
+
+ import argparse
+ import logging
+ import os
+ import time
+
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import yaml
+
+ from tqdm import tqdm
+
+ from vec2wav2.datasets import MelSCPDataset
+ from vec2wav2.utils import load_model, load_feat_codebook, idx2vec
+
+
+ def set_loglevel(verbose):
+     # set logger
+     if verbose > 1:
+         logging.basicConfig(
+             level=logging.DEBUG,
+             format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+         )
+     elif verbose > 0:
+         logging.basicConfig(
+             level=logging.INFO,
+             format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+         )
+     else:
+         logging.basicConfig(
+             level=logging.WARN,
+             format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+         )
+         logging.warning("Skip DEBUG/INFO messages")
+
+
+ def main():
+     """Run decoding process."""
+     parser = argparse.ArgumentParser(
+         description="Decode from audio tokens and acoustic prompts with trained vec2wav model"
+         "(See detail in vec2wav2/bin/decode.py)."
+     )
+     parser.add_argument(
+         "--feats-scp",
+         "--scp",
+         default=None,
+         type=str,
+         required=True,
+         help="kaldi-style feats.scp file. "
+     )
+     parser.add_argument(
+         "--prompt-scp",
+         default=None,
+         type=str,
+         help="kaldi-style prompt.scp file. Similar to feats.scp."
+     )
+     parser.add_argument(
+         "--outdir",
+         type=str,
+         required=True,
+         help="directory to save generated speech.",
+     )
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         required=True,
+         help="checkpoint file to be loaded.",
+     )
+     parser.add_argument(
+         "--config",
+         default=None,
+         type=str,
+         help="yaml format configuration file. if not explicitly provided, "
+         "it will be searched in the checkpoint directory. (default=None)",
+     )
+     parser.add_argument(
+         "--verbose",
+         type=int,
+         default=1,
+         help="logging level. higher is more logging. (default=1)",
+     )
+     args = parser.parse_args()
+     set_loglevel(args.verbose)
+
+     # check directory existence
+     if not os.path.exists(args.outdir):
+         os.makedirs(args.outdir)
+
+     # load config
+     if args.config is None:
+         dirname = os.path.dirname(args.checkpoint)
+         args.config = os.path.join(dirname, "config.yml")
+     with open(args.config) as f:
+         config = yaml.load(f, Loader=yaml.Loader)
+     config.update(vars(args))
+
+     # get dataset
+     dataset = MelSCPDataset(
+         vqidx_scp=args.feats_scp,
+         prompt_scp=args.prompt_scp,
+         return_utt_id=True,
+     )
+     logging.info(f"The number of features to be decoded = {len(dataset)}.")
+
+     # setup model
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     logging.info(f"Using {'GPU' if torch.cuda.is_available() else 'CPU'}.")
+
+     model = load_model(args.checkpoint, config)
+     logging.info(f"Loaded model parameters from {args.checkpoint}.")
+
+     model.backend.remove_weight_norm()
+     model = model.eval().to(device)
+
+     # load vq codebook
+     feat_codebook, feat_codebook_numgroups = load_feat_codebook(np.load(config["vq_codebook"], allow_pickle=True), device)
+
+     # start generation
+     total_rtf = 0.0
+     with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
+         for idx, batch in enumerate(pbar, 1):
+             utt_id, vqidx, prompt = batch[0], batch[1], batch[2]
+
+             vqidx = torch.tensor(vqidx).to(device)  # (L, G)
+             prompt = torch.tensor(prompt).unsqueeze(0).to(device)  # (1, L', D')
+
+             vqidx = vqidx.long()
+             vqvec = idx2vec(feat_codebook, vqidx, feat_codebook_numgroups).unsqueeze(0)  # (1, L, D)
+
+             # generate
+             start = time.time()
+             y = model.inference(vqvec, prompt)[-1].view(-1)
+             rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
+             pbar.set_postfix({"RTF": rtf})
+             total_rtf += rtf
+
+             tgt_dir = os.path.dirname(os.path.join(config["outdir"], f"{utt_id}.wav"))
+             os.makedirs(tgt_dir, exist_ok=True)
+             basename = os.path.basename(f"{utt_id}.wav")
+             # save as PCM 16 bit wav file
+             sf.write(
+                 os.path.join(tgt_dir, basename),
+                 y.cpu().numpy(),
+                 config["sampling_rate"],
+                 "PCM_16",
+             )
+
+     # report average RTF
+     logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")
+
+
+ if __name__ == "__main__":
+     main()
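decode.py expects Kaldi-style `feats.scp` (VQ index sequences) and `prompt.scp` (prompt feature matrices). One way to produce such files, sketched here with `kaldiio` (already listed in requirements.txt); the array shapes follow the `(L, G)` and `(L', D')` comments above, and the utterance ID and lengths are illustrative:

    import numpy as np
    from kaldiio import WriteHelper

    # Each call writes utt_id -> numpy array; 'ark,scp:' produces both the archive and its index.
    with WriteHelper("ark,scp:feats.ark,feats.scp") as writer:
        writer("utt001", np.zeros((250, 2), dtype=np.float32))    # (L, G) VQ indices, cast to long by decode.py
    with WriteHelper("ark,scp:prompt.ark,prompt.scp") as writer:
        writer("utt001", np.zeros((120, 1024), dtype=np.float32))  # (L', D') prompt features, D' = prompt_channels

These files can then be passed to decode.py via --feats-scp and --prompt-scp together with --checkpoint and --outdir.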
vec2wav2/bin/gradio_app.py ADDED
@@ -0,0 +1,51 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ import gradio as gr
+ import logging
+ import yaml
+ import soundfile as sf
+ import os
+ from pathlib import Path
+ from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args
+
+ # Create Gradio interface
+ def create_interface():
+     args = vc_args()
+     logger = configure_logging(args.verbose)
+     voice_converter = VoiceConverter(
+         expdir=args.expdir,
+         token_extractor=args.token_extractor,
+         prompt_extractor=args.prompt_extractor,
+         prompt_output_layer=args.prompt_output_layer,
+         checkpoint=args.checkpoint,
+         script_logger=logger
+     )
+     with gr.Blocks(title="Voice Conversion") as demo:
+         gr.Markdown("# vec2wav 2.0 Voice Conversion Demo")
+         gr.Markdown("Upload source audio and target speaker audio to convert the voice.")
+
+         with gr.Row():
+             source_audio = gr.Audio(label="Source Audio", type="filepath")
+             target_audio = gr.Audio(label="Target Speaker Audio", type="filepath")
+
+         examples = [
+             ["examples/Zuckerberg.wav", "examples/Rachel.wav"],
+             ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"]
+         ]
+         gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio])
+
+         convert_btn = gr.Button("Convert Voice")
+         output_audio = gr.Audio(label="Converted Audio")
+
+         convert_btn.click(
+             fn=voice_converter.voice_conversion,
+             inputs=[source_audio, target_audio],
+             outputs=output_audio
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True)
vec2wav2/bin/train.py ADDED
@@ -0,0 +1,1007 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Tomoki Hayashi
5
+ # MIT License (https://opensource.org/licenses/MIT)
6
+
7
+ # Modified by Yiwei Guo, 2024
8
+
9
+ """Train vec2wav."""
10
+
11
+ import argparse
12
+ import logging
13
+ import os
14
+ import sys
15
+ import random
16
+
17
+ from collections import defaultdict
18
+
19
+ import matplotlib
20
+ import numpy as np
21
+ import soundfile as sf
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import yaml
25
+ import torch.multiprocessing as mp
26
+ from tensorboardX import SummaryWriter
27
+ from torch.utils.data import DataLoader
28
+ from tqdm import tqdm
29
+
30
+ import vec2wav2
31
+ import vec2wav2.models
32
+ import vec2wav2.optimizers
33
+ from torch.utils.data.distributed import DistributedSampler
34
+
35
+ from vec2wav2.datasets import AudioMelSCPDataset
36
+ from vec2wav2.layers import PQMF
37
+ from vec2wav2.losses import DiscriminatorAdversarialLoss
38
+ from vec2wav2.losses import FeatureMatchLoss
39
+ from vec2wav2.losses import GeneratorAdversarialLoss
40
+ from vec2wav2.losses import MelSpectrogramLoss
41
+ from vec2wav2.losses import MultiResolutionSTFTLoss
42
+ from vec2wav2.utils import crop_seq, load_feat_codebook, idx2vec
43
+
44
+ from vec2wav2.utils.espnet_utils import pad_list, make_non_pad_mask
45
+
46
+ # set to avoid matplotlib error in CLI environment
47
+ matplotlib.use("Agg")
48
+
49
+
50
+ def set_loglevel(verbose):
51
+ # set logger
52
+ if verbose > 1:
53
+ logging.basicConfig(
54
+ level=logging.DEBUG,
55
+ stream=sys.stdout,
56
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
57
+ )
58
+ elif verbose > 0:
59
+ logging.basicConfig(
60
+ level=logging.INFO,
61
+ stream=sys.stdout,
62
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
63
+ )
64
+ else:
65
+ logging.basicConfig(
66
+ level=logging.WARN,
67
+ stream=sys.stdout,
68
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
69
+ )
70
+ logging.warning("Skip DEBUG/INFO messages")
71
+
72
+
73
+ class Trainer(object):
74
+ """Customized trainer module for Parallel WaveGAN training."""
75
+
76
+ def __init__(
77
+ self,
78
+ steps,
79
+ epochs,
80
+ data_loader,
81
+ sampler,
82
+ model,
83
+ criterion,
84
+ optimizer,
85
+ scheduler,
86
+ config,
87
+ device=torch.device("cpu"),
88
+ ):
89
+ """Initialize trainer.
90
+
91
+ Args:
92
+ steps (int): Initial global steps.
93
+ epochs (int): Initial global epochs.
94
+ data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
95
+ model (dict): Dict of models. It must contain "generator" and "discriminator" models.
96
+ criterion (dict): Dict of criteria. It must contain "stft" and "mse" criteria.
97
+ optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
98
+ scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
99
+ config (dict): Config dict loaded from yaml format configuration file.
100
+ device (torch.deive): Pytorch device instance.
101
+
102
+ """
103
+ self.steps = steps
104
+ self.epochs = epochs
105
+ self.data_loader = data_loader
106
+ self.sampler = sampler
107
+ self.model = model
108
+ self.criterion = criterion
109
+ self.optimizer = optimizer
110
+ self.scheduler = scheduler
111
+ self.config = config
112
+ self.device = device
113
+ self.writer = SummaryWriter(config["outdir"])
114
+ self.finish_train = False
115
+ self.total_train_loss = defaultdict(float)
116
+ self.total_eval_loss = defaultdict(float)
117
+
118
+ # load vq codebook
119
+ feat_codebook_path = self.config["vq_codebook"]
120
+
121
+ self.feat_codebook, self.feat_codebook_numgroups = load_feat_codebook(np.load(feat_codebook_path, allow_pickle=True), device)
122
+
123
+ def run(self):
124
+ """Run training."""
125
+ self.tqdm = tqdm(initial=self.steps, total=self.config["train_max_steps"], desc="[train]")
126
+ while True:
127
+ # train one epoch
128
+ self._train_epoch()
129
+
130
+ # check whether training is finished
131
+ if self.finish_train:
132
+ break
133
+
134
+ self.tqdm.close()
135
+ logging.info("Finished training.")
136
+
137
+ def save_checkpoint(self, checkpoint_path):
138
+ """Save checkpoint.
139
+ Args:
140
+ checkpoint_path (str): Checkpoint path to be saved.
141
+ """
142
+ state_dict = {
143
+ "optimizer": {
144
+ "generator": self.optimizer["generator"].state_dict(),
145
+ "discriminator": self.optimizer["discriminator"].state_dict(),
146
+ },
147
+ "scheduler": {
148
+ "generator": self.scheduler["generator"].state_dict(),
149
+ "discriminator": self.scheduler["discriminator"].state_dict(),
150
+ },
151
+ "steps": self.steps,
152
+ "epochs": self.epochs,
153
+ }
154
+ if self.config["distributed"]:
155
+ state_dict["model"] = {
156
+ "generator": self.model["generator"].module.state_dict(),
157
+ "discriminator": self.model["discriminator"].module.state_dict(),
158
+ }
159
+ else:
160
+ state_dict["model"] = {
161
+ "generator": self.model["generator"].state_dict(),
162
+ "discriminator": self.model["discriminator"].state_dict(),
163
+ }
164
+
165
+ if not os.path.exists(os.path.dirname(checkpoint_path)):
166
+ os.makedirs(os.path.dirname(checkpoint_path))
167
+ torch.save(state_dict, checkpoint_path)
168
+
169
+ def load_checkpoint(self, checkpoint_path, load_only_params=False):
170
+ """Load checkpoint.
171
+
172
+ Args:
173
+ checkpoint_path (str): Checkpoint path to be loaded.
174
+ load_only_params (bool): Whether to load only model parameters.
175
+
176
+ """
177
+ state_dict = torch.load(checkpoint_path, map_location="cpu")
178
+ if self.config["distributed"]:
179
+ self.model["generator"].module.load_state_dict(
180
+ state_dict["model"]["generator"]
181
+ )
182
+ self.model["discriminator"].module.load_state_dict(
183
+ state_dict["model"]["discriminator"]
184
+ )
185
+ else:
186
+ self.model["generator"].load_state_dict(state_dict["model"]["generator"])
187
+ self.model["discriminator"].load_state_dict(
188
+ state_dict["model"]["discriminator"]
189
+ )
190
+ if not load_only_params:
191
+ self.steps = state_dict["steps"]
192
+ self.epochs = state_dict["epochs"]
193
+ self.optimizer["generator"].load_state_dict(state_dict["optimizer"]["generator"])
194
+ self.optimizer["discriminator"].load_state_dict(state_dict["optimizer"]["discriminator"])
195
+ self.scheduler["generator"].load_state_dict(state_dict["scheduler"]["generator"])
196
+ self.scheduler["discriminator"].load_state_dict(state_dict["scheduler"]["discriminator"])
197
+
198
+ def _train_step(self, batch):
199
+ """Train model one step."""
200
+ # parse batch
201
+ vqidx, mel, prompt, y, xlens, prompt_lens = batch
202
+ vqidx = vqidx.to(self.device)
203
+ mel = mel.to(self.device)
204
+ prompt = prompt.to(self.device)
205
+ vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups) # (B, L, D)
206
+ y = y.unsqueeze(-2).to(self.device) # (B, 1, T)
207
+
208
+ # build mask
209
+ mask = make_non_pad_mask(xlens).to(self.device) # (B, L)
210
+ prompt_mask = make_non_pad_mask(prompt_lens).to(self.device) # (B, L_prompt)
211
+
212
+ # crop wav sequence
213
+ crop_xlen = min(self.config["crop_max_frames"], min(xlens))
214
+ x_offsets = [np.random.randint(0, l - crop_xlen + 1) for l in xlens]
215
+ crop_ylen = crop_xlen * self.config["hop_size"]
216
+ y_offsets = [o * self.config["hop_size"] for o in x_offsets]
217
+ y = crop_seq(y, y_offsets, crop_ylen)
218
+
219
+ #######################
220
+ # Generator #
221
+ #######################
222
+ if self.steps > self.config.get("generator_train_start_steps", 0):
223
+ mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets) # (B, L, 80), (B, C, T)
224
+
225
+ # initialize
226
+ gen_loss, aux_loss = 0.0, 0.0
227
+
228
+ # frontend mel prediction loss
229
+ if self.steps <= self.config.get("frontend_mel_prediction_stop_steps", 0):
230
+ frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)),
231
+ torch.masked_select(mel_, mask.unsqueeze(-1)))
232
+ self.total_train_loss["train/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item()
233
+ gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss
234
+
235
+ # multi-resolution sfft loss
236
+ if self.config["use_stft_loss"]:
237
+ sc_loss, mag_loss = self.criterion["stft"](y_, y)
238
+ aux_loss += sc_loss + mag_loss
239
+ self.total_train_loss["train/spectral_convergence_loss"] += sc_loss.item()
240
+ self.total_train_loss["train/log_stft_magnitude_loss"] += mag_loss.item()
241
+
242
+ # subband multi-resolution stft loss
243
+ if self.config["use_subband_stft_loss"]:
244
+ aux_loss *= 0.5 # for balancing with subband stft loss
245
+ y_mb = self.criterion["pqmf"].analysis(y)
246
+ y_mb_ = self.criterion["pqmf"].analysis(y_)
247
+ sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
248
+ aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
249
+ self.total_train_loss["train/sub_spectral_convergence_loss"] += sub_sc_loss.item()
250
+ self.total_train_loss["train/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
251
+
252
+ # mel spectrogram loss
253
+ if self.config["use_mel_loss"]:
254
+ mel_loss = self.criterion["mel"](y_, y)
255
+ aux_loss += mel_loss
256
+ self.total_train_loss["train/mel_loss"] += mel_loss.item()
257
+
258
+ # weighting aux loss
259
+ gen_loss += self.config.get("lambda_aux", 1.0) * aux_loss
260
+
261
+ # adversarial loss
262
+ if self.steps > self.config["discriminator_train_start_steps"]:
263
+ p_ = self.model["discriminator"](y_)
264
+ adv_loss = self.criterion["gen_adv"](p_)
265
+ self.total_train_loss["train/adversarial_loss"] += adv_loss.item()
266
+
267
+ # feature matching loss
268
+ if self.config["use_feat_match_loss"]:
269
+ # no need to track gradients
270
+ with torch.no_grad():
271
+ p = self.model["discriminator"](y)
272
+ fm_loss = self.criterion["feat_match"](p_, p)
273
+ self.total_train_loss["train/feature_matching_loss"] += fm_loss.item()
274
+ adv_loss += self.config["lambda_feat_match"] * fm_loss
275
+
276
+ # add adversarial loss to generator loss
277
+ gen_loss += self.config["lambda_adv"] * adv_loss
278
+
279
+ self.total_train_loss["train/generator_loss"] += gen_loss.item()
280
+
281
+ # update generator
282
+ self.optimizer["generator"].zero_grad()
283
+ gen_loss.backward()
284
+ if self.config["generator_grad_norm"] > 0:
285
+ torch.nn.utils.clip_grad_norm_(
286
+ self.model["generator"].parameters(),
287
+ self.config["generator_grad_norm"],
288
+ )
289
+ self.optimizer["generator"].step()
290
+ self.scheduler["generator"].step()
291
+
292
+ #######################
293
+ # Discriminator #
294
+ #######################
295
+ if self.steps > self.config["discriminator_train_start_steps"]:
296
+ # re-compute y_ which leads better quality
297
+ with torch.no_grad():
298
+ # logging.info(f"{vqvec.shape, prompt.shape, mask.shape, prompt_mask.shape}")
299
+ _, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets) # (B, L, 80), (B, C, T)
300
+
301
+ if self.config["generator_params"]["out_channels"] > 1:
302
+ y_ = self.criterion["pqmf"].synthesis(y_)
303
+
304
+ # discriminator loss
305
+ p = self.model["discriminator"](y)
306
+ p_ = self.model["discriminator"](y_.detach())
307
+ real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
308
+ dis_loss = real_loss + fake_loss
309
+ self.total_train_loss["train/real_loss"] += real_loss.item()
310
+ self.total_train_loss["train/fake_loss"] += fake_loss.item()
311
+ self.total_train_loss["train/discriminator_loss"] += dis_loss.item()
312
+
313
+ # update discriminator
314
+ self.optimizer["discriminator"].zero_grad()
315
+ dis_loss.backward()
316
+ if self.config["discriminator_grad_norm"] > 0:
317
+ torch.nn.utils.clip_grad_norm_(
318
+ self.model["discriminator"].parameters(),
319
+ self.config["discriminator_grad_norm"],
320
+ )
321
+ self.optimizer["discriminator"].step()
322
+ self.scheduler["discriminator"].step()
323
+
324
+ # update counts
325
+ self.steps += 1
326
+ self.tqdm.update(1)
327
+ self._check_train_finish()
328
+
329
+ def _train_epoch(self):
330
+ """Train model one epoch."""
331
+ for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
332
+ # train one step
333
+ self._train_step(batch)
334
+
335
+ # check interval
336
+ if self.config["rank"] == 0:
337
+ self._check_log_interval()
338
+ self._check_eval_interval()
339
+ self._check_save_interval()
340
+
341
+ # check whether training is finished
342
+ if self.finish_train:
343
+ return
344
+
345
+ # update
346
+ self.epochs += 1
347
+ self.train_steps_per_epoch = train_steps_per_epoch
348
+ logging.info(
349
+ f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
350
+ f"({self.train_steps_per_epoch} steps per epoch)."
351
+ )
352
+
353
+ # needed for shuffle in distributed training
354
+ if self.config["distributed"]:
355
+ self.sampler["train"].set_epoch(self.epochs)
356
+
357
+ @torch.no_grad()
358
+ def _eval_step(self, batch):
359
+ """Evaluate model one step."""
360
+ # parse batch
361
+ vqidx, mel, prompt, y, xlens, prompt_lens = batch
362
+ vqidx = vqidx.to(self.device).long()
363
+ mel = mel.to(self.device)
364
+ prompt = prompt.to(self.device)
365
+ vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups)
366
+ y = y.unsqueeze(-2).to(self.device) # (B, 1, T)
367
+
368
+ # build mask
369
+ mask = make_non_pad_mask(xlens).to(self.device) # (B, L)
370
+ prompt_mask = make_non_pad_mask(prompt_lens).to(self.device) # (B, L_prompt)
371
+
372
+ #######################
373
+ # Generator #
374
+ #######################
375
+ mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask) # (B, L, 80), (B, C, T)
376
+
377
+ # reconstruct the signal from multi-band signal
378
+ if self.config["generator_params"]["out_channels"] > 1:
379
+ y_mb_ = y_
380
+ y_ = self.criterion["pqmf"].synthesis(y_mb_)
381
+
382
+ # initialize
383
+ gen_loss = 0.0
384
+ aux_loss = 0.0
385
+
386
+ # frontend mel prediction loss
387
+ frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)),
388
+ torch.masked_select(mel_, mask.unsqueeze(-1)))
389
+ self.total_eval_loss["eval/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item()
390
+ gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss
391
+
392
+ # multi-resolution stft loss
393
+ if self.config["use_stft_loss"]:
394
+ sc_loss, mag_loss = self.criterion["stft"](y_, y)
395
+ aux_loss += sc_loss + mag_loss
396
+ self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
397
+ self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item()
398
+
399
+ # subband multi-resolution stft loss
400
+ if self.config.get("use_subband_stft_loss", False):
401
+ aux_loss *= 0.5 # for balancing with subband stft loss
402
+ y_mb = self.criterion["pqmf"].analysis(y)
403
+ sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
404
+ self.total_eval_loss["eval/sub_spectral_convergence_loss"] += sub_sc_loss.item()
405
+ self.total_eval_loss["eval/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
406
+ aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
407
+
408
+ # mel spectrogram loss
409
+ if self.config["use_mel_loss"]:
410
+ mel_loss = self.criterion["mel"](y_, y)
411
+ aux_loss += mel_loss
412
+ self.total_eval_loss["eval/mel_loss"] += mel_loss.item()
413
+
414
+ # weighting stft loss
415
+ gen_loss += aux_loss * self.config.get("lambda_aux", 1.0)
416
+
417
+ # adversarial loss
418
+ p_ = self.model["discriminator"](y_)
419
+ adv_loss = self.criterion["gen_adv"](p_)
420
+ gen_loss += self.config["lambda_adv"] * adv_loss
421
+
422
+ # feature matching loss
423
+ if self.config["use_feat_match_loss"]:
424
+ p = self.model["discriminator"](y)
425
+ fm_loss = self.criterion["feat_match"](p_, p)
426
+ self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item()
427
+ gen_loss += (
428
+ self.config["lambda_adv"] * self.config["lambda_feat_match"] * fm_loss
429
+ )
430
+
431
+ #######################
432
+ # Discriminator #
433
+ #######################
434
+ p = self.model["discriminator"](y)
435
+ p_ = self.model["discriminator"](y_)
436
+
437
+ # discriminator loss
438
+ real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
439
+ dis_loss = real_loss + fake_loss
440
+
441
+ # add to total eval loss
442
+ self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
443
+ self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
444
+ self.total_eval_loss["eval/real_loss"] += real_loss.item()
445
+ self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
446
+ self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()
447
+
448
+ def _eval_epoch(self):
449
+ """Evaluate model one epoch."""
450
+ logging.info(f"(Steps: {self.steps}) Start evaluation.")
451
+ # change mode
452
+ for key in self.model.keys():
453
+ self.model[key].eval()
454
+
455
+ # calculate loss for each batch
456
+ for eval_steps_per_epoch, batch in enumerate(tqdm(self.data_loader["dev"], desc="[eval]"), 1):
457
+ # eval one step
458
+ self._eval_step(batch)
459
+
460
+ logging.info(
461
+ f"(Steps: {self.steps}) Finished evaluation "
462
+ f"({eval_steps_per_epoch} steps per epoch)."
463
+ )
464
+
465
+ # average loss
466
+ for key in self.total_eval_loss.keys():
467
+ self.total_eval_loss[key] /= eval_steps_per_epoch
468
+ logging.info(f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}.")
469
+
470
+ # record
471
+ self._write_to_tensorboard(self.total_eval_loss)
472
+
473
+ # reset
474
+ self.total_eval_loss = defaultdict(float)
475
+
476
+ # restore mode
477
+ for key in self.model.keys():
478
+ self.model[key].train()
479
+
480
+ def _write_to_tensorboard(self, loss):
481
+ """Write to tensorboard."""
482
+ for key, value in loss.items():
483
+ self.writer.add_scalar(key, value, self.steps)
484
+
485
+ def _check_save_interval(self):
486
+ if self.steps % self.config["save_interval_steps"] == 0:
487
+ self.save_checkpoint(os.path.join(self.config["outdir"],
488
+ f"checkpoint-{self.steps}steps.pkl"))
489
+ logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")
490
+
491
+ def _check_eval_interval(self):
492
+ if self.steps % self.config["eval_interval_steps"] == 0:
493
+ self._eval_epoch()
494
+
495
+ def _check_log_interval(self):
496
+ if self.steps % self.config["log_interval_steps"] == 0:
497
+ for key in self.total_train_loss.keys():
498
+ self.total_train_loss[key] /= self.config["log_interval_steps"]
499
+ logging.info(f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}.")
500
+ self._write_to_tensorboard(self.total_train_loss)
501
+
502
+ # reset
503
+ self.total_train_loss = defaultdict(float)
504
+
505
+ def _check_train_finish(self):
506
+ if self.steps >= self.config["train_max_steps"]:
507
+ self.finish_train = True
508
+
509
+
510
+ class Collator(object):
511
+ """Customized collator for Pytorch DataLoader in training."""
512
+
513
+ def __init__(
514
+ self,
515
+ hop_size=256,
516
+ win_length=1024,
517
+ sampling_rate=16000,
518
+ prompt_dim=1024,
519
+ prompt_fold_by_2=False
520
+ ):
521
+ """Initialize customized collator for PyTorch DataLoader.
522
+
523
+ Args:
524
+ hop_size (int): Hop size of features, in sampling points.
525
+ win_length (int): window length of features.
526
+ sampling_rate (int): sampling rate of waveform data
527
+ prompt_dim (int): number of prompt embedding dimensions
528
+ """
529
+ self.hop_size = hop_size
530
+ self.win_length = win_length
531
+ self.sampling_rate = sampling_rate
532
+ self.prompt_dim = prompt_dim
533
+ if prompt_fold_by_2:
534
+ self.prompt_len_factor = 2
535
+ else:
536
+ self.prompt_len_factor = 1
537
+
538
+ def construct_prompt(self, mel_lens):
539
+ prompt_lens = [random.randint(int(l / (3 * self.prompt_len_factor)), int(l / (2 * self.prompt_len_factor))) for l in mel_lens]
540
+ prompt_starts = []
541
+ is_from_start = []
542
+ for ml, pl in zip(mel_lens, prompt_lens):
543
+ if random.random() > 0.5:
544
+ # from start
545
+ prompt_start = random.randint(0, 1 * self.sampling_rate // (self.hop_size * self.prompt_len_factor))
546
+ is_from_start.append(True)
547
+ else:
548
+ # from ending
549
+ prompt_start = random.randint((ml - 1 * self.sampling_rate // self.hop_size) // self.prompt_len_factor, ml // self.prompt_len_factor) - pl
550
+ is_from_start.append(False)
551
+ prompt_starts.append(prompt_start)
552
+ return prompt_lens, prompt_starts, is_from_start
553
+
554
+ def __call__(self, batch):
555
+ """Convert into batch tensors.
556
+
557
+ Args:
558
+ batch (list): list of tuple of the pair of audio and features.
559
+
560
+ This collator will automatically determine the prompt segment (acoustic context) for each utterance.
561
+ The prompt is cut off from the current utterance, ranging from one third to half of the original utterance.
562
+ The prompt can be cut from either the starting or the ending of the utterance, within 1 second margin.
563
+ The other features include 2-dim VQ features (2 is the number of groups), and D-dim prompts (e.g. WavLM features)
564
+
565
+ Returns:
566
+ Tensor ys: waveform batch (B, T).
567
+ Tensors vqs, mels: Auxiliary feature batch (B, C, T'), where T' = T / hop_size.
568
+ Tensor prompts: prompt feature batch (B, C, T'')
569
+ List c_lengths, prompt_lengths: list of lengths
570
+ """
571
+ batch = batch[0]
572
+
573
+ # check length
574
+ batch = [self._adjust_length(*b) for b in batch]
575
+ ys, vqs, mels, prompts_old = list(map(list, zip(*batch))) # [(a,b), (c,d)] -> [a, c], [b, d]
576
+
577
+ batch_size = len(vqs)
578
+
579
+ prompt_lengths, prompt_starts, is_from_starts = self.construct_prompt([len(m) for m in mels])
580
+ c_lengths = []
581
+ prompts = torch.zeros(batch_size, max(prompt_lengths), self.prompt_dim)
582
+ for i in range(batch_size):
583
+ prompts[i, :prompt_lengths[i]] = torch.tensor(prompts_old[i][prompt_starts[i]:prompt_starts[i]+prompt_lengths[i], :])
584
+ if is_from_starts[i]:
585
+ start_idx = (prompt_starts[i] + prompt_lengths[i])*self.prompt_len_factor
586
+ mels[i] = mels[i][start_idx:]
587
+ vqs[i] = vqs[i][start_idx:]
588
+ ys[i] = ys[i][start_idx * self.hop_size: ]
589
+ else:
590
+ end_idx = prompt_starts[i]*self.prompt_len_factor
591
+ mels[i] = mels[i][:end_idx]
592
+ vqs[i] = vqs[i][:end_idx]
593
+ ys[i] = ys[i][:end_idx * self.hop_size]
594
+ c_lengths.append(len(mels[i]))
595
+
596
+ vqs = pad_list([torch.tensor(c) for c in vqs], pad_value=0) # (B, L, Groups)
597
+ vqs = vqs.long()
598
+ mels = pad_list([torch.tensor(c) for c in mels], pad_value=0) # (B, L, 80)
599
+
600
+ ys = pad_list([torch.tensor(y, dtype=torch.float) for y in ys], pad_value=0)[:, :mels.size(1) * self.hop_size] # (B, T)
601
+ assert ys.size(1) == mels.size(1) * self.hop_size == vqs.size(1) * self.hop_size
602
+
603
+ return vqs, mels, prompts, ys, c_lengths, prompt_lengths
604
+
605
+ def _adjust_length(self, x, c, *args):
606
+ """Adjust the audio and feature lengths.
607
+
608
+ Note:
609
+ Basically we assume that the length of x and c are adjusted
610
+ through preprocessing stage, but if we use other library processed
611
+ features, this process will be needed.
612
+
613
+ """
614
+ if len(x) > len(c) * self.hop_size:
615
+ x = x[(self.win_length - self.hop_size) // 2:]
616
+ x = x[:len(c) * self.hop_size]
617
+
618
+ # check the legnth is valid
619
+ assert len(x) == len(c) * self.hop_size
620
+
621
+ return x, c, *args
622
+
623
+
624
+ def main(rank, n_gpus):
625
+ """Run training process."""
626
+ parser = argparse.ArgumentParser(
627
+ description="Train vec2wav2 (See detail in vec2wav2/bin/train.py)."
628
+ )
629
+ parser.add_argument(
630
+ "--train-wav-scp",
631
+ default=None,
632
+ type=str,
633
+ help="kaldi-style wav.scp file for training. "
634
+ )
635
+ parser.add_argument(
636
+ "--train-vqidx-scp",
637
+ default=None,
638
+ type=str,
639
+ help="kaldi-style feats.scp file for training. "
640
+ )
641
+ parser.add_argument(
642
+ "--train-mel-scp",
643
+ default=None,
644
+ type=str,
645
+ help="kaldi-style feats.scp file for training. "
646
+ )
647
+ parser.add_argument(
648
+ "--train-prompt-scp",
649
+ default=None,
650
+ type=str,
651
+ help="prompt scp (in this case, utt to path)"
652
+ )
653
+ parser.add_argument(
654
+ "--train-segments",
655
+ default=None,
656
+ type=str,
657
+ help="kaldi-style segments file for training.",
658
+ )
659
+ parser.add_argument(
660
+ "--train-num-frames",
661
+ default=None,
662
+ type=str,
663
+ help="kaldi-style utt2num_frames file for training.",
664
+ )
665
+ parser.add_argument(
666
+ "--dev-wav-scp",
667
+ default=None,
668
+ type=str,
669
+ help="kaldi-style wav.scp file for validation. "
670
+ )
671
+ parser.add_argument(
672
+ "--dev-vqidx-scp",
673
+ default=None,
674
+ type=str,
675
+ help="kaldi-style feats.scp file for vaidation. "
676
+ )
677
+ parser.add_argument(
678
+ "--dev-mel-scp",
679
+ default=None,
680
+ type=str,
681
+ help="kaldi-style feats.scp file for vaidation. "
682
+ )
683
+ parser.add_argument(
684
+ "--dev-prompt-scp",
685
+ default=None,
686
+ type=str,
687
+ help="prompt scp (in this case, utt to path)"
688
+ )
689
+ parser.add_argument(
690
+ "--dev-segments",
691
+ default=None,
692
+ type=str,
693
+ help="kaldi-style segments file for validation.",
694
+ )
695
+ parser.add_argument(
696
+ "--dev-num-frames",
697
+ default=None,
698
+ type=str,
699
+ help="kaldi-style utt2num_frames file for validation.",
700
+ )
701
+ parser.add_argument(
702
+ "--outdir",
703
+ type=str,
704
+ required=True,
705
+ help="directory to save checkpoints.",
706
+ )
707
+ parser.add_argument(
708
+ "--config",
709
+ type=str,
710
+ required=True,
711
+ help="yaml format configuration file.",
712
+ )
713
+ parser.add_argument(
714
+ "--pretrain",
715
+ default="",
716
+ type=str,
717
+ nargs="?",
718
+ help='checkpoint file path to load pretrained params. (default="")',
719
+ )
720
+ parser.add_argument(
721
+ "--resume",
722
+ default="",
723
+ type=str,
724
+ nargs="?",
725
+ help='checkpoint file path to resume training. (default="")',
726
+ )
727
+ parser.add_argument(
728
+ "--verbose",
729
+ type=int,
730
+ default=1,
731
+ help="logging level. higher is more logging. (default=1)",
732
+ )
733
+ parser.add_argument("--vq-codebook", default=None, type=str)
734
+ # parser.add_argument("--sampling-rate", type=int)
735
+ # parser.add_argument("--num-mels", type=int)
736
+ # parser.add_argument("--hop-size", type=int)
737
+ # parser.add_argument("--win-length", type=int)
738
+ args = parser.parse_args()
739
+
740
+ # init distributed training
741
+ device = torch.device("cuda")
742
+ # effective when using fixed size inputs
743
+ # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
744
+ torch.backends.cudnn.benchmark = True
745
+ # setup for distributed training
746
+ # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
747
+ if n_gpus == 1:
748
+ assert rank == 0
749
+
750
+ set_loglevel(args.verbose)
751
+
752
+ # check directory existence
753
+ if not os.path.exists(args.outdir):
754
+ os.makedirs(args.outdir)
755
+
756
+ # init process group
757
+ logging.info("Synchronizing between all workers.")
758
+ torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank)
759
+ torch.cuda.set_device(rank)
760
+ logging.info("Finished init process group.")
761
+
762
+ # load and save config
763
+ with open(args.config) as f:
764
+ config = yaml.load(f, Loader=yaml.Loader)
765
+ config.update(vars(args))
766
+ config['rank'] = rank
767
+ config['distributed'] = True
768
+ config['world_size'] = n_gpus
769
+ config["version"] = vec2wav2.__version__ # add version info
770
+ if rank == 0:
771
+ with open(os.path.join(args.outdir, "config.yml"), "w") as f:
772
+ yaml.dump(config, f, Dumper=yaml.Dumper)
773
+ for key, value in config.items():
774
+ logging.info(f"{key} = {value}")
775
+
776
+ # get dataset
777
+ train_dataset = AudioMelSCPDataset(
778
+ wav_scp=args.train_wav_scp,
779
+ vqidx_scp=args.train_vqidx_scp,
780
+ mel_scp=args.train_mel_scp,
781
+ prompt_scp=args.train_prompt_scp,
782
+ utt2num_frames=args.train_num_frames,
783
+ segments=args.train_segments,
784
+ batch_frames=config.get("batch_frames", None),
785
+ batch_size=config.get("batch_size", None),
786
+ min_num_frames=config.get("min_num_frames", None),
787
+ max_num_frames=config.get("max_num_frames", None),
788
+ allow_cache=config.get("allow_cache", False), # keep compatibility
789
+ length_tolerance=config.get("length_tolerance", 2),
790
+ prompt_fold_by_2=config.get("prompt_fold_by_2", True)
791
+ )
792
+ if rank == 0:
793
+ logging.info(f"The number of training batches = {len(train_dataset)}.")
794
+ dev_dataset = AudioMelSCPDataset(
795
+ wav_scp=args.dev_wav_scp,
796
+ vqidx_scp=args.dev_vqidx_scp,
797
+ mel_scp=args.dev_mel_scp,
798
+ prompt_scp=args.dev_prompt_scp,
799
+ utt2num_frames=args.dev_num_frames,
800
+ segments=args.dev_segments,
801
+ min_num_frames=config.get("min_num_frames", None),
802
+ max_num_frames=config.get("max_num_frames", None),
803
+ allow_cache=config.get("allow_cache", False), # keep compatibility
804
+ length_tolerance=config.get("length_tolerance", 2),
805
+ prompt_fold_by_2=config.get("prompt_fold_by_2", True)
806
+ )
807
+ if rank == 0:
808
+ logging.info(f"The number of development batches = {len(dev_dataset)}.")
809
+ dataset = {
810
+ "train": train_dataset,
811
+ "dev": dev_dataset,
812
+ }
813
+
814
+ # get data loader
815
+ collator = Collator(
816
+ hop_size=config["hop_size"],
817
+ win_length=config["win_length"],
818
+ sampling_rate=config["sampling_rate"],
819
+ prompt_dim=config['frontend_params']['prompt_channels'],
820
+ prompt_fold_by_2=config.get("prompt_fold_by_2", True)
821
+ )
822
+
823
+ sampler = {
824
+ "train": DistributedSampler(
825
+ dataset=dataset["train"],
826
+ num_replicas=n_gpus,
827
+ rank=rank,
828
+ shuffle=True,
829
+ ),
830
+ "dev": DistributedSampler(
831
+ dataset=dataset["dev"],
832
+ num_replicas=n_gpus,
833
+ rank=rank,
834
+ shuffle=False,
835
+ )}
836
+ data_loader = {
837
+ "train": DataLoader(
838
+ dataset=dataset["train"],
839
+ shuffle=False,
840
+ collate_fn=collator,
841
+ num_workers=config["num_workers"],
842
+ sampler=sampler["train"],
843
+ pin_memory=config["pin_memory"],
844
+ ),
845
+ "dev": DataLoader(
846
+ dataset=dataset["dev"],
847
+ shuffle=False,
848
+ collate_fn=collator,
849
+ num_workers=config["num_workers"],
850
+ sampler=sampler["dev"],
851
+ pin_memory=config["pin_memory"],
852
+ ),
853
+ }
854
+
855
+ # define models
856
+ generator_class = getattr(
857
+ vec2wav2.models,
858
+ # keep compatibility
859
+ config.get("generator_type", "ParallelWaveGANGenerator"),
860
+ )
861
+ discriminator_class = getattr(
862
+ vec2wav2.models,
863
+ # keep compatibility
864
+ config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
865
+ )
866
+ model = {
867
+ "generator": vec2wav2.models.VEC2WAV2Generator(
868
+ vec2wav2.models.CTXVEC2WAVFrontend(config["prompt_net_type"], config["num_mels"], **config["frontend_params"]),
869
+ generator_class(**config["generator_params"])
870
+ ).to(device),
871
+ "discriminator": discriminator_class(
872
+ **config["discriminator_params"],
873
+ ).to(device),
874
+ }
875
+
876
+ # define criteria
877
+ criterion = {
878
+ "gen_adv": GeneratorAdversarialLoss(
879
+ # keep compatibility
880
+ **config.get("generator_adv_loss_params", {})
881
+ ).to(device),
882
+ "dis_adv": DiscriminatorAdversarialLoss(
883
+ # keep compatibility
884
+ **config.get("discriminator_adv_loss_params", {})
885
+ ).to(device),
886
+ }
887
+ if config.get("use_stft_loss", True): # keep compatibility
888
+ config["use_stft_loss"] = True
889
+ criterion["stft"] = MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device)
890
+ if config.get("use_subband_stft_loss", False): # keep compatibility
891
+ assert config["generator_params"]["out_channels"] > 1
892
+ criterion["sub_stft"] = MultiResolutionSTFTLoss(**config["subband_stft_loss_params"]).to(device)
893
+ else:
894
+ config["use_subband_stft_loss"] = False
895
+ if config.get("use_feat_match_loss", False): # keep compatibility
896
+ criterion["feat_match"] = FeatureMatchLoss(
897
+ # keep compatibility
898
+ **config.get("feat_match_loss_params", {}),
899
+ ).to(device)
900
+ else:
901
+ config["use_feat_match_loss"] = False
902
+ if config.get("use_mel_loss", False): # keep compatibility
903
+ criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],).to(device)
904
+ else:
905
+ config["use_mel_loss"] = False
906
+
907
+ # define optimizers and schedulers
908
+ generator_optimizer_class = getattr(
909
+ vec2wav2.optimizers,
910
+ # keep compatibility
911
+ config.get("generator_optimizer_type", "RAdam"),
912
+ )
913
+ discriminator_optimizer_class = getattr(
914
+ vec2wav2.optimizers,
915
+ # keep compatibility
916
+ config.get("discriminator_optimizer_type", "RAdam"),
917
+ )
918
+ optimizer = {
919
+ "generator": generator_optimizer_class(
920
+ model["generator"].parameters(),
921
+ **config["generator_optimizer_params"],
922
+ ),
923
+ "discriminator": discriminator_optimizer_class(
924
+ model["discriminator"].parameters(),
925
+ **config["discriminator_optimizer_params"],
926
+ ),
927
+ }
928
+ generator_scheduler_class = getattr(
929
+ torch.optim.lr_scheduler,
930
+ # keep compatibility
931
+ config.get("generator_scheduler_type", "StepLR"),
932
+ )
933
+ discriminator_scheduler_class = getattr(
934
+ torch.optim.lr_scheduler,
935
+ # keep compatibility
936
+ config.get("discriminator_scheduler_type", "StepLR"),
937
+ )
938
+ scheduler = {
939
+ "generator": generator_scheduler_class(
940
+ optimizer=optimizer["generator"],
941
+ **config["generator_scheduler_params"],
942
+ ),
943
+ "discriminator": discriminator_scheduler_class(
944
+ optimizer=optimizer["discriminator"],
945
+ **config["discriminator_scheduler_params"],
946
+ ),
947
+ }
948
+ from torch.nn.parallel import DistributedDataParallel
949
+ model["generator"] = DistributedDataParallel(model["generator"], device_ids=[rank], find_unused_parameters=True)
950
+ model["discriminator"] = DistributedDataParallel(model["discriminator"], device_ids=[rank], find_unused_parameters=True)
951
+
952
+ if rank == 0:
953
+ # show settings
954
+ logging.info(model["generator"])
955
+ logging.info(f"Generator has nparams: {sum([p.numel() for p in model['generator'].parameters()])}")
956
+ logging.info(model["discriminator"])
957
+ logging.info(f"Discriminator has nparams: {sum([p.numel() for p in model['discriminator'].parameters()])}")
958
+ logging.info(optimizer["generator"])
959
+ logging.info(optimizer["discriminator"])
960
+
961
+ # define trainer
962
+ trainer = Trainer(
963
+ steps=0,
964
+ epochs=0,
965
+ data_loader=data_loader,
966
+ sampler=sampler,
967
+ model=model,
968
+ criterion=criterion,
969
+ optimizer=optimizer,
970
+ scheduler=scheduler,
971
+ config=config,
972
+ device=device,
973
+ )
974
+
975
+ # load pretrained parameters from checkpoint
976
+ if len(args.pretrain) != 0:
977
+ trainer.load_checkpoint(args.pretrain, load_only_params=True)
978
+ if rank == 0:
979
+ logging.info(f"Successfully load parameters from {args.pretrain}.")
980
+
981
+ # resume from checkpoint
982
+ if len(args.resume) != 0:
983
+ trainer.load_checkpoint(args.resume)
984
+ if rank == 0:
985
+ logging.info(f"Successfully resumed from {args.resume}.")
986
+
987
+ # run training loop
988
+ try:
989
+ trainer.run()
990
+ finally:
991
+ if rank == 0:
992
+ trainer.save_checkpoint(os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
993
+ logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
994
+
995
+
996
+ if __name__ == "__main__":
997
+ assert torch.cuda.is_available(), "CPU training is not allowed."
998
+ n_gpus = torch.cuda.device_count()
999
+ print(f"============> using {n_gpus} GPUS")
1000
+ os.environ["MASTER_ADDR"] = "localhost"
1001
+ os.environ["MASTER_PORT"] = "8000"
1002
+
1003
+ mp.spawn(
1004
+ main,
1005
+ nprocs=n_gpus,
1006
+ args=(n_gpus,)
1007
+ )
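To make the prompt cropping in Collator.construct_prompt concrete, a small numeric sketch (values are illustrative; it mirrors the arithmetic above with prompt_fold_by_2=True, i.e. prompt_len_factor = 2, and the hop_size/sampling_rate from the pretrained config):

    import random

    mel_len = 600            # mel frames in the current utterance
    prompt_len_factor = 2    # prompt features run at half the mel frame rate
    hop_size, sampling_rate = 240, 24000

    # Prompt length: between one third and one half of the utterance (in prompt frames).
    prompt_len = random.randint(mel_len // (3 * prompt_len_factor),   # 100
                                mel_len // (2 * prompt_len_factor))   # 150

    # When cutting from the start, the prompt may begin anywhere inside a 1-second margin:
    start_margin = 1 * sampling_rate // (hop_size * prompt_len_factor)  # 50 prompt frames
    prompt_start = random.randint(0, start_margin)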
vec2wav2/bin/vc.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright 2024 Yiwei Guo
4
+
5
+ """ Run VC inference with trained model """
6
+
7
+ import vec2wav2
8
+ from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
9
+ from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
10
+ # from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
11
+ import torch
12
+ import logging
13
+ import argparse
14
+ from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
15
+ import soundfile as sf
16
+ import yaml
17
+ import os
18
+
19
+
20
+ def configure_logging(verbose):
21
+ if verbose:
22
+ logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.DEBUG)
23
+ logging.getLogger().setLevel(logging.DEBUG)
24
+ logging.basicConfig(level=logging.DEBUG)
25
+ else:
26
+ logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.ERROR)
27
+ logging.getLogger().setLevel(logging.ERROR)
28
+ logging.basicConfig(level=logging.ERROR)
29
+
30
+ script_logger = logging.getLogger("script_logger")
31
+ handler = logging.StreamHandler()
32
+ handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
33
+ script_logger.addHandler(handler)
34
+ script_logger.setLevel(logging.INFO)
35
+ script_logger.propagate = False
36
+ return script_logger
37
+
38
+ def vc_args():
39
+ parser = argparse.ArgumentParser()
40
+ # required arguments
41
+ parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
42
+ help="source wav path")
43
+ parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
44
+ help="target speaker prompt path")
45
+ parser.add_argument("-o", "--output", default="output.wav", type=str,
46
+ help="path of the output wav file")
47
+
48
+ # optional arguments
49
+ parser.add_argument("--expdir", default="pretrained/", type=str,
50
+ help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
51
+ parser.add_argument('--checkpoint', default=None, type=str, help="checkpoint path (.pkl). If provided, will override expdir.")
52
+ parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
53
+ help="checkpoint or model flag of input token extractor")
54
+ parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
55
+ help="checkpoint or model flag of speaker prompt extractor")
56
+ parser.add_argument("--prompt-output-layer", default=6, type=int,
57
+ help="output layer when prompt is extracted from WavLM.")
58
+
59
+ parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
60
+
61
+ args = parser.parse_args()
62
+ return args
63
+
64
+
65
+ class VoiceConverter:
66
+ def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt",
67
+ prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
68
+ checkpoint=None, script_logger=None):
69
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
70
+ self.script_logger = script_logger
71
+ self.log_if_possible(f"Using device: {self.device}")
72
+
73
+ # set up token extractor
74
+ self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
75
+ feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device)
76
+ self.feat_codebook = feat_codebook
77
+ self.feat_codebook_numgroups = feat_codebook_numgroups
78
+ self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")
79
+
80
+ # set up prompt extractor
81
+ self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer)
82
+ self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")
83
+
84
+ # load VC model
85
+ self.config_path = os.path.join(expdir, "config.yml")
86
+ with open(self.config_path) as f:
87
+ self.config = yaml.load(f, Loader=yaml.Loader)
88
+ if checkpoint is not None:
89
+ checkpoint = os.path.join(expdir, checkpoint)
90
+ else:
91
+ checkpoint = os.path.join(expdir, "generator.ckpt")
92
+ self.model = load_model(checkpoint, self.config)
93
+ self.log_if_possible(f"Successfully set up VC model from {checkpoint}")
94
+
95
+ self.model.backend.remove_weight_norm()
96
+ self.model.eval().to(self.device)
97
+
98
+ @torch.no_grad()
99
+ def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
100
+ self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
101
+ source_wav = read_wav_16k(source_audio)
102
+ target_wav = read_wav_16k(target_audio)
103
+ vq_idx = self.token_extractor.extract(source_wav).long().to(self.device)
104
+
105
+ vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0)
106
+ prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device)
107
+ converted = self.model.inference(vqvec, prompt)[-1].view(-1)
108
+ sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate'])
109
+ self.log_if_possible(f"Saved audio file to {output_path}")
110
+ return output_path
111
+
112
+ def log_if_possible(self, msg):
113
+ if self.script_logger is not None:
114
+ self.script_logger.info(msg)
115
+
116
+
117
+ if __name__ == "__main__":
118
+ args = vc_args()
119
+ script_logger = configure_logging(args.verbose)
120
+
121
+ source_wav = read_wav_16k(args.source)
122
+ target_prompt = read_wav_16k(args.target)
123
+
124
+ with torch.no_grad():
125
+ voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
126
+ prompt_extractor=args.prompt_extractor, prompt_output_layer=args.prompt_output_layer,
127
+ checkpoint=args.checkpoint, script_logger=script_logger)
128
+ voice_converter.voice_conversion(args.source, args.target, args.output)
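Besides the CLI entry point above, VoiceConverter can be used as a library. A minimal sketch, assuming the default pretrained/ checkpoints are in place and the example wav paths exist:

```python
from vec2wav2.bin.vc import VoiceConverter, configure_logging

logger = configure_logging(verbose=False)
vc = VoiceConverter(expdir="pretrained/", script_logger=logger)

# Convert the source utterance into the target speaker's voice.
out_path = vc.voice_conversion("examples/source.wav", "examples/target.wav", "converted.wav")
print("Converted audio written to", out_path)
```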
vec2wav2/datasets/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .scp_dataset import * # NOQA
vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (288 Bytes). View file
 
vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (241 Bytes). View file
 
vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc ADDED
Binary file (8.4 kB). View file
 
vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc ADDED
Binary file (8.95 kB). View file
 
vec2wav2/datasets/scp_dataset.py ADDED
@@ -0,0 +1,300 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ # Modified by Yiwei Guo, 2024
7
+
8
+ """Dataset modules based on kaldi-style scp files."""
9
+
10
+ import logging
11
+ import random
12
+ import copy
13
+ from multiprocessing import Manager
14
+
15
+ import kaldiio
16
+ import numpy as np
17
+
18
+ from torch.utils.data import Dataset
19
+ from tqdm import tqdm
20
+ from vec2wav2.utils import HDF5ScpLoader
21
+ from vec2wav2.utils import NpyScpLoader
22
+
23
+
24
+ def _get_feats_scp_loader(feats_scp):
25
+ # read the first line of feats.scp file
26
+ with open(feats_scp) as f:
27
+ key, value = f.readlines()[0].replace("\n", "").split()
28
+
29
+ # check scp type
30
+ if ":" in value:
31
+ value_1, value_2 = value.split(":")
32
+ if value_1.endswith(".ark"):
33
+ # kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index
34
+ return kaldiio.load_scp(feats_scp)
35
+ elif value_1.endswith(".h5"):
36
+ # hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats
37
+ return HDF5ScpLoader(feats_scp)
38
+ else:
39
+ raise ValueError("Not supported feats.scp type.")
40
+ else:
41
+ if value.endswith(".h5"):
42
+ # hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5
43
+ return HDF5ScpLoader(feats_scp)
44
+ elif value.endswith(".npy"):
45
+ # npy case: utt_id_1 /path/to/utt_id_1.npy
46
+ return NpyScpLoader(feats_scp)
47
+ else:
48
+ raise ValueError("Not supported feats.scp type.")
49
+
50
+
51
+ class AudioMelSCPDataset(Dataset):
52
+ """PyTorch compatible audio and feat dataset based on kaldi-stype scp files."""
53
+
54
+ def __init__(
55
+ self,
56
+ wav_scp,
57
+ vqidx_scp,
58
+ mel_scp,
59
+ prompt_scp,
60
+ utt2num_frames=None,
61
+ segments=None,
62
+ batch_frames=None,
63
+ batch_size=None,
64
+ min_num_frames=None,
65
+ max_num_frames=None,
66
+ return_utt_id=False,
67
+ return_sampling_rate=False,
68
+ allow_cache=False,
69
+ length_tolerance=2,
70
+ prompt_fold_by_2=True
71
+ ):
72
+ """Initialize dataset.
73
+
74
+ Args:
75
+ wav_scp (str): Kaldi-style wav.scp file.
76
+ vqidx_scp (str): Kaldi-style feats.scp file.
77
+ mel_scp (str): Kaldi-style feats.scp file.
78
+ segments (str): Kaldi-style segments file.
79
+ min_num_frames (int): Threshold to remove short feature files.
80
+ max_num_frames (int): Threshold to remove long feature files.
81
+ return_utt_id (bool): Whether to return utterance id.
82
+ return_sampling_rate (bool): Whether to return sampling rate.
83
+ allow_cache (bool): Whether to allow cache of the loaded files.
84
+ prompt_fold_by_2 (bool): If true, the prompt has half the length of the vqidx sequence.
85
+
86
+ """
87
+ # load scp as lazy dict
88
+ self.audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
89
+ self.vqidx_loader = _get_feats_scp_loader(vqidx_scp)
90
+ self.mel_loader = _get_feats_scp_loader(mel_scp)
91
+
92
+ self.prompt_loader = _get_feats_scp_loader(prompt_scp)
93
+
94
+ self.utt_ids = list(self.mel_loader.keys())
95
+ self.return_utt_id = return_utt_id
96
+ self.return_sampling_rate = return_sampling_rate
97
+ self.allow_cache = allow_cache
98
+
99
+ utt2num_frames_loader = None
100
+ if utt2num_frames is not None:
101
+ with open(utt2num_frames, 'r') as f:
102
+ utt2num_frames_loader = dict([(x.split()[0], int(x.split()[1])) for x in f.readlines()])
103
+ else:
104
+ utt2num_frames_loader = dict([(k, mel.shape[0]) for k, mel in self.mel_loader.items()])
105
+
106
+ self.utt2num_frames_loader = utt2num_frames_loader
107
+
108
+ # filter by threshold
109
+ if (min_num_frames or max_num_frames) is not None:
110
+ mel_lengths = [utt2num_frames_loader[key] for key in self.utt_ids]
111
+ idxs = [
112
+ idx
113
+ for idx in range(len(self.utt_ids))
114
+ if (min_num_frames and mel_lengths[idx] >= min_num_frames) and (max_num_frames and mel_lengths[idx] <= max_num_frames)
115
+ ]
116
+ if len(self.utt_ids) != len(idxs):
117
+ logging.warning(
118
+ f"Some files are filtered by mel length threshold "
119
+ f"({len(self.utt_ids)} -> {len(idxs)})."
120
+ )
121
+ self.utt_ids = [self.utt_ids[idx] for idx in idxs]
122
+
123
+ # batchify
124
+ if batch_frames is not None:
125
+ self.batches = self.batchify(utt2num_frames_loader, batch_frames=batch_frames)
126
+ elif batch_size is not None:
127
+ self.batches = self.batchify(utt2num_frames_loader, batch_size=batch_size)
128
+ else:
129
+ self.batches = [[utt_id] for utt_id in self.utt_ids]
130
+
131
+ if allow_cache:
132
+ # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
133
+ self.manager = Manager()
134
+ self.caches = self.manager.dict()
135
+ self.length_tolerance = length_tolerance
136
+ if prompt_fold_by_2:
137
+ self.prompt_len_factor = 2
138
+ else:
139
+ self.prompt_len_factor = 1
140
+
141
+ def batchify(self, utt2num_frames_loader, batch_frames=None, batch_size=None, min_batch_size=1, drop_last=True):
142
+
143
+ assert batch_size is None or batch_size > min_batch_size
144
+
145
+ batches = []
146
+ batch = []
147
+ accum_num_frames = 0
148
+ utt_ids_set = set(self.utt_ids)
149
+ for utt_id, mel_length in tqdm(sorted(list(utt2num_frames_loader.items()), key=lambda x: x[1], reverse=True)):
150
+ if utt_id not in utt_ids_set:
151
+ continue
152
+ if (batch_frames is not None and accum_num_frames + mel_length > batch_frames and len(batch) > min_batch_size) or (batch_size is not None and len(batch) == batch_size):
153
+ batches.append(batch)
154
+ batch = []
155
+ accum_num_frames = 0
156
+ batch.append(utt_id)
157
+ accum_num_frames += mel_length
158
+ if len(batch) > min_batch_size and not drop_last:
159
+ batches.append(batch)
160
+ return batches
161
+
162
+ def __getitem__(self, idx):
163
+ """Get specified idx items.
164
+
165
+ Args:
166
+ idx (int): Index of the item.
167
+
168
+ Returns:
169
+ str: Utterance id (only in return_utt_id = True).
170
+ ndarray or tuple: Audio signal (T,), or a tuple of (audio, sampling rate) if return_sampling_rate = True.
171
+ ndarrays: Features (T', C).
172
+
173
+ """
174
+ batch = self.batches[idx]
175
+ batch_items = []
176
+
177
+ for utt_id in batch:
178
+ if self.allow_cache and self.caches.get(utt_id) is not None:
179
+ items = self.caches[utt_id]
180
+ else:
181
+ fs, audio = self.audio_loader[utt_id]
182
+ mel = self.mel_loader[utt_id]
183
+ prompt = self.prompt_loader[utt_id]
184
+ vqidx = self.vqidx_loader[utt_id]
185
+
186
+ min_len = min(len(mel), len(vqidx), len(prompt)*self.prompt_len_factor)
187
+ assert ((abs(len(mel) - min_len) <= self.length_tolerance) and
188
+ (abs(len(vqidx) - min_len) <= self.length_tolerance) and
189
+ (abs(len(prompt)*self.prompt_len_factor - min_len) <= self.length_tolerance)), \
190
+ f"Audio feature lengths difference exceeds length tolerance for {utt_id}"
191
+ mel, vqidx, prompt = mel[:min_len], vqidx[:min_len], prompt[:min_len//self.prompt_len_factor]
192
+
193
+ # normalize audio signal to be [-1, 1]
194
+ audio = audio.astype(np.float32)
195
+ audio /= 1 << (16 - 1) # assume that wav is PCM 16 bit
196
+
197
+ if self.return_sampling_rate:
198
+ audio = (audio, fs)
199
+
200
+ if self.return_utt_id:
201
+ items = utt_id, audio, vqidx, mel, prompt
202
+ else:
203
+ items = audio, vqidx, mel, prompt
204
+
205
+ if self.allow_cache:
206
+ self.caches[utt_id] = items
207
+
208
+ batch_items.append(items)
209
+
210
+ return batch_items
211
+
212
+ def __len__(self):
213
+ """Return dataset length.
214
+ Returns:
215
+ int: The length of dataset.
216
+ """
217
+ return len(self.batches)
218
+
219
+
220
+ class MelSCPDataset(Dataset):
221
+ """PyTorch compatible feat dataset based on kaldi-stype scp files."""
222
+
223
+ def __init__(
224
+ self,
225
+ vqidx_scp,
226
+ prompt_scp,
227
+ return_utt_id=False,
228
+ allow_cache=False,
229
+ ):
230
+ """Initialize dataset.
231
+
232
+ Args:
233
+ vqidx_scp (str): Kaldi-style feats.scp file.
234
+ prompt_scp (str): Kaldi-style scp file. In this file, every utt is associated with its prompt's mel-spectrogram.
235
+ min_num_frames (int): Threshold to remove short feature files.
236
+ max_num_frames (int): Threshold to remove long feature files.
237
+ return_utt_id (bool): Whether to return utterance id.
238
+ allow_cache (bool): Whether to allow cache of the loaded files.
239
+ """
240
+ # load scp as lazy dict
241
+ vqidx_loader = _get_feats_scp_loader(vqidx_scp)
242
+ self.prompt_loader = _get_feats_scp_loader(prompt_scp)
243
+ # self.prompt_loader = dict()
244
+ # with open(prompt_scp, 'r') as fr:
245
+ # for line in fr.readlines():
246
+ # terms = line.strip().split()
247
+ # self.prompt_loader[terms[0]] = terms[1]
248
+ vqidx_keys = list(set(self.prompt_loader.keys()) & set(vqidx_loader.keys()))
249
+
250
+ # NOTE: this dataset does not apply filtering, because it is usually used for decoding
251
+
252
+ self.vqidx_loader = vqidx_loader
253
+ self.utt_ids = vqidx_keys
254
+ self.return_utt_id = return_utt_id
255
+ self.allow_cache = allow_cache
256
+
257
+ if allow_cache:
258
+ # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
259
+ self.manager = Manager()
260
+ self.caches = self.manager.list()
261
+ self.caches += [() for _ in range(len(self.utt_ids))]
262
+
263
+ def __getitem__(self, idx):
264
+ """Get specified idx items.
265
+
266
+ Args:
267
+ idx (int): Index of the item.
268
+
269
+ Returns:
270
+ str: Utterance id (only in return_utt_id = True).
271
+ ndarray: Feature (T', C).
272
+
273
+ """
274
+ if self.allow_cache and len(self.caches[idx]) != 0:
275
+ return self.caches[idx]
276
+
277
+ utt_id = self.utt_ids[idx]
278
+ vqidx = self.vqidx_loader[utt_id].astype(int)
279
+
280
+ # prompt = torch.load(self.prompt_loader[utt_id]).float().numpy()
281
+ prompt = self.prompt_loader[utt_id]
282
+
283
+ if self.return_utt_id:
284
+ items = utt_id, vqidx, prompt
285
+ else:
286
+ items = vqidx, prompt
287
+
288
+ if self.allow_cache:
289
+ self.caches[idx] = items
290
+
291
+ return items
292
+
293
+ def __len__(self):
294
+ """Return dataset length.
295
+
296
+ Returns:
297
+ int: The length of dataset.
298
+
299
+ """
300
+ return len(self.utt_ids)
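For reference, a sketch of the kaldi-style scp formats that _get_feats_scp_loader dispatches on, followed by a hypothetical MelSCPDataset instantiation for decoding (all paths are placeholders):

```python
# Each scp line maps an utterance id to a feature location, e.g.:
#   utt_0001 /data/feats/raw_feats.ark:42      -> kaldi ark (kaldiio.load_scp)
#   utt_0001 /data/feats/utt_0001.h5:feats     -> hdf5 with internal path (HDF5ScpLoader)
#   utt_0001 /data/feats/utt_0001.npy          -> numpy file (NpyScpLoader)
from vec2wav2.datasets import MelSCPDataset

dataset = MelSCPDataset(
    vqidx_scp="data/eval/vqidx.scp",    # hypothetical path
    prompt_scp="data/eval/prompt.scp",  # hypothetical path
    return_utt_id=True,
)
utt_id, vqidx, prompt = dataset[0]      # vq token indices and the speaker prompt features
```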
vec2wav2/distributed/__init__.py ADDED
File without changes
vec2wav2/distributed/launch.py ADDED
@@ -0,0 +1,163 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """Distributed process launcher.
5
+
6
+ This code is modified from https://github.com/pytorch/pytorch/blob/v1.3.0/torch/distributed/launch.py.
7
+
8
+ """
9
+ import os
10
+ import subprocess
11
+ import sys
12
+
13
+ from argparse import ArgumentParser
14
+ from argparse import REMAINDER
15
+
16
+
17
+ def parse_args():
18
+ """Parse arguments."""
19
+ parser = ArgumentParser(
20
+ description="PyTorch distributed training launch "
21
+ "helper utilty that will spawn up "
22
+ "multiple distributed processes"
23
+ )
24
+
25
+ # Optional arguments for the launch helper
26
+ parser.add_argument(
27
+ "--nnodes",
28
+ type=int,
29
+ default=1,
30
+ help="The number of nodes to use for distributed " "training",
31
+ )
32
+ parser.add_argument(
33
+ "--node_rank",
34
+ type=int,
35
+ default=0,
36
+ help="The rank of the node for multi-node distributed " "training",
37
+ )
38
+ parser.add_argument(
39
+ "--nproc_per_node",
40
+ type=int,
41
+ default=1,
42
+ help="The number of processes to launch on each node, "
43
+ "for GPU training, this is recommended to be set "
44
+ "to the number of GPUs in your system so that "
45
+ "each process can be bound to a single GPU.",
46
+ )
47
+ parser.add_argument(
48
+ "--master_addr",
49
+ default="127.0.0.1",
50
+ type=str,
51
+ help="Master node (rank 0)'s address, should be either "
52
+ "the IP address or the hostname of node 0, for "
53
+ "single node multi-proc training, the "
54
+ "--master_addr can simply be 127.0.0.1",
55
+ )
56
+ parser.add_argument(
57
+ "--master_port",
58
+ default=29500,
59
+ type=int,
60
+ help="Master node (rank 0)'s free port that needs to "
61
+ "be used for communciation during distributed "
62
+ "training",
63
+ )
64
+ parser.add_argument(
65
+ "--use_env",
66
+ default=False,
67
+ action="store_true",
68
+ help="Use environment variable to pass "
69
+ "'local rank'. For legacy reasons, the default value is False. "
70
+ "If set to True, the script will not pass "
71
+ "--local_rank as argument, and will instead set LOCAL_RANK.",
72
+ )
73
+ parser.add_argument(
74
+ "-m",
75
+ "--module",
76
+ default=False,
77
+ action="store_true",
78
+ help="Changes each process to interpret the launch script "
79
+ "as a python module, executing with the same behavior as"
80
+ "'python -m'.",
81
+ )
82
+ parser.add_argument(
83
+ "-c",
84
+ "--command",
85
+ default=False,
86
+ action="store_true",
87
+ help="Changes each process to interpret the launch script " "as a command.",
88
+ )
89
+
90
+ # positional
91
+ parser.add_argument(
92
+ "training_script",
93
+ type=str,
94
+ help="The full path to the single GPU training "
95
+ "program/script/command to be launched in parallel, "
96
+ "followed by all the arguments for the "
97
+ "training script",
98
+ )
99
+
100
+ # rest from the training program
101
+ parser.add_argument("training_script_args", nargs=REMAINDER)
102
+ return parser.parse_args()
103
+
104
+
105
+ def main():
106
+ """Launch distributed processes."""
107
+ args = parse_args()
108
+
109
+ # world size in terms of number of processes
110
+ dist_world_size = args.nproc_per_node * args.nnodes
111
+
112
+ # set PyTorch distributed related environmental variables
113
+ current_env = os.environ.copy()
114
+ current_env["MASTER_ADDR"] = args.master_addr
115
+ current_env["MASTER_PORT"] = str(args.master_port)
116
+ current_env["WORLD_SIZE"] = str(dist_world_size)
117
+
118
+ processes = []
119
+
120
+ if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1:
121
+ current_env["OMP_NUM_THREADS"] = str(1)
122
+ print(
123
+ "*****************************************\n"
124
+ "Setting OMP_NUM_THREADS environment variable for each process "
125
+ "to be {} in default, to avoid your system being overloaded, "
126
+ "please further tune the variable for optimal performance in "
127
+ "your application as needed. \n"
128
+ "*****************************************".format(
129
+ current_env["OMP_NUM_THREADS"]
130
+ )
131
+ )
132
+
133
+ for local_rank in range(0, args.nproc_per_node):
134
+ # each process's rank
135
+ dist_rank = args.nproc_per_node * args.node_rank + local_rank
136
+ current_env["RANK"] = str(dist_rank)
137
+ current_env["LOCAL_RANK"] = str(local_rank)
138
+
139
+ # spawn the processes
140
+ if args.command:
141
+ cmd = [args.training_script]
142
+ else:
143
+ cmd = [sys.executable, "-u"]
144
+ if args.module:
145
+ cmd.append("-m")
146
+ cmd.append(args.training_script)
147
+
148
+ if not args.use_env:
149
+ cmd.append("--local_rank={}".format(local_rank))
150
+
151
+ cmd.extend(args.training_script_args)
152
+
153
+ process = subprocess.Popen(cmd, env=current_env)
154
+ processes.append(process)
155
+
156
+ for process in processes:
157
+ process.wait()
158
+ if process.returncode != 0:
159
+ raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
160
+
161
+
162
+ if __name__ == "__main__":
163
+ main()
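A sketch of how a training script spawned by this launcher might read the environment variables it sets (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR/PORT). This is illustrative only; train.py in this repo uses torch.multiprocessing.spawn rather than this launcher:

```python
import os
import torch

def setup_from_launcher_env():
    """Join the process group using the variables exported by launch.py."""
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank, local_rank, world_size
```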
vec2wav2/layers/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .causal_conv import * # NOQA
2
+ from .pqmf import * # NOQA
3
+ from .residual_block import * # NOQA
4
+ from .residual_stack import * # NOQA
5
+ from .tade_res_block import * # NOQA
6
+ from .upsample import * # NOQA
vec2wav2/layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (344 Bytes). View file
 
vec2wav2/layers/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (353 Bytes). View file
 
vec2wav2/layers/__pycache__/activations.cpython-310.pyc ADDED
Binary file (6.64 kB). View file
 
vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc ADDED
Binary file (2.23 kB). View file
 
vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc ADDED
Binary file (2.24 kB). View file
 
vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc ADDED
Binary file (4.14 kB). View file
 
vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc ADDED
Binary file (4.14 kB). View file
 
vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc ADDED
Binary file (6.21 kB). View file
 
vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc ADDED
Binary file (6.18 kB). View file
 
vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc ADDED
Binary file (2.51 kB). View file
 
vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc ADDED
Binary file (2.51 kB). View file
 
vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc ADDED
Binary file (3.59 kB). View file
 
vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc ADDED
Binary file (3.56 kB). View file
 
vec2wav2/layers/__pycache__/upsample.cpython-310.pyc ADDED
Binary file (6.01 kB). View file
 
vec2wav2/layers/__pycache__/upsample.cpython-39.pyc ADDED
Binary file (6 kB). View file
 
vec2wav2/layers/activations.py ADDED
@@ -0,0 +1,197 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ # Modified by Yiwei Guo, 2024
5
+ # including conditioned snakebeta activation
6
+
7
+ import torch
8
+ from torch import nn, sin, pow
9
+ from torch.nn import Parameter
10
+
11
+
12
+ class Snake(nn.Module):
13
+ '''
14
+ Implementation of a sine-based periodic activation function
15
+ Shape:
16
+ - Input: (B, C, T)
17
+ - Output: (B, C, T), same shape as the input
18
+ Parameters:
19
+ - alpha - trainable parameter
20
+ References:
21
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
22
+ https://arxiv.org/abs/2006.08195
23
+ Examples:
24
+ >>> a1 = Snake(256)
25
+ >>> x = torch.randn(256)
26
+ >>> x = a1(x)
27
+ '''
28
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
29
+ '''
30
+ Initialization.
31
+ INPUT:
32
+ - in_features: shape of the input
33
+ - alpha: trainable parameter
34
+ alpha is initialized to 1 by default, higher values = higher-frequency.
35
+ alpha will be trained along with the rest of your model.
36
+ '''
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ '''
53
+ Forward pass of the function.
54
+ Applies the function to the input elementwise.
55
+ Snake := x + 1/a * sin^2 (xa)
56
+ '''
57
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
58
+ if self.alpha_logscale:
59
+ alpha = torch.exp(alpha)
60
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
61
+
62
+ return x
63
+
64
+
65
+ class SnakeBeta(nn.Module):
66
+ '''
67
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
68
+ Shape:
69
+ - Input: (B, C, T)
70
+ - Output: (B, C, T), same shape as the input
71
+ Parameters:
72
+ - alpha - trainable parameter that controls frequency
73
+ - beta - trainable parameter that controls magnitude
74
+ References:
75
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
76
+ https://arxiv.org/abs/2006.08195
77
+ Examples:
78
+ >>> a1 = SnakeBeta(256)
79
+ >>> x = torch.randn(256)
80
+ >>> x = a1(x)
81
+ '''
82
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
83
+ '''
84
+ Initialization.
85
+ INPUT:
86
+ - in_features: shape of the input
87
+ - alpha - trainable parameter that controls frequency
88
+ - beta - trainable parameter that controls magnitude
89
+ alpha is initialized to 1 by default, higher values = higher-frequency.
90
+ beta is initialized to 1 by default, higher values = higher-magnitude.
91
+ alpha will be trained along with the rest of your model.
92
+ '''
93
+ super(SnakeBeta, self).__init__()
94
+ self.in_features = in_features
95
+
96
+ # initialize alpha
97
+ self.alpha_logscale = alpha_logscale
98
+ if self.alpha_logscale: # log scale alphas initialized to zeros
99
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
100
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
101
+ else: # linear scale alphas initialized to ones
102
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
103
+ self.beta = Parameter(torch.ones(in_features) * alpha)
104
+
105
+ self.alpha.requires_grad = alpha_trainable
106
+ self.beta.requires_grad = alpha_trainable
107
+
108
+ self.no_div_by_zero = 0.000000001
109
+
110
+ def forward(self, x, cond=None):
111
+ '''
112
+ Forward pass of the function.
113
+ Applies the function to the input elementwise.
114
+ SnakeBeta := x + 1/b * sin^2 (xa)
115
+ '''
116
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
117
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
118
+ if self.alpha_logscale:
119
+ alpha = torch.exp(alpha)
120
+ beta = torch.exp(beta)
121
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
122
+
123
+ return x
124
+
125
+
126
+ class SnakeBetaWithCondition(nn.Module):
127
+ '''
128
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
129
+ Shape:
130
+ - Input: (B, C, T)
131
+ - Condition: (B, D), where D-dimension will be mapped to C dimensions
132
+ - Output: (B, C, T), same shape as the input
133
+ Parameters:
134
+ - alpha - trainable parameter that controls frequency
135
+ - beta - trainable parameter that controls magnitude
136
+ - condition_alpha_prenet - trainable parameter that controls alpha and beta using condition
137
+ References:
138
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
139
+ https://arxiv.org/abs/2006.08195
140
+ Examples:
141
+ >>> a1 = SnakeBetaWithCondition(256, 128)
142
+ >>> x = torch.randn(256)
143
+ >>> cond = torch.randn(128)
144
+ >>> x = a1(x, cond)
145
+ '''
146
+ def __init__(self, in_features, condition_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
147
+ '''
148
+ Initialization.
149
+ INPUT:
150
+ - in_features: dimension of the input
151
+ - condition_features: dimension of the condition vectors
152
+ - alpha - trainable parameter that controls frequency
153
+ - beta - trainable parameter that controls magnitude
154
+ alpha is initialized to 1 by default, higher values = higher-frequency.
155
+ beta is initialized to 1 by default, higher values = higher-magnitude.
156
+ alpha, beta will be trained along with the rest of your model.
157
+ '''
158
+ super(SnakeBetaWithCondition, self).__init__()
159
+ self.in_features = in_features
160
+
161
+ self.condition_alpha_prenet = torch.nn.Linear(condition_features, in_features)
162
+ # self.condition_beta_prenet = torch.nn.Linear(condition_features, in_features)
163
+
164
+ # initialize alpha
165
+ self.alpha_logscale = alpha_logscale
166
+ if self.alpha_logscale: # log scale alphas initialized to zeros
167
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
168
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
169
+ else: # linear scale alphas initialized to ones
170
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
171
+ self.beta = Parameter(torch.ones(in_features) * alpha)
172
+
173
+ self.alpha.requires_grad = alpha_trainable
174
+ self.beta.requires_grad = alpha_trainable
175
+
176
+ self.no_div_by_zero = 0.000000001
177
+
178
+ def forward(self, x, condition):
179
+ '''
180
+ condition: [B, D]
181
+ Forward pass of the function.
182
+ Applies the function to the input elementwise.
183
+ SnakeBeta := x + 1/b * sin^2 (xa)
184
+ '''
185
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
186
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
187
+ if self.alpha_logscale:
188
+ alpha = torch.exp(alpha)
189
+ beta = torch.exp(beta)
190
+
191
+ condition = torch.tanh(self.condition_alpha_prenet(condition).unsqueeze(-1)) # Same prenet for both alpha and beta, to save parameters
192
+ alpha = alpha + condition
193
+ beta = beta + 0.5 * condition # multiply 0.5 for avoiding beta being too small
194
+
195
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
196
+
197
+ return x
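A quick shape check for the activations above (sizes are illustrative):

```python
import torch
from vec2wav2.layers.activations import Snake, SnakeBetaWithCondition

x = torch.randn(4, 256, 100)                    # (B, C, T)
act = Snake(256, alpha_logscale=True)
print(act(x).shape)                             # torch.Size([4, 256, 100])

cond_act = SnakeBetaWithCondition(in_features=256, condition_features=128)
cond = torch.randn(4, 128)                      # (B, D), e.g. a speaker prompt embedding
print(cond_act(x, cond).shape)                  # torch.Size([4, 256, 100])
```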
vec2wav2/layers/causal_conv.py ADDED
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Causal convolusion layer modules."""
7
+
8
+
9
+ import torch
10
+
11
+
12
+ class CausalConv1d(torch.nn.Module):
13
+ """CausalConv1d module with customized initialization."""
14
+
15
+ def __init__(
16
+ self,
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size,
20
+ dilation=1,
21
+ bias=True,
22
+ pad="ConstantPad1d",
23
+ pad_params={"value": 0.0},
24
+ ):
25
+ """Initialize CausalConv1d module."""
26
+ super(CausalConv1d, self).__init__()
27
+ self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
28
+ self.conv = torch.nn.Conv1d(
29
+ in_channels, out_channels, kernel_size, dilation=dilation, bias=bias
30
+ )
31
+
32
+ def forward(self, x):
33
+ """Calculate forward propagation.
34
+
35
+ Args:
36
+ x (Tensor): Input tensor (B, in_channels, T).
37
+
38
+ Returns:
39
+ Tensor: Output tensor (B, out_channels, T).
40
+
41
+ """
42
+ return self.conv(self.pad(x))[:, :, : x.size(2)]
43
+
44
+
45
+ class CausalConvTranspose1d(torch.nn.Module):
46
+ """CausalConvTranspose1d module with customized initialization."""
47
+
48
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
49
+ """Initialize CausalConvTranspose1d module."""
50
+ super(CausalConvTranspose1d, self).__init__()
51
+ self.deconv = torch.nn.ConvTranspose1d(
52
+ in_channels, out_channels, kernel_size, stride, bias=bias
53
+ )
54
+ self.stride = stride
55
+
56
+ def forward(self, x):
57
+ """Calculate forward propagation.
58
+
59
+ Args:
60
+ x (Tensor): Input tensor (B, in_channels, T_in).
61
+
62
+ Returns:
63
+ Tensor: Output tensor (B, out_channels, T_out).
64
+
65
+ """
66
+ return self.deconv(x)[:, :, : -self.stride]
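An illustrative check that CausalConv1d keeps the input length and only depends on past samples:

```python
import torch
from vec2wav2.layers.causal_conv import CausalConv1d

conv = CausalConv1d(in_channels=1, out_channels=1, kernel_size=3, dilation=2)
x = torch.randn(1, 1, 50)
y1 = conv(x)

x2 = x.clone()
x2[..., 30:] = 0.0                                  # perturb only future samples
y2 = conv(x2)

print(y1.shape)                                     # (1, 1, 50): same length as the input
print(torch.allclose(y1[..., :30], y2[..., :30]))   # True: outputs before t=30 are unaffected
```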
vec2wav2/layers/pqmf.py ADDED
@@ -0,0 +1,150 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Pseudo QMF modules."""
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from scipy.signal import kaiser
13
+
14
+
15
+ def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
16
+ """Design prototype filter for PQMF.
17
+
18
+ This method is based on `A Kaiser window approach for the design of prototype
19
+ filters of cosine modulated filterbanks`_.
20
+
21
+ Args:
22
+ taps (int): The number of filter taps.
23
+ cutoff_ratio (float): Cut-off frequency ratio.
24
+ beta (float): Beta coefficient for kaiser window.
25
+
26
+ Returns:
27
+ ndarray: Impulse response of prototype filter (taps + 1,).
28
+
29
+ .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
30
+ https://ieeexplore.ieee.org/abstract/document/681427
31
+
32
+ """
33
+ # check the arguments are valid
34
+ assert taps % 2 == 0, "The number of taps must be an even number."
35
+ assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
36
+
37
+ # make initial filter
38
+ omega_c = np.pi * cutoff_ratio
39
+ with np.errstate(invalid="ignore"):
40
+ h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
41
+ np.pi * (np.arange(taps + 1) - 0.5 * taps)
42
+ )
43
+ h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
44
+
45
+ # apply kaiser window
46
+ w = kaiser(taps + 1, beta)
47
+ h = h_i * w
48
+
49
+ return h
50
+
51
+
52
+ class PQMF(torch.nn.Module):
53
+ """PQMF module.
54
+
55
+ This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
56
+
57
+ .. _`Near-perfect-reconstruction pseudo-QMF banks`:
58
+ https://ieeexplore.ieee.org/document/258122
59
+
60
+ """
61
+
62
+ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
63
+ """Initilize PQMF module.
64
+
65
+ The cutoff_ratio and beta parameters are optimized for #subbands = 4.
66
+ See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
67
+
68
+ Args:
69
+ subbands (int): The number of subbands.
70
+ taps (int): The number of filter taps.
71
+ cutoff_ratio (float): Cut-off frequency ratio.
72
+ beta (float): Beta coefficient for kaiser window.
73
+
74
+ """
75
+ super(PQMF, self).__init__()
76
+
77
+ # build analysis & synthesis filter coefficients
78
+ h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
79
+ h_analysis = np.zeros((subbands, len(h_proto)))
80
+ h_synthesis = np.zeros((subbands, len(h_proto)))
81
+ for k in range(subbands):
82
+ h_analysis[k] = (
83
+ 2
84
+ * h_proto
85
+ * np.cos(
86
+ (2 * k + 1)
87
+ * (np.pi / (2 * subbands))
88
+ * (np.arange(taps + 1) - (taps / 2))
89
+ + (-1) ** k * np.pi / 4
90
+ )
91
+ )
92
+ h_synthesis[k] = (
93
+ 2
94
+ * h_proto
95
+ * np.cos(
96
+ (2 * k + 1)
97
+ * (np.pi / (2 * subbands))
98
+ * (np.arange(taps + 1) - (taps / 2))
99
+ - (-1) ** k * np.pi / 4
100
+ )
101
+ )
102
+
103
+ # convert to tensor
104
+ analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
105
+ synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
106
+
107
+ # register coefficients as buffer
108
+ self.register_buffer("analysis_filter", analysis_filter)
109
+ self.register_buffer("synthesis_filter", synthesis_filter)
110
+
111
+ # filter for downsampling & upsampling
112
+ updown_filter = torch.zeros((subbands, subbands, subbands)).float()
113
+ for k in range(subbands):
114
+ updown_filter[k, k, 0] = 1.0
115
+ self.register_buffer("updown_filter", updown_filter)
116
+ self.subbands = subbands
117
+
118
+ # keep padding info
119
+ self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
120
+
121
+ def analysis(self, x):
122
+ """Analysis with PQMF.
123
+
124
+ Args:
125
+ x (Tensor): Input tensor (B, 1, T).
126
+
127
+ Returns:
128
+ Tensor: Output tensor (B, subbands, T // subbands).
129
+
130
+ """
131
+ x = F.conv1d(self.pad_fn(x), self.analysis_filter)
132
+ return F.conv1d(x, self.updown_filter, stride=self.subbands)
133
+
134
+ def synthesis(self, x):
135
+ """Synthesis with PQMF.
136
+
137
+ Args:
138
+ x (Tensor): Input tensor (B, subbands, T // subbands).
139
+
140
+ Returns:
141
+ Tensor: Output tensor (B, 1, T).
142
+
143
+ """
144
+ # NOTE(kan-bayashi): Power will be decreased, so here multiply by # subbands.
145
+ # Not sure this is the correct way, it is better to check again.
146
+ # TODO(kan-bayashi): Understand the reconstruction procedure
147
+ x = F.conv_transpose1d(
148
+ x, self.updown_filter * self.subbands, stride=self.subbands
149
+ )
150
+ return F.conv1d(self.pad_fn(x), self.synthesis_filter)
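An illustrative analysis/synthesis round trip with the PQMF module above; reconstruction is near-perfect but not exact by design:

```python
import torch
from vec2wav2.layers.pqmf import PQMF

pqmf = PQMF(subbands=4)
x = torch.randn(1, 1, 16000)            # (B, 1, T) waveform
subbands = pqmf.analysis(x)             # (B, 4, T // 4)
x_hat = pqmf.synthesis(subbands)        # (B, 1, T)
print(subbands.shape, x_hat.shape)
print(torch.mean((x - x_hat) ** 2))     # small but nonzero reconstruction error (plus edge effects)
```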
vec2wav2/layers/residual_block.py ADDED
@@ -0,0 +1,222 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """Residual block modules.
4
+
5
+ References:
6
+ - https://github.com/r9y9/wavenet_vocoder
7
+ - https://github.com/jik876/hifi-gan
8
+
9
+ """
10
+
11
+ import math
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+
17
+ class Conv1d(torch.nn.Conv1d):
18
+ """Conv1d module with customized initialization."""
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ """Initialize Conv1d module."""
22
+ super(Conv1d, self).__init__(*args, **kwargs)
23
+
24
+ def reset_parameters(self):
25
+ """Reset parameters."""
26
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
27
+ if self.bias is not None:
28
+ torch.nn.init.constant_(self.bias, 0.0)
29
+
30
+
31
+ class Conv1d1x1(Conv1d):
32
+ """1x1 Conv1d with customized initialization."""
33
+
34
+ def __init__(self, in_channels, out_channels, bias):
35
+ """Initialize 1x1 Conv1d module."""
36
+ super(Conv1d1x1, self).__init__(
37
+ in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
38
+ )
39
+
40
+
41
+ class WaveNetResidualBlock(torch.nn.Module):
42
+ """Residual block module in WaveNet."""
43
+
44
+ def __init__(
45
+ self,
46
+ kernel_size=3,
47
+ residual_channels=64,
48
+ gate_channels=128,
49
+ skip_channels=64,
50
+ aux_channels=80,
51
+ dropout=0.0,
52
+ dilation=1,
53
+ bias=True,
54
+ use_causal_conv=False,
55
+ ):
56
+ """Initialize WaveNetResidualBlock module.
57
+
58
+ Args:
59
+ kernel_size (int): Kernel size of dilation convolution layer.
60
+ residual_channels (int): Number of channels for residual connection.
61
+ skip_channels (int): Number of channels for skip connection.
62
+ aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
63
+ dropout (float): Dropout probability.
64
+ dilation (int): Dilation factor.
65
+ bias (bool): Whether to add bias parameter in convolution layers.
66
+ use_causal_conv (bool): Whether to use use_causal_conv or non-use_causal_conv convolution.
67
+
68
+ """
69
+ super().__init__()
70
+ self.dropout = dropout
71
+ # no future time stamps available
72
+ if use_causal_conv:
73
+ padding = (kernel_size - 1) * dilation
74
+ else:
75
+ assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
76
+ padding = (kernel_size - 1) // 2 * dilation
77
+ self.use_causal_conv = use_causal_conv
78
+
79
+ # dilation conv
80
+ self.conv = Conv1d(
81
+ residual_channels,
82
+ gate_channels,
83
+ kernel_size,
84
+ padding=padding,
85
+ dilation=dilation,
86
+ bias=bias,
87
+ )
88
+
89
+ # local conditioning
90
+ if aux_channels > 0:
91
+ self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
92
+ else:
93
+ self.conv1x1_aux = None
94
+
95
+ # conv output is split into two groups
96
+ gate_out_channels = gate_channels // 2
97
+ self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
98
+ self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)
99
+
100
+ def forward(self, x, c):
101
+ """Calculate forward propagation.
102
+
103
+ Args:
104
+ x (Tensor): Input tensor (B, residual_channels, T).
105
+ c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).
106
+
107
+ Returns:
108
+ Tensor: Output tensor for residual connection (B, residual_channels, T).
109
+ Tensor: Output tensor for skip connection (B, skip_channels, T).
110
+
111
+ """
112
+ residual = x
113
+ x = F.dropout(x, p=self.dropout, training=self.training)
114
+ x = self.conv(x)
115
+
116
+ # remove future time steps if use_causal_conv conv
117
+ x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x
118
+
119
+ # split into two part for gated activation
120
+ splitdim = 1
121
+ xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
122
+
123
+ # local conditioning
124
+ if c is not None:
125
+ assert self.conv1x1_aux is not None
126
+ c = self.conv1x1_aux(c)
127
+ ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
128
+ xa, xb = xa + ca, xb + cb
129
+
130
+ x = torch.tanh(xa) * torch.sigmoid(xb)
131
+
132
+ # for skip connection
133
+ s = self.conv1x1_skip(x)
134
+
135
+ # for residual connection
136
+ x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)
137
+
138
+ return x, s
139
+
140
+
141
+ class HiFiGANResidualBlock(torch.nn.Module):
142
+ """Residual block module in HiFiGAN."""
143
+
144
+ def __init__(
145
+ self,
146
+ kernel_size=3,
147
+ channels=512,
148
+ dilations=(1, 3, 5),
149
+ bias=True,
150
+ use_additional_convs=True,
151
+ nonlinear_activation="LeakyReLU",
152
+ nonlinear_activation_params={"negative_slope": 0.1},
153
+ ):
154
+ """Initialize HiFiGANResidualBlock module.
155
+
156
+ Args:
157
+ kernel_size (int): Kernel size of dilation convolution layer.
158
+ channels (int): Number of channels for convolution layer.
159
+ dilations (List[int]): List of dilation factors.
160
+ use_additional_convs (bool): Whether to use additional convolution layers.
161
+ bias (bool): Whether to add bias parameter in convolution layers.
162
+ nonlinear_activation (str): Activation function module name.
163
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
164
+
165
+ """
166
+ super().__init__()
167
+ self.use_additional_convs = use_additional_convs
168
+ self.convs1 = torch.nn.ModuleList()
169
+ if use_additional_convs:
170
+ self.convs2 = torch.nn.ModuleList()
171
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
172
+ for dilation in dilations:
173
+ self.convs1 += [
174
+ torch.nn.Sequential(
175
+ getattr(torch.nn, nonlinear_activation)(
176
+ **nonlinear_activation_params
177
+ ),
178
+ torch.nn.Conv1d(
179
+ channels,
180
+ channels,
181
+ kernel_size,
182
+ 1,
183
+ dilation=dilation,
184
+ bias=bias,
185
+ padding=(kernel_size - 1) // 2 * dilation,
186
+ ),
187
+ )
188
+ ]
189
+ if use_additional_convs:
190
+ self.convs2 += [
191
+ torch.nn.Sequential(
192
+ getattr(torch.nn, nonlinear_activation)(
193
+ **nonlinear_activation_params
194
+ ),
195
+ torch.nn.Conv1d(
196
+ channels,
197
+ channels,
198
+ kernel_size,
199
+ 1,
200
+ dilation=1,
201
+ bias=bias,
202
+ padding=(kernel_size - 1) // 2,
203
+ ),
204
+ )
205
+ ]
206
+
207
+ def forward(self, x):
208
+ """Calculate forward propagation.
209
+
210
+ Args:
211
+ x (Tensor): Input tensor (B, channels, T).
212
+
213
+ Returns:
214
+ Tensor: Output tensor (B, channels, T).
215
+
216
+ """
217
+ for idx in range(len(self.convs1)):
218
+ xt = self.convs1[idx](x)
219
+ if self.use_additional_convs:
220
+ xt = self.convs2[idx](xt)
221
+ x = xt + x
222
+ return x
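Illustrative forward passes through the two residual blocks defined above, with output shapes noted:

```python
import torch
from vec2wav2.layers.residual_block import WaveNetResidualBlock, HiFiGANResidualBlock

x = torch.randn(2, 64, 100)     # (B, residual_channels, T)
c = torch.randn(2, 80, 100)     # (B, aux_channels, T) local conditioning
wn_block = WaveNetResidualBlock(residual_channels=64, gate_channels=128,
                                skip_channels=64, aux_channels=80)
res, skip = wn_block(x, c)
print(res.shape, skip.shape)    # both (2, 64, 100)

h = torch.randn(2, 512, 100)    # (B, channels, T)
hifi_block = HiFiGANResidualBlock(kernel_size=3, channels=512, dilations=(1, 3, 5))
print(hifi_block(h).shape)      # (2, 512, 100)
```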
vec2wav2/layers/residual_stack.py ADDED
@@ -0,0 +1,85 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Residual stack module in MelGAN."""
7
+
8
+ import torch
9
+
10
+ from vec2wav2.layers import CausalConv1d
11
+
12
+
13
+ class ResidualStack(torch.nn.Module):
14
+ """Residual stack module introduced in MelGAN."""
15
+
16
+ def __init__(
17
+ self,
18
+ kernel_size=3,
19
+ channels=32,
20
+ dilation=1,
21
+ bias=True,
22
+ nonlinear_activation="LeakyReLU",
23
+ nonlinear_activation_params={"negative_slope": 0.2},
24
+ pad="ReflectionPad1d",
25
+ pad_params={},
26
+ use_causal_conv=False,
27
+ ):
28
+ """Initialize ResidualStack module.
29
+
30
+ Args:
31
+ kernel_size (int): Kernel size of dilation convolution layer.
32
+ channels (int): Number of channels of convolution layers.
33
+ dilation (int): Dilation factor.
34
+ bias (bool): Whether to add bias parameter in convolution layers.
35
+ nonlinear_activation (str): Activation function module name.
36
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
37
+ pad (str): Padding function module name before dilated convolution layer.
38
+ pad_params (dict): Hyperparameters for padding function.
39
+ use_causal_conv (bool): Whether to use causal convolution.
40
+
41
+ """
42
+ super(ResidualStack, self).__init__()
43
+
44
+ # define residual stack part
45
+ if not use_causal_conv:
46
+ assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
47
+ self.stack = torch.nn.Sequential(
48
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
49
+ getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
50
+ torch.nn.Conv1d(
51
+ channels, channels, kernel_size, dilation=dilation, bias=bias
52
+ ),
53
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
54
+ torch.nn.Conv1d(channels, channels, 1, bias=bias),
55
+ )
56
+ else:
57
+ self.stack = torch.nn.Sequential(
58
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
59
+ CausalConv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ dilation=dilation,
64
+ bias=bias,
65
+ pad=pad,
66
+ pad_params=pad_params,
67
+ ),
68
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
69
+ torch.nn.Conv1d(channels, channels, 1, bias=bias),
70
+ )
71
+
72
+ # define extra layer for skip connection
73
+ self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
74
+
75
+ def forward(self, c):
76
+ """Calculate forward propagation.
77
+
78
+ Args:
79
+ c (Tensor): Input tensor (B, channels, T).
80
+
81
+ Returns:
82
+ Tensor: Output tensor (B, channels, T).
83
+
84
+ """
85
+ return self.stack(c) + self.skip_layer(c)
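A minimal usage sketch of the MelGAN-style residual stack above:

```python
import torch
from vec2wav2.layers.residual_stack import ResidualStack

stack = ResidualStack(kernel_size=3, channels=32, dilation=2)
c = torch.randn(1, 32, 200)     # (B, channels, T)
print(stack(c).shape)           # (1, 32, 200): dilated residual path plus 1x1 skip projection
```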
vec2wav2/layers/tade_res_block.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright 2021 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+
4
+ """StyleMelGAN's TADEResBlock Modules."""
5
+
6
+ from functools import partial
7
+
8
+ import torch
9
+
10
+
11
+ class TADELayer(torch.nn.Module):
12
+ """TADE Layer module."""
13
+
14
+ def __init__(
15
+ self,
16
+ in_channels=64,
17
+ aux_channels=80,
18
+ kernel_size=9,
19
+ bias=True,
20
+ upsample_factor=2,
21
+ upsample_mode="nearest",
22
+ ):
23
+ """Initilize TADE layer."""
24
+ super().__init__()
25
+ self.norm = torch.nn.InstanceNorm1d(in_channels)
26
+ self.aux_conv = torch.nn.Sequential(
27
+ torch.nn.Conv1d(
28
+ aux_channels,
29
+ in_channels,
30
+ kernel_size,
31
+ 1,
32
+ bias=bias,
33
+ padding=(kernel_size - 1) // 2,
34
+ ),
35
+ # NOTE(kan-bayashi): Use non-linear activation?
36
+ )
37
+ self.gated_conv = torch.nn.Sequential(
38
+ torch.nn.Conv1d(
39
+ in_channels,
40
+ in_channels * 2,
41
+ kernel_size,
42
+ 1,
43
+ bias=bias,
44
+ padding=(kernel_size - 1) // 2,
45
+ ),
46
+ # NOTE(kan-bayashi): Use non-linear activation?
47
+ )
48
+ self.upsample = torch.nn.Upsample(
49
+ scale_factor=upsample_factor, mode=upsample_mode
50
+ )
51
+
52
+ def forward(self, x, c):
53
+ """Calculate forward propagation.
54
+
55
+ Args:
56
+ x (Tensor): Input tensor (B, in_channels, T).
57
+ c (Tensor): Auxiliary input tensor (B, aux_channels, T').
58
+
59
+ Returns:
60
+ Tensor: Output tensor (B, in_channels, T * in_upsample_factor).
61
+ Tensor: Upsampled aux tensor (B, in_channels, T * aux_upsample_factor).
62
+
63
+ """
64
+ x = self.norm(x)
65
+ c = self.upsample(c)
66
+ c = self.aux_conv(c)
67
+ cg = self.gated_conv(c)
68
+ cg1, cg2 = cg.split(cg.size(1) // 2, dim=1)
69
+ # NOTE(kan-bayashi): Use upsample for noise input here?
70
+ y = cg1 * self.upsample(x) + cg2
71
+ # NOTE(kan-bayashi): Return upsampled aux here?
72
+ return y, c
73
+
74
+
75
+ class TADEResBlock(torch.nn.Module):
76
+ """TADEResBlock module."""
77
+
78
+ def __init__(
79
+ self,
80
+ in_channels=64,
81
+ aux_channels=80,
82
+ kernel_size=9,
83
+ dilation=2,
84
+ bias=True,
85
+ upsample_factor=2,
86
+ upsample_mode="nearest",
87
+ gated_function="softmax",
88
+ ):
89
+ """Initialize TADEResBlock module."""
90
+ super().__init__()
91
+ self.tade1 = TADELayer(
92
+ in_channels=in_channels,
93
+ aux_channels=aux_channels,
94
+ kernel_size=kernel_size,
95
+ bias=bias,
96
+ # NOTE(kan-bayashi): Use upsample in the first TADE layer?
97
+ upsample_factor=1,
98
+ upsample_mode=upsample_mode,
99
+ )
100
+ self.gated_conv1 = torch.nn.Conv1d(
101
+ in_channels,
102
+ in_channels * 2,
103
+ kernel_size,
104
+ 1,
105
+ bias=bias,
106
+ padding=(kernel_size - 1) // 2,
107
+ )
108
+ self.tade2 = TADELayer(
109
+ in_channels=in_channels,
110
+ aux_channels=in_channels,
111
+ kernel_size=kernel_size,
112
+ bias=bias,
113
+ upsample_factor=upsample_factor,
114
+ upsample_mode=upsample_mode,
115
+ )
116
+ self.gated_conv2 = torch.nn.Conv1d(
117
+ in_channels,
118
+ in_channels * 2,
119
+ kernel_size,
120
+ 1,
121
+ bias=bias,
122
+ dilation=dilation,
123
+ padding=(kernel_size - 1) // 2 * dilation,
124
+ )
125
+ self.upsample = torch.nn.Upsample(
126
+ scale_factor=upsample_factor, mode=upsample_mode
127
+ )
128
+ if gated_function == "softmax":
129
+ self.gated_function = partial(torch.softmax, dim=1)
130
+ elif gated_function == "sigmoid":
131
+ self.gated_function = torch.sigmoid
132
+ else:
133
+ raise ValueError(f"{gated_function} is not supported.")
134
+
135
+ def forward(self, x, c):
136
+ """Calculate forward propagation.
137
+
138
+ Args:
139
+ x (Tensor): Input tensor (B, in_channels, T).
140
+ c (Tensor): Auxiliary input tensor (B, aux_channels, T').
141
+
142
+ Returns:
143
+ Tensor: Output tensor (B, in_channels, T * in_upsample_factor).
144
+ Tensor: Upsampled auxiliary tensor (B, in_channels, T * in_upsample_factor).
145
+
146
+ """
147
+ residual = x
148
+
149
+ x, c = self.tade1(x, c)
150
+ x = self.gated_conv1(x)
151
+ xa, xb = x.split(x.size(1) // 2, dim=1)
152
+ x = self.gated_function(xa) * torch.tanh(xb)
153
+
154
+ x, c = self.tade2(x, c)
155
+ x = self.gated_conv2(x)
156
+ xa, xb = x.split(x.size(1) // 2, dim=1)
157
+ x = self.gated_function(xa) * torch.tanh(xb)
158
+
159
+ # NOTE(kan-bayashi): Return upsampled aux here?
160
+ return self.upsample(residual) + x, c
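A minimal sketch showing that TADEResBlock upsamples its input by upsample_factor while modulating it with the auxiliary features:

```python
import torch
from vec2wav2.layers.tade_res_block import TADEResBlock

block = TADEResBlock(in_channels=64, aux_channels=80, upsample_factor=2)
x = torch.randn(1, 64, 50)      # (B, in_channels, T)
c = torch.randn(1, 80, 50)      # (B, aux_channels, T)
y, c_up = block(x, c)
print(y.shape, c_up.shape)      # (1, 64, 100) and (1, 64, 100)
```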
vec2wav2/layers/upsample.py ADDED
@@ -0,0 +1,194 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """Upsampling module.
4
+
5
+ This code is modified from https://github.com/r9y9/wavenet_vocoder.
6
+
7
+ """
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from vec2wav2.layers import Conv1d
14
+
15
+
16
+ class Stretch2d(torch.nn.Module):
17
+ """Stretch2d module."""
18
+
19
+ def __init__(self, x_scale, y_scale, mode="nearest"):
20
+ """Initialize Stretch2d module.
21
+
22
+ Args:
23
+ x_scale (int): X scaling factor (Time axis in spectrogram).
24
+ y_scale (int): Y scaling factor (Frequency axis in spectrogram).
25
+ mode (str): Interpolation mode.
26
+
27
+ """
28
+ super(Stretch2d, self).__init__()
29
+ self.x_scale = x_scale
30
+ self.y_scale = y_scale
31
+ self.mode = mode
32
+
33
+ def forward(self, x):
34
+ """Calculate forward propagation.
35
+
36
+ Args:
37
+ x (Tensor): Input tensor (B, C, F, T).
38
+
39
+ Returns:
40
+ Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
41
+
42
+ """
43
+ return F.interpolate(
44
+ x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
45
+ )
46
+
47
+
48
+ class Conv2d(torch.nn.Conv2d):
49
+ """Conv2d module with customized initialization."""
50
+
51
+ def __init__(self, *args, **kwargs):
52
+ """Initialize Conv2d module."""
53
+ super(Conv2d, self).__init__(*args, **kwargs)
54
+
55
+ def reset_parameters(self):
56
+ """Reset parameters."""
57
+ self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
58
+ if self.bias is not None:
59
+ torch.nn.init.constant_(self.bias, 0.0)
60
+
61
+
62
+ class UpsampleNetwork(torch.nn.Module):
63
+ """Upsampling network module."""
64
+
65
+ def __init__(
66
+ self,
67
+ upsample_scales,
68
+ nonlinear_activation=None,
69
+ nonlinear_activation_params={},
70
+ interpolate_mode="nearest",
71
+ freq_axis_kernel_size=1,
72
+ use_causal_conv=False,
73
+ ):
74
+ """Initialize upsampling network module.
75
+
76
+ Args:
77
+ upsample_scales (list): List of upsampling scales.
78
+ nonlinear_activation (str): Activation function name.
79
+ nonlinear_activation_params (dict): Arguments for specified activation function.
80
+ interpolate_mode (str): Interpolation mode.
81
+ freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
82
+
83
+ """
84
+ super(UpsampleNetwork, self).__init__()
85
+ self.use_causal_conv = use_causal_conv
86
+ self.up_layers = torch.nn.ModuleList()
87
+ for scale in upsample_scales:
88
+ # interpolation layer
89
+ stretch = Stretch2d(scale, 1, interpolate_mode)
90
+ self.up_layers += [stretch]
91
+
92
+ # conv layer
93
+ assert (
94
+ freq_axis_kernel_size - 1
95
+ ) % 2 == 0, "Not support even number freq axis kernel size."
96
+ freq_axis_padding = (freq_axis_kernel_size - 1) // 2
97
+ kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
98
+ if use_causal_conv:
99
+ padding = (freq_axis_padding, scale * 2)
100
+ else:
101
+ padding = (freq_axis_padding, scale)
102
+ conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
103
+ self.up_layers += [conv]
104
+
105
+ # nonlinear
106
+ if nonlinear_activation is not None:
107
+ nonlinear = getattr(torch.nn, nonlinear_activation)(
108
+ **nonlinear_activation_params
109
+ )
110
+ self.up_layers += [nonlinear]
111
+
112
+ def forward(self, c):
113
+ """Calculate forward propagation.
114
+
115
+ Args:
116
+ c : Input tensor (B, C, T).
117
+
118
+ Returns:
119
+ Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
120
+
121
+ """
122
+ c = c.unsqueeze(1) # (B, 1, C, T)
123
+ for f in self.up_layers:
124
+ if self.use_causal_conv and isinstance(f, Conv2d):
125
+ c = f(c)[..., : c.size(-1)]
126
+ else:
127
+ c = f(c)
128
+ return c.squeeze(1) # (B, C, T')
129
+
130
+
131
+ class ConvInUpsampleNetwork(torch.nn.Module):
132
+ """Convolution + upsampling network module."""
133
+
134
+ def __init__(
135
+ self,
136
+ upsample_scales,
137
+ nonlinear_activation=None,
138
+ nonlinear_activation_params={},
139
+ interpolate_mode="nearest",
140
+ freq_axis_kernel_size=1,
141
+ aux_channels=80,
142
+ aux_context_window=0,
143
+ use_causal_conv=False,
144
+ ):
145
+ """Initialize convolution + upsampling network module.
146
+
147
+ Args:
148
+ upsample_scales (list): List of upsampling scales.
149
+ nonlinear_activation (str): Activation function name.
150
+ nonlinear_activation_params (dict): Arguments for specified activation function.
151
+ mode (str): Interpolation mode.
152
+ freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
153
+ aux_channels (int): Number of channels of pre-convolutional layer.
154
+ aux_context_window (int): Context window size of the pre-convolutional layer.
155
+ use_causal_conv (bool): Whether to use causal structure.
156
+
157
+ """
158
+ super(ConvInUpsampleNetwork, self).__init__()
159
+ self.aux_context_window = aux_context_window
160
+ self.use_causal_conv = use_causal_conv and aux_context_window > 0
161
+ # To capture wide-context information in conditional features
162
+ kernel_size = (
163
+ aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
164
+ )
165
+ # NOTE(kan-bayashi): Here do not use padding because the input is already padded
166
+ self.conv_in = Conv1d(
167
+ aux_channels, aux_channels, kernel_size=kernel_size, bias=False
168
+ )
169
+ self.upsample = UpsampleNetwork(
170
+ upsample_scales=upsample_scales,
171
+ nonlinear_activation=nonlinear_activation,
172
+ nonlinear_activation_params=nonlinear_activation_params,
173
+ interpolate_mode=interpolate_mode,
174
+ freq_axis_kernel_size=freq_axis_kernel_size,
175
+ use_causal_conv=use_causal_conv,
176
+ )
177
+
178
+ def forward(self, c):
179
+ """Calculate forward propagation.
180
+
181
+ Args:
182
+ c : Input tensor (B, C, T').
183
+
184
+ Returns:
185
+ Tensor: Upsampled tensor (B, C, T),
186
+ where T = (T' - aux_context_window * 2) * prod(upsample_scales).
187
+
188
+ Note:
189
+ The length of inputs considers the context window size.
190
+
191
+ """
192
+ c_ = self.conv_in(c)
193
+ c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
194
+ return self.upsample(c)
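
As a quick sanity check of the shapes documented above, here is a minimal usage sketch (assuming the vec2wav2 package from this commit is importable; the scales, channel count, and context window are illustrative values, not the project's configuration):

import torch
from vec2wav2.layers.upsample import ConvInUpsampleNetwork

# Upsample 80-channel conditioning features by prod([4, 5, 3, 4]) = 240.
upsampler = ConvInUpsampleNetwork(
    upsample_scales=[4, 5, 3, 4],
    aux_channels=80,
    aux_context_window=2,
)

T = 50
c = torch.randn(1, 80, T + 2 * 2)  # input is pre-padded by aux_context_window on each side
y = upsampler(c)
print(y.shape)  # torch.Size([1, 80, 12000]) == (1, 80, T * 240)
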
vec2wav2/losses/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .adversarial_loss import *  # NOQA
+from .feat_match_loss import *  # NOQA
+from .mel_loss import *  # NOQA
+from .stft_loss import *  # NOQA
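
This aggregator re-exports everything from the four loss submodules so training code can import losses from vec2wav2.losses directly. A hedged sketch of that usage; the class name below follows the upstream ParallelWaveGAN naming convention and is an assumption, since the submodules themselves are not part of this diff:

# Assumed class name -- not confirmed by this diff; adjust to the actual submodule contents.
from vec2wav2.losses import MultiResolutionSTFTLoss

stft_loss = MultiResolutionSTFTLoss()
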