diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..847f7b59de0f3d9d484e5c2ced4410396995e08b --- /dev/null +++ b/app.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import gradio as gr +import logging +import yaml +import soundfile as sf +import os +from pathlib import Path +from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args + +# Create Gradio interface +def create_interface(): + args = vc_args() + logger = configure_logging(args.verbose) + voice_converter = VoiceConverter( + expdir=args.expdir, + token_extractor=args.token_extractor, + prompt_extractor=args.prompt_extractor, + prompt_output_layer=args.prompt_output_layer, + checkpoint=args.checkpoint, + script_logger=logger + ) + with gr.Blocks(title="Voice Conversion") as demo: + gr.Markdown("# vec2wav 2.0 Voice Conversion Demo") + gr.Markdown("Upload source audio and target speaker audio to convert the voice.") + + with gr.Row(): + source_audio = gr.Audio(label="Source Audio", type="filepath") + target_audio = gr.Audio(label="Target Speaker Audio", type="filepath") + + examples = [ + ["examples/Zuckerberg.wav", "examples/Rachel.wav"], + ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"] + ] + gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio]) + + convert_btn = gr.Button("Convert Voice") + output_audio = gr.Audio(label="Converted Audio") + + convert_btn.click( + fn=voice_converter.voice_conversion, + inputs=[source_audio, target_audio], + outputs=output_audio + ) + + return demo + +if __name__ == "__main__": + demo = create_interface() + demo.launch(share=True) diff --git a/pretrained/WavLM-Large.pt b/pretrained/WavLM-Large.pt new file mode 100644 index 0000000000000000000000000000000000000000..b704cf463982904321df737eb8a3fe092a0aa019 --- /dev/null +++ b/pretrained/WavLM-Large.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f +size 1261965425 diff --git a/pretrained/config.yml b/pretrained/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..2b639c1cb797233dd3fef2d0963aa84228b40057 --- /dev/null +++ b/pretrained/config.yml @@ -0,0 +1,201 @@ +allow_cache: false +batch_frames: 3600 +config: conf/ctxv2w.v1.yaml +crop_max_frames: 100 +discriminator_adv_loss_params: + average_by_discriminators: false +discriminator_grad_norm: -1 +discriminator_optimizer_params: + betas: + - 0.5 + - 0.9 + lr: 0.0002 + weight_decay: 0.0 +discriminator_optimizer_type: Adam +discriminator_params: + follow_official_norm: true + period_discriminator_params: + bias: true + channels: 32 + downsample_scales: + - 3 + - 3 + - 3 + - 3 + - 1 + in_channels: 1 + kernel_sizes: + - 5 + - 3 + max_downsample_channels: 1024 + nonlinear_activation: LeakyReLU + nonlinear_activation_params: + negative_slope: 0.1 + out_channels: 1 + use_spectral_norm: false + use_weight_norm: true + periods: + - 2 + - 3 + - 5 + - 7 + - 11 + scale_discriminator_params: + bias: true + channels: 128 + downsample_scales: + - 4 + - 4 + - 4 + - 4 + - 1 + in_channels: 1 + kernel_sizes: + - 15 + - 41 + - 5 + - 3 + max_downsample_channels: 1024 + max_groups: 16 + nonlinear_activation: LeakyReLU + nonlinear_activation_params: + negative_slope: 0.1 + out_channels: 1 + scale_downsample_pooling: AvgPool1d + scale_downsample_pooling_params: + kernel_size: 4 + padding: 2 + stride: 2 + scales: 3 +discriminator_scheduler_params: + gamma: 0.5 + milestones: + - 200000 + - 400000 + - 
600000 + - 800000 +discriminator_scheduler_type: MultiStepLR +discriminator_train_start_steps: 0 +discriminator_type: HiFiGANMultiScaleMultiPeriodDiscriminator +distributed: true +dropout_features: 0.0 +eval_interval_steps: 100000 +feat_match_loss_params: + average_by_discriminators: false + average_by_layers: false + include_final_outputs: false +frontend_mel_prediction_stop_steps: 200000 +frontend_params: + conformer_params: + activation_type: swish + attention_dim: 184 + attention_dropout_rate: 0.2 + attention_heads: 2 + cnn_module_kernel: 31 + concat_after: false + dropout_rate: 0.2 + linear_units: 1536 + macaron_style: true + normalize_before: true + num_blocks: 2 + pos_enc_layer_type: rel_pos + positional_dropout_rate: 0.2 + positionwise_conv_kernel_size: 3 + positionwise_layer_type: conv1d + selfattention_layer_type: rel_selfattn + use_cnn_module: true + prompt_channels: 1024 + vqvec_channels: 512 +generator_adv_loss_params: + average_by_discriminators: false +generator_grad_norm: -1 +generator_optimizer_params: + betas: + - 0.5 + - 0.9 + lr: 0.0002 + weight_decay: 0.0 +generator_optimizer_type: Adam +generator_params: + bias: true + channels: 512 + condition_dim: 1024 + in_channels: 184 + kernel_size: 7 + nonlinear_activation: snakebeta-condition + out_channels: 1 + resblock: '1' + resblock_dilations: + - - 1 + - 3 + - 5 + - - 1 + - 3 + - 5 + - - 1 + - 3 + - 5 + resblock_kernel_sizes: + - 3 + - 7 + - 11 + snake_logscale: true + upsample_kernel_sizes: + - 16 + - 10 + - 6 + - 4 + upsample_scales: + - 8 + - 5 + - 3 + - 2 + use_additional_convs: true + use_weight_norm: true +generator_scheduler_params: + gamma: 0.5 + milestones: + - 200000 + - 400000 + - 600000 + - 800000 +generator_scheduler_type: MultiStepLR +generator_train_start_steps: 1 +generator_type: BigVGAN +hop_size: 240 +lambda_adv: 1.0 +lambda_aux: 45.0 +lambda_feat_match: 2.0 +lambda_frontend_mel_prediction: 60 +log_interval_steps: 1000 +max_num_frames: 3000 +mel_loss_params: + fft_size: 2048 + fmax: 8000 + fmin: 40 + fs: 24000 + hop_size: 300 + log_base: null + num_mels: 80 + win_length: 1200 + window: hann +min_num_frames: 600 +num_mels: 80 +num_save_intermediate_results: 4 +num_workers: 8 +outdir: exp/train_all_ctxv2w.v1 +pin_memory: true +pretrain: '' +prompt_fold_by_2: true +prompt_net_type: ConvPromptPrenet +rank: 0 +sampling_rate: 24000 +save_interval_steps: 10000 +use_feat_match_loss: true +use_mel_loss: true +use_stft_loss: false +verbose: 1 +version: 0.5.3 +vq_codebook: feats/vqidx/codebook.npy +win_length: 697 +world_size: 4 diff --git a/pretrained/generator.ckpt b/pretrained/generator.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..98e008e79fa35bea3ea34bfaeb659172cd65e6d1 --- /dev/null +++ b/pretrained/generator.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a10b9df62462bbf48382970ffba267b458b00b361bcb245701e3d3c0b6bd19f +size 161604549 diff --git a/pretrained/vq-wav2vec_kmeans.pt b/pretrained/vq-wav2vec_kmeans.pt new file mode 100644 index 0000000000000000000000000000000000000000..6709c4e657fc700204fed1a5ccbcaed2572e0372 --- /dev/null +++ b/pretrained/vq-wav2vec_kmeans.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c975a93479dc5f3cfc4339032e1547c6034eddd15eb1cba73364c20786b42a5a +size 336509919 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7e0abf23bd829bbe1906d752cbc88090826be582 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,25 @@ +torchaudio==0.13.1 
+auraloss==0.4.0 +cython==3.0.10 +einops +debugpy==1.8.0 +fairseq==0.12.2 +filelock~=3.12.2 +h5py +kaldiio~=2.18.0 +librosa==0.8.1 +matplotlib~=3.4.3 +nltk==3.8.1 +numpy +pathlib~=1.0.1 +pyyaml~=6.0 +scikit-learn +scipy~=1.7.1 +setuptools==65.6.3 +six==1.16.0 +soundfile~=0.10.3.post1 +sox +tensorboard +tensorboardx~=2.5.1 +tqdm~=4.62.3 +transformers==4.42.3 \ No newline at end of file diff --git a/vec2wav2/__init__.py b/vec2wav2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd57cc89a2ba3cc800ce37986a5c03fc37340e5 --- /dev/null +++ b/vec2wav2/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +__version__ = "" diff --git a/vec2wav2/__pycache__/__init__.cpython-310.pyc b/vec2wav2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4aa33d2e095cac36362ea23278c17236b4f0f7 Binary files /dev/null and b/vec2wav2/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/__pycache__/__init__.cpython-311.pyc b/vec2wav2/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f94fe5da11233c9a5f5d0b142f67a08444ea6385 Binary files /dev/null and b/vec2wav2/__pycache__/__init__.cpython-311.pyc differ diff --git a/vec2wav2/__pycache__/__init__.cpython-39.pyc b/vec2wav2/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed8422feab0c92a6ffc693a62652983e16acb3b5 Binary files /dev/null and b/vec2wav2/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/bin/.DS_Store b/vec2wav2/bin/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/vec2wav2/bin/.DS_Store differ diff --git a/vec2wav2/bin/__init__.py b/vec2wav2/bin/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vec2wav2/bin/__pycache__/__init__.cpython-310.pyc b/vec2wav2/bin/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf269eb68cdae608b05d8b0ce78834f269a4967b Binary files /dev/null and b/vec2wav2/bin/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/bin/__pycache__/vc.cpython-310.pyc b/vec2wav2/bin/__pycache__/vc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b10f9f94b21a25b26824126a6c03caf1e95d7f69 Binary files /dev/null and b/vec2wav2/bin/__pycache__/vc.cpython-310.pyc differ diff --git a/vec2wav2/bin/decode.py b/vec2wav2/bin/decode.py new file mode 100755 index 0000000000000000000000000000000000000000..c5a4cafdc49495fa0acf83f25cc6f801dcb04322 --- /dev/null +++ b/vec2wav2/bin/decode.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +# Modified by Yiwei Guo, 2024 + +"""Decode with trained vec2wav Generator.""" + +import argparse +import logging +import os +import time + +import numpy as np +import soundfile as sf +import torch +import yaml + +from tqdm import tqdm + +from vec2wav2.datasets import MelSCPDataset +from vec2wav2.utils import load_model, load_feat_codebook, idx2vec + + +def set_loglevel(verbose): + # set logger + if verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif verbose > 0: + logging.basicConfig( + level=logging.INFO, + 
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + +def main(): + """Run decoding process.""" + parser = argparse.ArgumentParser( + description="Decode from audio tokens and acoustic prompts with trained vec2wav model" + "(See detail in vec2wav2/bin/decode.py)." + ) + parser.add_argument( + "--feats-scp", + "--scp", + default=None, + type=str, + required=True, + help="kaldi-style feats.scp file. " + ) + parser.add_argument( + "--prompt-scp", + default=None, + type=str, + help="kaldi-style prompt.scp file. Similar to feats.scp." + ) + parser.add_argument( + "--outdir", + type=str, + required=True, + help="directory to save generated speech.", + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="checkpoint file to be loaded.", + ) + parser.add_argument( + "--config", + default=None, + type=str, + help="yaml format configuration file. if not explicitly provided, " + "it will be searched in the checkpoint directory. (default=None)", + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. (default=1)", + ) + args = parser.parse_args() + set_loglevel(args.verbose) + + # check directory existence + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + # load config + if args.config is None: + dirname = os.path.dirname(args.checkpoint) + args.config = os.path.join(dirname, "config.yml") + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + config.update(vars(args)) + + # get dataset + dataset = MelSCPDataset( + vqidx_scp=args.feats_scp, + prompt_scp=args.prompt_scp, + return_utt_id=True, + ) + logging.info(f"The number of features to be decoded = {len(dataset)}.") + + # setup model + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logging.info(f"Using {'GPU' if torch.cuda.is_available() else 'CPU'}.") + + model = load_model(args.checkpoint, config) + logging.info(f"Loaded model parameters from {args.checkpoint}.") + + model.backend.remove_weight_norm() + model = model.eval().to(device) + + # load vq codebook + feat_codebook, feat_codebook_numgroups = load_feat_codebook(np.load(config["vq_codebook"], allow_pickle=True), device) + + # start generation + total_rtf = 0.0 + with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar: + for idx, batch in enumerate(pbar, 1): + utt_id, vqidx, prompt = batch[0], batch[1], batch[2] + + vqidx = torch.tensor(vqidx).to(device) # (L, G) + prompt = torch.tensor(prompt).unsqueeze(0).to(device) # (1, L', D') + + vqidx = vqidx.long() + vqvec = idx2vec(feat_codebook, vqidx, feat_codebook_numgroups).unsqueeze(0) # (1, L, D) + + # generate + start = time.time() + y = model.inference(vqvec, prompt)[-1].view(-1) + rtf = (time.time() - start) / (len(y) / config["sampling_rate"]) + pbar.set_postfix({"RTF": rtf}) + total_rtf += rtf + + tgt_dir = os.path.dirname(os.path.join(config["outdir"], f"{utt_id}.wav")) + os.makedirs(tgt_dir, exist_ok=True) + basename = os.path.basename(f"{utt_id}.wav") + # save as PCM 16 bit wav file + sf.write( + os.path.join(tgt_dir, basename), + y.cpu().numpy(), + config["sampling_rate"], + "PCM_16", + ) + + # report average RTF + logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).") + + +if __name__ == "__main__": + main() diff --git a/vec2wav2/bin/gradio_app.py 
b/vec2wav2/bin/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..847f7b59de0f3d9d484e5c2ced4410396995e08b --- /dev/null +++ b/vec2wav2/bin/gradio_app.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import gradio as gr +import logging +import yaml +import soundfile as sf +import os +from pathlib import Path +from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args + +# Create Gradio interface +def create_interface(): + args = vc_args() + logger = configure_logging(args.verbose) + voice_converter = VoiceConverter( + expdir=args.expdir, + token_extractor=args.token_extractor, + prompt_extractor=args.prompt_extractor, + prompt_output_layer=args.prompt_output_layer, + checkpoint=args.checkpoint, + script_logger=logger + ) + with gr.Blocks(title="Voice Conversion") as demo: + gr.Markdown("# vec2wav 2.0 Voice Conversion Demo") + gr.Markdown("Upload source audio and target speaker audio to convert the voice.") + + with gr.Row(): + source_audio = gr.Audio(label="Source Audio", type="filepath") + target_audio = gr.Audio(label="Target Speaker Audio", type="filepath") + + examples = [ + ["examples/Zuckerberg.wav", "examples/Rachel.wav"], + ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"] + ] + gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio]) + + convert_btn = gr.Button("Convert Voice") + output_audio = gr.Audio(label="Converted Audio") + + convert_btn.click( + fn=voice_converter.voice_conversion, + inputs=[source_audio, target_audio], + outputs=output_audio + ) + + return demo + +if __name__ == "__main__": + demo = create_interface() + demo.launch(share=True) diff --git a/vec2wav2/bin/train.py b/vec2wav2/bin/train.py new file mode 100755 index 0000000000000000000000000000000000000000..b93ba78c51530de37a0c8c66e9e3f3472022dc3c --- /dev/null +++ b/vec2wav2/bin/train.py @@ -0,0 +1,1007 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +# Modified by Yiwei Guo, 2024 + +"""Train vec2wav.""" + +import argparse +import logging +import os +import sys +import random + +from collections import defaultdict + +import matplotlib +import numpy as np +import soundfile as sf +import torch +import torch.nn.functional as F +import yaml +import torch.multiprocessing as mp +from tensorboardX import SummaryWriter +from torch.utils.data import DataLoader +from tqdm import tqdm + +import vec2wav2 +import vec2wav2.models +import vec2wav2.optimizers +from torch.utils.data.distributed import DistributedSampler + +from vec2wav2.datasets import AudioMelSCPDataset +from vec2wav2.layers import PQMF +from vec2wav2.losses import DiscriminatorAdversarialLoss +from vec2wav2.losses import FeatureMatchLoss +from vec2wav2.losses import GeneratorAdversarialLoss +from vec2wav2.losses import MelSpectrogramLoss +from vec2wav2.losses import MultiResolutionSTFTLoss +from vec2wav2.utils import crop_seq, load_feat_codebook, idx2vec + +from vec2wav2.utils.espnet_utils import pad_list, make_non_pad_mask + +# set to avoid matplotlib error in CLI environment +matplotlib.use("Agg") + + +def set_loglevel(verbose): + # set logger + if verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif verbose > 0: + logging.basicConfig( + level=logging.INFO, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + 
else: + logging.basicConfig( + level=logging.WARN, + stream=sys.stdout, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + +class Trainer(object): + """Customized trainer module for Parallel WaveGAN training.""" + + def __init__( + self, + steps, + epochs, + data_loader, + sampler, + model, + criterion, + optimizer, + scheduler, + config, + device=torch.device("cpu"), + ): + """Initialize trainer. + + Args: + steps (int): Initial global steps. + epochs (int): Initial global epochs. + data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders. + model (dict): Dict of models. It must contain "generator" and "discriminator" models. + criterion (dict): Dict of criteria. It must contain "stft" and "mse" criteria. + optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers. + scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers. + config (dict): Config dict loaded from yaml format configuration file. + device (torch.deive): Pytorch device instance. + + """ + self.steps = steps + self.epochs = epochs + self.data_loader = data_loader + self.sampler = sampler + self.model = model + self.criterion = criterion + self.optimizer = optimizer + self.scheduler = scheduler + self.config = config + self.device = device + self.writer = SummaryWriter(config["outdir"]) + self.finish_train = False + self.total_train_loss = defaultdict(float) + self.total_eval_loss = defaultdict(float) + + # load vq codebook + feat_codebook_path = self.config["vq_codebook"] + + self.feat_codebook, self.feat_codebook_numgroups = load_feat_codebook(np.load(feat_codebook_path, allow_pickle=True), device) + + def run(self): + """Run training.""" + self.tqdm = tqdm(initial=self.steps, total=self.config["train_max_steps"], desc="[train]") + while True: + # train one epoch + self._train_epoch() + + # check whether training is finished + if self.finish_train: + break + + self.tqdm.close() + logging.info("Finished training.") + + def save_checkpoint(self, checkpoint_path): + """Save checkpoint. + Args: + checkpoint_path (str): Checkpoint path to be saved. + """ + state_dict = { + "optimizer": { + "generator": self.optimizer["generator"].state_dict(), + "discriminator": self.optimizer["discriminator"].state_dict(), + }, + "scheduler": { + "generator": self.scheduler["generator"].state_dict(), + "discriminator": self.scheduler["discriminator"].state_dict(), + }, + "steps": self.steps, + "epochs": self.epochs, + } + if self.config["distributed"]: + state_dict["model"] = { + "generator": self.model["generator"].module.state_dict(), + "discriminator": self.model["discriminator"].module.state_dict(), + } + else: + state_dict["model"] = { + "generator": self.model["generator"].state_dict(), + "discriminator": self.model["discriminator"].state_dict(), + } + + if not os.path.exists(os.path.dirname(checkpoint_path)): + os.makedirs(os.path.dirname(checkpoint_path)) + torch.save(state_dict, checkpoint_path) + + def load_checkpoint(self, checkpoint_path, load_only_params=False): + """Load checkpoint. + + Args: + checkpoint_path (str): Checkpoint path to be loaded. + load_only_params (bool): Whether to load only model parameters. 
+ + """ + state_dict = torch.load(checkpoint_path, map_location="cpu") + if self.config["distributed"]: + self.model["generator"].module.load_state_dict( + state_dict["model"]["generator"] + ) + self.model["discriminator"].module.load_state_dict( + state_dict["model"]["discriminator"] + ) + else: + self.model["generator"].load_state_dict(state_dict["model"]["generator"]) + self.model["discriminator"].load_state_dict( + state_dict["model"]["discriminator"] + ) + if not load_only_params: + self.steps = state_dict["steps"] + self.epochs = state_dict["epochs"] + self.optimizer["generator"].load_state_dict(state_dict["optimizer"]["generator"]) + self.optimizer["discriminator"].load_state_dict(state_dict["optimizer"]["discriminator"]) + self.scheduler["generator"].load_state_dict(state_dict["scheduler"]["generator"]) + self.scheduler["discriminator"].load_state_dict(state_dict["scheduler"]["discriminator"]) + + def _train_step(self, batch): + """Train model one step.""" + # parse batch + vqidx, mel, prompt, y, xlens, prompt_lens = batch + vqidx = vqidx.to(self.device) + mel = mel.to(self.device) + prompt = prompt.to(self.device) + vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups) # (B, L, D) + y = y.unsqueeze(-2).to(self.device) # (B, 1, T) + + # build mask + mask = make_non_pad_mask(xlens).to(self.device) # (B, L) + prompt_mask = make_non_pad_mask(prompt_lens).to(self.device) # (B, L_prompt) + + # crop wav sequence + crop_xlen = min(self.config["crop_max_frames"], min(xlens)) + x_offsets = [np.random.randint(0, l - crop_xlen + 1) for l in xlens] + crop_ylen = crop_xlen * self.config["hop_size"] + y_offsets = [o * self.config["hop_size"] for o in x_offsets] + y = crop_seq(y, y_offsets, crop_ylen) + + ####################### + # Generator # + ####################### + if self.steps > self.config.get("generator_train_start_steps", 0): + mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets) # (B, L, 80), (B, C, T) + + # initialize + gen_loss, aux_loss = 0.0, 0.0 + + # frontend mel prediction loss + if self.steps <= self.config.get("frontend_mel_prediction_stop_steps", 0): + frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)), + torch.masked_select(mel_, mask.unsqueeze(-1))) + self.total_train_loss["train/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item() + gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss + + # multi-resolution sfft loss + if self.config["use_stft_loss"]: + sc_loss, mag_loss = self.criterion["stft"](y_, y) + aux_loss += sc_loss + mag_loss + self.total_train_loss["train/spectral_convergence_loss"] += sc_loss.item() + self.total_train_loss["train/log_stft_magnitude_loss"] += mag_loss.item() + + # subband multi-resolution stft loss + if self.config["use_subband_stft_loss"]: + aux_loss *= 0.5 # for balancing with subband stft loss + y_mb = self.criterion["pqmf"].analysis(y) + y_mb_ = self.criterion["pqmf"].analysis(y_) + sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb) + aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + self.total_train_loss["train/sub_spectral_convergence_loss"] += sub_sc_loss.item() + self.total_train_loss["train/sub_log_stft_magnitude_loss"] += sub_mag_loss.item() + + # mel spectrogram loss + if self.config["use_mel_loss"]: + mel_loss = self.criterion["mel"](y_, y) + aux_loss += mel_loss + self.total_train_loss["train/mel_loss"] += mel_loss.item() + + # weighting aux loss + gen_loss += self.config.get("lambda_aux", 1.0) 
* aux_loss + + # adversarial loss + if self.steps > self.config["discriminator_train_start_steps"]: + p_ = self.model["discriminator"](y_) + adv_loss = self.criterion["gen_adv"](p_) + self.total_train_loss["train/adversarial_loss"] += adv_loss.item() + + # feature matching loss + if self.config["use_feat_match_loss"]: + # no need to track gradients + with torch.no_grad(): + p = self.model["discriminator"](y) + fm_loss = self.criterion["feat_match"](p_, p) + self.total_train_loss["train/feature_matching_loss"] += fm_loss.item() + adv_loss += self.config["lambda_feat_match"] * fm_loss + + # add adversarial loss to generator loss + gen_loss += self.config["lambda_adv"] * adv_loss + + self.total_train_loss["train/generator_loss"] += gen_loss.item() + + # update generator + self.optimizer["generator"].zero_grad() + gen_loss.backward() + if self.config["generator_grad_norm"] > 0: + torch.nn.utils.clip_grad_norm_( + self.model["generator"].parameters(), + self.config["generator_grad_norm"], + ) + self.optimizer["generator"].step() + self.scheduler["generator"].step() + + ####################### + # Discriminator # + ####################### + if self.steps > self.config["discriminator_train_start_steps"]: + # re-compute y_ which leads better quality + with torch.no_grad(): + # logging.info(f"{vqvec.shape, prompt.shape, mask.shape, prompt_mask.shape}") + _, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets) # (B, L, 80), (B, C, T) + + if self.config["generator_params"]["out_channels"] > 1: + y_ = self.criterion["pqmf"].synthesis(y_) + + # discriminator loss + p = self.model["discriminator"](y) + p_ = self.model["discriminator"](y_.detach()) + real_loss, fake_loss = self.criterion["dis_adv"](p_, p) + dis_loss = real_loss + fake_loss + self.total_train_loss["train/real_loss"] += real_loss.item() + self.total_train_loss["train/fake_loss"] += fake_loss.item() + self.total_train_loss["train/discriminator_loss"] += dis_loss.item() + + # update discriminator + self.optimizer["discriminator"].zero_grad() + dis_loss.backward() + if self.config["discriminator_grad_norm"] > 0: + torch.nn.utils.clip_grad_norm_( + self.model["discriminator"].parameters(), + self.config["discriminator_grad_norm"], + ) + self.optimizer["discriminator"].step() + self.scheduler["discriminator"].step() + + # update counts + self.steps += 1 + self.tqdm.update(1) + self._check_train_finish() + + def _train_epoch(self): + """Train model one epoch.""" + for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1): + # train one step + self._train_step(batch) + + # check interval + if self.config["rank"] == 0: + self._check_log_interval() + self._check_eval_interval() + self._check_save_interval() + + # check whether training is finished + if self.finish_train: + return + + # update + self.epochs += 1 + self.train_steps_per_epoch = train_steps_per_epoch + logging.info( + f"(Steps: {self.steps}) Finished {self.epochs} epoch training " + f"({self.train_steps_per_epoch} steps per epoch)." 
+ ) + + # needed for shuffle in distributed training + if self.config["distributed"]: + self.sampler["train"].set_epoch(self.epochs) + + @torch.no_grad() + def _eval_step(self, batch): + """Evaluate model one step.""" + # parse batch + vqidx, mel, prompt, y, xlens, prompt_lens = batch + vqidx = vqidx.to(self.device).long() + mel = mel.to(self.device) + prompt = prompt.to(self.device) + vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups) + y = y.unsqueeze(-2).to(self.device) # (B, 1, T) + + # build mask + mask = make_non_pad_mask(xlens).to(self.device) # (B, L) + prompt_mask = make_non_pad_mask(prompt_lens).to(self.device) # (B, L_prompt) + + ####################### + # Generator # + ####################### + mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask) # (B, L, 80), (B, C, T) + + # reconstruct the signal from multi-band signal + if self.config["generator_params"]["out_channels"] > 1: + y_mb_ = y_ + y_ = self.criterion["pqmf"].synthesis(y_mb_) + + # initialize + gen_loss = 0.0 + aux_loss = 0.0 + + # frontend mel prediction loss + frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)), + torch.masked_select(mel_, mask.unsqueeze(-1))) + self.total_eval_loss["eval/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item() + gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss + + # multi-resolution stft loss + if self.config["use_stft_loss"]: + sc_loss, mag_loss = self.criterion["stft"](y_, y) + aux_loss += sc_loss + mag_loss + self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item() + self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item() + + # subband multi-resolution stft loss + if self.config.get("use_subband_stft_loss", False): + aux_loss *= 0.5 # for balancing with subband stft loss + y_mb = self.criterion["pqmf"].analysis(y) + sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb) + self.total_eval_loss["eval/sub_spectral_convergence_loss"] += sub_sc_loss.item() + self.total_eval_loss["eval/sub_log_stft_magnitude_loss"] += sub_mag_loss.item() + aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss) + + # mel spectrogram loss + if self.config["use_mel_loss"]: + mel_loss = self.criterion["mel"](y_, y) + aux_loss += mel_loss + self.total_eval_loss["eval/mel_loss"] += mel_loss.item() + + # weighting stft loss + gen_loss += aux_loss * self.config.get("lambda_aux", 1.0) + + # adversarial loss + p_ = self.model["discriminator"](y_) + adv_loss = self.criterion["gen_adv"](p_) + gen_loss += self.config["lambda_adv"] * adv_loss + + # feature matching loss + if self.config["use_feat_match_loss"]: + p = self.model["discriminator"](y) + fm_loss = self.criterion["feat_match"](p_, p) + self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item() + gen_loss += ( + self.config["lambda_adv"] * self.config["lambda_feat_match"] * fm_loss + ) + + ####################### + # Discriminator # + ####################### + p = self.model["discriminator"](y) + p_ = self.model["discriminator"](y_) + + # discriminator loss + real_loss, fake_loss = self.criterion["dis_adv"](p_, p) + dis_loss = real_loss + fake_loss + + # add to total eval loss + self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item() + self.total_eval_loss["eval/generator_loss"] += gen_loss.item() + self.total_eval_loss["eval/real_loss"] += real_loss.item() + self.total_eval_loss["eval/fake_loss"] += fake_loss.item() + self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item() + + def 
_eval_epoch(self): + """Evaluate model one epoch.""" + logging.info(f"(Steps: {self.steps}) Start evaluation.") + # change mode + for key in self.model.keys(): + self.model[key].eval() + + # calculate loss for each batch + for eval_steps_per_epoch, batch in enumerate(tqdm(self.data_loader["dev"], desc="[eval]"), 1): + # eval one step + self._eval_step(batch) + + logging.info( + f"(Steps: {self.steps}) Finished evaluation " + f"({eval_steps_per_epoch} steps per epoch)." + ) + + # average loss + for key in self.total_eval_loss.keys(): + self.total_eval_loss[key] /= eval_steps_per_epoch + logging.info(f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}.") + + # record + self._write_to_tensorboard(self.total_eval_loss) + + # reset + self.total_eval_loss = defaultdict(float) + + # restore mode + for key in self.model.keys(): + self.model[key].train() + + def _write_to_tensorboard(self, loss): + """Write to tensorboard.""" + for key, value in loss.items(): + self.writer.add_scalar(key, value, self.steps) + + def _check_save_interval(self): + if self.steps % self.config["save_interval_steps"] == 0: + self.save_checkpoint(os.path.join(self.config["outdir"], + f"checkpoint-{self.steps}steps.pkl")) + logging.info(f"Successfully saved checkpoint @ {self.steps} steps.") + + def _check_eval_interval(self): + if self.steps % self.config["eval_interval_steps"] == 0: + self._eval_epoch() + + def _check_log_interval(self): + if self.steps % self.config["log_interval_steps"] == 0: + for key in self.total_train_loss.keys(): + self.total_train_loss[key] /= self.config["log_interval_steps"] + logging.info(f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}.") + self._write_to_tensorboard(self.total_train_loss) + + # reset + self.total_train_loss = defaultdict(float) + + def _check_train_finish(self): + if self.steps >= self.config["train_max_steps"]: + self.finish_train = True + + +class Collator(object): + """Customized collator for Pytorch DataLoader in training.""" + + def __init__( + self, + hop_size=256, + win_length=1024, + sampling_rate=16000, + prompt_dim=1024, + prompt_fold_by_2=False + ): + """Initialize customized collator for PyTorch DataLoader. + + Args: + hop_size (int): Hop size of features, in sampling points. + win_length (int): window length of features. + sampling_rate (int): sampling rate of waveform data + prompt_dim (int): number of prompt embedding dimensions + """ + self.hop_size = hop_size + self.win_length = win_length + self.sampling_rate = sampling_rate + self.prompt_dim = prompt_dim + if prompt_fold_by_2: + self.prompt_len_factor = 2 + else: + self.prompt_len_factor = 1 + + def construct_prompt(self, mel_lens): + prompt_lens = [random.randint(int(l / (3 * self.prompt_len_factor)), int(l / (2 * self.prompt_len_factor))) for l in mel_lens] + prompt_starts = [] + is_from_start = [] + for ml, pl in zip(mel_lens, prompt_lens): + if random.random() > 0.5: + # from start + prompt_start = random.randint(0, 1 * self.sampling_rate // (self.hop_size * self.prompt_len_factor)) + is_from_start.append(True) + else: + # from ending + prompt_start = random.randint((ml - 1 * self.sampling_rate // self.hop_size) // self.prompt_len_factor, ml // self.prompt_len_factor) - pl + is_from_start.append(False) + prompt_starts.append(prompt_start) + return prompt_lens, prompt_starts, is_from_start + + def __call__(self, batch): + """Convert into batch tensors. + + Args: + batch (list): list of tuple of the pair of audio and features. 
+ + This collator will automatically determine the prompt segment (acoustic context) for each utterance. + The prompt is cut off from the current utterance, ranging from one third to half of the original utterance. + The prompt can be cut from either the starting or the ending of the utterance, within 1 second margin. + The other features include 2-dim VQ features (2 is the number of groups), and D-dim prompts (e.g. WavLM features) + + Returns: + Tensor ys: waveform batch (B, T). + Tensors vqs, mels: Auxiliary feature batch (B, C, T'), where T' = T / hop_size. + Tensor prompts: prompt feature batch (B, C, T'') + List c_lengths, prompt_lengths: list of lengths + """ + batch = batch[0] + + # check length + batch = [self._adjust_length(*b) for b in batch] + ys, vqs, mels, prompts_old = list(map(list, zip(*batch))) # [(a,b), (c,d)] -> [a, c], [b, d] + + batch_size = len(vqs) + + prompt_lengths, prompt_starts, is_from_starts = self.construct_prompt([len(m) for m in mels]) + c_lengths = [] + prompts = torch.zeros(batch_size, max(prompt_lengths), self.prompt_dim) + for i in range(batch_size): + prompts[i, :prompt_lengths[i]] = torch.tensor(prompts_old[i][prompt_starts[i]:prompt_starts[i]+prompt_lengths[i], :]) + if is_from_starts[i]: + start_idx = (prompt_starts[i] + prompt_lengths[i])*self.prompt_len_factor + mels[i] = mels[i][start_idx:] + vqs[i] = vqs[i][start_idx:] + ys[i] = ys[i][start_idx * self.hop_size: ] + else: + end_idx = prompt_starts[i]*self.prompt_len_factor + mels[i] = mels[i][:end_idx] + vqs[i] = vqs[i][:end_idx] + ys[i] = ys[i][:end_idx * self.hop_size] + c_lengths.append(len(mels[i])) + + vqs = pad_list([torch.tensor(c) for c in vqs], pad_value=0) # (B, L, Groups) + vqs = vqs.long() + mels = pad_list([torch.tensor(c) for c in mels], pad_value=0) # (B, L, 80) + + ys = pad_list([torch.tensor(y, dtype=torch.float) for y in ys], pad_value=0)[:, :mels.size(1) * self.hop_size] # (B, T) + assert ys.size(1) == mels.size(1) * self.hop_size == vqs.size(1) * self.hop_size + + return vqs, mels, prompts, ys, c_lengths, prompt_lengths + + def _adjust_length(self, x, c, *args): + """Adjust the audio and feature lengths. + + Note: + Basically we assume that the length of x and c are adjusted + through preprocessing stage, but if we use other library processed + features, this process will be needed. + + """ + if len(x) > len(c) * self.hop_size: + x = x[(self.win_length - self.hop_size) // 2:] + x = x[:len(c) * self.hop_size] + + # check the legnth is valid + assert len(x) == len(c) * self.hop_size + + return x, c, *args + + +def main(rank, n_gpus): + """Run training process.""" + parser = argparse.ArgumentParser( + description="Train vec2wav2 (See detail in vec2wav2/bin/train.py)." + ) + parser.add_argument( + "--train-wav-scp", + default=None, + type=str, + help="kaldi-style wav.scp file for training. " + ) + parser.add_argument( + "--train-vqidx-scp", + default=None, + type=str, + help="kaldi-style feats.scp file for training. " + ) + parser.add_argument( + "--train-mel-scp", + default=None, + type=str, + help="kaldi-style feats.scp file for training. 
" + ) + parser.add_argument( + "--train-prompt-scp", + default=None, + type=str, + help="prompt scp (in this case, utt to path)" + ) + parser.add_argument( + "--train-segments", + default=None, + type=str, + help="kaldi-style segments file for training.", + ) + parser.add_argument( + "--train-num-frames", + default=None, + type=str, + help="kaldi-style utt2num_frames file for training.", + ) + parser.add_argument( + "--dev-wav-scp", + default=None, + type=str, + help="kaldi-style wav.scp file for validation. " + ) + parser.add_argument( + "--dev-vqidx-scp", + default=None, + type=str, + help="kaldi-style feats.scp file for vaidation. " + ) + parser.add_argument( + "--dev-mel-scp", + default=None, + type=str, + help="kaldi-style feats.scp file for vaidation. " + ) + parser.add_argument( + "--dev-prompt-scp", + default=None, + type=str, + help="prompt scp (in this case, utt to path)" + ) + parser.add_argument( + "--dev-segments", + default=None, + type=str, + help="kaldi-style segments file for validation.", + ) + parser.add_argument( + "--dev-num-frames", + default=None, + type=str, + help="kaldi-style utt2num_frames file for validation.", + ) + parser.add_argument( + "--outdir", + type=str, + required=True, + help="directory to save checkpoints.", + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="yaml format configuration file.", + ) + parser.add_argument( + "--pretrain", + default="", + type=str, + nargs="?", + help='checkpoint file path to load pretrained params. (default="")', + ) + parser.add_argument( + "--resume", + default="", + type=str, + nargs="?", + help='checkpoint file path to resume training. (default="")', + ) + parser.add_argument( + "--verbose", + type=int, + default=1, + help="logging level. higher is more logging. 
(default=1)", + ) + parser.add_argument("--vq-codebook", default=None, type=str) + # parser.add_argument("--sampling-rate", type=int) + # parser.add_argument("--num-mels", type=int) + # parser.add_argument("--hop-size", type=int) + # parser.add_argument("--win-length", type=int) + args = parser.parse_args() + + # init distributed training + device = torch.device("cuda") + # effective when using fixed size inputs + # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 + torch.backends.cudnn.benchmark = True + # setup for distributed training + # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed + if n_gpus == 1: + assert rank == 0 + + set_loglevel(args.verbose) + + # check directory existence + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + + # init process group + logging.info("Synchronizing between all workers.") + torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank) + torch.cuda.set_device(rank) + logging.info("Finished init process group.") + + # load and save config + with open(args.config) as f: + config = yaml.load(f, Loader=yaml.Loader) + config.update(vars(args)) + config['rank'] = rank + config['distributed'] = True + config['world_size'] = n_gpus + config["version"] = vec2wav2.__version__ # add version info + if rank == 0: + with open(os.path.join(args.outdir, "config.yml"), "w") as f: + yaml.dump(config, f, Dumper=yaml.Dumper) + for key, value in config.items(): + logging.info(f"{key} = {value}") + + # get dataset + train_dataset = AudioMelSCPDataset( + wav_scp=args.train_wav_scp, + vqidx_scp=args.train_vqidx_scp, + mel_scp=args.train_mel_scp, + prompt_scp=args.train_prompt_scp, + utt2num_frames=args.train_num_frames, + segments=args.train_segments, + batch_frames=config.get("batch_frames", None), + batch_size=config.get("batch_size", None), + min_num_frames=config.get("min_num_frames", None), + max_num_frames=config.get("max_num_frames", None), + allow_cache=config.get("allow_cache", False), # keep compatibility + length_tolerance=config.get("length_tolerance", 2), + prompt_fold_by_2=config.get("prompt_fold_by_2", True) + ) + if rank == 0: + logging.info(f"The number of training batches = {len(train_dataset)}.") + dev_dataset = AudioMelSCPDataset( + wav_scp=args.dev_wav_scp, + vqidx_scp=args.dev_vqidx_scp, + mel_scp=args.dev_mel_scp, + prompt_scp=args.dev_prompt_scp, + utt2num_frames=args.dev_num_frames, + segments=args.dev_segments, + min_num_frames=config.get("min_num_frames", None), + max_num_frames=config.get("max_num_frames", None), + allow_cache=config.get("allow_cache", False), # keep compatibility + length_tolerance=config.get("length_tolerance", 2), + prompt_fold_by_2=config.get("prompt_fold_by_2", True) + ) + if rank == 0: + logging.info(f"The number of development batches = {len(dev_dataset)}.") + dataset = { + "train": train_dataset, + "dev": dev_dataset, + } + + # get data loader + collator = Collator( + hop_size=config["hop_size"], + win_length=config["win_length"], + sampling_rate=config["sampling_rate"], + prompt_dim=config['frontend_params']['prompt_channels'], + prompt_fold_by_2=config.get("prompt_fold_by_2", True) + ) + + sampler = { + "train": DistributedSampler( + dataset=dataset["train"], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ), + "dev": DistributedSampler( + dataset=dataset["dev"], + num_replicas=n_gpus, + rank=rank, + shuffle=False, + )} + data_loader = { + "train": DataLoader( + dataset=dataset["train"], 
+ shuffle=False, + collate_fn=collator, + num_workers=config["num_workers"], + sampler=sampler["train"], + pin_memory=config["pin_memory"], + ), + "dev": DataLoader( + dataset=dataset["dev"], + shuffle=False, + collate_fn=collator, + num_workers=config["num_workers"], + sampler=sampler["dev"], + pin_memory=config["pin_memory"], + ), + } + + # define models + generator_class = getattr( + vec2wav2.models, + # keep compatibility + config.get("generator_type", "ParallelWaveGANGenerator"), + ) + discriminator_class = getattr( + vec2wav2.models, + # keep compatibility + config.get("discriminator_type", "ParallelWaveGANDiscriminator"), + ) + model = { + "generator": vec2wav2.models.VEC2WAV2Generator( + vec2wav2.models.CTXVEC2WAVFrontend(config["prompt_net_type"], config["num_mels"], **config["frontend_params"]), + generator_class(**config["generator_params"]) + ).to(device), + "discriminator": discriminator_class( + **config["discriminator_params"], + ).to(device), + } + + # define criteria + criterion = { + "gen_adv": GeneratorAdversarialLoss( + # keep compatibility + **config.get("generator_adv_loss_params", {}) + ).to(device), + "dis_adv": DiscriminatorAdversarialLoss( + # keep compatibility + **config.get("discriminator_adv_loss_params", {}) + ).to(device), + } + if config.get("use_stft_loss", True): # keep compatibility + config["use_stft_loss"] = True + criterion["stft"] = MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device) + if config.get("use_subband_stft_loss", False): # keep compatibility + assert config["generator_params"]["out_channels"] > 1 + criterion["sub_stft"] = MultiResolutionSTFTLoss(**config["subband_stft_loss_params"]).to(device) + else: + config["use_subband_stft_loss"] = False + if config.get("use_feat_match_loss", False): # keep compatibility + criterion["feat_match"] = FeatureMatchLoss( + # keep compatibility + **config.get("feat_match_loss_params", {}), + ).to(device) + else: + config["use_feat_match_loss"] = False + if config.get("use_mel_loss", False): # keep compatibility + criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],).to(device) + else: + config["use_mel_loss"] = False + + # define optimizers and schedulers + generator_optimizer_class = getattr( + vec2wav2.optimizers, + # keep compatibility + config.get("generator_optimizer_type", "RAdam"), + ) + discriminator_optimizer_class = getattr( + vec2wav2.optimizers, + # keep compatibility + config.get("discriminator_optimizer_type", "RAdam"), + ) + optimizer = { + "generator": generator_optimizer_class( + model["generator"].parameters(), + **config["generator_optimizer_params"], + ), + "discriminator": discriminator_optimizer_class( + model["discriminator"].parameters(), + **config["discriminator_optimizer_params"], + ), + } + generator_scheduler_class = getattr( + torch.optim.lr_scheduler, + # keep compatibility + config.get("generator_scheduler_type", "StepLR"), + ) + discriminator_scheduler_class = getattr( + torch.optim.lr_scheduler, + # keep compatibility + config.get("discriminator_scheduler_type", "StepLR"), + ) + scheduler = { + "generator": generator_scheduler_class( + optimizer=optimizer["generator"], + **config["generator_scheduler_params"], + ), + "discriminator": discriminator_scheduler_class( + optimizer=optimizer["discriminator"], + **config["discriminator_scheduler_params"], + ), + } + from torch.nn.parallel import DistributedDataParallel + model["generator"] = DistributedDataParallel(model["generator"], device_ids=[rank], find_unused_parameters=True) + 
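+    # NOTE: find_unused_parameters=True lets DDP tolerate parameters that receive no gradient in some steps (e.g. the frontend's mel-prediction head after frontend_mel_prediction_stop_steps, or loss branches disabled by the config).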
model["discriminator"] = DistributedDataParallel(model["discriminator"], device_ids=[rank], find_unused_parameters=True) + + if rank == 0: + # show settings + logging.info(model["generator"]) + logging.info(f"Generator has nparams: {sum([p.numel() for p in model['generator'].parameters()])}") + logging.info(model["discriminator"]) + logging.info(f"Discriminator has nparams: {sum([p.numel() for p in model['discriminator'].parameters()])}") + logging.info(optimizer["generator"]) + logging.info(optimizer["discriminator"]) + + # define trainer + trainer = Trainer( + steps=0, + epochs=0, + data_loader=data_loader, + sampler=sampler, + model=model, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + config=config, + device=device, + ) + + # load pretrained parameters from checkpoint + if len(args.pretrain) != 0: + trainer.load_checkpoint(args.pretrain, load_only_params=True) + if rank == 0: + logging.info(f"Successfully load parameters from {args.pretrain}.") + + # resume from checkpoint + if len(args.resume) != 0: + trainer.load_checkpoint(args.resume) + if rank == 0: + logging.info(f"Successfully resumed from {args.resume}.") + + # run training loop + try: + trainer.run() + finally: + if rank == 0: + trainer.save_checkpoint(os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl")) + logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.") + + +if __name__ == "__main__": + assert torch.cuda.is_available(), "CPU training is not allowed." + n_gpus = torch.cuda.device_count() + print(f"============> using {n_gpus} GPUS") + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8000" + + mp.spawn( + main, + nprocs=n_gpus, + args=(n_gpus,) + ) diff --git a/vec2wav2/bin/vc.py b/vec2wav2/bin/vc.py new file mode 100755 index 0000000000000000000000000000000000000000..e81a0e800d246103f80d3ef5b3ad4fa9722db0cd --- /dev/null +++ b/vec2wav2/bin/vc.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024 Yiwei Guo + +""" Run VC inference with trained model """ + +import vec2wav2 +from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor +from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor +# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor +import torch +import logging +import argparse +from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k +import soundfile as sf +import yaml +import os + + +def configure_logging(verbose): + if verbose: + logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.DEBUG) + logging.getLogger().setLevel(logging.DEBUG) + logging.basicConfig(level=logging.DEBUG) + else: + logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.ERROR) + logging.getLogger().setLevel(logging.ERROR) + logging.basicConfig(level=logging.ERROR) + + script_logger = logging.getLogger("script_logger") + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s')) + script_logger.addHandler(handler) + script_logger.setLevel(logging.INFO) + script_logger.propagate = False + return script_logger + +def vc_args(): + parser = argparse.ArgumentParser() + # required arguments + parser.add_argument("-s", "--source", default="examples/source.wav", type=str, + help="source wav path") + parser.add_argument("-t", "--target", default="examples/target.wav", type=str, + help="target speaker prompt path") + parser.add_argument("-o", "--output", 
default="output.wav", type=str, + help="path of the output wav file") + + # optional arguments + parser.add_argument("--expdir", default="pretrained/", type=str, + help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.") + parser.add_argument('--checkpoint', default=None, type=str, help="checkpoint path (.pkl). If provided, will override expdir.") + parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str, + help="checkpoint or model flag of input token extractor") + parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str, + help="checkpoint or model flag of speaker prompt extractor") + parser.add_argument("--prompt-output-layer", default=6, type=int, + help="output layer when prompt is extracted from WavLM.") + + parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") + + args = parser.parse_args() + return args + + +class VoiceConverter: + def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt", + prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6, + checkpoint=None, script_logger=None): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.script_logger = script_logger + self.log_if_possible(f"Using device: {self.device}") + + # set up token extractor + self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device) + feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device) + self.feat_codebook = feat_codebook + self.feat_codebook_numgroups = feat_codebook_numgroups + self.log_if_possible(f"Successfully set up token extractor from {token_extractor}") + + # set up prompt extractor + self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer) + self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}") + + # load VC model + self.config_path = os.path.join(expdir, "config.yml") + with open(self.config_path) as f: + self.config = yaml.load(f, Loader=yaml.Loader) + if checkpoint is not None: + checkpoint = os.path.join(expdir, checkpoint) + else: + checkpoint = os.path.join(expdir, "generator.ckpt") + self.model = load_model(checkpoint, self.config) + self.log_if_possible(f"Successfully set up VC model from {checkpoint}") + + self.model.backend.remove_weight_norm() + self.model.eval().to(self.device) + + @torch.no_grad() + def voice_conversion(self, source_audio, target_audio, output_path="output.wav"): + self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}") + source_wav = read_wav_16k(source_audio) + target_wav = read_wav_16k(target_audio) + vq_idx = self.token_extractor.extract(source_wav).long().to(self.device) + + vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0) + prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device) + converted = self.model.inference(vqvec, prompt)[-1].view(-1) + sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate']) + self.log_if_possible(f"Saved audio file to {output_path}") + return output_path + + def log_if_possible(self, msg): + if self.script_logger is not None: + self.script_logger.info(msg) + + +if __name__ == "__main__": + args = vc_args() + script_logger = configure_logging(args.verbose) + + source_wav = read_wav_16k(args.source) + target_prompt = read_wav_16k(args.target) + + with 
torch.no_grad(): + voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor, + prompt_extractor=args.prompt_extractor, prompt_output_layer=args.prompt_output_layer, + checkpoint=args.checkpoint, script_logger=script_logger) + voice_converter.voice_conversion(args.source, args.target, args.output) diff --git a/vec2wav2/datasets/__init__.py b/vec2wav2/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c32e9d65b83d4846a11ffe7dd5d1d6327f4b4fe8 --- /dev/null +++ b/vec2wav2/datasets/__init__.py @@ -0,0 +1 @@ +from .scp_dataset import * # NOQA diff --git a/vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc b/vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a2b7f7376a07492a9f57b0a1377b2cbd73ffa1c Binary files /dev/null and b/vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc b/vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800fa2a65dc898c28e7481e7e63bade0d739fd0f Binary files /dev/null and b/vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc b/vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8917e1e64932e8a311fbc3256f473696a4d7a9c4 Binary files /dev/null and b/vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc differ diff --git a/vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc b/vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed432eba961459a8c1f61bbd1acc87cc9ec3fd9 Binary files /dev/null and b/vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc differ diff --git a/vec2wav2/datasets/scp_dataset.py b/vec2wav2/datasets/scp_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c6620d47165f33b9ec941503b35479f0db34c3a5 --- /dev/null +++ b/vec2wav2/datasets/scp_dataset.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +# Modified by Yiwei Guo, 2024 + +"""Dataset modules based on kaldi-style scp files.""" + +import logging +import random +import copy +from multiprocessing import Manager + +import kaldiio +import numpy as np + +from torch.utils.data import Dataset +from tqdm import tqdm +from vec2wav2.utils import HDF5ScpLoader +from vec2wav2.utils import NpyScpLoader + + +def _get_feats_scp_loader(feats_scp): + # read the first line of feats.scp file + with open(feats_scp) as f: + key, value = f.readlines()[0].replace("\n", "").split() + + # check scp type + if ":" in value: + value_1, value_2 = value.split(":") + if value_1.endswith(".ark"): + # kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index + return kaldiio.load_scp(feats_scp) + elif value_1.endswith(".h5"): + # hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats + return HDF5ScpLoader(feats_scp) + else: + raise ValueError("Not supported feats.scp type.") + else: + if value.endswith(".h5"): + # hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5 + return HDF5ScpLoader(feats_scp) + elif value.endswith(".npy"): + # npy case: utt_id_1 /path/to/utt_id_1.npy + return NpyScpLoader(feats_scp) + else: + raise ValueError("Not supported feats.scp type.") + + +class 
AudioMelSCPDataset(Dataset): + """PyTorch compatible audio and feat dataset based on kaldi-stype scp files.""" + + def __init__( + self, + wav_scp, + vqidx_scp, + mel_scp, + prompt_scp, + utt2num_frames=None, + segments=None, + batch_frames=None, + batch_size=None, + min_num_frames=None, + max_num_frames=None, + return_utt_id=False, + return_sampling_rate=False, + allow_cache=False, + length_tolerance=2, + prompt_fold_by_2=True + ): + """Initialize dataset. + + Args: + wav_scp (str): Kaldi-style wav.scp file. + vqidx_scp (str): Kaldi-style fests.scp file. + mel_scp (str): Kaldi-style fests.scp file. + segments (str): Kaldi-style segments file. + min_num_frames (int): Threshold to remove short feature files. + max_num_frames (int): Threshold to remove long feature files. + return_utt_id (bool): Whether to return utterance id. + return_sampling_rate (bool): Whether to return sampling rate. + allow_cache (bool): Whether to allow cache of the loaded files. + prompt_fold_by_2 (bool): if true, then prompt have half the length of vqidx sequence. + + """ + # load scp as lazy dict + self.audio_loader = kaldiio.load_scp(wav_scp, segments=segments) + self.vqidx_loader = _get_feats_scp_loader(vqidx_scp) + self.mel_loader = _get_feats_scp_loader(mel_scp) + + self.prompt_loader = _get_feats_scp_loader(prompt_scp) + + self.utt_ids = list(self.mel_loader.keys()) + self.return_utt_id = return_utt_id + self.return_sampling_rate = return_sampling_rate + self.allow_cache = allow_cache + + utt2num_frames_loader = None + if utt2num_frames is not None: + with open(utt2num_frames, 'r') as f: + utt2num_frames_loader = dict([(x.split()[0], int(x.split()[1])) for x in f.readlines()]) + else: + utt2num_frames_loader = dict([(k, mel.shape[0]) for k, mel in self.mel_loader.items()]) + + self.utt2num_frames_loader = utt2num_frames_loader + + # filter by threshold + if (min_num_frames or max_num_frames) is not None: + mel_lengths = [utt2num_frames_loader[key] for key in self.utt_ids] + idxs = [ + idx + for idx in range(len(self.utt_ids)) + if (min_num_frames and mel_lengths[idx] >= min_num_frames) and (max_num_frames and mel_lengths[idx] <= max_num_frames) + ] + if len(self.utt_ids) != len(idxs): + logging.warning( + f"Some files are filtered by mel length threshold " + f"({len(self.utt_ids)} -> {len(idxs)})." 
+ ) + self.utt_ids = [self.utt_ids[idx] for idx in idxs] + + # batchify + if batch_frames is not None: + self.batches = self.batchify(utt2num_frames_loader, batch_frames=batch_frames) + elif batch_size is not None: + self.batches = self.batchify(utt2num_frames_loader, batch_size=batch_size) + else: + self.batches = [[utt_id] for utt_id in self.utt_ids] + + if allow_cache: + # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0 + self.manager = Manager() + self.caches = self.manager.dict() + self.length_tolerance = length_tolerance + if prompt_fold_by_2: + self.prompt_len_factor = 2 + else: + self.prompt_len_factor = 1 + + def batchify(self, utt2num_frames_loader, batch_frames=None, batch_size=None, min_batch_size=1, drop_last=True): + + assert batch_size is None or batch_size > min_batch_size + + batches = [] + batch = [] + accum_num_frames = 0 + utt_ids_set = set(self.utt_ids) + for utt_id, mel_length in tqdm(sorted(list(utt2num_frames_loader.items()), key=lambda x: x[1], reverse=True)): + if utt_id not in utt_ids_set: + continue + if (batch_frames is not None and accum_num_frames + mel_length > batch_frames and len(batch) > min_batch_size) or (batch_size is not None and len(batch) == batch_size): + batches.append(batch) + batch = [] + accum_num_frames = 0 + batch.append(utt_id) + accum_num_frames += mel_length + if len(batch) > min_batch_size and not drop_last: + batches.append(batch) + return batches + + def __getitem__(self, idx): + """Get specified idx items. + + Args: + idx (int): Index of the item. + + Returns: + str: Utterance id (only in return_utt_id = True). + ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True). + ndarrays: Features (T', C). + + """ + batch = self.batches[idx] + batch_items = [] + + for utt_id in batch: + if self.allow_cache and self.caches.get(utt_id) is not None: + items = self.caches[utt_id] + else: + fs, audio = self.audio_loader[utt_id] + mel = self.mel_loader[utt_id] + prompt = self.prompt_loader[utt_id] + vqidx = self.vqidx_loader[utt_id] + + min_len = min(len(mel), len(vqidx), len(prompt)*self.prompt_len_factor) + assert ((abs(len(mel) - min_len) <= self.length_tolerance) and + (abs(len(vqidx) - min_len) <= self.length_tolerance) and + (abs(len(prompt)*self.prompt_len_factor - min_len) <= self.length_tolerance)), \ + f"Audio feature lengths difference exceeds length tolerance for {utt_id}" + mel, vqidx, prompt = mel[:min_len], vqidx[:min_len], prompt[:min_len//self.prompt_len_factor] + + # normalize audio signal to be [-1, 1] + audio = audio.astype(np.float32) + audio /= 1 << (16 - 1) # assume that wav is PCM 16 bit + + if self.return_sampling_rate: + audio = (audio, fs) + + if self.return_utt_id: + items = utt_id, audio, vqidx, mel, prompt + else: + items = audio, vqidx, mel, prompt + + if self.allow_cache: + self.caches[utt_id] = items + + batch_items.append(items) + + return batch_items + + def __len__(self): + """Return dataset length. + Returns: + int: The length of dataset. + """ + return len(self.batches) + + +class MelSCPDataset(Dataset): + """PyTorch compatible feat dataset based on kaldi-stype scp files.""" + + def __init__( + self, + vqidx_scp, + prompt_scp, + return_utt_id=False, + allow_cache=False, + ): + """Initialize dataset. + + Args: + vqidx_scp (str): Kaldi-style fests.scp file. + prompt_scp (str): Kaldi-style scp file. In this file, every utt is associated with its prompt's mel-spectrogram. + min_num_frames (int): Threshold to remove short feature files. 
+ max_num_frames (int): Threshold to remove long feature files. + return_utt_id (bool): Whether to return utterance id. + allow_cache (bool): Whether to allow cache of the loaded files. + """ + # load scp as lazy dict + vqidx_loader = _get_feats_scp_loader(vqidx_scp) + self.prompt_loader = _get_feats_scp_loader(prompt_scp) + # self.prompt_loader = dict() + # with open(prompt_scp, 'r') as fr: + # for line in fr.readlines(): + # terms = line.strip().split() + # self.prompt_loader[terms[0]] = terms[1] + vqidx_keys = list(set(self.prompt_loader.keys()) & set(vqidx_loader.keys())) + + # NOTE: this dataset does not apply filtering, because it is usually used for decoding + + self.vqidx_loader = vqidx_loader + self.utt_ids = vqidx_keys + self.return_utt_id = return_utt_id + self.allow_cache = allow_cache + + if allow_cache: + # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0 + self.manager = Manager() + self.caches = self.manager.list() + self.caches += [() for _ in range(len(self.utt_ids))] + + def __getitem__(self, idx): + """Get specified idx items. + + Args: + idx (int): Index of the item. + + Returns: + str: Utterance id (only in return_utt_id = True). + ndarray: Feature (T', C). + + """ + if self.allow_cache and len(self.caches[idx]) != 0: + return self.caches[idx] + + utt_id = self.utt_ids[idx] + vqidx = self.vqidx_loader[utt_id].astype(int) + + # prompt = torch.load(self.prompt_loader[utt_id]).float().numpy() + prompt = self.prompt_loader[utt_id] + + if self.return_utt_id: + items = utt_id, vqidx, prompt + else: + items = vqidx, prompt + + if self.allow_cache: + self.caches[idx] = items + + return items + + def __len__(self): + """Return dataset length. + + Returns: + int: The length of dataset. + + """ + return len(self.utt_ids) diff --git a/vec2wav2/distributed/__init__.py b/vec2wav2/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vec2wav2/distributed/launch.py b/vec2wav2/distributed/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..292f2a92287bfd201815748465727b76d9a5008e --- /dev/null +++ b/vec2wav2/distributed/launch.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""Distributed process launcher. + +This code is modified from https://github.com/pytorch/pytorch/blob/v1.3.0/torch/distributed/launch.py. 
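+
+Example usage (an illustrative sketch; the training script path and its
+arguments are placeholders, not files shipped in this repository):
+
+    python -m vec2wav2.distributed.launch --nnodes 1 --nproc_per_node 4 \
+        path/to/train.py --your-training-args ...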
+ +""" +import os +import subprocess +import sys + +from argparse import ArgumentParser +from argparse import REMAINDER + + +def parse_args(): + """Parse arguments.""" + parser = ArgumentParser( + description="PyTorch distributed training launch " + "helper utilty that will spawn up " + "multiple distributed processes" + ) + + # Optional arguments for the launch helper + parser.add_argument( + "--nnodes", + type=int, + default=1, + help="The number of nodes to use for distributed " "training", + ) + parser.add_argument( + "--node_rank", + type=int, + default=0, + help="The rank of the node for multi-node distributed " "training", + ) + parser.add_argument( + "--nproc_per_node", + type=int, + default=1, + help="The number of processes to launch on each node, " + "for GPU training, this is recommended to be set " + "to the number of GPUs in your system so that " + "each process can be bound to a single GPU.", + ) + parser.add_argument( + "--master_addr", + default="127.0.0.1", + type=str, + help="Master node (rank 0)'s address, should be either " + "the IP address or the hostname of node 0, for " + "single node multi-proc training, the " + "--master_addr can simply be 127.0.0.1", + ) + parser.add_argument( + "--master_port", + default=29500, + type=int, + help="Master node (rank 0)'s free port that needs to " + "be used for communciation during distributed " + "training", + ) + parser.add_argument( + "--use_env", + default=False, + action="store_true", + help="Use environment variable to pass " + "'local rank'. For legacy reasons, the default value is False. " + "If set to True, the script will not pass " + "--local_rank as argument, and will instead set LOCAL_RANK.", + ) + parser.add_argument( + "-m", + "--module", + default=False, + action="store_true", + help="Changes each process to interpret the launch script " + "as a python module, executing with the same behavior as" + "'python -m'.", + ) + parser.add_argument( + "-c", + "--command", + default=False, + action="store_true", + help="Changes each process to interpret the launch script " "as a command.", + ) + + # positional + parser.add_argument( + "training_script", + type=str, + help="The full path to the single GPU training " + "program/script/command to be launched in parallel, " + "followed by all the arguments for the " + "training script", + ) + + # rest from the training program + parser.add_argument("training_script_args", nargs=REMAINDER) + return parser.parse_args() + + +def main(): + """Launch distributed processes.""" + args = parse_args() + + # world size in terms of number of processes + dist_world_size = args.nproc_per_node * args.nnodes + + # set PyTorch distributed related environmental variables + current_env = os.environ.copy() + current_env["MASTER_ADDR"] = args.master_addr + current_env["MASTER_PORT"] = str(args.master_port) + current_env["WORLD_SIZE"] = str(dist_world_size) + + processes = [] + + if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1: + current_env["OMP_NUM_THREADS"] = str(1) + print( + "*****************************************\n" + "Setting OMP_NUM_THREADS environment variable for each process " + "to be {} in default, to avoid your system being overloaded, " + "please further tune the variable for optimal performance in " + "your application as needed. 
\n" + "*****************************************".format( + current_env["OMP_NUM_THREADS"] + ) + ) + + for local_rank in range(0, args.nproc_per_node): + # each process's rank + dist_rank = args.nproc_per_node * args.node_rank + local_rank + current_env["RANK"] = str(dist_rank) + current_env["LOCAL_RANK"] = str(local_rank) + + # spawn the processes + if args.command: + cmd = [args.training_script] + else: + cmd = [sys.executable, "-u"] + if args.module: + cmd.append("-m") + cmd.append(args.training_script) + + if not args.use_env: + cmd.append("--local_rank={}".format(local_rank)) + + cmd.extend(args.training_script_args) + + process = subprocess.Popen(cmd, env=current_env) + processes.append(process) + + for process in processes: + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/vec2wav2/layers/__init__.py b/vec2wav2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0b7f142ce105f662f69f3e0c5d4967b5c86c22 --- /dev/null +++ b/vec2wav2/layers/__init__.py @@ -0,0 +1,6 @@ +from .causal_conv import * # NOQA +from .pqmf import * # NOQA +from .residual_block import * # NOQA +from .residual_stack import * # NOQA +from .tade_res_block import * # NOQA +from .upsample import * # NOQA diff --git a/vec2wav2/layers/__pycache__/__init__.cpython-310.pyc b/vec2wav2/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a7f9fd848ab9791853355cb569e1645aed0f5f6 Binary files /dev/null and b/vec2wav2/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/__init__.cpython-39.pyc b/vec2wav2/layers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97a6f53eca7e48bcb9117a2b8a329d4c942dfc34 Binary files /dev/null and b/vec2wav2/layers/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/layers/__pycache__/activations.cpython-310.pyc b/vec2wav2/layers/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8b5ddce4545ea2441a87e4b4301b653bee09f3d Binary files /dev/null and b/vec2wav2/layers/__pycache__/activations.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc b/vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff232260788614257d40f304c5d736326980c0e5 Binary files /dev/null and b/vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc b/vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a0b1a269d7deb986418030fae67909d78f2c0b5 Binary files /dev/null and b/vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc differ diff --git a/vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc b/vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f82cb076574b3a900a7a58551f55657da8f8f3c6 Binary files /dev/null and b/vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc b/vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e0538ba42a6cd8866cce0c56b5b6a8f6d7a9123 Binary files /dev/null and 
b/vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc differ diff --git a/vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc b/vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e7331961b6d6687fe0a89565fd3c839ead91451 Binary files /dev/null and b/vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc b/vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf7ef008d2e23b228388b7dee52ee1735bd38f8 Binary files /dev/null and b/vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc differ diff --git a/vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc b/vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24407f41c2ee11a6d6f89034c7207b7a5e1d190c Binary files /dev/null and b/vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc b/vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d896e8994f8d4f49ca9015d96b903749e74d3a97 Binary files /dev/null and b/vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc differ diff --git a/vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc b/vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbbb33f7e95e10020fa2cf9e31076d2ee5de0f8d Binary files /dev/null and b/vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc b/vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..138bda42d9cc9ccbbead0abfea12bdb74ab61052 Binary files /dev/null and b/vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc differ diff --git a/vec2wav2/layers/__pycache__/upsample.cpython-310.pyc b/vec2wav2/layers/__pycache__/upsample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1a7565b116f2c36b6b989337f3206fb78af4c70 Binary files /dev/null and b/vec2wav2/layers/__pycache__/upsample.cpython-310.pyc differ diff --git a/vec2wav2/layers/__pycache__/upsample.cpython-39.pyc b/vec2wav2/layers/__pycache__/upsample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26f5630b423065e01625065b8d08197f23e8b576 Binary files /dev/null and b/vec2wav2/layers/__pycache__/upsample.cpython-39.pyc differ diff --git a/vec2wav2/layers/activations.py b/vec2wav2/layers/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..9c98131e2352ae589fd067cd3e82ad0a27801369 --- /dev/null +++ b/vec2wav2/layers/activations.py @@ -0,0 +1,197 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. 
+ +# Modified by Yiwei Guo, 2024 +# including conditioned snakebeta activation + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake := x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x, cond=None): + ''' + Forward pass of the function. 
+ Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBetaWithCondition(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Condition: (B, D), where D-dimension will be mapped to C dimensions + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + - condition_alpha_prenet - trainable parameter that controls alpha and beta using condition + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256, 128) + >>> x = torch.randn(256) + >>> cond = torch.randn(128) + >>> x = a1(x, cond) + ''' + def __init__(self, in_features, condition_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: dimension of the input + - condition_features: dimension of the condition vectors + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha, beta will be trained along with the rest of your model. + ''' + super(SnakeBetaWithCondition, self).__init__() + self.in_features = in_features + + self.condition_alpha_prenet = torch.nn.Linear(condition_features, in_features) + # self.condition_beta_prenet = torch.nn.Linear(condition_features, in_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x, condition): + ''' + condition: [B, D] + Forward pass of the function. + Applies the function to the input elementwise. 
+ SnakeBeta := x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + + condition = torch.tanh(self.condition_alpha_prenet(condition).unsqueeze(-1)) # Same prenet for both alpha and beta, to save parameters + alpha = alpha + condition + beta = beta + 0.5 * condition # multiply 0.5 for avoiding beta being too small + + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/vec2wav2/layers/causal_conv.py b/vec2wav2/layers/causal_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..abf51b8e95dc5eaefb8938aac10d77a07d85dca6 --- /dev/null +++ b/vec2wav2/layers/causal_conv.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Causal convolusion layer modules.""" + + +import torch + + +class CausalConv1d(torch.nn.Module): + """CausalConv1d module with customized initialization.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation=1, + bias=True, + pad="ConstantPad1d", + pad_params={"value": 0.0}, + ): + """Initialize CausalConv1d module.""" + super(CausalConv1d, self).__init__() + self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params) + self.conv = torch.nn.Conv1d( + in_channels, out_channels, kernel_size, dilation=dilation, bias=bias + ) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + return self.conv(self.pad(x))[:, :, : x.size(2)] + + +class CausalConvTranspose1d(torch.nn.Module): + """CausalConvTranspose1d module with customized initialization.""" + + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True): + """Initialize CausalConvTranspose1d module.""" + super(CausalConvTranspose1d, self).__init__() + self.deconv = torch.nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride, bias=bias + ) + self.stride = stride + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_in). + + Returns: + Tensor: Output tensor (B, out_channels, T_out). + + """ + return self.deconv(x)[:, :, : -self.stride] diff --git a/vec2wav2/layers/pqmf.py b/vec2wav2/layers/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd46a3ca1e1ef272f7ac21c4bad22e0391f6555 --- /dev/null +++ b/vec2wav2/layers/pqmf.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Pseudo QMF modules.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from scipy.signal import kaiser + + +def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0): + """Design prototype filter for PQMF. + + This method is based on `A Kaiser window approach for the design of prototype + filters of cosine modulated filterbanks`_. + + Args: + taps (int): The number of filter taps. + cutoff_ratio (float): Cut-off frequency ratio. + beta (float): Beta coefficient for kaiser window. + + Returns: + ndarray: Impluse response of prototype filter (taps + 1,). + + .. 
_`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: + https://ieeexplore.ieee.org/abstract/document/681427 + + """ + # check the arguments are valid + assert taps % 2 == 0, "The number of taps mush be even number." + assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." + + # make initial filter + omega_c = np.pi * cutoff_ratio + with np.errstate(invalid="ignore"): + h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / ( + np.pi * (np.arange(taps + 1) - 0.5 * taps) + ) + h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form + + # apply kaiser window + w = kaiser(taps + 1, beta) + h = h_i * w + + return h + + +class PQMF(torch.nn.Module): + """PQMF module. + + This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. + + .. _`Near-perfect-reconstruction pseudo-QMF banks`: + https://ieeexplore.ieee.org/document/258122 + + """ + + def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0): + """Initilize PQMF module. + + The cutoff_ratio and beta parameters are optimized for #subbands = 4. + See dicussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195. + + Args: + subbands (int): The number of subbands. + taps (int): The number of filter taps. + cutoff_ratio (float): Cut-off frequency ratio. + beta (float): Beta coefficient for kaiser window. + + """ + super(PQMF, self).__init__() + + # build analysis & synthesis filter coefficients + h_proto = design_prototype_filter(taps, cutoff_ratio, beta) + h_analysis = np.zeros((subbands, len(h_proto))) + h_synthesis = np.zeros((subbands, len(h_proto))) + for k in range(subbands): + h_analysis[k] = ( + 2 + * h_proto + * np.cos( + (2 * k + 1) + * (np.pi / (2 * subbands)) + * (np.arange(taps + 1) - (taps / 2)) + + (-1) ** k * np.pi / 4 + ) + ) + h_synthesis[k] = ( + 2 + * h_proto + * np.cos( + (2 * k + 1) + * (np.pi / (2 * subbands)) + * (np.arange(taps + 1) - (taps / 2)) + - (-1) ** k * np.pi / 4 + ) + ) + + # convert to tensor + analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) + synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) + + # register coefficients as beffer + self.register_buffer("analysis_filter", analysis_filter) + self.register_buffer("synthesis_filter", synthesis_filter) + + # filter for downsampling & upsampling + updown_filter = torch.zeros((subbands, subbands, subbands)).float() + for k in range(subbands): + updown_filter[k, k, 0] = 1.0 + self.register_buffer("updown_filter", updown_filter) + self.subbands = subbands + + # keep padding info + self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) + + def analysis(self, x): + """Analysis with PQMF. + + Args: + x (Tensor): Input tensor (B, 1, T). + + Returns: + Tensor: Output tensor (B, subbands, T // subbands). + + """ + x = F.conv1d(self.pad_fn(x), self.analysis_filter) + return F.conv1d(x, self.updown_filter, stride=self.subbands) + + def synthesis(self, x): + """Synthesis with PQMF. + + Args: + x (Tensor): Input tensor (B, subbands, T // subbands). + + Returns: + Tensor: Output tensor (B, 1, T). + + """ + # NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands. + # Not sure this is the correct way, it is better to check again. 
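+        # In multirate terms: the transposed convolution with updown_filter
+        # zero-stuffs each subband signal by a factor of `subbands`, which
+        # attenuates the recovered band by the same factor after the low-pass
+        # synthesis filtering, hence the multiplication by `subbands` below.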
+ # TODO(kan-bayashi): Understand the reconstruction procedure + x = F.conv_transpose1d( + x, self.updown_filter * self.subbands, stride=self.subbands + ) + return F.conv1d(self.pad_fn(x), self.synthesis_filter) diff --git a/vec2wav2/layers/residual_block.py b/vec2wav2/layers/residual_block.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e9d6d240213ec897d4872d4a7d2b5d7d1158af --- /dev/null +++ b/vec2wav2/layers/residual_block.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- + +"""Residual block modules. + +References: + - https://github.com/r9y9/wavenet_vocoder + - https://github.com/jik876/hifi-gan + +""" + +import math + +import torch +import torch.nn.functional as F + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super(Conv1d, self).__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels, out_channels, bias): + """Initialize 1x1 Conv1d module.""" + super(Conv1d1x1, self).__init__( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +class WaveNetResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False, + ): + """Initialize WaveNetResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. + skip_channels (int): Number of channels for skip connection. + aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. + dropout (float): Dropout probability. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + use_causal_conv (bool): Whether to use use_causal_conv or non-use_causal_conv convolution. + + """ + super().__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) + + def forward(self, x, c): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). 
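+
+        Example (an illustrative sketch; batch size and length are arbitrary and
+            the channel sizes assume the default hyper-parameters):
+
+            >>> block = WaveNetResidualBlock()
+            >>> x = torch.randn(2, 64, 100)  # (B, residual_channels, T)
+            >>> c = torch.randn(2, 80, 100)  # (B, aux_channels, T)
+            >>> x_out, s = block(x, c)       # both (B, 64, 100)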
+ + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) + + return x, s + + +class HiFiGANResidualBlock(torch.nn.Module): + """Residual block module in HiFiGAN.""" + + def __init__( + self, + kernel_size=3, + channels=512, + dilations=(1, 3, 5), + bias=True, + use_additional_convs=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.1}, + ): + """Initialize HiFiGANResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels for convolution layer. + dilations (List[int]): List of dilation factors. + use_additional_convs (bool): Whether to use additional convolution layers. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + + """ + super().__init__() + self.use_additional_convs = use_additional_convs + self.convs1 = torch.nn.ModuleList() + if use_additional_convs: + self.convs2 = torch.nn.ModuleList() + assert kernel_size % 2 == 1, "Kernel size must be odd number." + for dilation in dilations: + self.convs1 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + bias=bias, + padding=(kernel_size - 1) // 2 * dilation, + ), + ) + ] + if use_additional_convs: + self.convs2 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + ) + ] + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, channels, T). 
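+
+        Example (an illustrative sketch; shapes assume the default channels=512):
+
+            >>> block = HiFiGANResidualBlock()
+            >>> y = block(torch.randn(2, 512, 100))  # -> (2, 512, 100)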
+ + """ + for idx in range(len(self.convs1)): + xt = self.convs1[idx](x) + if self.use_additional_convs: + xt = self.convs2[idx](xt) + x = xt + x + return x diff --git a/vec2wav2/layers/residual_stack.py b/vec2wav2/layers/residual_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..d57263069cd8c387315057c407bbb2e3cb0eeec2 --- /dev/null +++ b/vec2wav2/layers/residual_stack.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Residual stack module in MelGAN.""" + +import torch + +from vec2wav2.layers import CausalConv1d + + +class ResidualStack(torch.nn.Module): + """Residual stack module introduced in MelGAN.""" + + def __init__( + self, + kernel_size=3, + channels=32, + dilation=1, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_causal_conv=False, + ): + """Initialize ResidualStack module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels of convolution layers. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_causal_conv (bool): Whether to use causal convolution. + + """ + super(ResidualStack, self).__init__() + + # defile residual stack part + if not use_causal_conv: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + self.stack = torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params), + torch.nn.Conv1d( + channels, channels, kernel_size, dilation=dilation, bias=bias + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + torch.nn.Conv1d(channels, channels, 1, bias=bias), + ) + else: + self.stack = torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + CausalConv1d( + channels, + channels, + kernel_size, + dilation=dilation, + bias=bias, + pad=pad, + pad_params=pad_params, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + torch.nn.Conv1d(channels, channels, 1, bias=bias), + ) + + # defile extra layer for skip connection + self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias) + + def forward(self, c): + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, chennels, T). 
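+
+        Example (an illustrative sketch; shapes assume the default channels=32):
+
+            >>> stack = ResidualStack()
+            >>> y = stack(torch.randn(2, 32, 100))  # -> (2, 32, 100)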
+ + """ + return self.stack(c) + self.skip_layer(c) diff --git a/vec2wav2/layers/tade_res_block.py b/vec2wav2/layers/tade_res_block.py new file mode 100644 index 0000000000000000000000000000000000000000..bcad421c351d50a582559bf2395ce0559f667737 --- /dev/null +++ b/vec2wav2/layers/tade_res_block.py @@ -0,0 +1,160 @@ +# Copyright 2021 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""StyleMelGAN's TADEResBlock Modules.""" + +from functools import partial + +import torch + + +class TADELayer(torch.nn.Module): + """TADE Layer module.""" + + def __init__( + self, + in_channels=64, + aux_channels=80, + kernel_size=9, + bias=True, + upsample_factor=2, + upsample_mode="nearest", + ): + """Initilize TADE layer.""" + super().__init__() + self.norm = torch.nn.InstanceNorm1d(in_channels) + self.aux_conv = torch.nn.Sequential( + torch.nn.Conv1d( + aux_channels, + in_channels, + kernel_size, + 1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + # NOTE(kan-bayashi): Use non-linear activation? + ) + self.gated_conv = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + in_channels * 2, + kernel_size, + 1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + # NOTE(kan-bayashi): Use non-linear activation? + ) + self.upsample = torch.nn.Upsample( + scale_factor=upsample_factor, mode=upsample_mode + ) + + def forward(self, x, c): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + c (Tensor): Auxiliary input tensor (B, aux_channels, T'). + + Returns: + Tensor: Output tensor (B, in_channels, T * in_upsample_factor). + Tensor: Upsampled aux tensor (B, in_channels, T * aux_upsample_factor). + + """ + x = self.norm(x) + c = self.upsample(c) + c = self.aux_conv(c) + cg = self.gated_conv(c) + cg1, cg2 = cg.split(cg.size(1) // 2, dim=1) + # NOTE(kan-bayashi): Use upsample for noise input here? + y = cg1 * self.upsample(x) + cg2 + # NOTE(kan-bayashi): Return upsampled aux here? + return y, c + + +class TADEResBlock(torch.nn.Module): + """TADEResBlock module.""" + + def __init__( + self, + in_channels=64, + aux_channels=80, + kernel_size=9, + dilation=2, + bias=True, + upsample_factor=2, + upsample_mode="nearest", + gated_function="softmax", + ): + """Initialize TADEResBlock module.""" + super().__init__() + self.tade1 = TADELayer( + in_channels=in_channels, + aux_channels=aux_channels, + kernel_size=kernel_size, + bias=bias, + # NOTE(kan-bayashi): Use upsample in the first TADE layer? + upsample_factor=1, + upsample_mode=upsample_mode, + ) + self.gated_conv1 = torch.nn.Conv1d( + in_channels, + in_channels * 2, + kernel_size, + 1, + bias=bias, + padding=(kernel_size - 1) // 2, + ) + self.tade2 = TADELayer( + in_channels=in_channels, + aux_channels=in_channels, + kernel_size=kernel_size, + bias=bias, + upsample_factor=upsample_factor, + upsample_mode=upsample_mode, + ) + self.gated_conv2 = torch.nn.Conv1d( + in_channels, + in_channels * 2, + kernel_size, + 1, + bias=bias, + dilation=dilation, + padding=(kernel_size - 1) // 2 * dilation, + ) + self.upsample = torch.nn.Upsample( + scale_factor=upsample_factor, mode=upsample_mode + ) + if gated_function == "softmax": + self.gated_function = partial(torch.softmax, dim=1) + elif gated_function == "sigmoid": + self.gated_function = torch.sigmoid + else: + raise ValueError(f"{gated_function} is not supported.") + + def forward(self, x, c): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + c (Tensor): Auxiliary input tensor (B, aux_channels, T'). 
+ + Returns: + Tensor: Output tensor (B, in_channels, T * in_upsample_factor). + Tensor: Upsampled auxirialy tensor (B, in_channels, T * in_upsample_factor). + + """ + residual = x + + x, c = self.tade1(x, c) + x = self.gated_conv1(x) + xa, xb = x.split(x.size(1) // 2, dim=1) + x = self.gated_function(xa) * torch.tanh(xb) + + x, c = self.tade2(x, c) + x = self.gated_conv2(x) + xa, xb = x.split(x.size(1) // 2, dim=1) + x = self.gated_function(xa) * torch.tanh(xb) + + # NOTE(kan-bayashi): Return upsampled aux here? + return self.upsample(residual) + x, c diff --git a/vec2wav2/layers/upsample.py b/vec2wav2/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..adf32a3a051a1335e0cd6f7b3ee263dfaf8eca59 --- /dev/null +++ b/vec2wav2/layers/upsample.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- + +"""Upsampling module. + +This code is modified from https://github.com/r9y9/wavenet_vocoder. + +""" + +import numpy as np +import torch +import torch.nn.functional as F + +from vec2wav2.layers import Conv1d + + +class Stretch2d(torch.nn.Module): + """Stretch2d module.""" + + def __init__(self, x_scale, y_scale, mode="nearest"): + """Initialize Stretch2d module. + + Args: + x_scale (int): X scaling factor (Time axis in spectrogram). + y_scale (int): Y scaling factor (Frequency axis in spectrogram). + mode (str): Interpolation mode. + + """ + super(Stretch2d, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, C, F, T). + + Returns: + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + + """ + return F.interpolate( + x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode + ) + + +class Conv2d(torch.nn.Conv2d): + """Conv2d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv2d module.""" + super(Conv2d, self).__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + self.weight.data.fill_(1.0 / np.prod(self.kernel_size)) + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class UpsampleNetwork(torch.nn.Module): + """Upsampling network module.""" + + def __init__( + self, + upsample_scales, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + """Initialize upsampling network module. + + Args: + upsample_scales (list): List of upsampling scales. + nonlinear_activation (str): Activation function name. + nonlinear_activation_params (dict): Arguments for specified activation function. + interpolate_mode (str): Interpolation mode. + freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. + + """ + super(UpsampleNetwork, self).__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_scales: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert ( + freq_axis_kernel_size - 1 + ) % 2 == 0, "Not support even number freq axis kernel size." 
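+            # The time-axis kernel spans (scale * 2 + 1) stretched frames so each
+            # repeated frame is smoothed with its neighbours; the padding chosen
+            # below keeps the sequence length unchanged (the causal variant pads
+            # more and trims back to the original length in forward()).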
+ freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ) + self.up_layers += [nonlinear] + + def forward(self, c): + """Calculate forward propagation. + + Args: + c : Input tensor (B, C, T). + + Returns: + Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). + + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + if self.use_causal_conv and isinstance(f, Conv2d): + c = f(c)[..., : c.size(-1)] + else: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvInUpsampleNetwork(torch.nn.Module): + """Convolution + upsampling network module.""" + + def __init__( + self, + upsample_scales, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False, + ): + """Initialize convolution + upsampling network module. + + Args: + upsample_scales (list): List of upsampling scales. + nonlinear_activation (str): Activation function name. + nonlinear_activation_params (dict): Arguments for specified activation function. + mode (str): Interpolation mode. + freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. + aux_channels (int): Number of channels of pre-convolutional layer. + aux_context_window (int): Context window size of the pre-convolutional layer. + use_causal_conv (bool): Whether to use causal structure. + + """ + super(ConvInUpsampleNetwork, self).__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = ( + aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + ) + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = Conv1d( + aux_channels, aux_channels, kernel_size=kernel_size, bias=False + ) + self.upsample = UpsampleNetwork( + upsample_scales=upsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """Calculate forward propagation. + + Args: + c : Input tensor (B, C, T'). + + Returns: + Tensor: Upsampled tensor (B, C, T), + where T = (T' - aux_context_window * 2) * prod(upsample_scales). + + Note: + The length of inputs considers the context window size. 
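+
+        Example (an illustrative sketch; the upsample scales and shapes here are
+            arbitrary placeholders):
+
+            >>> net = ConvInUpsampleNetwork(upsample_scales=[4, 4, 4, 4])
+            >>> c = torch.randn(2, 80, 10)  # default aux_channels=80, window=0
+            >>> y = net(c)                  # -> (2, 80, 2560), i.e. T' * 4 ** 4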
+ + """ + c_ = self.conv_in(c) + c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/vec2wav2/losses/__init__.py b/vec2wav2/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adb36e634a0f4f769663e31b86d205a90dc141bc --- /dev/null +++ b/vec2wav2/losses/__init__.py @@ -0,0 +1,4 @@ +from .adversarial_loss import * # NOQA +from .feat_match_loss import * # NOQA +from .mel_loss import * # NOQA +from .stft_loss import * # NOQA diff --git a/vec2wav2/losses/__pycache__/__init__.cpython-310.pyc b/vec2wav2/losses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a1fd9a289a5292358c895f40d53d73a1d87a803 Binary files /dev/null and b/vec2wav2/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/losses/__pycache__/__init__.cpython-39.pyc b/vec2wav2/losses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f91c4be0f616cd26a214fc88014ed01d87073e5 Binary files /dev/null and b/vec2wav2/losses/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/losses/__pycache__/adversarial_loss.cpython-310.pyc b/vec2wav2/losses/__pycache__/adversarial_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23466a803de31f4433759fd77fc0764c398cbbd3 Binary files /dev/null and b/vec2wav2/losses/__pycache__/adversarial_loss.cpython-310.pyc differ diff --git a/vec2wav2/losses/__pycache__/adversarial_loss.cpython-39.pyc b/vec2wav2/losses/__pycache__/adversarial_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccaa0093f7cb43346efb93e9081d630d492664eb Binary files /dev/null and b/vec2wav2/losses/__pycache__/adversarial_loss.cpython-39.pyc differ diff --git a/vec2wav2/losses/__pycache__/feat_match_loss.cpython-310.pyc b/vec2wav2/losses/__pycache__/feat_match_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d37d6116370a70f6fd4e9fb902d0c951e4a3f37 Binary files /dev/null and b/vec2wav2/losses/__pycache__/feat_match_loss.cpython-310.pyc differ diff --git a/vec2wav2/losses/__pycache__/feat_match_loss.cpython-39.pyc b/vec2wav2/losses/__pycache__/feat_match_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd33d0e0072410d0c42322501fe8d5b71726ef9 Binary files /dev/null and b/vec2wav2/losses/__pycache__/feat_match_loss.cpython-39.pyc differ diff --git a/vec2wav2/losses/__pycache__/mel_loss.cpython-310.pyc b/vec2wav2/losses/__pycache__/mel_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..845c6e72e5a51a2fdc75cd2075d7661804182d0f Binary files /dev/null and b/vec2wav2/losses/__pycache__/mel_loss.cpython-310.pyc differ diff --git a/vec2wav2/losses/__pycache__/mel_loss.cpython-39.pyc b/vec2wav2/losses/__pycache__/mel_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8936a14dff40acfb23c8ff53a5eaac808fa6932 Binary files /dev/null and b/vec2wav2/losses/__pycache__/mel_loss.cpython-39.pyc differ diff --git a/vec2wav2/losses/__pycache__/stft_loss.cpython-310.pyc b/vec2wav2/losses/__pycache__/stft_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..242eb6ba457de61c8d58a1e6ab743dfd1ea3cbe9 Binary files /dev/null and b/vec2wav2/losses/__pycache__/stft_loss.cpython-310.pyc differ diff --git a/vec2wav2/losses/__pycache__/stft_loss.cpython-39.pyc 
b/vec2wav2/losses/__pycache__/stft_loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e651f0d8cd901e3d2bdd2019f3faca3a568bb7f0 Binary files /dev/null and b/vec2wav2/losses/__pycache__/stft_loss.cpython-39.pyc differ diff --git a/vec2wav2/losses/adversarial_loss.py b/vec2wav2/losses/adversarial_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c7624fa95e61261e9ded6ff3e6e39828fa878e0e --- /dev/null +++ b/vec2wav2/losses/adversarial_loss.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Adversarial loss modules.""" + +import torch +import torch.nn.functional as F + + +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators=True, + loss_type="mse", + ): + """Initialize GeneratorAversarialLoss module.""" + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward(self, outputs): + """Calcualate generator adversarial loss. + + Args: + outputs (Tensor or list): Discriminator outputs or list of + discriminator outputs. + + Returns: + Tensor: Generator adversarial loss value. + + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators=True, + loss_type="mse", + ): + """Initialize DiscriminatorAversarialLoss module.""" + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward(self, outputs_hat, outputs): + """Calcualate discriminator adversarial loss. + + Args: + outputs_hat (Tensor or list): Discriminator outputs or list of + discriminator outputs calculated from generator outputs. + outputs (Tensor or list): Discriminator outputs or list of + discriminator outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value. 
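+
+        Example (an illustrative sketch with plain tensors standing in for real
+            discriminator outputs):
+
+            >>> criterion = DiscriminatorAdversarialLoss()
+            >>> d_fake = torch.randn(4, 1, 100)  # D output on generated audio
+            >>> d_real = torch.randn(4, 1, 100)  # D output on real audio
+            >>> real_loss, fake_loss = criterion(d_fake, d_real)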
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x): + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x): + return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) + + def _hinge_fake_loss(self, x): + return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) diff --git a/vec2wav2/losses/feat_match_loss.py b/vec2wav2/losses/feat_match_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9cee14db09b89b631d2a315ec6dd01f6d2f5a65c --- /dev/null +++ b/vec2wav2/losses/feat_match_loss.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +# Copyright 2021 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Feature matching loss modules.""" + +import torch +import torch.nn.functional as F + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers=True, + average_by_discriminators=True, + include_final_outputs=False, + ): + """Initialize FeatureMatchLoss module.""" + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward(self, feats_hat, feats): + """Calcualate feature matching loss. + + Args: + feats_hat (list): List of list of discriminator outputs + calcuated from generater outputs. + feats (list): List of list of discriminator outputs + calcuated from groundtruth. + + Returns: + Tensor: Feature matching loss value. 
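+
+        Example (an illustrative sketch; two discriminators with three feature
+            maps each, all shapes arbitrary):
+
+            >>> criterion = FeatureMatchLoss()
+            >>> feats_hat = [[torch.randn(4, 16, 100)] * 3 for _ in range(2)]
+            >>> feats = [[torch.randn(4, 16, 100)] * 3 for _ in range(2)]
+            >>> loss = criterion(feats_hat, feats)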
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss diff --git a/vec2wav2/losses/mel_loss.py b/vec2wav2/losses/mel_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..58b12bb76a4e9755d749ae83ba520ca2a3dbea2b --- /dev/null +++ b/vec2wav2/losses/mel_loss.py @@ -0,0 +1,166 @@ +# Copyright 2021 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Mel-spectrogram loss modules.""" + +from distutils.version import LooseVersion + +import librosa +import torch +import torch.nn.functional as F + + +is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7") + + +class MelSpectrogram(torch.nn.Module): + """Calculate Mel-spectrogram.""" + + def __init__( + self, + fs=22050, + fft_size=1024, + hop_size=256, + win_length=None, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + center=True, + normalized=False, + onesided=True, + eps=1e-10, + log_base=10.0, + ): + """Initialize MelSpectrogram module.""" + super().__init__() + self.fft_size = fft_size + if win_length is None: + self.win_length = fft_size + else: + self.win_length = win_length + self.hop_size = hop_size + self.center = center + self.normalized = normalized + self.onesided = onesided + if window is not None and not hasattr(torch, f"{window}_window"): + raise ValueError(f"{window} window is not implemented") + self.window = window + self.eps = eps + + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + melmat = librosa.filters.mel( + sr=fs, + n_fft=fft_size, + n_mels=num_mels, + fmin=fmin, + fmax=fmax, + ) + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + self.stft_params = { + "n_fft": self.fft_size, + "win_length": self.win_length, + "hop_length": self.hop_size, + "center": self.center, + "normalized": self.normalized, + "onesided": self.onesided, + } + if is_pytorch_17plus: + self.stft_params["return_complex"] = False + + self.log_base = log_base + if self.log_base is None: + self.log = torch.log + elif self.log_base == 2.0: + self.log = torch.log2 + elif self.log_base == 10.0: + self.log = torch.log10 + else: + raise ValueError(f"log_base: {log_base} is not supported.") + + def forward(self, x): + """Calculate Mel-spectrogram. + + Args: + x (Tensor): Input waveform tensor (B, T) or (B, 1, T). + + Returns: + Tensor: Mel-spectrogram (B, #mels, #frames). 
+ + """ + if x.dim() == 3: + # (B, C, T) -> (B*C, T) + x = x.reshape(-1, x.size(2)) + + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + window = window_func(self.win_length, dtype=x.dtype, device=x.device) + else: + window = None + + x_stft = torch.stft(x, window=window, **self.stft_params) + # (B, #freqs, #frames, 2) -> (B, $frames, #freqs, 2) + x_stft = x_stft.transpose(1, 2) + x_power = x_stft[..., 0] ** 2 + x_stft[..., 1] ** 2 + x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps)) + + x_mel = torch.matmul(x_amp, self.melmat) + x_mel = torch.clamp(x_mel, min=self.eps) + + return self.log(x_mel).transpose(1, 2) + + +class MelSpectrogramLoss(torch.nn.Module): + """Mel-spectrogram loss.""" + + def __init__( + self, + fs=22050, + fft_size=1024, + hop_size=256, + win_length=None, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + center=True, + normalized=False, + onesided=True, + eps=1e-10, + log_base=10.0, + ): + """Initialize Mel-spectrogram loss.""" + super().__init__() + self.mel_spectrogram = MelSpectrogram( + fs=fs, + fft_size=fft_size, + hop_size=hop_size, + win_length=win_length, + window=window, + num_mels=num_mels, + fmin=fmin, + fmax=fmax, + center=center, + normalized=normalized, + onesided=onesided, + eps=eps, + log_base=log_base, + ) + + def forward(self, y_hat, y): + """Calculate Mel-spectrogram loss. + + Args: + y_hat (Tensor): Generated single tensor (B, 1, T). + y (Tensor): Groundtruth single tensor (B, 1, T). + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_hat = self.mel_spectrogram(y_hat) + mel = self.mel_spectrogram(y) + mel_loss = F.l1_loss(mel_hat, mel) + + return mel_loss diff --git a/vec2wav2/losses/stft_loss.py b/vec2wav2/losses/stft_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b5923559d6cae5c335b6febc8b8e2124ce0c4487 --- /dev/null +++ b/vec2wav2/losses/stft_loss.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F + +from distutils.version import LooseVersion + +is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7") + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + + """ + if is_pytorch_17plus: + x_stft = torch.stft( + x, fft_size, hop_size, win_length, window, return_complex=False + ) + else: + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) + + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + """Initilize spectral convergence loss module.""" + super(SpectralConvergenceLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + + Returns: + Tensor: Spectral convergence loss value. 
+ + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self): + """Initilize los STFT magnitude loss module.""" + super(LogSTFTMagnitudeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + + Returns: + Tensor: Log STFT magnitude loss value. + + """ + return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__( + self, fft_size=1024, shift_size=120, win_length=600, window="hann_window" + ): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.spectral_convergence_loss = SpectralConvergenceLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + # NOTE(kan-bayashi): Use register_buffer to fix #223 + self.register_buffer("window", getattr(torch, window)(win_length)) + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + sc_loss = self.spectral_convergence_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__( + self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window", + ): + """Initialize Multi resolution STFT loss module. + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window)] + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T) or (B, #subband, T). + y (Tensor): Groundtruth signal (B, T) or (B, #subband, T). + + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. 
+ + """ + if len(x.shape) == 3: + x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T) + y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T) + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss diff --git a/vec2wav2/models/__init__.py b/vec2wav2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da11d57538a263c318102773190441cc1e29ebf0 --- /dev/null +++ b/vec2wav2/models/__init__.py @@ -0,0 +1,3 @@ +from .hifigan import * # NOQA +from .melgan import * # NOQA +from .v2w2 import * diff --git a/vec2wav2/models/__pycache__/__init__.cpython-310.pyc b/vec2wav2/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed67cb17f474988585a325ffc6739cd1cb6c3407 Binary files /dev/null and b/vec2wav2/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/models/__pycache__/__init__.cpython-39.pyc b/vec2wav2/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a40ce93530ec9eaa782398006fb4a5a7635ac5e Binary files /dev/null and b/vec2wav2/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/models/__pycache__/bigvgan.cpython-310.pyc b/vec2wav2/models/__pycache__/bigvgan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb450cc71faf480ebfb0da49526ae7438b77736e Binary files /dev/null and b/vec2wav2/models/__pycache__/bigvgan.cpython-310.pyc differ diff --git a/vec2wav2/models/__pycache__/ctx_v2w.cpython-310.pyc b/vec2wav2/models/__pycache__/ctx_v2w.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9bd4010dae74e3dbace0ee374ca90696bdd8d1f Binary files /dev/null and b/vec2wav2/models/__pycache__/ctx_v2w.cpython-310.pyc differ diff --git a/vec2wav2/models/__pycache__/ctx_v2w.cpython-39.pyc b/vec2wav2/models/__pycache__/ctx_v2w.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..393fe68a134683abec896e4f342cd360b9f7ab9b Binary files /dev/null and b/vec2wav2/models/__pycache__/ctx_v2w.cpython-39.pyc differ diff --git a/vec2wav2/models/__pycache__/hifigan.cpython-310.pyc b/vec2wav2/models/__pycache__/hifigan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..152dec09f80e9967608fbe6dae3aa24a9a0e8a39 Binary files /dev/null and b/vec2wav2/models/__pycache__/hifigan.cpython-310.pyc differ diff --git a/vec2wav2/models/__pycache__/hifigan.cpython-39.pyc b/vec2wav2/models/__pycache__/hifigan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9014a1950a19fe6025f86fe66b81ddbb1b87ec98 Binary files /dev/null and b/vec2wav2/models/__pycache__/hifigan.cpython-39.pyc differ diff --git a/vec2wav2/models/__pycache__/melgan.cpython-310.pyc b/vec2wav2/models/__pycache__/melgan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a44836b9de1acb3c3385db84c028f8e8ad66e54 Binary files /dev/null and b/vec2wav2/models/__pycache__/melgan.cpython-310.pyc differ diff --git a/vec2wav2/models/__pycache__/melgan.cpython-39.pyc b/vec2wav2/models/__pycache__/melgan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d06ffbe4a017867ef8efb0f4fb6521f94b1036a8 Binary files /dev/null and b/vec2wav2/models/__pycache__/melgan.cpython-39.pyc differ diff --git 
a/vec2wav2/models/__pycache__/prompt_prenet.cpython-310.pyc b/vec2wav2/models/__pycache__/prompt_prenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44210924ac131867a4cd79acfaea7849c5563879 Binary files /dev/null and b/vec2wav2/models/__pycache__/prompt_prenet.cpython-310.pyc differ diff --git a/vec2wav2/models/__pycache__/v2w2.cpython-310.pyc b/vec2wav2/models/__pycache__/v2w2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17bcb1dcf8880ef594ee2e1661b1e5223c012d97 Binary files /dev/null and b/vec2wav2/models/__pycache__/v2w2.cpython-310.pyc differ diff --git a/vec2wav2/models/bigvgan.py b/vec2wav2/models/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..4dde6e2a0fbb81245831f5d4c32b8b3ce5baf486 --- /dev/null +++ b/vec2wav2/models/bigvgan.py @@ -0,0 +1,414 @@ +# Copyright (c) 2022 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +# Modified by Yiwei Guo, 2024 +# including upsample ConvTranspose padding, output_padding +# and conditioned snakebeta activation + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +import logging + +import vec2wav2.layers.activations as activations +from alias_free_torch import * +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +class AMPBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None, condition_dim=1024): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h['snake_logscale'])) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h['snake_logscale'])) + for _ in 
range(self.num_layers) + ]) + elif activation == 'snakebeta-condition': # periodic nonlinearity with snakebeta function and anti-aliasing, conditioned by spk + self.activations = nn.ModuleList([ + Activation1dWithCondition( + activation=activations.SnakeBetaWithCondition(channels, condition_dim, alpha_logscale=h['snake_logscale']) + ) for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x, cond): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x, cond=cond) + xt = c1(xt) + xt = a2(xt, cond=cond) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None, condition_dim=1024): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h['snake_logscale'])) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h['snake_logscale'])) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta-condition': # periodic nonlinearity with snakebeta function and anti-aliasing, conditioned by spk + self.activations = nn.ModuleList([ + Activation1dWithCondition( + activation=activations.SnakeBetaWithCondition(channels, condition_dim, alpha_logscale=h['snake_logscale']) + ) for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + def forward(self, x, cond): + for c, a in zip(self.convs, self.activations): + xt = a(x, cond=cond) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, **kwargs): + super(BigVGAN, self).__init__() + self.h = kwargs + + self.num_kernels = len(kwargs['resblock_kernel_sizes']) + self.num_upsamples = len(kwargs['upsample_scales']) + + # pre conv + self.conv_pre = weight_norm(Conv1d(kwargs['in_channels'], kwargs['channels'], 7, 1, padding=3)) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if kwargs['resblock'] == '1' else AMPBlock2 + + # transposed conv-based upsamplers. 
does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(kwargs['upsample_scales'], kwargs['upsample_kernel_sizes'])): + self.ups.append(nn.ModuleList([ + weight_norm(ConvTranspose1d(kwargs['channels'] // (2 ** i), + kwargs['channels'] // (2 ** (i + 1)), + k, u, padding=u//2 + u % 2, output_padding=u%2)) + ])) # NOTE: this is different from official BigVGAN, to avoid length mismatch. + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = kwargs['channels'] // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(kwargs['resblock_kernel_sizes'], kwargs['resblock_dilations'])): + self.resblocks.append(resblock(kwargs, ch, k, d, activation=kwargs['nonlinear_activation'], condition_dim=kwargs['condition_dim'])) + + # post conv + if kwargs['nonlinear_activation'] == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=kwargs['snake_logscale']) + self.activation_post = Activation1d(activation=activation_post) + elif kwargs['nonlinear_activation'] == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=kwargs['snake_logscale']) + self.activation_post = Activation1d(activation=activation_post) + elif kwargs['nonlinear_activation'] == 'snakebeta-condition': + activation_post = activations.SnakeBetaWithCondition(ch, kwargs['condition_dim'], alpha_logscale=kwargs['snake_logscale']) + self.activation_post = Activation1dWithCondition(activation=activation_post) + else: + raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x, cond): + # pre conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x, cond) + else: + xs += self.resblocks[i * self.num_kernels + j](x, cond) + x = xs / self.num_kernels + + # post conv + x = self.activation_post(x, cond=cond) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + logging.info('Removing weight norm...') + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.d_mult = h.discriminator_channel_mult + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, int(32*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(32*self.d_mult), int(128*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(128*self.d_mult), int(512*self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(512*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 
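# Quick sanity sketch (not part of the patch), assuming the repository and its
# dependencies (e.g. alias_free_torch) are importable: get_padding() gives "same"
# padding for the odd kernel sizes used in the AMP residual blocks, so every
# dilated convolution preserves the frame length and the residual additions
# x = xt + x line up.
import torch
from torch.nn import Conv1d
from vec2wav2.models.bigvgan import get_padding

x = torch.randn(1, 32, 100)  # (B, channels, T)
for kernel_size in (3, 7, 11):
    for dilation in (1, 3, 5):
        conv = Conv1d(32, 32, kernel_size, 1, dilation=dilation,
                      padding=get_padding(kernel_size, dilation))
        assert conv(x).shape == x.shape  # length preserved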
(stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(int(1024*self.d_mult), int(1024*self.d_mult), (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(int(1024*self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, h): + super(MultiPeriodDiscriminator, self).__init__() + self.mpd_reshapes = h.mpd_reshapes + print("mpd_reshapes: {}".format(self.mpd_reshapes)) + discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes] + self.discriminators = nn.ModuleList(discriminators) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg, resolution): + super().__init__() + + self.resolution = resolution + assert len(self.resolution) == 3, \ + "MRD layer requires list with len=3, got {}".format(self.resolution) + self.lrelu_slope = LRELU_SLOPE + + norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm + if hasattr(cfg, "mrd_use_spectral_norm"): + print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm)) + norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm + self.d_mult = cfg.discriminator_channel_mult + if hasattr(cfg, "mrd_channel_mult"): + print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult)) + self.d_mult = cfg.mrd_channel_mult + + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, int(32*self.d_mult), (3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(int(32*self.d_mult), int(32*self.d_mult), (3, 3), padding=(1, 1))), + ]) + self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect') + x = x.squeeze(1) + x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim =-1) #[B, F, TT] + + return mag + + +class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + 
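# Illustrative usage sketch (not part of the patch): running the multi-period
# discriminator on real and generated audio. The periods, channel multiplier and
# waveform shapes are assumed values passed via a SimpleNamespace, since
# DiscriminatorP/MultiPeriodDiscriminator read the config as attributes of `h`.
import torch
from types import SimpleNamespace
from vec2wav2.models.bigvgan import MultiPeriodDiscriminator

h = SimpleNamespace(mpd_reshapes=[2, 3, 5, 7, 11],
                    use_spectral_norm=False,
                    discriminator_channel_mult=1)
mpd = MultiPeriodDiscriminator(h)

y = torch.randn(2, 1, 8192)      # real waveforms (B, 1, T)
y_hat = torch.randn(2, 1, 8192)  # generated waveforms (B, 1, T)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
print(len(y_d_rs), y_d_rs[0].shape)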
self.resolutions = cfg.resolutions + assert len(self.resolutions) == 3,\ + "MRD requires list of list with len=3, each element having a list with len=3. got {}".\ + format(self.resolutions) + self.discriminators = nn.ModuleList( + [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + diff --git a/vec2wav2/models/conformer/__init__.py b/vec2wav2/models/conformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/vec2wav2/models/conformer/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/vec2wav2/models/conformer/__pycache__/__init__.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d17e92fef6f9fba20dbe1d13ee6172b8b37c741 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/__init__.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c86748d235763b207d6f5345f4c97ffe114605e Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/__init__.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/__init__.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db4384558fa90ca36a95a64dc71724ff253fba25 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/argument.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/argument.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5719214d2901197e90d8f7e7b3c12e43f26accec Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/argument.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/attention.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa65d09c4b68fa6ac5e6c1188d1a4cf63f4a702f Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/attention.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/attention.cpython-39.pyc 
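# Illustrative wiring sketch (not part of the patch) of the multi-resolution
# discriminator together with the HiFi-GAN-style helper losses defined above.
# The (n_fft, hop, win_length) resolutions and tensor shapes are assumptions.
import torch
from types import SimpleNamespace
from vec2wav2.models.bigvgan import (
    MultiResolutionDiscriminator, discriminator_loss, generator_loss, feature_loss)

cfg = SimpleNamespace(resolutions=[[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
                      use_spectral_norm=False,
                      discriminator_channel_mult=1)
mrd = MultiResolutionDiscriminator(cfg)

y = torch.randn(2, 1, 8192)      # real waveforms
y_hat = torch.randn(2, 1, 8192)  # generated waveforms
y_d_rs, y_d_gs, fmap_rs, fmap_gs = mrd(y, y_hat)

d_loss, _, _ = discriminator_loss(y_d_rs, y_d_gs)  # discriminator update term
g_adv, _ = generator_loss(y_d_gs)                  # adversarial term for the generator
g_fm = feature_loss(fmap_rs, fmap_gs)              # feature-matching term
print(d_loss.item(), g_adv.item(), g_fm.item())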
b/vec2wav2/models/conformer/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c729b4d4b5cdd15b71b5f5a31fd8249b945f621 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/attention.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/convolution.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f60b3a51d131a000540f87607c26bc889f75f76d Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/convolution.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/convolution.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/convolution.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de09987c47a004b94c56eb579949a32c76a304d Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/convolution.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/convolution.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/convolution.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e95cd54c40d4844df75a5c86efff98c3de100be Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/convolution.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/decoder.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a17d190321dbb5f4eea44d1e90489ed29d57ab0 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/decoder.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/decoder.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/decoder.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eb54d4f136c7aa8761f97066fcacbeb8a678259 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/decoder.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/decoder.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e17e36f8d90137998608a6ad9c1a217ad1b221a Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/decoder.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ab7b4e51074a5f4147f5180ba990c786820e5e Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fae90fd5baf90501d4d81b4dd7c1a7a57e924d3 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d69e562c091acdda635f6f68a93985a884220aa Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/decoder_layer.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/embedding.cpython-310.pyc 
b/vec2wav2/models/conformer/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16234d1dc933cb5f38692e829ce8934fc25b5767 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/embedding.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/embedding.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/embedding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c81cef76c97ac99018464c52cc580dffd002a47 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/embedding.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/encoder.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/encoder.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05028a550f7a8db193a67fd689e54fad9530539e Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/encoder.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/encoder_layer.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/encoder_layer.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b13bfd70d3bb93f162ebafb891b6baabbdc61dc8 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/encoder_layer.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/layer_norm.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/layer_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab066813d9140bc19d573073cac361ae46ca933c Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/layer_norm.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/layer_norm.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/layer_norm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44194bab5fbe901d2aff7077bfe2adb6f177ccf2 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/layer_norm.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/multi_layer_conv.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/multi_layer_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93a4a449051e999c927ff056fb90da167807b754 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/multi_layer_conv.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/multi_layer_conv.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/multi_layer_conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95d7dcf2aea2e80e440a633f0bf0e0db443cb441 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/multi_layer_conv.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/nets_utils.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/nets_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eeed770768499dd20f84f167847752a67a3cb27 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/nets_utils.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/nets_utils.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/nets_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03fb8849bd2a6f7cdd9e4d63840a913733c3606e Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/nets_utils.cpython-39.pyc differ diff --git 
a/vec2wav2/models/conformer/__pycache__/positionwise_feed_forward.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/positionwise_feed_forward.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32469ab42a59c1ed323bc40df47feb3b9a67cfa3 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/positionwise_feed_forward.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/positionwise_feed_forward.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/positionwise_feed_forward.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a31a7944320ae984c6cfa2eae5719e43e17f6bd Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/positionwise_feed_forward.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/repeat.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/repeat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f191998a2f6d55b1bfdfb63b3abd36e598a097e6 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/repeat.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/repeat.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/repeat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b05cf5440b158e0834a2b83e79e6b7d695697f19 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/repeat.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/subsampling.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/subsampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99427f5556ca952f1cd816efa9fee8cca51601ff Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/subsampling.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/subsampling.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/subsampling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7018397db3f29a4cff0f3814ec523a6cfd72eda Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/subsampling.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/swish.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/swish.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2d147bd18e31f6d4d378d07aafc1e64c6744fe7 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/swish.cpython-310.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/swish.cpython-36.pyc b/vec2wav2/models/conformer/__pycache__/swish.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a62eb62fe87bff409424320f86db4adc77a2befc Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/swish.cpython-36.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/swish.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/swish.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35f6901859d7cdd0617793f4e16e9ae48f28c8b8 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/swish.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/__pycache__/vgg2l.cpython-310.pyc b/vec2wav2/models/conformer/__pycache__/vgg2l.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d13d52fffd8ce3ef124acba981d493202a17eb6 Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/vgg2l.cpython-310.pyc differ diff 
--git a/vec2wav2/models/conformer/__pycache__/vgg2l.cpython-39.pyc b/vec2wav2/models/conformer/__pycache__/vgg2l.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..306677fdd08f55fd717e24f5c5e92f823ee8880a Binary files /dev/null and b/vec2wav2/models/conformer/__pycache__/vgg2l.cpython-39.pyc differ diff --git a/vec2wav2/models/conformer/argument.py b/vec2wav2/models/conformer/argument.py new file mode 100644 index 0000000000000000000000000000000000000000..d5681565256125941daaeff61e050141fcafbeb1 --- /dev/null +++ b/vec2wav2/models/conformer/argument.py @@ -0,0 +1,87 @@ +# Copyright 2020 Hirofumi Inaguma +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Conformer common arguments.""" + + +from distutils.util import strtobool +import logging + + +def add_arguments_conformer_common(group): + """Add Transformer common arguments.""" + group.add_argument( + "--transformer-encoder-pos-enc-layer-type", + type=str, + default="abs_pos", + choices=["abs_pos", "scaled_abs_pos", "rel_pos"], + help="Transformer encoder positional encoding layer type", + ) + group.add_argument( + "--transformer-encoder-activation-type", + type=str, + default="swish", + choices=["relu", "hardtanh", "selu", "swish"], + help="Transformer encoder activation function type", + ) + group.add_argument( + "--macaron-style", + default=False, + type=strtobool, + help="Whether to use macaron style for positionwise layer", + ) + # Attention + group.add_argument( + "--zero-triu", + default=False, + type=strtobool, + help="If true, zero the uppper triangular part of attention matrix.", + ) + # Relative positional encoding + group.add_argument( + "--rel-pos-type", + type=str, + default="legacy", + choices=["legacy", "latest"], + help="Whether to use the latest relative positional encoding or the legacy one." + "The legacy relative positional encoding will be deprecated in the future." + "More Details can be found in https://github.com/espnet/espnet/pull/2816.", + ) + # CNN module + group.add_argument( + "--use-cnn-module", + default=False, + type=strtobool, + help="Use convolution module or not", + ) + group.add_argument( + "--cnn-module-kernel", + default=31, + type=int, + help="Kernel size of convolution module.", + ) + return group + + +def verify_rel_pos_type(args): + """Verify the relative positional encoding type for compatibility. + + Args: + args (Namespace): original arguments + Returns: + args (Namespace): modified arguments + """ + rel_pos_type = getattr(args, "rel_pos_type", None) + if rel_pos_type is None or rel_pos_type == "legacy": + if args.transformer_encoder_pos_enc_layer_type == "rel_pos": + args.transformer_encoder_pos_enc_layer_type = "legacy_rel_pos" + logging.warning( + "Using legacy_rel_pos and it will be deprecated in the future." + ) + if args.transformer_encoder_selfattn_layer_type == "rel_selfattn": + args.transformer_encoder_selfattn_layer_type = "legacy_rel_selfattn" + logging.warning( + "Using legacy_rel_selfattn and it will be deprecated in the future." 
+ ) + + return args diff --git a/vec2wav2/models/conformer/attention.py b/vec2wav2/models/conformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8b68089ec7629b22346f538dab359ff7560acd --- /dev/null +++ b/vec2wav2/models/conformer/attention.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention layer definition.""" + +import math + +import numpy +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float( + numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min + ) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). 
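# Illustrative sketch (not part of the patch) of wiring these helpers into an
# argparse-based script. Note that verify_rel_pos_type() also reads
# args.transformer_encoder_selfattn_layer_type, which this file does not define
# (in ESPnet it comes from the transformer argument group), so it is added by
# hand here as an assumption.
import argparse
from vec2wav2.models.conformer.argument import (
    add_arguments_conformer_common, verify_rel_pos_type)

parser = argparse.ArgumentParser()
group = parser.add_argument_group("conformer")
add_arguments_conformer_common(group)
group.add_argument("--transformer-encoder-selfattn-layer-type",
                   type=str, default="rel_selfattn")

args = parser.parse_args(["--transformer-encoder-pos-enc-layer-type", "rel_pos",
                          "--use-cnn-module", "true"])
args = verify_rel_pos_type(args)  # legacy rel_pos -> legacy_rel_pos / legacy_rel_selfattn
print(args.transformer_encoder_pos_enc_layer_type,
      args.transformer_encoder_selfattn_layer_type)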
+ value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, time2). + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
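# Illustrative self-attention call (not part of the patch); shapes are assumptions.
import torch
from vec2wav2.models.conformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)

x = torch.randn(2, 50, 256)                     # (#batch, time, n_feat)
mask = torch.ones(2, 1, 50, dtype=torch.bool)   # (#batch, 1, time2); nonzero = keep
out = mha(x, x, x, mask)                        # self-attention: query = key = value
print(out.shape)                                # torch.Size([2, 50, 256])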
+ + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, 2*time1-1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
+ + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) diff --git a/vec2wav2/models/conformer/contextual_block_encoder_layer.py b/vec2wav2/models/conformer/contextual_block_encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..91cf0b068ab7182402d57c2d1123836c6fc074d3 --- /dev/null +++ b/vec2wav2/models/conformer/contextual_block_encoder_layer.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Aug 21 16:57:31 2021. + +@author: Keqi Deng (UCAS) +""" + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +import torch +from torch import nn + + +class ContextualBlockEncoderLayer(nn.Module): + """Contexutal Block Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + total_layer_num (int): Total number of layers + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + feed_forward_macaron, + conv_module, + dropout_rate, + total_layer_num, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(ContextualBlockEncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + self.total_layer_num = total_layer_num + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward( + self, + x, + mask, + infer_mode=False, + past_ctx=None, + next_ctx=None, + is_short_segment=False, + layer_idx=0, + cache=None, + ): + """Calculate forward propagation.""" + if self.training or not infer_mode: + return self.forward_train(x, mask, past_ctx, next_ctx, layer_idx, cache) + else: + return self.forward_infer( + x, mask, past_ctx, next_ctx, is_short_segment, layer_idx, cache + ) + + def forward_train( + self, x, mask, past_ctx=None, next_ctx=None, layer_idx=0, cache=None + ): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + past_ctx (torch.Tensor): Previous contexutal vector + next_ctx (torch.Tensor): Next contexutal vector + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). 
+ cur_ctx (torch.Tensor): Current contexutal vector + next_ctx (torch.Tensor): Next contexutal vector + layer_idx (int): layer index number + + """ + nbatch = x.size(0) + nblock = x.size(1) + + if past_ctx is not None: + if next_ctx is None: + # store all context vectors in one tensor + next_ctx = past_ctx.new_zeros( + nbatch, nblock, self.total_layer_num, x.size(-1) + ) + else: + x[:, :, 0] = past_ctx[:, :, layer_idx] + + # reshape ( nbatch, nblock, block_size + 2, dim ) + # -> ( nbatch * nblock, block_size + 2, dim ) + x = x.view(-1, x.size(-2), x.size(-1)) + if mask is not None: + mask = mask.view(-1, mask.size(-2), mask.size(-1)) + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) + if not self.normalize_before: + x = self.norm1(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + layer_idx += 1 + # reshape ( nbatch * nblock, block_size + 2, dim ) + # -> ( nbatch, nblock, block_size + 2, dim ) + x = x.view(nbatch, -1, x.size(-2), x.size(-1)).squeeze(1) + if mask is not None: + mask = mask.view(nbatch, -1, mask.size(-2), mask.size(-1)).squeeze(1) + + if next_ctx is not None and layer_idx < self.total_layer_num: + next_ctx[:, 0, layer_idx, :] = x[:, 0, -1, :] + next_ctx[:, 1:, layer_idx, :] = x[:, 0:-1, -1, :] + + return x, mask, False, next_ctx, next_ctx, layer_idx + + def forward_infer( + self, + x, + mask, + past_ctx=None, + next_ctx=None, + is_short_segment=False, + layer_idx=0, + cache=None, + ): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + past_ctx (torch.Tensor): Previous contexutal vector + next_ctx (torch.Tensor): Next contexutal vector + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). 
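A minimal sketch of the macaron-style half-step residual that forward_train and forward_infer apply around both feed-forward modules (pre-norm, ff_scale = 0.5). The HalfStepFFN module below is a hypothetical stand-in for illustration only, not part of this diff:

import torch
from torch import nn

class HalfStepFFN(nn.Module):
    # hypothetical stand-in for the feed_forward / feed_forward_macaron branches above
    def __init__(self, size=256, hidden=1024, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.ffn = nn.Sequential(nn.Linear(size, hidden), nn.SiLU(), nn.Linear(hidden, size))
        self.dropout = nn.Dropout(dropout)
        self.ff_scale = 0.5  # half-step residual weight, as in the macaron branch

    def forward(self, x):
        # pre-norm: normalize, transform, then add the scaled residual
        return x + self.ff_scale * self.dropout(self.ffn(self.norm(x)))

x = torch.randn(2, 10, 256)
assert HalfStepFFN()(x).shape == x.shape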
+ cur_ctx (torch.Tensor): Current contexutal vector + next_ctx (torch.Tensor): Next contexutal vector + layer_idx (int): layer index number + + """ + nbatch = x.size(0) + nblock = x.size(1) + # if layer_idx == 0, next_ctx has to be None + if layer_idx == 0: + assert next_ctx is None + next_ctx = x.new_zeros(nbatch, self.total_layer_num, x.size(-1)) + + # reshape ( nbatch, nblock, block_size + 2, dim ) + # -> ( nbatch * nblock, block_size + 2, dim ) + x = x.view(-1, x.size(-2), x.size(-1)) + if mask is not None: + mask = mask.view(-1, mask.size(-2), mask.size(-1)) + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) + if not self.normalize_before: + x = self.norm1(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + # reshape ( nbatch * nblock, block_size + 2, dim ) + # -> ( nbatch, nblock, block_size + 2, dim ) + x = x.view(nbatch, nblock, x.size(-2), x.size(-1)) + if mask is not None: + mask = mask.view(nbatch, nblock, mask.size(-2), mask.size(-1)) + + # Propagete context information (the last frame of each block) + # to the first frame + # of the next block + + if not is_short_segment: + if past_ctx is None: + # First block of an utterance + x[:, 0, 0, :] = x[:, 0, -1, :] + else: + x[:, 0, 0, :] = past_ctx[:, layer_idx, :] + if nblock > 1: + x[:, 1:, 0, :] = x[:, 0:-1, -1, :] + next_ctx[:, layer_idx, :] = x[:, -1, -1, :] + else: + next_ctx = None + + return x, mask, True, past_ctx, next_ctx, is_short_segment, layer_idx + 1 diff --git a/vec2wav2/models/conformer/convolution.py b/vec2wav2/models/conformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5d2c30c313e73fa2097bc28721be00aeb6910f --- /dev/null +++ b/vec2wav2/models/conformer/convolution.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""ConvolutionModule definition.""" + +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. 
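The inference path above hands context across blocks by copying the last frame of each block into the first frame of the next one, and by exporting the last frame of the final block as next_ctx for the following segment (per layer, indexed by layer_idx). A plain-tensor sketch of that copy, for illustration only:

import torch

nbatch, nblock, block_len, dim = 1, 3, 6, 4
x = torch.randn(nbatch, nblock, block_len, dim)
past_ctx = None  # first segment; later segments would pass past_ctx[:, layer_idx, :]

if past_ctx is None:
    x[:, 0, 0, :] = x[:, 0, -1, :]       # first block: context is its own last frame
else:
    x[:, 0, 0, :] = past_ctx             # later segments: context from the previous segment
if nblock > 1:
    x[:, 1:, 0, :] = x[:, :-1, -1, :]    # block i receives the last frame of block i-1
next_ctx = x[:, -1, -1, :]               # carried into the next segment (one row per layer in the real code)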
+ + """ + + def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward(self, x): + """Compute convolution module. + + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) diff --git a/vec2wav2/models/conformer/decoder.py b/vec2wav2/models/conformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d6aaeb4803d560f2a67f135c00eacafd61684219 --- /dev/null +++ b/vec2wav2/models/conformer/decoder.py @@ -0,0 +1,247 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" + +import logging +import torch + +from vec2wav2.models.conformer.convolution import ConvolutionModule +from vec2wav2.models.conformer.decoder_layer import DecoderLayer +from vec2wav2.models.conformer.nets_utils import get_activation +from vec2wav2.models.conformer.vgg2l import VGG2L +from vec2wav2.models.conformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 + LegacyRelPositionMultiHeadedAttention, # noqa: H301 +) +from vec2wav2.models.conformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 + LegacyRelPositionalEncoding, # noqa: H301 +) +from vec2wav2.models.conformer.layer_norm import LayerNorm +from vec2wav2.models.conformer.multi_layer_conv import Conv1dLinear, MultiLayeredConv1d +from vec2wav2.models.conformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from vec2wav2.models.conformer.repeat import repeat +from vec2wav2.models.conformer.subsampling import Conv2dSubsampling + + +class Decoder(torch.nn.Module): + """Conformer encoder module. + + Args: + idim (int): Input dimension. + attention_dim (int): Dimention of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. 
+ concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + macaron_style (bool): Whether to use macaron style for positionwise layer. + pos_enc_layer_type (str): Encoder positional encoding layer type. + selfattention_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. + + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + zero_triu=False, + cnn_module_kernel=31, + padding_idx=-1, + ): + """Construct an Encoder object.""" + super(Decoder, self).__init__() + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + pos_enc_class = LegacyRelPositionalEncoding + assert selfattention_layer_type == "legacy_rel_selfattn" + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + self.conv_subsampling_factor = 1 + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + self.conv_subsampling_factor = 4 + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + self.conv_subsampling_factor = 4 + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + + # self-attention module definition + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + 
attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + logging.info("encoder self-attention layer type = relative self-attention") + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + zero_triu, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + # feed-forward module definition + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + MultiHeadedAttention( + attention_heads, attention_dim, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, xs, masks, memory, memory_mask): + """Encode input sequence. + + Args: + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + xs, masks, memory, memory_mask = self.decoders(xs, masks, memory, memory_mask) + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks diff --git a/vec2wav2/models/conformer/decoder_layer.py b/vec2wav2/models/conformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..773401e5fcf0a5da8a13ca05bf1e58438a94056e --- /dev/null +++ b/vec2wav2/models/conformer/decoder_layer.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from vec2wav2.models.conformer.layer_norm import LayerNorm + + +class DecoderLayer(nn.Module): + """Encoder layer module. 
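Putting the pieces together, the Decoder above embeds its input, runs num_blocks DecoderLayer instances that self-attend and then cross-attend over a memory sequence, and applies the final LayerNorm when normalize_before is set. A hedged usage sketch with illustrative shapes, assuming the package from this diff (including its attention module) is importable:

import torch
from vec2wav2.models.conformer.decoder import Decoder

dec = Decoder(idim=512, attention_dim=256, attention_heads=4, linear_units=1024,
              num_blocks=2, input_layer="linear", macaron_style=True,
              use_cnn_module=True, cnn_module_kernel=31)
xs = torch.randn(2, 50, 512)                      # (batch, time, idim)
masks = torch.ones(2, 1, 50, dtype=torch.bool)    # non-padded positions
memory = torch.randn(2, 30, 256)                  # cross-attended sequence, already attention_dim wide
memory_mask = torch.ones(2, 1, 30, dtype=torch.bool)
out, out_masks = dec(xs, masks, memory, memory_mask)
assert out.shape == (2, 50, 256)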
+ + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + src_attn, + feed_forward, + feed_forward_macaron, + conv_module, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(DecoderLayer, self).__init__() + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + self.norm2 = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, x_input, mask, memory, memory_mask, cache=None): + """Compute encoded features. + + Args: + x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. + - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. + - w/o pos emb: Tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). 
+ + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # cross attention + residual = x + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) + if not self.normalize_before: + x = self.norm2(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask, memory, memory_mask + + return x, mask, memory, memory_mask diff --git a/vec2wav2/models/conformer/embedding.py b/vec2wav2/models/conformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..3a92e0f8cc9091744d3b9434fb9dfe8f8231c9df --- /dev/null +++ b/vec2wav2/models/conformer/embedding.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positional Encoding Module.""" + +import math + +import torch + + +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + reverse (bool): Whether to reverse the input position. Only for + the class LegacyRelPositionalEncoding. We remove it in the current + class RelPositionalEncoding. 
+ + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class ScaledPositionalEncoding(PositionalEncoding): + """Scaled positional encoding module. + + See Sec. 3.2 https://arxiv.org/abs/1809.08895 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) + self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + + def reset_parameters(self): + """Reset parameters.""" + self.alpha.data = torch.tensor(1.0) + + def forward(self, x): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(x) + + +class LegacyRelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__( + d_model=d_model, + dropout_rate=dropout_rate, + max_len=max_len, + reverse=True, + ) + + def forward(self, x): + """Compute positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. 
+ + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + macaron_style (bool): Whether to use macaron style for positionwise layer. + pos_enc_layer_type (str): Encoder positional encoding layer type. + selfattention_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. 
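The relative encoding stores both positive and negative offsets, so for an input of length T its embedding covers 2*T - 1 positions; that is the tensor the RelPositionMultiHeadedAttention shown earlier re-aligns with rel_shift. A shape check under the assumption that this class follows the usual ESPnet return convention of (encoded input, position embedding), since part of its forward is not shown here:

import torch
from vec2wav2.models.conformer.embedding import RelPositionalEncoding

rel_pe = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(2, 20, 256)
y, pos_emb = rel_pe(x)                        # assumed return convention (see note above)
assert y.shape == (2, 20, 256)
assert pos_emb.shape == (1, 2 * 20 - 1, 256)  # offsets from +(T-1) down to -(T-1)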
+ + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + zero_triu=False, + cnn_module_kernel=31, + padding_idx=-1, + ): + """Construct an Encoder object.""" + super(Encoder, self).__init__() + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + pos_enc_class = LegacyRelPositionalEncoding + assert selfattention_layer_type == "legacy_rel_selfattn" + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + self.conv_subsampling_factor = 1 + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + self.conv_subsampling_factor = 4 + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + self.conv_subsampling_factor = 4 + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + + # self-attention module definition + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + logging.info("encoder self-attention layer type = relative self-attention") + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + zero_triu, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + # feed-forward module definition + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + 
linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, xs, masks): + """Encode input sequence. + + Args: + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + xs, masks = self.encoders(xs, masks) + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks diff --git a/vec2wav2/models/conformer/encoder_layer.py b/vec2wav2/models/conformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e8571e01eee2e126fcd0ce64524ae60f433ade2a --- /dev/null +++ b/vec2wav2/models/conformer/encoder_layer.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + feed_forward_macaron, + conv_module, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """Compute encoded features. + + Args: + x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. + - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. + - w/o pos emb: Tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/vec2wav2/models/conformer/layer_norm.py b/vec2wav2/models/conformer/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..6e934e644bf27aead8f299123519ad536758cdd7 --- /dev/null +++ b/vec2wav2/models/conformer/layer_norm.py @@ 
-0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer normalization module.""" + +import torch + + +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + + Args: + nout (int): Output dim size. + dim (int): Dimension to be normalized. + + """ + + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor. + + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return ( + super(LayerNorm, self) + .forward(x.transpose(self.dim, -1)) + .transpose(self.dim, -1) + ) diff --git a/vec2wav2/models/conformer/multi_layer_conv.py b/vec2wav2/models/conformer/multi_layer_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb0717b060d5815d44c83b711f8fc4659987f3a --- /dev/null +++ b/vec2wav2/models/conformer/multi_layer_conv.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" + +import torch + + +class MultiLayeredConv1d(torch.nn.Module): + """Multi-layered conv1d for Transformer block. + + This is a module of multi-leyered conv1d designed + to replace positionwise feed-forward network + in Transforner block, which is introduced in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize MultiLayeredConv1d module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Conv1d( + hidden_chans, + in_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + + +class Conv1dLinear(torch.nn.Module): + """Conv1D + Linear for Transformer block. + + A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize Conv1dLinear module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. 
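LayerNorm above can normalize a non-trailing dimension via its dim argument, and MultiLayeredConv1d swaps the position-wise FFN for two 1-D convolutions so neighbouring frames are mixed. A shape sketch, assuming the package from this diff is importable:

import torch
from vec2wav2.models.conformer.layer_norm import LayerNorm
from vec2wav2.models.conformer.multi_layer_conv import MultiLayeredConv1d

ffn = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
x = torch.randn(2, 50, 256)                 # (batch, time, channels)
assert ffn(x).shape == x.shape              # second conv maps back to in_chans

norm = LayerNorm(256, dim=1)                # normalize channel-first tensors (B, C, T)
assert norm(x.transpose(1, 2)).shape == (2, 256, 50)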
+ + """ + super(Conv1dLinear, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x)) diff --git a/vec2wav2/models/conformer/nets_utils.py b/vec2wav2/models/conformer/nets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a13ea9302eebb4dc366a6152106153053d7b0e0d --- /dev/null +++ b/vec2wav2/models/conformer/nets_utils.py @@ -0,0 +1,498 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + if isinstance(m, torch.nn.Module): + device = next(m.parameters()).device + elif isinstance(m, torch.Tensor): + device = m.device + else: + raise TypeError( + "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}" + ) + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, : xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
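make_pad_mask marks padded positions as True, either as a (B, Tmax) mask built from the lengths alone or broadcast to the shape of a reference tensor. A quick check, assuming the package from this diff is importable:

import torch
from vec2wav2.models.conformer.nets_utils import make_pad_mask

lengths = [5, 3, 2]
mask = make_pad_mask(lengths)                    # (B, Tmax), True on padding
assert mask.shape == (3, 5)
assert mask.tolist()[1] == [False, False, False, True, True]

xs = torch.zeros(3, 2, 5)
assert make_pad_mask(lengths, xs).shape == xs.shape   # expanded to match xs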
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) + ).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
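th_accuracy above expects flattened predictions of shape (B * Lmax, D) next to (B, Lmax) integer targets, and skips every position equal to ignore_label. A minimal check, assuming the package from this diff is importable:

import torch
from vec2wav2.models.conformer.nets_utils import th_accuracy

B, L, D, ignore = 2, 3, 4, -1
logits = torch.randn(B * L, D)                 # flattened predictions
targets = torch.randint(0, D, (B, L))
targets[1, -1] = ignore                        # padded position, excluded from the count
acc = th_accuracy(logits, targets, ignore_label=ignore)
assert 0.0 <= acc <= 1.0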
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == "c": + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if "real" not in x or "imag" not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x["real"], x["imag"]) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ( + "x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x)) + ) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == "transformer": + return np.array([1]) + + elif mode == "mt" and arch == "rnn": + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + logging.warning("Subsampling is not performed for machine translation.") + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif ( + (mode == "asr" and arch in ("rnn", "rnn-t")) + or (mode == "mt" and arch == "rnn") + or (mode == "st" and arch == "rnn") + ): + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mix": + subsample = np.ones( + train_args.elayers_sd + train_args.elayers + 1, dtype=np.int + ) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range( + min(train_args.elayers_sd + train_args.elayers + 1, len(ss)) + ): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." 
+ ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mulenc": + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) + if train_args.etype[idx].endswith("p") and not train_args.etype[ + idx + ].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Encoder %d: Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN.", + idx + 1, + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch)) + + +def rename_state_dict( + old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor] +): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f"Rename: {old_prefix} -> {new_prefix}") + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from vec2wav2.models.conformer.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": Swish, + } + + return activation_funcs[act]() diff --git a/vec2wav2/models/conformer/positionwise_feed_forward.py b/vec2wav2/models/conformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..5a66445e9557c9ea5f4ad382a7532f9d4204ff54 --- /dev/null +++ b/vec2wav2/models/conformer/positionwise_feed_forward.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. 
+ + """ + + def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): + """Construct an PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.dropout = torch.nn.Dropout(dropout_rate) + self.activation = activation + + def forward(self, x): + """Forward funciton.""" + return self.w_2(self.dropout(self.activation(self.w_1(x)))) diff --git a/vec2wav2/models/conformer/repeat.py b/vec2wav2/models/conformer/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d2676a8020bbb4cb44e84a199baece2c9e763b --- /dev/null +++ b/vec2wav2/models/conformer/repeat.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Repeat the same layer definition.""" + +import torch + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential.""" + + def forward(self, *args): + """Repeat.""" + for m in self: + args = m(*args) + return args + + +def repeat(N, fn): + """Repeat module N times. + + Args: + N (int): Number of repeat time. + fn (Callable): Function to generate module. + + Returns: + MultiSequential: Repeated model instance. + + """ + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/vec2wav2/models/conformer/subsampling.py b/vec2wav2/models/conformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..eb024025a6866de7eb3bcd94212362015ecb3826 --- /dev/null +++ b/vec2wav2/models/conformer/subsampling.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Subsampling layer definition.""" + +import torch + +from vec2wav2.models.conformer.embedding import PositionalEncoding + + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. + + Args: + message (str): Message for error catch + actual_size (int): the short size that cannot pass the subsampling + limit (int): the limit size for subsampling + + """ + + def __init__(self, message, actual_size, limit): + """Construct a TooShortUttError for error handler.""" + super().__init__(message) + self.actual_size = actual_size + self.limit = limit + + +def check_short_utt(ins, size): + """Check if the utterance is too short for subsampling.""" + if isinstance(ins, Conv2dSubsampling2) and size < 3: + return True, 3 + if isinstance(ins, Conv2dSubsampling) and size < 7: + return True, 7 + if isinstance(ins, Conv2dSubsampling6) and size < 11: + return True, 11 + if isinstance(ins, Conv2dSubsampling8) and size < 15: + return True, 15 + return False, -1 + + +class Conv2dSubsampling(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
+ + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:2] + + def __getitem__(self, key): + """Get item. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dSubsampling2(torch.nn.Module): + """Convolutional 2D subsampling (to 1/2 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling2 object.""" + super(Conv2dSubsampling2, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 2. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 2. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:1] + + def __getitem__(self, key): + """Get item. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dSubsampling6(torch.nn.Module): + """Convolutional 2D subsampling (to 1/6 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
+ + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling6 object.""" + super(Conv2dSubsampling6, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-4:3] + + +class Conv2dSubsampling8(torch.nn.Module): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling8 object.""" + super(Conv2dSubsampling8, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. 
+ + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] diff --git a/vec2wav2/models/conformer/swish.py b/vec2wav2/models/conformer/swish.py new file mode 100644 index 0000000000000000000000000000000000000000..c53a7a98bfc6d983c3a308c4b40f81e315aa7875 --- /dev/null +++ b/vec2wav2/models/conformer/swish.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Swish() activation function for Conformer.""" + +import torch + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x): + """Return Swich activation function.""" + return x * torch.sigmoid(x) diff --git a/vec2wav2/models/conformer/vgg2l.py b/vec2wav2/models/conformer/vgg2l.py new file mode 100644 index 0000000000000000000000000000000000000000..18aeafb0f32c1feea7f38c28645ecac2d461b0e5 --- /dev/null +++ b/vec2wav2/models/conformer/vgg2l.py @@ -0,0 +1,89 @@ +"""VGG2L module definition for transformer encoder.""" + +from typing import Tuple +from typing import Union + +import torch + + +class VGG2L(torch.nn.Module): + """VGG2L module for custom encoder. + + Args: + idim: Dimension of inputs + odim: Dimension of outputs + pos_enc: Positional encoding class + + """ + + def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None): + """Construct a VGG2L object.""" + super().__init__() + + self.vgg2l = torch.nn.Sequential( + torch.nn.Conv2d(1, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(64, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((3, 2)), + torch.nn.Conv2d(64, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(128, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + if pos_enc is not None: + self.output = torch.nn.Sequential( + torch.nn.Linear(128 * ((idim // 2) // 2), odim), pos_enc + ) + else: + self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor + ) -> Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + ]: + """VGG2L forward for x. + + Args: + x: Input tensor (B, T, idim) + x_mask: Input mask (B, 1, T) + + Returns: + x: Output tensor (B, sub(T), odim) + or ((B, sub(T), odim), (B, sub(T), att_dim)) + x_mask: Output mask (B, 1, sub(T)) + + """ + x = x.unsqueeze(1) + x = self.vgg2l(x) + + b, c, t, f = x.size() + + x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if x_mask is not None: + x_mask = self.create_new_mask(x_mask) + + return x, x_mask + + def create_new_mask(self, x_mask: torch.Tensor) -> torch.Tensor: + """Create a subsampled version of x_mask. 
+ + Args: + x_mask: Input mask (B, 1, T) + + Returns: + x_mask: Output mask (B, 1, sub(T)) + + """ + x_t1 = x_mask.size(2) - (x_mask.size(2) % 3) + x_mask = x_mask[:, :, :x_t1][:, :, ::3] + + x_t2 = x_mask.size(2) - (x_mask.size(2) % 2) + x_mask = x_mask[:, :, :x_t2][:, :, ::2] + + return x_mask diff --git a/vec2wav2/models/fairseq_modules/__init__.py b/vec2wav2/models/fairseq_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vec2wav2/models/fairseq_modules/__pycache__/__init__.cpython-310.pyc b/vec2wav2/models/fairseq_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..594da672cfb8765edc5a3ab564cb91b2c4102064 Binary files /dev/null and b/vec2wav2/models/fairseq_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/models/fairseq_modules/__pycache__/fp32_group_norm.cpython-310.pyc b/vec2wav2/models/fairseq_modules/__pycache__/fp32_group_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..018a98a015a36ecf65c995af508170148ee94125 Binary files /dev/null and b/vec2wav2/models/fairseq_modules/__pycache__/fp32_group_norm.cpython-310.pyc differ diff --git a/vec2wav2/models/fairseq_modules/__pycache__/layer_norm.cpython-310.pyc b/vec2wav2/models/fairseq_modules/__pycache__/layer_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c6276ebb273b329eead34fd7a93e30e98f1c2da Binary files /dev/null and b/vec2wav2/models/fairseq_modules/__pycache__/layer_norm.cpython-310.pyc differ diff --git a/vec2wav2/models/fairseq_modules/__pycache__/transpose_last.cpython-310.pyc b/vec2wav2/models/fairseq_modules/__pycache__/transpose_last.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00738ba3bffe42cc3ec73d1f78afd6efebdd28fe Binary files /dev/null and b/vec2wav2/models/fairseq_modules/__pycache__/transpose_last.cpython-310.pyc differ diff --git a/vec2wav2/models/fairseq_modules/fp32_group_norm.py b/vec2wav2/models/fairseq_modules/fp32_group_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d03aac022e30c8c14a600062d1d86429504ba003 --- /dev/null +++ b/vec2wav2/models/fairseq_modules/fp32_group_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Layer norm done in fp32 (for fp16 training) +""" + +import torch.nn as nn +import torch.nn.functional as F + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/vec2wav2/models/fairseq_modules/gumbel_vector_quantizer.py b/vec2wav2/models/fairseq_modules/gumbel_vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..867b019f676d72a51db8f8ea54e08fab2b535bfc --- /dev/null +++ b/vec2wav2/models/fairseq_modules/gumbel_vector_quantizer.py @@ -0,0 +1,212 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
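+"""Gumbel-softmax vector quantizer (vendored from fairseq).
+
+Rough usage sketch (the argument values here are illustrative only, not the
+settings used elsewhere in this repository):
+
+    quantizer = GumbelVectorQuantizer(
+        dim=256, num_vars=320, temp=(2.0, 0.5, 0.999995),
+        groups=2, combine_groups=False, vq_dim=256, time_first=True,
+    )
+    out = quantizer(torch.randn(4, 100, 256))
+    out["x"]                # quantized features, shape (4, 100, 256)
+    out["prob_perplexity"]  # perplexity of the (soft) code distribution
+    quantizer.set_num_updates(num_updates)  # anneal the Gumbel temperature towards min_temp
+"""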
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GumbelVectorQuantizer(nn.Module): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + hard=True, + std=0, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + self.hard = hard + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + if std == 0: + nn.init.uniform_(self.vars) + else: + nn.init.normal_(self.vars, mean=0, std=std) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + if isinstance(temp, str): + import ast + + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay**num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor( + inds, dtype=torch.long, device=self.vars.device + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view( + self.num_vars**self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .view(self.num_vars**self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert 
( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars**exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + with torch.no_grad(): + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + avg_probs = torch.softmax( + x.view(bsz * tsz, self.groups, -1).float(), dim=-1 + ).mean(dim=0) + result["prob_perplexity"] = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ).sum() + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=self.hard).type_as( + x + ) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if produce_targets: + result["targets"] = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + result["x"] = x + + return result diff --git a/vec2wav2/models/fairseq_modules/kmeans_vector_quantizer.py b/vec2wav2/models/fairseq_modules/kmeans_vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a63a189c80c82bb27154fbcfe29abfb1cf91edec --- /dev/null +++ b/vec2wav2/models/fairseq_modules/kmeans_vector_quantizer.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from vec2wav2.models.fairseq_modules.fp32_group_norm import Fp32GroupNorm + + +class KmeansVectorQuantizer(nn.Module): + def __init__( + self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25 + ): + """Vector quantization using straight pass-through estimator (i.e. 
kmeans) + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + gamma: commitment loss coefficient + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.vq_dim = vq_dim + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + self.var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.embedding = nn.Parameter( + 0.01 * torch.randn(num_vars, num_groups, self.var_dim) + ) + self.projection = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size=1, groups=groups, bias=False), + Fp32GroupNorm(groups, dim), + ) + self.gamma = gamma + self.mse_mean = nn.MSELoss(reduction="mean") + + def _pass_grad(self, x, y): + """Manually set gradient for backward pass. + for y = f(x), ensure that during the backward pass, + dL/dy = dL/dx regardless of f(x). + Returns: + y, with the gradient forced to be dL/dy = dL/dx. + """ + + return y.detach() + (x - x.detach()) + + @property + def expand_embedding(self): + if self.combine_groups: + return self.embedding.expand(self.num_vars, self.groups, self.var_dim) + return self.embedding + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward_idx_limited(self, x, valid_label2vqidx_mat): + # mask_mat = convert_valid_label2vqidx_to_mask_mat(valid_label2vqidx) + res = self.forward_group2(x, mask_mat=valid_label2vqidx_mat, produce_targets=True) + return res['x'], res['targets'] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars} + + if self.time_first: + x = x.transpose(1, 2) + + bsz, fsz, tsz = x.shape + + ze = self.projection(x) + ze_ = ze.view(bsz, self.groups, self.var_dim, tsz).permute(0, 3, 1, 2) + d = ( + (ze_.unsqueeze(0) - self.expand_embedding.unsqueeze(1).unsqueeze(1)) + .view(self.num_vars, bsz, tsz, self.groups, -1) + .norm(dim=-1, p=2) + ) + idx = d.argmin(dim=0) + zq = ( + torch.stack( + [ + self.expand_embedding[idx[..., group], group] + for group in range(self.groups) + ], + dim=-2, + ) + .view(bsz, tsz, self.groups * self.var_dim) + .permute(0, 2, 1) + ) + assert ze.shape == zq.shape, (ze.shape, zq.shape) + x = self._pass_grad(ze, zq) + + with torch.no_grad(): + hard_x = ( + idx.new_zeros(bsz * tsz * self.groups, self.num_vars) + .scatter_(-1, idx.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + if produce_targets: + result["targets"] = idx + + if self.time_first: + x = x.transpose(1, 2) # BCT -> BTC + result["x"] = x + + ze = ze.float() + zq = zq.float() + latent_loss = self.mse_mean(zq, ze.detach()) + commitment_loss = self.mse_mean(ze, zq.detach()) + + result["kmeans_loss"] = latent_loss + self.gamma * commitment_loss + + return result + + def forward_group2(self, x, mask_mat=None, produce_targets=False, inf=999999): + assert mask_mat is not None + + result = {"num_vars": self.num_vars} + + if self.time_first: + x = x.transpose(1, 2) + + bsz, fsz, tsz = x.shape + + ze = self.projection(x) 
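+
+        # The two codebooks are searched jointly: squared distances to each
+        # group's codewords are broadcast into a (num_vars x num_vars) grid per
+        # frame, index pairs with mask_mat == 0 are penalized by the large
+        # constant `inf`, and the argmin over the flattened grid is split back
+        # into the two group indices via // and % num_vars.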
+ + ze_ = ze.view(bsz, self.groups, self.var_dim, tsz).permute(0, 3, 1, 2) + ze_0 = ze_[:, :, 0, None, :] + ze_1 = ze_[:, :, 1, None, :] # 4 * 100 * 320 * 128 + cb0_expand = self.expand_embedding[:, 0, :] + cb1_expand = self.expand_embedding[:, 1, :] # 320 * 128 + dist_0 = ((ze_0 - cb0_expand) ** 2).sum(dim=-1)[:, :, :, None] + dist_1 = ((ze_1 - cb1_expand) ** 2).sum(dim=-1)[:, :, None, :] + res_0, res_1 = torch.broadcast_tensors(dist_0, dist_1) + mask_mat = (1 - mask_mat[None, None, :, :].to(res_0.device) * torch.ones_like(res_0)) * inf + # mask_mat = mask_mat.to(x.device) + d_flt = (res_0 + res_1 + mask_mat).view(bsz, tsz, -1) + idx_flt = torch.argmin(d_flt, dim=-1) + idx = torch.stack((idx_flt // self.num_vars, idx_flt % self.num_vars), dim=-1) + + zq = ( + torch.stack( + [ + self.expand_embedding[idx[..., group], group] + for group in range(self.groups) + ], + dim=-2, + ) + .view(bsz, tsz, self.groups * self.var_dim) + .permute(0, 2, 1) + ) + assert ze.shape == zq.shape, (ze.shape, zq.shape) + x = self._pass_grad(ze, zq) + + with torch.no_grad(): + hard_x = ( + idx.new_zeros(bsz * tsz * self.groups, self.num_vars) + .scatter_(-1, idx.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + if produce_targets: + result["targets"] = idx + + if self.time_first: + x = x.transpose(1, 2) # BCT -> BTC + result["x"] = x + + ze = ze.float() + zq = zq.float() + latent_loss = self.mse_mean(zq, ze.detach()) + commitment_loss = self.mse_mean(ze, zq.detach()) + + result["kmeans_loss"] = latent_loss + self.gamma * commitment_loss + + return result + +if __name__ == "__main__": + quantizer = KmeansVectorQuantizer(dim=256, num_vars=320, groups=2, combine_groups=False, vq_dim=256, time_first=True) + x = torch.ones(4, 100, 256) + result = quantizer.forward_group2(x, mask_mat=torch.randint(0, 2, (320, 320))) + print(result) diff --git a/vec2wav2/models/fairseq_modules/layer_norm.py b/vec2wav2/models/fairseq_modules/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..0b276ce02fc6bcb9619c9e8a0f7ec10cd28bc420 --- /dev/null +++ b/vec2wav2/models/fairseq_modules/layer_norm.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
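+
+# Note: the LayerNorm() factory defined below returns apex's FusedLayerNorm
+# when apex is importable, CUDA is available and the module is not being
+# scripted/traced, and falls back to torch.nn.LayerNorm otherwise.
+# Fp32LayerNorm runs the normalization in float32 so that fp16 training
+# remains numerically stable.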
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/vec2wav2/models/fairseq_modules/transpose_last.py b/vec2wav2/models/fairseq_modules/transpose_last.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cca9a4bbdb3f455217380f96a2f2d77eae8630 --- /dev/null +++ b/vec2wav2/models/fairseq_modules/transpose_last.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +transpose last 2 dimensions of the input +""" + +import torch.nn as nn + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) diff --git a/vec2wav2/models/hifigan.py b/vec2wav2/models/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..33984b0033a2cc5aed14a10a82b6d8ac38a8671b --- /dev/null +++ b/vec2wav2/models/hifigan.py @@ -0,0 +1,732 @@ +# -*- coding: utf-8 -*- + +"""HiFi-GAN Modules. + +This code is based on https://github.com/jik876/hifi-gan. + +""" + +import copy +import logging + +import numpy as np +import torch +import torch.nn.functional as F + +from vec2wav2.layers import HiFiGANResidualBlock as ResidualBlock +from vec2wav2.utils import read_hdf5 + + +class HiFiGANGenerator(torch.nn.Module): + """HiFiGAN generator module.""" + + def __init__( + self, + in_channels=80, + out_channels=1, + channels=512, + kernel_size=7, + upsample_scales=(8, 8, 2, 2), + upsample_kernel_sizes=(16, 16, 4, 4), + resblock_kernel_sizes=(3, 7, 11), + resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], + use_additional_convs=True, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.1}, + use_weight_norm=True, + ): + """Initialize HiFiGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (list): List of upsampling scales. 
+ upsample_kernel_sizes (list): List of kernel sizes for upsampling layers. + resblock_kernel_sizes (list): List of kernel sizes for residual blocks. + resblock_dilations (list): List of dilation list for residual blocks. + use_additional_convs (bool): Whether to use additional conv layers in residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all the conv layers. + + """ + super().__init__() + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernel size must be odd number." + assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, c): + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + c = self.input_conv(c) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. 
+ https://github.com/jik876/hifi-gan/blob/master/models.py + + """ + + def _reset_parameters(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + m.weight.data.normal_(0.0, 0.01) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def remove_weight_norm(self): + """Remove weight normalization module from all the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def register_stats(self, stats): + """Register stats for de-normalization as buffer. + + Args: + stats (str): Path of statistics file (".npy" or ".h5"). + + """ + assert stats.endswith(".h5") or stats.endswith(".npy") + if stats.endswith(".h5"): + mean = read_hdf5(stats, "mean").reshape(-1) + scale = read_hdf5(stats, "scale").reshape(-1) + else: + mean = np.load(stats)[0].reshape(-1) + scale = np.load(stats)[1].reshape(-1) + self.register_buffer("mean", torch.from_numpy(mean).float()) + self.register_buffer("scale", torch.from_numpy(scale).float()) + logging.info("Successfully registered stats as buffer.") + + def inference(self, c, normalize_before=False): + """Perform inference. + + Args: + c (Union[Tensor, ndarray]): Input tensor (T, in_channels). + normalize_before (bool): Whether to perform normalization. + + Returns: + Tensor: Output tensor (T ** prod(upsample_scales), out_channels). + + """ + # if not isinstance(c, torch.Tensor): + # c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) + if normalize_before: + c = (c - self.mean) / self.scale + c = self.forward(c.transpose(1, 2)) + return c.squeeze(0).transpose(1, 0) + + +class HiFiGANPeriodDiscriminator(torch.nn.Module): + """HiFiGAN period discriminator module.""" + + def __init__( + self, + in_channels=1, + out_channels=1, + period=3, + kernel_sizes=[5, 3], + channels=32, + downsample_scales=[3, 3, 3, 3, 1], + max_downsample_channels=1024, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.1}, + use_weight_norm=True, + use_spectral_norm=False, + ): + """Initialize HiFiGANPeriodDiscriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + period (int): Period. + kernel_sizes (list): Kernel sizes of initial conv layers and the final conv layer. + channels (int): Number of initial channels. + downsample_scales (list): List of downsampling scales. + max_downsample_channels (int): Number of maximum downsampling channels. + use_additional_convs (bool): Whether to use additional conv layers in residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. + If set to true, it will be applied to all the conv layers. 
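+
+        Example (illustrative sketch with the default arguments):
+
+            >>> d = HiFiGANPeriodDiscriminator()
+            >>> outs = d(torch.randn(1, 1, 8192))
+            >>> len(outs)  # one tensor per conv block plus the flattened output
+            6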
+ + """ + super().__init__() + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number." + assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number." + + self.period = period + self.convs = torch.nn.ModuleList() + in_chs = in_channels + out_chs = channels + for downsample_scale in downsample_scales: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_chs, + out_chs, + (kernel_sizes[0], 1), + (downsample_scale, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Use downsample_scale + 1? + out_chs = min(out_chs * 4, max_downsample_channels) + self.output_conv = torch.nn.Conv2d( + out_chs, + out_channels, + (kernel_sizes[1] - 1, 1), + 1, + padding=((kernel_sizes[1] - 1) // 2, 0), + ) + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x): + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + list: List of each layer's tensors. + + """ + # transform 1d to 2d -> (B, C, T/P, P) + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t += n_pad + x = x.view(b, c, t // self.period, self.period) + + # forward conv + outs = [] + for layer in self.convs: + x = layer(x) + outs += [x] + x = self.output_conv(x) + x = torch.flatten(x, 1, -1) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all the layers.""" + + def _apply_spectral_norm(m): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + +class HiFiGANMultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN multi-period discriminator module.""" + + def __init__( + self, + periods=[2, 3, 5, 7, 11], + discriminator_params={ + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initialize HiFiGANMultiPeriodDiscriminator module. + + Args: + periods (list): List of periods. + discriminator_params (dict): Parameters for hifi-gan period discriminator module. + The period parameter will be overwritten. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + for period in periods: + params = copy.deepcopy(discriminator_params) + params["period"] = period + self.discriminators += [HiFiGANPeriodDiscriminator(**params)] + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each layer output tensors. 
+ + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + + return outs + + +class HiFiGANScaleDiscriminator(torch.nn.Module): + """HiFi-GAN scale discriminator module.""" + + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=[15, 41, 5, 3], + channels=128, + max_downsample_channels=1024, + max_groups=16, + bias=True, + downsample_scales=[2, 2, 4, 4, 1], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.1}, + use_weight_norm=True, + use_spectral_norm=False, + ): + """Initilize HiFiGAN scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (list): List of four kernel sizes. The first will be used for the first conv layer, + and the second is for downsampling part, and the remaining two are for output layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. + If set to true, it will be applied to all the conv layers. + + """ + super().__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 4 + for ks in kernel_sizes: + assert ks % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + channels, + # NOTE(kan-bayashi): Use always the same kernel size + kernel_sizes[0], + bias=bias, + padding=(kernel_sizes[0] - 1) // 2, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + out_chs = channels + # NOTE(kan-bayashi): Remove hard coding? + groups = 4 + for downsample_scale in downsample_scales: + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[1], + stride=downsample_scale, + padding=(kernel_sizes[1] - 1) // 2, + groups=groups, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Remove hard coding? + out_chs = min(in_chs * 2, max_downsample_channels) + # NOTE(kan-bayashi): Remove hard coding? + groups = min(groups * 4, max_groups) + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_size=kernel_sizes[3], + stride=1, + padding=(kernel_sizes[3] - 1) // 2, + bias=bias, + ), + ] + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x): + """Calculate forward propagation. 
+ + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of output tensors of each layer. + + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all the layers.""" + + def _apply_spectral_norm(m): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + +class HiFiGANMultiScaleDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale discriminator module.""" + + def __init__( + self, + scales=3, + downsample_pooling="AvgPool1d", + # follow the official implementation setting + downsample_pooling_params={ + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + discriminator_params={ + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm=False, + ): + """Initilize HiFiGAN multi-scale discriminator module. + + Args: + scales (int): Number of multi-scales. + downsample_pooling (str): Pooling module name for downsampling of the inputs. + downsample_pooling_params (dict): Parameters for the above pooling module. + discriminator_params (dict): Parameters for hifi-gan scale discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the official + implementaion. The first discriminator uses spectral norm and the other + discriminators use weight norm. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for i in range(scales): + params = copy.deepcopy(discriminator_params) + if follow_official_norm: + if i == 0: + params["use_weight_norm"] = False + params["use_spectral_norm"] = True + else: + params["use_weight_norm"] = True + params["use_spectral_norm"] = False + self.discriminators += [HiFiGANScaleDiscriminator(**params)] + self.pooling = getattr(torch.nn, downsample_pooling)( + **downsample_pooling_params + ) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each layer output tensors. 
+ + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + x = self.pooling(x) + + return outs + + +class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale + multi-period discriminator module.""" + + def __init__( + self, + # Multi-scale discriminator related + scales=3, + scale_downsample_pooling="AvgPool1d", + scale_downsample_pooling_params={ + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + scale_discriminator_params={ + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm=True, + # Multi-period discriminator related + periods=[2, 3, 5, 7, 11], + period_discriminator_params={ + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initilize HiFiGAN multiscale + multi-period discriminator module. + + Args: + scales (int): Number of multi-scales. + scale_downsample_pooling (str): Pooling module name for downsampling of the inputs. + scale_downsample_pooling_params (dict): Parameters for the above pooling module. + scale_discriminator_params (dict): Parameters for hifi-gan scale discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the official + implementaion. The first discriminator uses spectral norm and the other + discriminators use weight norm. + periods (list): List of periods. + period_discriminator_params (dict): Parameters for hifi-gan period discriminator module. + The period parameter will be overwritten. + + """ + super().__init__() + self.msd = HiFiGANMultiScaleDiscriminator( + scales=scales, + downsample_pooling=scale_downsample_pooling, + downsample_pooling_params=scale_downsample_pooling_params, + discriminator_params=scale_discriminator_params, + follow_official_norm=follow_official_norm, + ) + self.mpd = HiFiGANMultiPeriodDiscriminator( + periods=periods, + discriminator_params=period_discriminator_params, + ) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, + which consists of each layer output tensors. + Multiscale and multi period ones are concatenated. 
+ + """ + msd_outs = self.msd(x) + mpd_outs = self.mpd(x) + return msd_outs + mpd_outs diff --git a/vec2wav2/models/melgan.py b/vec2wav2/models/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..d3129e17885efe7f9855ec1365dfd468e358ef48 --- /dev/null +++ b/vec2wav2/models/melgan.py @@ -0,0 +1,516 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""MelGAN Modules.""" + +import logging + +import numpy as np +import torch + +from vec2wav2.layers import CausalConv1d +from vec2wav2.layers import CausalConvTranspose1d +from vec2wav2.layers import ResidualStack +from vec2wav2.utils import read_hdf5 + + +class MelGANGenerator(torch.nn.Module): + """MelGAN generator module.""" + + def __init__( + self, + in_channels=80, + out_channels=1, + kernel_size=7, + channels=512, + bias=True, + upsample_scales=[8, 8, 2, 2], + stack_kernel_size=3, + stacks=3, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_final_nonlinear_activation=True, + use_weight_norm=True, + use_causal_conv=False, + ): + """Initialize MelGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of initial and final conv layer. + channels (int): Initial number of channels for conv layer. + bias (bool): Whether to add bias parameter in convolution layers. + upsample_scales (list): List of upsampling scales. + stack_kernel_size (int): Kernel size of dilated conv layers in residual stack. + stacks (int): Number of stacks in a single residual stack. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_final_nonlinear_activation (torch.nn.Module): Activation function for the final layer. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv (bool): Whether to use causal convolution. + + """ + super(MelGANGenerator, self).__init__() + + # check hyper parameters is valid + assert channels >= np.prod(upsample_scales) + assert channels % (2 ** len(upsample_scales)) == 0 + if not use_causal_conv: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
+ + # add initial layer + layers = [] + if not use_causal_conv: + layers += [ + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), + torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias), + ] + else: + layers += [ + CausalConv1d( + in_channels, + channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, + ), + ] + + for i, upsample_scale in enumerate(upsample_scales): + # add upsampling layer + layers += [ + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + bias=bias, + ) + ] + else: + layers += [ + CausalConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_scale * 2, + stride=upsample_scale, + bias=bias, + ) + ] + + # add residual stack + for j in range(stacks): + layers += [ + ResidualStack( + kernel_size=stack_kernel_size, + channels=channels // (2 ** (i + 1)), + dilation=stack_kernel_size ** j, + bias=bias, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + use_causal_conv=use_causal_conv, + ) + ] + + # add final layer + layers += [ + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + ] + if not use_causal_conv: + layers += [ + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias + ), + ] + else: + layers += [ + CausalConv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, + ), + ] + if use_final_nonlinear_activation: + layers += [torch.nn.Tanh()] + + # define the model as a single function + self.melgan = torch.nn.Sequential(*layers) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + # initialize pqmf for inference + self.pqmf = None + + def forward(self, c): + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, 1, T ** prod(upsample_scales)). + + """ + return self.melgan(c) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + """Reset parameters. + + This initialization follows official implementation manner. 
+ https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py + + """ + + def _reset_parameters(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + m.weight.data.normal_(0.0, 0.02) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def register_stats(self, stats): + """Register stats for de-normalization as buffer. + + Args: + stats (str): Path of statistics file (".npy" or ".h5"). + + """ + assert stats.endswith(".h5") or stats.endswith(".npy") + if stats.endswith(".h5"): + mean = read_hdf5(stats, "mean").reshape(-1) + scale = read_hdf5(stats, "scale").reshape(-1) + else: + mean = np.load(stats)[0].reshape(-1) + scale = np.load(stats)[1].reshape(-1) + self.register_buffer("mean", torch.from_numpy(mean).float()) + self.register_buffer("scale", torch.from_numpy(scale).float()) + logging.info("Successfully registered stats as buffer.") + + def inference(self, c, normalize_before=False): + """Perform inference. + + Args: + c (Union[Tensor, ndarray]): Input tensor (T, in_channels). + normalize_before (bool): Whether to perform normalization. + + Returns: + Tensor: Output tensor (T ** prod(upsample_scales), out_channels). + + """ + # if not isinstance(c, torch.Tensor): + # c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device) + if normalize_before: + c = (c - self.mean) / self.scale + c = self.melgan(c.transpose(1, 2)) + if self.pqmf is not None: + c = self.pqmf.synthesis(c) + return c.squeeze(0).transpose(1, 0) + + +class MelGANDiscriminator(torch.nn.Module): + """MelGAN discriminator module.""" + + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=[5, 3], + channels=16, + max_downsample_channels=1024, + bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + ): + """Initilize MelGAN discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, + the last two layers' kernel size will be 5 and 3, respectively. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. 
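+
+        Example (rough sketch using the default arguments):
+
+            >>> d = MelGANDiscriminator()
+            >>> outs = d(torch.randn(1, 1, 16384))
+            >>> len(outs)  # first layer + 4 downsample layers + 2 output layers
+            7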
+ + """ + super(MelGANDiscriminator, self).__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1 + assert kernel_sizes[1] % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), + torch.nn.Conv1d( + in_channels, channels, np.prod(kernel_sizes), bias=bias + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + for downsample_scale in downsample_scales: + out_chs = min(in_chs * downsample_scale, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=downsample_scale * 10 + 1, + stride=downsample_scale, + padding=downsample_scale * 5, + groups=in_chs // 4, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_sizes[0], + padding=(kernel_sizes[0] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_sizes[1], + padding=(kernel_sizes[1] - 1) // 2, + bias=bias, + ), + ] + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of output tensors of each layer. + + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + +class MelGANMultiScaleDiscriminator(torch.nn.Module): + """MelGAN multi-scale discriminator module.""" + + def __init__( + self, + in_channels=1, + out_channels=1, + scales=3, + downsample_pooling="AvgPool1d", + # follow the official implementation setting + downsample_pooling_params={ + "kernel_size": 4, + "stride": 2, + "padding": 1, + "count_include_pad": False, + }, + kernel_sizes=[5, 3], + channels=16, + max_downsample_channels=1024, + bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_weight_norm=True, + ): + """Initilize MelGAN multi-scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + scales (int): Number of multi-scales. + downsample_pooling (str): Pooling module name for downsampling of the inputs. + downsample_pooling_params (dict): Parameters for the above pooling module. + kernel_sizes (list): List of two kernel sizes. The sum will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_causal_conv (bool): Whether to use causal convolution. 
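Two details of the discriminator geometry are easy to miss: the first convolution uses the product of `kernel_sizes` (15 for the default `[5, 3]`), even though the multi-scale docstring just above says "sum", and each downsample block uses kernel `scale * 10 + 1` with stride `scale`, which maps a length-`T` input to roughly `ceil(T / scale)`. A quick standalone shape check of these two settings (illustrative sketch, not part of the patch):

```python
import numpy as np
import torch

kernel_sizes = [5, 3]
first_kernel = int(np.prod(kernel_sizes))                 # 15, the product rather than the sum
pad = torch.nn.ReflectionPad1d((first_kernel - 1) // 2)
first = torch.nn.Conv1d(1, 16, first_kernel)

x = torch.randn(1, 1, 22050)
y = first(pad(x))
assert y.shape[-1] == 22050                               # "same" length after reflection padding

down = torch.nn.Conv1d(16, 64, kernel_size=4 * 10 + 1, stride=4, padding=4 * 5, groups=16 // 4)
assert down(y).shape[-1] == (22050 + 3) // 4              # ceil(T / 4)
```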
+ + """ + super(MelGANMultiScaleDiscriminator, self).__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for _ in range(scales): + self.discriminators += [ + MelGANDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + channels=channels, + max_downsample_channels=max_downsample_channels, + bias=bias, + downsample_scales=downsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + ) + ] + self.pooling = getattr(torch.nn, downsample_pooling)( + **downsample_pooling_params + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + x = self.pooling(x) + + return outs + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + """Reset parameters. + + This initialization follows official implementation manner. 
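When wiring up adversarial and feature-matching losses it helps to keep the output structure of the multi-scale discriminator in mind: a list with one entry per scale, each entry being that discriminator's per-layer feature maps, with the score map last. A short sketch, assuming the classes defined above are importable (the exact module path is not shown in this hunk):

```python
import torch

d = MelGANMultiScaleDiscriminator()        # defaults: 3 scales, configured as above
wav = torch.randn(2, 1, 8192)              # (B, 1, T)
outs = d(wav)

assert len(outs) == 3                      # one inner list per scale
assert all(isinstance(o, list) for o in outs)
scores = [o[-1] for o in outs]             # last tensor of each inner list is the score map
features = [o[:-1] for o in outs]          # earlier tensors can drive a feature-matching loss
```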
+ https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py + + """ + + def _reset_parameters(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + m.weight.data.normal_(0.0, 0.02) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) diff --git a/vec2wav2/models/prompt_prenet.py b/vec2wav2/models/prompt_prenet.py new file mode 100644 index 0000000000000000000000000000000000000000..24b0a32a892444e52e5e62a6284d809bc53751fa --- /dev/null +++ b/vec2wav2/models/prompt_prenet.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Yiwei Guo +# Derived mostly from fairseq (https://github.com/facebookresearch/fairseq) + +"""Prompt Pre-net Modules.""" + +import math + +import torch.nn as nn +from vec2wav2.models.fairseq_modules.fp32_group_norm import Fp32GroupNorm +from vec2wav2.models.fairseq_modules.layer_norm import Fp32LayerNorm +from vec2wav2.models.fairseq_modules.transpose_last import TransposeLast +import torch + + +def norm_block(is_layer_norm, dim, affine=True): + if is_layer_norm: + mod = nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=affine), + TransposeLast(), + ) + else: + mod = Fp32GroupNorm(1, dim, affine=affine) + + return mod + + +class ZeroPad1d(nn.Module): + def __init__(self, pad_left, pad_right): + super().__init__() + self.pad_left = pad_left + self.pad_right = pad_right + + def forward(self, x): + return nn.functional.pad(x, (self.pad_left, self.pad_right)) + + +class ConvPromptPrenet(nn.Module): + def __init__( + self, + conv_layers, + embed, + dropout, + skip_connections, + residual_scale, + non_affine_group_norm, + conv_bias, + activation, + ): + super().__init__() + + def block(n_in, n_out, k, stride, pad): + return nn.Sequential( + nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias, padding=pad), + nn.Dropout(p=dropout), + norm_block(False, n_out, affine=not non_affine_group_norm), + activation, + ) + + in_d = embed + self.conv_layers = nn.ModuleList() + self.residual_proj = nn.ModuleList() + for dim, k, stride, pad in conv_layers: + if in_d != dim and skip_connections: + self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) + else: + self.residual_proj.append(None) + + self.conv_layers.append(block(in_d, dim, k, stride, pad)) + in_d = dim + self.conv_layers = nn.Sequential(*self.conv_layers) + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + for rproj, conv in zip(self.residual_proj, self.conv_layers): + residual = x + x = conv(x) + if self.skip_connections: + if rproj is not None: + residual = rproj(residual) + x = (x + residual) * self.residual_scale + return x diff --git a/vec2wav2/models/quantization/__init__.py b/vec2wav2/models/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7061f130235610db9195f1833cecf819460dd0b9 --- /dev/null +++ b/vec2wav2/models/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
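Looking back at `ConvPromptPrenet` above: each `(dim, k, stride, pad)` tuple in `conv_layers` becomes one Conv1d block, and with stride 1 and pad `(k - 1) // 2` the time axis is preserved, so the prenet only reshapes the channel dimension of the mel prompt. A standalone sketch of that shape behaviour; the layer spec mirrors the one used in `v2w2.py` later in this patch, and the 184-dimensional output is a placeholder value:

```python
import torch

torch.manual_seed(0)
attention_dim = 184                                    # placeholder, for illustration only
spec = [(128, 3, 1, 1), (256, 5, 1, 2), (512, 5, 1, 2), (attention_dim, 3, 1, 1)]

x = torch.randn(2, 80, 150)                            # (B, prompt_channels, T')
for dim, k, stride, pad in spec:
    x = torch.nn.Conv1d(x.shape[1], dim, k, stride=stride, padding=pad)(x)
assert x.shape == (2, attention_dim, 150)              # channels change, T' is preserved
```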
+ +# flake8: noqa +from vec2wav2.models.quantization.vq import QuantizedResult, ResidualVectorQuantizer diff --git a/vec2wav2/models/quantization/ac.py b/vec2wav2/models/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..6b80be289989d1a4a5b7fd0e238ac2a4a451c7aa --- /dev/null +++ b/vec2wav2/models/quantization/ac.py @@ -0,0 +1,291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int, + roundoff: float = 1e-8, min_range: int = 2, + check: bool = True) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2 ** total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1] + if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. 
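A tiny self-contained illustration of step 2) above, independent of the classes in this file: with `p = [0.75, 0.25]` and the current range `[0, 3]`, the range splits into `[0, 2]` for the first symbol and `[3, 3]` for the second, exactly as described.

```python
p = [0.75, 0.25]
low, high = 0, 3
width = high - low + 1

chunks, cum = [], 0.0
for prob in p:
    start = low + int(round(cum * width))
    cum += prob
    end = low + int(round(cum * width)) - 1
    chunks.append((start, end))

print(chunks)  # [(0, 2), (3, 3)]
```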
+ + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take fewer bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. 
+ """ + while self.delta < 2 ** self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + + def flush(self): + """Flush the remaining information to the stream. + """ + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the same common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. + + """ + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= (b1 << self.max_bit) + self.high -= (b1 << self.max_bit) + self.current -= (b1 << self.max_bit) + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exactly** + the same cdf as the one used at encoding time. 
+ """ + while self.delta < 2 ** self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return mid, low, high, self.current + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/vec2wav2/models/quantization/core_vq.py b/vec2wav2/models/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..a232226fc1843309e585a13d828f64ad1a60cd82 --- /dev/null +++ b/vec2wav2/models/quantization/core_vq.py @@ -0,0 +1,423 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# Modified by Yiwei Guo, 2024 +# including Group VQ + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F +import logging +import numpy as np +import vec2wav2.distributed.distrib as distrib + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +def preprocess(x): + x = rearrange(x, "... d -> (...) d") + return x + + +def postprocess_emb(embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
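With the `kmeans` helper above in scope, seeding a codebook from a batch of flattened frame vectors looks like the sketch below (arbitrary sizes; the import assumes the `vec2wav2` package and its dependencies such as `einops` are installed):

```python
import torch
from vec2wav2.models.quantization.core_vq import kmeans  # helper defined in this file

torch.manual_seed(0)
samples = torch.randn(1000, 8)                  # (N, dim) flattened frame features
means, bins = kmeans(samples, num_clusters=16, num_iters=10)

assert means.shape == (16, 8)                   # centroids used to initialise the codebook
assert bins.sum().item() == 1000                # every sample is assigned to one cluster
```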
+ """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + init_codebook=None, + perform_expire: bool = False, + perform_ema_update: bool = False, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if (not kmeans_init) and (init_codebook is not None) else torch.zeros + embed = init_fn(codebook_size, dim) + if init_codebook is not None: + # So that we don't need to care about mask issues in forward + embed[:] = torch.from_numpy(init_codebook) # the saved codebook is [V, D] + logging.info(f"Initiated EuclideanCodebook from {init_codebook}") + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([(not kmeans_init) or (init_codebook is not None)])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + self.training = True + self.perform_expire = perform_expire + self.perform_ema_update = perform_ema_update + + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + distrib.broadcast_tensors(self.buffers()) + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
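The `quantize` method above avoids materialising pairwise differences by expanding the squared Euclidean distance: maximising `-(||x||^2 - 2 x·e + ||e||^2)` over codewords picks the nearest codeword. A quick standalone check of that identity with plain tensors:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 8)           # flattened inputs (N, dim)
codebook = torch.randn(64, 8)    # (codebook_size, dim), plays the role of self.embed

embed = codebook.t()
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
idx_trick = dist.max(dim=-1).indices
idx_nearest = torch.cdist(x, codebook).argmin(dim=-1)
assert torch.equal(idx_trick, idx_nearest)
```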
+ if self.perform_expire: + self.expire_codes_(x) + if self.perform_ema_update: + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently, supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1., + init_codebook=None, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, + kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, + decay=decay, epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + init_codebook=init_codebook) + self.codebook_size = codebook_size + self.training = True + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + # NOTE: we bypass the rearranging of the tensor, and directly input [something, D] tensor. + device = x.device + # x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() # NOTE: pass grad. + + loss = torch.tensor(0.0, device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + # quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. 
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out + + +class GroupVectorQuantization(nn.Module): + def __init__(self, *, num_quantizers, init_codebook=None, **kwargs): + super().__init__() + if (init_codebook is not None) and (len(init_codebook) > 0): + init_codebook = np.load(init_codebook) # [G, V, D] + else: + init_codebook = None + self.groups = nn.ModuleList( + [VectorQuantization(init_codebook=init_codebook[group_id] if init_codebook is not None else None, + **kwargs) for group_id in range(num_quantizers)] + ) + self.n_group = num_quantizers + + def forward(self, x): + # x is assumed to have [something, D] shape. + num, dim_in = x.shape + dim_per_q = dim_in // self.n_group + quantized_result = torch.zeros_like(x) + quantized_indices = torch.zeros(num, self.n_group).long().to(x.device) + loss = torch.tensor(0.0).to(x.device) + for group_i in range(self.n_group): + dim_range = slice(group_i * dim_per_q, (group_i+1) * dim_per_q) + quantized_i, indices_i, loss_i = self.groups[group_i](x[:, dim_range]) + quantized_result[:, dim_range] = quantized_i + quantized_indices[:, group_i] = indices_i + loss = loss + loss_i + return quantized_result, quantized_indices, loss + + def encode(self, x): + return self.forward(x)[1] # only return indices + + def decode(self, q_indices): + # [B, G, L] + quantized_out = [] # tensors of shape [B, D, L] + for group_i in range(q_indices.shape[1]): + quantized_out.append(self.groups[group_i].decode(q_indices[:, group_i, :])) + return torch.concat(quantized_out, dim=1) diff --git a/vec2wav2/models/quantization/vq.py b/vec2wav2/models/quantization/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..d45725d261374ec8f37e2d8915a77102bf941ea4 --- /dev/null +++ b/vec2wav2/models/quantization/vq.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
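A standalone sketch of the group-VQ scheme implemented in `GroupVectorQuantization` above: the feature dimension is split into `n_group` equal slices and each slice is quantised against its own codebook, giving one index per group per frame. Plain tensors stand in for the `VectorQuantization` members here:

```python
import torch

torch.manual_seed(0)
n_group, dim_in = 2, 16
dim_per_q = dim_in // n_group
x = torch.randn(100, dim_in)                               # (N, D) frame features
codebooks = [torch.randn(128, dim_per_q) for _ in range(n_group)]

quantized = torch.zeros_like(x)
indices = torch.zeros(100, n_group, dtype=torch.long)
for g, cb in enumerate(codebooks):
    sl = slice(g * dim_per_q, (g + 1) * dim_per_q)
    idx = torch.cdist(x[:, sl], cb).argmin(dim=-1)         # nearest codeword per slice
    indices[:, g] = idx
    quantized[:, sl] = cb[idx]

assert quantized.shape == x.shape and indices.shape == (100, n_group)
```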
+ +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +from vec2wav2.models.quantization.core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.training = True + self.model = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + sample_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. + """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + quantized, codes, commit_loss = self.model(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: + """Return n_q based on specified target bandwidth. + """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.: + n_q = int(max(1, math.floor(bandwidth / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, sample_rate: int): + """Return bandwidth per quantizer for a given input sample rate. + """ + return math.log2(self.bins) * sample_rate / 1000 + + def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. 
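Worked numbers for the two bandwidth helpers above (the 75 Hz code rate and the 6 kb/s budget are illustrative values only): with `bins=1024` each code carries `log2(1024) = 10` bits, so one quantizer costs 0.75 kb/s at 75 codes per second, and a 6 kb/s budget admits `floor(6 / 0.75) = 8` quantizers.

```python
import math

bins, code_rate_hz = 1024, 75
bw_per_q = math.log2(bins) * code_rate_hz / 1000   # mirrors get_bandwidth_per_quantizer -> 0.75
n_q = max(1, math.floor(6.0 / bw_per_q))           # mirrors get_num_quantizers_for_bandwidth -> 8
assert (bw_per_q, n_q) == (0.75, 8)
```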
+ """ + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + codes = self.model.encode(x, n_q=n_q) + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + """ + quantized = self.model.decode(codes) + return quantized diff --git a/vec2wav2/models/v2w2.py b/vec2wav2/models/v2w2.py new file mode 100644 index 0000000000000000000000000000000000000000..e44fcd4d1e1472ca8dc22561396adc6bb2a24505 --- /dev/null +++ b/vec2wav2/models/v2w2.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Yiwei Guo +# Licensed under the Apache 2.0 license. + +"""vec2wav2.0 main architectures""" + +import torch +from vec2wav2.models.conformer.decoder import Decoder as ConformerDecoder +from vec2wav2.utils import crop_seq +from vec2wav2.models.bigvgan import BigVGAN +from vec2wav2.models.prompt_prenet import ConvPromptPrenet +import logging + + +class CTXVEC2WAVFrontend(torch.nn.Module): + + def __init__(self, + prompt_net_type, + num_mels, + vqvec_channels, + prompt_channels, + conformer_params): + + super(CTXVEC2WAVFrontend, self).__init__() + + if prompt_net_type == "ConvPromptPrenet": + self.prompt_prenet = ConvPromptPrenet( + embed=prompt_channels, + conv_layers=[(128, 3, 1, 1), (256, 5, 1, 2), (512, 5, 1, 2), (conformer_params["attention_dim"], 3, 1, 1)], + dropout=0.1, + skip_connections=True, + residual_scale=0.25, + non_affine_group_norm=False, + conv_bias=True, + activation=torch.nn.ReLU() + ) + elif prompt_net_type == "Conv1d": + self.prompt_prenet = torch.nn.Conv1d(prompt_channels, conformer_params["attention_dim"], kernel_size=5, padding=2) + else: + raise NotImplementedError + + self.encoder1 = ConformerDecoder(vqvec_channels, input_layer='linear', **conformer_params) + + self.hidden_proj = torch.nn.Linear(conformer_params["attention_dim"], conformer_params["attention_dim"]) + + self.encoder2 = ConformerDecoder(0, input_layer=None, **conformer_params) + self.mel_proj = torch.nn.Linear(conformer_params["attention_dim"], num_mels) + + def forward(self, vqvec, prompt, mask=None, prompt_mask=None): + """ + params: + vqvec: sequence of VQ-vectors. + prompt: sequence of mel-spectrogram prompt (acoustic context) + mask: mask of the vqvec. True or 1 stands for valid values. + prompt_mask: mask of the prompt. + vqvec and prompt are of shape [B, D, T]. All masks are of shape [B, T]. + returns: + enc_out: the input to the vec2wav2 Generator (BigVGAN); + mel: the frontend predicted mel spectrogram (for faster convergence); + """ + prompt = self.prompt_prenet(prompt.transpose(1, 2)).transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(-2) + if prompt_mask is not None: + prompt_mask = prompt_mask.unsqueeze(-2) + enc_out, _ = self.encoder1(vqvec, mask, prompt, prompt_mask) + + h = self.hidden_proj(enc_out) + + enc_out, _ = self.encoder2(h, mask, prompt, prompt_mask) + mel = self.mel_proj(enc_out) # (B, L, 80) + + return enc_out, mel, None + + +class VEC2WAV2Generator(torch.nn.Module): + + def __init__(self, frontend: CTXVEC2WAVFrontend, backend: BigVGAN): + + super(VEC2WAV2Generator, self).__init__() + self.frontend = frontend + self.backend = backend + + def forward(self, vqvec, prompt, mask=None, prompt_mask=None, crop_len=0, crop_offsets=None): + """ + :param vqvec: (torch.Tensor) The shape is (B, L, D). Sequence of VQ-vectors. + :param prompt: (torch.Tensor) The shape is (B, L', 80). Sequence of mel-spectrogram prompt (acoustic context) + :param mask: (torch.Tensor) The dtype is torch.bool. 
The shape is (B, L). True or 1 stands for valid values in `vqvec`. + :param prompt_mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L'). True or 1 stands for valid values in `prompt`. + :return: frontend predicted mel spectrogram; reconstructed waveform. + """ + h, mel, _ = self.frontend(vqvec, prompt, mask=mask, prompt_mask=prompt_mask) # (B, L, adim), (B, L, 80) + if mask is not None: + h = h.masked_fill(~mask.unsqueeze(-1), 0) + h = h.transpose(1, 2) + if crop_len > 0: + h = crop_seq(h, crop_offsets, crop_len) + if prompt_mask is not None: + prompt_avg = prompt.masked_fill(~prompt_mask.unsqueeze(-1), 0).sum(1) / prompt_mask.sum(1).unsqueeze(-1) + else: + prompt_avg = prompt.mean(1) + wav = self.backend(h, prompt_avg) # (B, C, T) + return mel, None, wav + + def inference(self, vqvec, prompt): + h, mel, _ = self.frontend(vqvec, prompt) + wav = self.backend(h.transpose(1,2), prompt.mean(1)) + + return mel, None, wav + diff --git a/vec2wav2/optimizers/__init__.py b/vec2wav2/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db777e82841eb9e5cbcb28ba46634b6807c986a4 --- /dev/null +++ b/vec2wav2/optimizers/__init__.py @@ -0,0 +1,3 @@ +from torch.optim import * # NOQA + +from .radam import * # NOQA diff --git a/vec2wav2/optimizers/__pycache__/__init__.cpython-310.pyc b/vec2wav2/optimizers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64c671e28108a11e16c8a1efb446b7862210d4ad Binary files /dev/null and b/vec2wav2/optimizers/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/optimizers/__pycache__/__init__.cpython-39.pyc b/vec2wav2/optimizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194b5e09228a4fe2b5e87cffe8209cdd7aa0b29b Binary files /dev/null and b/vec2wav2/optimizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/optimizers/__pycache__/radam.cpython-310.pyc b/vec2wav2/optimizers/__pycache__/radam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fff30ed81731139ffbc13571e9675fe886c88e53 Binary files /dev/null and b/vec2wav2/optimizers/__pycache__/radam.cpython-310.pyc differ diff --git a/vec2wav2/optimizers/__pycache__/radam.cpython-39.pyc b/vec2wav2/optimizers/__pycache__/radam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a9255b53725754f47c79554a352df45e2f77d84 Binary files /dev/null and b/vec2wav2/optimizers/__pycache__/radam.cpython-39.pyc differ diff --git a/vec2wav2/optimizers/radam.py b/vec2wav2/optimizers/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..36ae6e5cec78cf0a2db40b4cf8089289820ccbbe --- /dev/null +++ b/vec2wav2/optimizers/radam.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +"""RAdam optimizer. + +This code is derived from https://github.com/LiyuanLucasLiu/RAdam. 
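Back in `VEC2WAV2Generator.forward` above (just before this optimizer file), the prompt embedding handed to BigVGAN is a masked time-average of the mel prompt. A small check that the masking arithmetic gives the mean over valid frames only:

```python
import torch

prompt = torch.randn(2, 5, 80)                          # (B, L', 80)
prompt_mask = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]]).bool()    # (B, L'), True marks valid frames

avg = prompt.masked_fill(~prompt_mask.unsqueeze(-1), 0).sum(1) / prompt_mask.sum(1).unsqueeze(-1)
assert torch.allclose(avg[0], prompt[0, :3].mean(0))    # only the 3 valid frames contribute
assert torch.allclose(avg[1], prompt[1].mean(0))
```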
+""" + +import math +import torch + +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + """Rectified Adam optimizer.""" + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + """Initilize RAdam optimizer.""" + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + """Set state.""" + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + """Run one step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state["step"] += 1 + buffered = self.buffer[int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / ( + 1 - beta1 ** state["step"] + ) # NOQA + else: + step_size = 1.0 / (1 - beta1 ** state["step"]) + buffered[2] = step_size + + if group["weight_decay"] != 0: + p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group["lr"], exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/vec2wav2/ssl_models/WavLM.py b/vec2wav2/ssl_models/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..1179957953355bc8dc1080fba3f389a09717b96e --- /dev/null +++ b/vec2wav2/ssl_models/WavLM.py @@ -0,0 +1,752 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from vec2wav2.ssl_models.wavlm_modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + 
init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = 
parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional 
embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + # ============= GYW ADD ============== + if padding_mask.size(1) < features.size(1): + extra = features.size(1) - padding_mask.size(1) + padding_mask = torch.concat([padding_mask, + torch.ones(len(padding_mask), extra).bool().to(padding_mask.device)], + dim=-1) + # ==================================== + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + 
padding_mask.size(0), features.size(1), -1 + ) + # print(padding_mask) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + print(features.shape, padding_mask.shape) + padding_mask = self.forward_padding_mask(features, padding_mask) + # print(padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for 
i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in 
enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
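# --- Illustrative sketch, not part of the patch above ---
# The forward() that follows branches on `layer_norm_first`. The two residual
# orderings it implements are shown here in minimal form; `sublayer` stands in
# for either the self-attention or the feed-forward block, and all names in
# this sketch are hypothetical.
import torch
from torch import nn

def pre_norm_step(x: torch.Tensor, norm: nn.LayerNorm, sublayer: nn.Module) -> torch.Tensor:
    # layer_norm_first=True: normalize, run the sublayer, then add the residual.
    return x + sublayer(norm(x))

def post_norm_step(x: torch.Tensor, norm: nn.LayerNorm, sublayer: nn.Module) -> torch.Tensor:
    # layer_norm_first=False: run the sublayer, add the residual, then normalize.
    return norm(x + sublayer(x))

if __name__ == "__main__":
    dim = 768
    x = torch.randn(10, 2, dim)  # (T, B, C), matching the encoder's internal layout
    ffn = nn.Sequential(nn.Linear(dim, 3072), nn.GELU(), nn.Linear(3072, dim))
    norm = nn.LayerNorm(dim)
    print(pre_norm_step(x, norm, ffn).shape, post_norm_step(x, norm, ffn).shape)
# --- end sketch ---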
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/vec2wav2/ssl_models/__init__.py b/vec2wav2/ssl_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vec2wav2/ssl_models/__pycache__/WavLM.cpython-310.pyc b/vec2wav2/ssl_models/__pycache__/WavLM.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b1da8d61aef01de4e8ef602b838cea18918bfe7 Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/WavLM.cpython-310.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/__init__.cpython-310.pyc b/vec2wav2/ssl_models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b00a95cb6e285e323eaf80eb50b312ea371cabb Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/__init__.cpython-311.pyc b/vec2wav2/ssl_models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de365b5dd05a09266707cafd6f0dc36abccbfa95 Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/__init__.cpython-311.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/vqw2v_extractor.cpython-310.pyc b/vec2wav2/ssl_models/__pycache__/vqw2v_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b3cdb5ef9622ef44b6f4311af7cd43b21730045 Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/vqw2v_extractor.cpython-310.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/vqw2v_extractor.cpython-311.pyc b/vec2wav2/ssl_models/__pycache__/vqw2v_extractor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddb3269c81ad40b2defa6c2330f85f315d3595d9 Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/vqw2v_extractor.cpython-311.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/w2v2_extractor.cpython-310.pyc b/vec2wav2/ssl_models/__pycache__/w2v2_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11738224b9b73be22df7d6224b1a2f66e5fc77ea Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/w2v2_extractor.cpython-310.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/wavlm_extractor.cpython-310.pyc b/vec2wav2/ssl_models/__pycache__/wavlm_extractor.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..dc51f1c09045dbfdd32017453b04eb07d928f3ad Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/wavlm_extractor.cpython-310.pyc differ diff --git a/vec2wav2/ssl_models/__pycache__/wavlm_modules.cpython-310.pyc b/vec2wav2/ssl_models/__pycache__/wavlm_modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..364c4c7c8d3e16385a0c9cddc275d4344ee3cb61 Binary files /dev/null and b/vec2wav2/ssl_models/__pycache__/wavlm_modules.cpython-310.pyc differ diff --git a/vec2wav2/ssl_models/vqw2v_extractor.py b/vec2wav2/ssl_models/vqw2v_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b4150dea6a771d98c74b772d28ef8f9c440909ca --- /dev/null +++ b/vec2wav2/ssl_models/vqw2v_extractor.py @@ -0,0 +1,67 @@ +# Copyright 2024 Yiwei Guo +# Licensed under Apache 2.0 + +"""Extract VQ indexes using vq-wav2vec model (from fairseq)""" + +import torch +import logging +from kaldiio import WriteHelper +import os +import fairseq +import argparse +import numpy as np +from pathlib import Path +import soundfile as sf +from tqdm import tqdm +from vec2wav2.utils.utils import read_wav_16k + +logging.basicConfig(level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s') + +class Extractor: + def __init__(self, checkpoint="pretrained/vq-wav2vec_kmeans.pt", device="cuda"): + self.device = device + self.model, self.cfg, self.task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint]) + self.model = self.model[0].to(device) + self.model.eval() + for p in self.model.parameters(): + p.requires_grad_(False) + + def extract(self, wav: np.ndarray) -> torch.Tensor: + with torch.no_grad(): + audio = torch.from_numpy(wav).float().unsqueeze(0).to(self.device) + + z = self.model.feature_extractor(audio) + _, idxs = self.model.vector_quantizer.forward_idx(z) + return idxs[0].cpu() # [L, Groups] + + def get_codebook(self) -> np.ndarray: + quantizer = self.model.vector_quantizer + if self.cfg.model.vq_type == "kmeans": + codebook = quantizer.expand_embedding.data.transpose(0,1).contiguous() + elif self.cfg.model.vq_type == "gumbel": + codebook = quantizer.vars.data + if quantizer.combine_groups: + codebook = codebook.repeat(1, quantizer.groups, 1) + codebook = codebook.view(quantizer.groups, quantizer.num_vars, -1) + + codebook = codebook.cpu().numpy() + return codebook + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--wav-scp', type=str) + parser.add_argument("--out-dir", type=str) + parser.add_argument('--model', default="pretrained/vq-wav2vec_kmeans.pt", type=str) + args = parser.parse_args() + + extractor = Extractor(checkpoint=args.model, device="cuda" if torch.cuda.is_available() else "cpu") + + out_dir=Path(args.out_dir).absolute() + with open(args.wav_scp, 'r') as f, torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer: + for line in tqdm(f.readlines()): + uttid, wav_path = line.strip().split(maxsplit=1) + logging.info("Extracting " + uttid) + audio = read_wav_16k(wav_path) + idxs = extractor.extract(audio).cpu().numpy() + idxs = idxs.astype(float) + writer(uttid, idxs) diff --git a/vec2wav2/ssl_models/w2v2_extractor.py b/vec2wav2/ssl_models/w2v2_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7f89dfe4dc72e8ac5ad4a40684090e3f953ac1 --- /dev/null +++ b/vec2wav2/ssl_models/w2v2_extractor.py @@ -0,0 +1,72 @@ +# Copyright 2024 Yiwei Guo +# Licensed under Apache 2.0 + 
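# --- Illustrative sketch, not part of the patch ---
# A short usage sketch for the vq-wav2vec Extractor defined in
# vec2wav2/ssl_models/vqw2v_extractor.py above, using the script's default
# checkpoint path; the wav path is hypothetical.
import torch
from vec2wav2.ssl_models.vqw2v_extractor import Extractor
from vec2wav2.utils.utils import read_wav_16k

extractor = Extractor(checkpoint="pretrained/vq-wav2vec_kmeans.pt",
                      device="cuda" if torch.cuda.is_available() else "cpu")
wav = read_wav_16k("path/to/utterance.wav")     # 16 kHz mono waveform as np.ndarray
idxs = extractor.extract(wav)                   # index tensor of shape [L, Groups]
codebook = extractor.get_codebook()             # np.ndarray codebook(s), one per group
print(idxs.shape, codebook.shape)
# --- end sketch ---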
+"""Extract VQ indexes using wav2vec2.0 model (from fairseq)""" + +import torch +import logging +from kaldiio import WriteHelper +import os +from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining +import argparse +import numpy as np +from pathlib import Path +import soundfile as sf +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s') + +class Extractor: + def __init__(self, checkpoint="pretrained/wav2vec2-large-lv60/", device="cuda"): + self.device = device + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint) + model = Wav2Vec2ForPreTraining.from_pretrained(checkpoint) + model.to(self.device) + model.half() + model.eval() + self.model = model + self.feature_extractor = feature_extractor + logging.info(self.model) + for p in self.model.parameters(): + p.requires_grad_(False) + + def extract(self, wav: np.ndarray, sample_rate: int) -> torch.Tensor: + with torch.no_grad(): + wav = torch.from_numpy(wav).float() + + input_values = self.feature_extractor(wav, return_tensors="pt", sampling_rate=sample_rate).input_values + input_values = input_values.half().to(self.device) + outputs = self.model.wav2vec2(input_values) + extract_features = self.model.dropout_features(outputs[1]) + hidden_states = extract_features + batch_size, sequence_length, hidden_size = hidden_states.shape + hidden_states = self.model.quantizer.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.model.quantizer.num_groups, -1) + codevector_idx = hidden_states.argmax(dim=-1) + idxs = codevector_idx.view(batch_size, sequence_length, self.model.quantizer.num_groups) + return idxs[0].cpu() # [L, Groups] + + def get_codebook(self) -> np.ndarray: + quantizer = self.model.quantizer + codebook = quantizer.codevectors # (1, 640, 384) + codebook = codebook.view(quantizer.num_groups, quantizer.num_vars, -1) # (2, 320, 384) + return codebook.cpu().numpy() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--wav-scp', type=str) + parser.add_argument("--out-dir", type=str) + parser.add_argument('--model', default="pretrained/wav2vec2-large-lv60/", type=str) + args = parser.parse_args() + + extractor = Extractor(checkpoint=args.model, device="cuda" if torch.cuda.is_available() else "cpu") + + out_dir=Path(args.out_dir).absolute() + with open(args.wav_scp, 'r') as f, torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer: + for line in tqdm(f.readlines()): + uttid, wav_path = line.strip().split(maxsplit=1) + logging.info("Extracting " + uttid) + audio, sample_rate = sf.read(wav_path) + idxs = extractor.extract(audio, sample_rate=sample_rate) + idxs = idxs.astype(float) + writer(uttid, idxs) diff --git a/vec2wav2/ssl_models/wavlm_extractor.py b/vec2wav2/ssl_models/wavlm_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..bb139cc75143b7eb769c7d75b19d0baed67012b5 --- /dev/null +++ b/vec2wav2/ssl_models/wavlm_extractor.py @@ -0,0 +1,82 @@ +# Copyright 2024 Yiwei Guo +# Licensed under Apache 2.0 + +"""Extract VQ indexes using WavLM model (from microsoft UniLM)""" + +import torch +from vec2wav2.ssl_models.WavLM import WavLM, WavLMConfig +import soundfile as sf +from vec2wav2.utils.espnet_utils import pad_list, make_pad_mask +import time +from pathlib import Path +import argparse +from kaldiio import WriteHelper +from tqdm import tqdm +import logging +from 
vec2wav2.utils.utils import read_wav_16k + +class Extractor: + def __init__(self, checkpoint="pretrained/WavLM-Large.pt", device="cuda", output_layer=6): + self.device = device + checkpoint = torch.load(checkpoint) + self.cfg = WavLMConfig(checkpoint['cfg']) + self.model = WavLM(self.cfg) + self.model.load_state_dict(checkpoint['model']) + self.model.to(device) + self.model.eval() + for p in self.model.parameters(): + p.requires_grad_(False) + self.output_layer = output_layer + + def extract(self, wav): + with torch.no_grad(): + wav_input_16khz = torch.from_numpy(wav).unsqueeze(0).float().to(self.device) + if self.cfg.normalize: + wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape) + rep = self.model.extract_features(wav_input_16khz, output_layer=self.output_layer)[0] + return rep.squeeze(0).clone().detach() # torch.tensor [T, D] + + def extract_batch(self, wav_list, frame_lens): + # suppose wav is already a tensor padded with 0 + # should be careful with LayerNorm since it may cause difference between batch vs single modes. + pad_mask = make_pad_mask(frame_lens).to(self.device) + with torch.no_grad(): + wav_input_16khz = [torch.from_numpy(wav).float().to(self.device) for wav in wav_list] + if self.cfg.normalize: + wav_input_16khz = [torch.nn.functional.layer_norm(wav, wav.shape) for wav in wav_input_16khz] + wav_input_16khz = pad_list(wav_input_16khz, 0) + s = time.time() + rep = self.model.extract_features(wav_input_16khz, output_layer=self.output_layer, padding_mask=pad_mask)[0] + t = time.time() + print(f'in batch mode, pure extracting costs {t-s} s') + return rep.clone().detach() # [B, T, D] + + +def calc_out_len(in_len, k, s): + return int((in_len-(k-1)-1)/s + 1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--wav-scp', type=str) + parser.add_argument("--out-dir", type=str) + parser.add_argument('--model', default="pretrained/WavLM-Large.pt", type=str) + parser.add_argument('--output-layer', default=6, type=int) + args = parser.parse_args() + + extractor = Extractor(checkpoint=args.model, + device="cuda" if torch.cuda.is_available() else "cpu", + output_layer=args.output_layer) + + out_dir=Path(args.out_dir).absolute() + out_dir.mkdir(parents=True, exist_ok=True) + + with open(args.wav_scp, 'r') as f, torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer: + for line in tqdm(f.readlines()): + uttid, wav_path = line.strip().split(maxsplit=1) + logging.info("Extracting " + uttid) + audio = read_wav_16k(wav_path) + rep = extractor.extract(audio) + rep = rep.cpu().numpy() + writer(uttid, rep) + \ No newline at end of file diff --git a/vec2wav2/ssl_models/wavlm_modules.py b/vec2wav2/ssl_models/wavlm_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1dcfc6f061cc189ca51fc90107116f38e2e48daf --- /dev/null +++ b/vec2wav2/ssl_models/wavlm_modules.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, 
nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return 
gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * 
out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if 
self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. 
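# --- Illustrative sketch, not part of the patch ---
# A small, self-contained sketch of the T5-style relative-position bucketing
# used by compute_bias()/_relative_positions_bucket() above: nearby offsets get
# exact buckets, distant offsets share logarithmically spaced buckets, and the
# sign selects the upper/lower half. Defaults mirror the WavLMConfig values
# (num_buckets=320, max_distance=1280); the function name here is hypothetical.
import math
import torch

def relative_position_bucket(relative_positions: torch.Tensor,
                             num_buckets: int = 320,
                             max_distance: int = 1280) -> torch.Tensor:
    num_buckets //= 2                                   # bidirectional: split by sign
    buckets = (relative_positions > 0).long() * num_buckets
    rel = relative_positions.abs()
    max_exact = num_buckets // 2                        # exact buckets for small offsets
    large = max_exact + (
        torch.log(rel.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.minimum(large, torch.full_like(large, num_buckets - 1))
    return buckets + torch.where(rel < max_exact, rel, large)

query_len, key_len = 4, 4
context = torch.arange(query_len)[:, None]
memory = torch.arange(key_len)[None, :]
print(relative_position_bucket(memory - context))       # (4, 4) matrix of bucket ids
# --- end sketch ---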
Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, 
self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
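# --- Illustrative sketch, not part of the patch ---
# A shape-only sketch of the head split used above: (T, B, E) projections are
# reshaped to (B*H, T, head_dim) so a single bmm computes all heads' attention
# scores at once. The dimensions here are illustrative, not taken from a real run.
import torch

tgt_len, bsz, embed_dim, num_heads = 50, 2, 768, 12
head_dim = embed_dim // num_heads

q = torch.randn(tgt_len, bsz, embed_dim)
k = torch.randn(tgt_len, bsz, embed_dim)

q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
k = k.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

attn_weights = torch.bmm(q, k.transpose(1, 2))
print(attn_weights.shape)   # torch.Size([24, 50, 50]) == (bsz*num_heads, tgt_len, src_len)
# --- end sketch ---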
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will 
be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/vec2wav2/utils/__init__.py b/vec2wav2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fa95a020706b5412c3959fbf6e5980019c0d5f --- /dev/null +++ b/vec2wav2/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * # NOQA diff --git a/vec2wav2/utils/__pycache__/__init__.cpython-310.pyc b/vec2wav2/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c5d457be8414a05c9cf56da5b96b05ee39eaf80 Binary files /dev/null and b/vec2wav2/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/vec2wav2/utils/__pycache__/__init__.cpython-39.pyc b/vec2wav2/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90c95599acffcad59079678145d453a61bdcc6c7 Binary files /dev/null and b/vec2wav2/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/vec2wav2/utils/__pycache__/espnet_utils.cpython-310.pyc b/vec2wav2/utils/__pycache__/espnet_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1219fcb36314d01d7fe19d9051e6c835ee55e472 Binary files /dev/null and b/vec2wav2/utils/__pycache__/espnet_utils.cpython-310.pyc differ diff --git a/vec2wav2/utils/__pycache__/espnet_utils.cpython-39.pyc b/vec2wav2/utils/__pycache__/espnet_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc4fc2ac82e612669dee5c9da8c9cfc693192298 Binary files /dev/null and b/vec2wav2/utils/__pycache__/espnet_utils.cpython-39.pyc differ diff --git a/vec2wav2/utils/__pycache__/utils.cpython-310.pyc b/vec2wav2/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..121e3e5af88a487eff30d60046c2cd0e3c9517a1 Binary files /dev/null and b/vec2wav2/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/vec2wav2/utils/__pycache__/utils.cpython-39.pyc b/vec2wav2/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..775badf60bebe93a955e9d24d1c6c4816d058309 Binary files /dev/null and b/vec2wav2/utils/__pycache__/utils.cpython-39.pyc differ diff --git a/vec2wav2/utils/espnet_utils.py b/vec2wav2/utils/espnet_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..45ceb556ab4cfcdfc0a84e3ad0cfbce9a2edfc17 --- /dev/null +++ b/vec2wav2/utils/espnet_utils.py @@ -0,0 +1,503 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" +# Retrieved from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/nets_utils.py +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + if isinstance(m, torch.nn.Module): + device = next(m.parameters()).device + elif isinstance(m, torch.Tensor): + device = m.device + else: + raise TypeError( + "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}" + ) + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, : xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.long().tolist() + + bs = int(len(lengths)) + if maxlen is None: + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + else: + assert xs is None + assert maxlen >= int(max(lengths)) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) + ).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == "c": + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if "real" not in x or "imag" not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x["real"], x["imag"]) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ( + "x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x)) + ) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == "transformer": + return np.array([1]) + + elif mode == "mt" and arch == "rnn": + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int64) + logging.warning("Subsampling is not performed for machine translation.") + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif ( + (mode == "asr" and arch in ("rnn", "rnn-t")) + or (mode == "mt" and arch == "rnn") + or (mode == "st" and arch == "rnn") + ): + subsample = np.ones(train_args.elayers + 1, dtype=np.int64) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mix": + subsample = np.ones( + train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64 + ) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range( + min(train_args.elayers_sd + train_args.elayers + 1, len(ss)) + ): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." 
+ ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mulenc": + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64) + if train_args.etype[idx].endswith("p") and not train_args.etype[ + idx + ].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Encoder %d: Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN.", + idx + 1, + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch)) + + +def rename_state_dict( + old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor] +): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f"Rename: {old_prefix} -> {new_prefix}") + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from espnet.nets.pytorch_backend.conformer.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": Swish, + } + + return activation_funcs[act]() diff --git a/vec2wav2/utils/utils.py b/vec2wav2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5aa98fb6539c05988c2a8bfa6443229b793a6a2 --- /dev/null +++ b/vec2wav2/utils/utils.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Utility functions.""" + +import fnmatch +import logging +import os +import sys +import tarfile + +from distutils.version import LooseVersion +from filelock import FileLock + +import h5py +import numpy as np +import torch +import yaml +import soundfile as sf +import torchaudio.transforms as transforms + +def read_wav_16k(audio_path): + """Process audio file to 16kHz sample rate""" + if isinstance(audio_path, tuple): # Gradio audio input returns (sample_rate, audio_data) + wav = audio_path[1] + sr = audio_path[0] + else: # Regular file path + assert os.path.exists(audio_path), f"File not found: {audio_path}" + wav, sr = sf.read(audio_path) + + if sr != 16000: + audio_tensor = torch.tensor(wav, dtype=torch.float32) + resampler = transforms.Resample(orig_freq=sr, new_freq=16000) + wav = resampler(audio_tensor) + wav = wav.numpy() + return wav + + + +def find_files(root_dir, query="*.wav", include_root_dir=True): + """Find files recursively. + + Args: + root_dir (str): Root root_dir to find. + query (str): Query to find. + include_root_dir (bool): If False, root_dir name is not included. + + Returns: + list: List of found filenames. 
+ + """ + files = [] + for root, dirnames, filenames in os.walk(root_dir, followlinks=True): + for filename in fnmatch.filter(filenames, query): + files.append(os.path.join(root, filename)) + if not include_root_dir: + files = [file_.replace(root_dir + "/", "") for file_ in files] + + return files + + +def read_hdf5(hdf5_name, hdf5_path): + """Read hdf5 dataset. + + Args: + hdf5_name (str): Filename of hdf5 file. + hdf5_path (str): Dataset name in hdf5 file. + + Return: + any: Dataset values. + + """ + if not os.path.exists(hdf5_name): + logging.error(f"There is no such a hdf5 file ({hdf5_name}).") + sys.exit(1) + + hdf5_file = h5py.File(hdf5_name, "r") + + if hdf5_path not in hdf5_file: + logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") + sys.exit(1) + + hdf5_data = hdf5_file[hdf5_path][()] + hdf5_file.close() + + return hdf5_data + + +def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): + """Write dataset to hdf5. + + Args: + hdf5_name (str): Hdf5 dataset filename. + hdf5_path (str): Dataset path in hdf5. + write_data (ndarray): Data to write. + is_overwrite (bool): Whether to overwrite dataset. + + """ + # convert to numpy array + write_data = np.array(write_data) + + # check folder existence + folder_name, _ = os.path.split(hdf5_name) + if not os.path.exists(folder_name) and len(folder_name) != 0: + os.makedirs(folder_name) + + # check hdf5 existence + if os.path.exists(hdf5_name): + # if already exists, open with r+ mode + hdf5_file = h5py.File(hdf5_name, "r+") + # check dataset existence + if hdf5_path in hdf5_file: + if is_overwrite: + logging.warning( + "Dataset in hdf5 file already exists. " "recreate dataset in hdf5." + ) + hdf5_file.__delitem__(hdf5_path) + else: + logging.error( + "Dataset in hdf5 file already exists. " + "if you want to overwrite, please set is_overwrite = True." + ) + hdf5_file.close() + sys.exit(1) + else: + # if not exists, open with w mode + hdf5_file = h5py.File(hdf5_name, "w") + + # write data to hdf5 + hdf5_file.create_dataset(hdf5_path, data=write_data) + hdf5_file.flush() + hdf5_file.close() + + +class HDF5ScpLoader(object): + """Loader class for a fests.scp file of hdf5 file. + + Examples: + key1 /some/path/a.h5:feats + key2 /some/path/b.h5:feats + key3 /some/path/c.h5:feats + key4 /some/path/d.h5:feats + ... + >>> loader = HDF5ScpLoader("hdf5.scp") + >>> array = loader["key1"] + + key1 /some/path/a.h5 + key2 /some/path/b.h5 + key3 /some/path/c.h5 + key4 /some/path/d.h5 + ... + >>> loader = HDF5ScpLoader("hdf5.scp", "feats") + >>> array = loader["key1"] + + key1 /some/path/a.h5:feats_1,feats_2 + key2 /some/path/b.h5:feats_1,feats_2 + key3 /some/path/c.h5:feats_1,feats_2 + key4 /some/path/d.h5:feats_1,feats_2 + ... + >>> loader = HDF5ScpLoader("hdf5.scp") + # feats_1 and feats_2 will be concatenated + >>> array = loader["key1"] + + """ + + def __init__(self, feats_scp, default_hdf5_path="feats"): + """Initialize HDF5 scp loader. + + Args: + feats_scp (str): Kaldi-style feats.scp file with hdf5 format. + default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used. 
+ + """ + self.default_hdf5_path = default_hdf5_path + with open(feats_scp) as f: + lines = [line.replace("\n", "") for line in f.readlines()] + self.data = {} + for line in lines: + key, value = line.split() + self.data[key] = value + + def get_path(self, key): + """Get hdf5 file path for a given key.""" + return self.data[key] + + def __getitem__(self, key): + """Get ndarray for a given key.""" + p = self.data[key] + if ":" in p: + if len(p.split(",")) == 1: + return read_hdf5(*p.split(":")) + else: + p1, p2 = p.split(":") + feats = [read_hdf5(p1, p) for p in p2.split(",")] + return np.concatenate( + [f if len(f.shape) != 1 else f.reshape(-1, 1) for f in feats], 1 + ) + else: + return read_hdf5(p, self.default_hdf5_path) + + def __len__(self): + """Return the length of the scp file.""" + return len(self.data) + + def __iter__(self): + """Return the iterator of the scp file.""" + return iter(self.data) + + def keys(self): + """Return the keys of the scp file.""" + return self.data.keys() + + def values(self): + """Return the values of the scp file.""" + for key in self.keys(): + yield self[key] + + +class NpyScpLoader(object): + """Loader class for a fests.scp file of npy file. + + Examples: + key1 /some/path/a.npy + key2 /some/path/b.npy + key3 /some/path/c.npy + key4 /some/path/d.npy + ... + >>> loader = NpyScpLoader("feats.scp") + >>> array = loader["key1"] + + """ + + def __init__(self, feats_scp): + """Initialize npy scp loader. + + Args: + feats_scp (str): Kaldi-style feats.scp file with npy format. + + """ + with open(feats_scp) as f: + lines = [line.replace("\n", "") for line in f.readlines()] + self.data = {} + for line in lines: + key, value = line.split() + self.data[key] = value + + def get_path(self, key): + """Get npy file path for a given key.""" + return self.data[key] + + def __getitem__(self, key): + """Get ndarray for a given key.""" + return np.load(self.data[key]) + + def __len__(self): + """Return the length of the scp file.""" + return len(self.data) + + def __iter__(self): + """Return the iterator of the scp file.""" + return iter(self.data) + + def keys(self): + """Return the keys of the scp file.""" + return self.data.keys() + + def values(self): + """Return the values of the scp file.""" + for key in self.keys(): + yield self[key] + + +def load_model(checkpoint, config=None): + """Load trained model. + + Args: + checkpoint (str): Checkpoint path. + config (dict): Configuration dict. + + Return: + torch.nn.Module: Model instance. + + """ + # load config if not provided + if config is None: + dirname = os.path.dirname(checkpoint) + config = os.path.join(dirname, "config.yml") + with open(config) as f: + config = yaml.load(f, Loader=yaml.Loader) + + # lazy load for circular error + import vec2wav2.models + + # get model and load parameters + model_class = getattr( + vec2wav2.models, + config.get("generator_type", "BigVGAN"), + ) + model = vec2wav2.models.VEC2WAV2Generator( + vec2wav2.models.CTXVEC2WAVFrontend(config["prompt_net_type"], config["num_mels"], **config["frontend_params"]), + model_class(**config["generator_params"]) + ) + model.load_state_dict( + torch.load(checkpoint, map_location="cpu")["model"]["generator"] + ) + + return model + +def load_feat_codebook(codebook: np.ndarray, device: str="cuda"): + """Given a codebook of shape [G, V, D], convert into a torch Module that can be called. 
+ """ + feat_codebook = torch.tensor(codebook).to(device) # (2, 320, 384) + feat_codebook_numgroups = feat_codebook.shape[0] + feat_codebook = torch.nn.ModuleList([torch.nn.Embedding.from_pretrained(feat_codebook[i], freeze=True) for i in range(feat_codebook_numgroups)]).to(device) + return feat_codebook, feat_codebook_numgroups + +def idx2vec(codebook: torch.nn.Module, idx: torch.Tensor, num_groups: int): + """Given a codebook (converted, so can be called), and a idx tensor with shape [L, groups] or [B, L, groups] + Return the corresponding vectors + """ + return torch.cat([codebook[i](idx[..., i]) for i in range(num_groups)], dim=-1) # (L, D) + +def crop_seq(x, offsets, length): + """Crop padded tensor with specified length. + + :param x: (torch.Tensor) The shape is (B, C, D) + :param offsets: (list) + :param min_len: (int) + :return: + """ + B, C, T = x.shape + x_ = x.new_zeros(B, C, length) + for i in range(B): + x_[i, :] = x[i, :, offsets[i]: offsets[i] + length] + return x_ +