# -*- coding: utf-8 -*- # Copyright 2024 Yiwei Guo # Licensed under the Apache 2.0 license. """vec2wav2.0 main architectures""" import torch from vec2wav2.models.conformer.decoder import Decoder as ConformerDecoder from vec2wav2.utils import crop_seq from vec2wav2.models.bigvgan import BigVGAN from vec2wav2.models.prompt_prenet import ConvPromptPrenet import logging class CTXVEC2WAVFrontend(torch.nn.Module): def __init__(self, prompt_net_type, num_mels, vqvec_channels, prompt_channels, conformer_params): super(CTXVEC2WAVFrontend, self).__init__() if prompt_net_type == "ConvPromptPrenet": self.prompt_prenet = ConvPromptPrenet( embed=prompt_channels, conv_layers=[(128, 3, 1, 1), (256, 5, 1, 2), (512, 5, 1, 2), (conformer_params["attention_dim"], 3, 1, 1)], dropout=0.1, skip_connections=True, residual_scale=0.25, non_affine_group_norm=False, conv_bias=True, activation=torch.nn.ReLU() ) elif prompt_net_type == "Conv1d": self.prompt_prenet = torch.nn.Conv1d(prompt_channels, conformer_params["attention_dim"], kernel_size=5, padding=2) else: raise NotImplementedError self.encoder1 = ConformerDecoder(vqvec_channels, input_layer='linear', **conformer_params) self.hidden_proj = torch.nn.Linear(conformer_params["attention_dim"], conformer_params["attention_dim"]) self.encoder2 = ConformerDecoder(0, input_layer=None, **conformer_params) self.mel_proj = torch.nn.Linear(conformer_params["attention_dim"], num_mels) def forward(self, vqvec, prompt, mask=None, prompt_mask=None): """ params: vqvec: sequence of VQ-vectors. prompt: sequence of mel-spectrogram prompt (acoustic context) mask: mask of the vqvec. True or 1 stands for valid values. prompt_mask: mask of the prompt. vqvec and prompt are of shape [B, D, T]. All masks are of shape [B, T]. returns: enc_out: the input to the vec2wav2 Generator (BigVGAN); mel: the frontend predicted mel spectrogram (for faster convergence); """ prompt = self.prompt_prenet(prompt.transpose(1, 2)).transpose(1, 2) if mask is not None: mask = mask.unsqueeze(-2) if prompt_mask is not None: prompt_mask = prompt_mask.unsqueeze(-2) enc_out, _ = self.encoder1(vqvec, mask, prompt, prompt_mask) h = self.hidden_proj(enc_out) enc_out, _ = self.encoder2(h, mask, prompt, prompt_mask) mel = self.mel_proj(enc_out) # (B, L, 80) return enc_out, mel, None class VEC2WAV2Generator(torch.nn.Module): def __init__(self, frontend: CTXVEC2WAVFrontend, backend: BigVGAN): super(VEC2WAV2Generator, self).__init__() self.frontend = frontend self.backend = backend def forward(self, vqvec, prompt, mask=None, prompt_mask=None, crop_len=0, crop_offsets=None): """ :param vqvec: (torch.Tensor) The shape is (B, L, D). Sequence of VQ-vectors. :param prompt: (torch.Tensor) The shape is (B, L', 80). Sequence of mel-spectrogram prompt (acoustic context) :param mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L). True or 1 stands for valid values in `vqvec`. :param prompt_mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L'). True or 1 stands for valid values in `prompt`. :return: frontend predicted mel spectrogram; reconstructed waveform. """ h, mel, _ = self.frontend(vqvec, prompt, mask=mask, prompt_mask=prompt_mask) # (B, L, adim), (B, L, 80) if mask is not None: h = h.masked_fill(~mask.unsqueeze(-1), 0) h = h.transpose(1, 2) if crop_len > 0: h = crop_seq(h, crop_offsets, crop_len) if prompt_mask is not None: prompt_avg = prompt.masked_fill(~prompt_mask.unsqueeze(-1), 0).sum(1) / prompt_mask.sum(1).unsqueeze(-1) else: prompt_avg = prompt.mean(1) wav = self.backend(h, prompt_avg) # (B, C, T) return mel, None, wav def inference(self, vqvec, prompt): h, mel, _ = self.frontend(vqvec, prompt) wav = self.backend(h.transpose(1,2), prompt.mean(1)) return mel, None, wav