Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Copyright 2024 Yiwei Guo | |
# Licensed under the Apache 2.0 license. | |
"""vec2wav2.0 main architectures""" | |
import torch | |
from vec2wav2.models.conformer.decoder import Decoder as ConformerDecoder | |
from vec2wav2.utils import crop_seq | |
from vec2wav2.models.bigvgan import BigVGAN | |
from vec2wav2.models.prompt_prenet import ConvPromptPrenet | |
import logging | |
class CTXVEC2WAVFrontend(torch.nn.Module):
    """vec2wav2.0 frontend: VQ-vector sequence + acoustic prompt -> hidden features and a mel estimate.

    The prompt is projected to ``conformer_params["attention_dim"]`` channels by a
    prompt prenet. Two Conformer decoders then process the VQ-vector sequence,
    each receiving the projected prompt and its mask.
    """

    _SUPPORTED_PROMPT_NETS = ("ConvPromptPrenet", "Conv1d")

    def __init__(self,
                 prompt_net_type,
                 num_mels,
                 vqvec_channels,
                 prompt_channels,
                 conformer_params):
        """
        params:
            prompt_net_type: which prompt prenet to build; "ConvPromptPrenet" or "Conv1d".
            num_mels: output dimension of the predicted mel spectrogram.
            vqvec_channels: feature dimension of the input VQ-vectors.
            prompt_channels: feature dimension of the prompt (e.g. number of mel bins).
            conformer_params: keyword arguments forwarded to both Conformer decoders;
                must contain "attention_dim".
        raises:
            NotImplementedError: if ``prompt_net_type`` is not one of the supported types.
        """
        super().__init__()
        # Validate first so an unsupported type is reported before any config lookup.
        if prompt_net_type not in self._SUPPORTED_PROMPT_NETS:
            raise NotImplementedError(
                f"Unsupported prompt_net_type {prompt_net_type!r}; "
                f"expected one of {self._SUPPORTED_PROMPT_NETS}"
            )
        adim = conformer_params["attention_dim"]
        if prompt_net_type == "ConvPromptPrenet":
            self.prompt_prenet = ConvPromptPrenet(
                embed=prompt_channels,
                conv_layers=[(128, 3, 1, 1), (256, 5, 1, 2), (512, 5, 1, 2), (adim, 3, 1, 1)],
                dropout=0.1,
                skip_connections=True,
                residual_scale=0.25,
                non_affine_group_norm=False,
                conv_bias=True,
                activation=torch.nn.ReLU()
            )
        else:  # "Conv1d"
            self.prompt_prenet = torch.nn.Conv1d(prompt_channels, adim, kernel_size=5, padding=2)
        # First decoder embeds the VQ-vectors through a linear input layer.
        self.encoder1 = ConformerDecoder(vqvec_channels, input_layer='linear', **conformer_params)
        self.hidden_proj = torch.nn.Linear(adim, adim)
        # Second decoder has no input layer; presumably it consumes adim-sized
        # features directly (input size 0 is passed) — confirm against ConformerDecoder.
        self.encoder2 = ConformerDecoder(0, input_layer=None, **conformer_params)
        self.mel_proj = torch.nn.Linear(adim, num_mels)

    def forward(self, vqvec, prompt, mask=None, prompt_mask=None):
        """
        params:
            vqvec: sequence of VQ-vectors, shape [B, T, vqvec_channels].
            prompt: mel-spectrogram prompt (acoustic context), shape [B, T', prompt_channels].
                It is transposed to channel-first [B, C, T'] for the prenet and back afterwards.
            mask: mask of the vqvec, shape [B, T]. True or 1 stands for valid values.
            prompt_mask: mask of the prompt, shape [B, T'].
        returns:
            enc_out: [B, T, attention_dim], the input to the vec2wav2 Generator (BigVGAN);
            mel: [B, T, num_mels], the frontend-predicted mel spectrogram (for faster convergence);
            None: placeholder kept for interface compatibility.
        """
        prompt = self.prompt_prenet(prompt.transpose(1, 2)).transpose(1, 2)
        # Masks become [B, 1, T] so they can be passed to the decoders.
        if mask is not None:
            mask = mask.unsqueeze(-2)
        if prompt_mask is not None:
            prompt_mask = prompt_mask.unsqueeze(-2)
        enc_out, _ = self.encoder1(vqvec, mask, prompt, prompt_mask)
        h = self.hidden_proj(enc_out)
        enc_out, _ = self.encoder2(h, mask, prompt, prompt_mask)
        mel = self.mel_proj(enc_out)  # (B, L, num_mels)
        return enc_out, mel, None
class VEC2WAV2Generator(torch.nn.Module):
    """vec2wav2.0 generator: a frontend producing hidden features, followed by a waveform backend."""

    def __init__(self, frontend: "CTXVEC2WAVFrontend", backend: "BigVGAN"):
        """
        :param frontend: module mapping (vqvec, prompt, masks) to (hidden, mel, _).
        :param backend: module mapping (hidden [B, adim, L], averaged prompt [B, D]) to a waveform.
        """
        super().__init__()
        self.frontend = frontend
        self.backend = backend

    def forward(self, vqvec, prompt, mask=None, prompt_mask=None, crop_len=0, crop_offsets=None):
        """
        :param vqvec: (torch.Tensor) The shape is (B, L, D). Sequence of VQ-vectors.
        :param prompt: (torch.Tensor) The shape is (B, L', 80). Sequence of mel-spectrogram prompt (acoustic context)
        :param mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L). True or 1 stands for valid values in `vqvec`.
        :param prompt_mask: (torch.Tensor) The dtype is torch.bool. The shape is (B, L'). True or 1 stands for valid values in `prompt`.
        :param crop_len: if > 0, the hidden sequence is cropped with ``crop_seq`` before the backend.
        :param crop_offsets: offsets passed to ``crop_seq`` (required when crop_len > 0).
        :return: (frontend predicted mel spectrogram, None, reconstructed waveform).
        """
        h, mel, _ = self.frontend(vqvec, prompt, mask=mask, prompt_mask=prompt_mask)  # (B, L, adim), (B, L, 80)
        if mask is not None:
            # Zero the hidden features at padded positions (mask == False).
            h = h.masked_fill(~mask.unsqueeze(-1), 0)
        h = h.transpose(1, 2)  # (B, adim, L) layout handed to the backend
        if crop_len > 0:
            h = crop_seq(h, crop_offsets, crop_len)
        if prompt_mask is not None:
            # Mean over valid prompt frames only. The count is clamped to >= 1 so a
            # row without any valid frame yields zeros instead of 0/0 = NaN.
            valid_frames = prompt_mask.sum(1).clamp(min=1).unsqueeze(-1)
            prompt_avg = prompt.masked_fill(~prompt_mask.unsqueeze(-1), 0).sum(1) / valid_frames
        else:
            prompt_avg = prompt.mean(1)
        wav = self.backend(h, prompt_avg)  # (B, C, T)
        return mel, None, wav

    def inference(self, vqvec, prompt):
        """Synthesize without masks or cropping.

        Equivalent to ``forward(vqvec, prompt)``; the prompt is averaged over all frames.
        :return: (frontend predicted mel spectrogram, None, reconstructed waveform).
        """
        return self.forward(vqvec, prompt)