Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torch.nn import Module | |
from torch.nn.utils import parametrize | |
from models.config import PreprocessingConfig, VocoderModelConfig | |
from models.helpers import get_mask_from_lengths | |
from .lvc_block import LVCBlock | |
class Generator(Module):
    """UnivNet Generator: maps a mel-spectrogram conditioning sequence plus a
    noise latent to a raw audio waveform via a stack of LVC (location-variable
    convolution) blocks."""

    def __init__(
        self,
        model_config: VocoderModelConfig,
        preprocess_config: PreprocessingConfig,
    ):
        r"""UnivNet Generator.

        Initializes the UnivNet module.

        Args:
            model_config (VocoderModelConfig): the model configuration.
            preprocess_config (PreprocessingConfig): the preprocessing configuration.
        """
        super().__init__()
        self.mel_channel = preprocess_config.stft.n_mel_channels
        self.noise_dim = model_config.gen.noise_dim
        self.hop_length = preprocess_config.stft.hop_length
        channel_size = model_config.gen.channel_size
        kpnet_conv_size = model_config.gen.kpnet_conv_size

        # Each LVC block upsamples the signal by its stride; `hop_length` here
        # accumulates the product of strides so every block knows the total
        # upsampling factor applied so far (used to align the conditioning).
        hop_length = 1
        self.res_stack = nn.ModuleList()
        for stride in model_config.gen.strides:
            hop_length = stride * hop_length
            self.res_stack.append(
                LVCBlock(
                    channel_size,
                    preprocess_config.stft.n_mel_channels,
                    stride=stride,
                    dilations=model_config.gen.dilations,
                    lReLU_slope=model_config.gen.lReLU_slope,
                    cond_hop_length=hop_length,
                    kpnet_conv_size=kpnet_conv_size,
                ),
            )

        # Input projection: noise (noise_dim channels) -> channel_size channels.
        self.conv_pre = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(
                model_config.gen.noise_dim,
                channel_size,
                7,
                padding=3,
                padding_mode="reflect",
            ),
        )

        # Output head: LeakyReLU -> weight-normed Conv1d -> Tanh, producing a
        # single-channel waveform in [-1, 1].
        self.conv_post = nn.Sequential(
            nn.LeakyReLU(model_config.gen.lReLU_slope),
            nn.utils.parametrizations.weight_norm(
                nn.Conv1d(
                    channel_size,
                    1,
                    7,
                    padding=3,
                    padding_mode="reflect",
                ),
            ),
            nn.Tanh(),
        )

        # Output of STFT(zeros) — log-mel value used to fill masked frames.
        self.mel_mask_value = -11.5129

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the Generator module.

        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)

        Returns:
            Tensor: the generated audio waveform (batch, 1, out_length)
        """
        # Sample the noise latent with the same dtype as the model weights so
        # the conv accepts it under mixed precision.
        z = torch.randn(
            c.shape[0],
            self.noise_dim,
            c.shape[2],
            device=c.device,
            dtype=self.conv_pre.weight.data.dtype,
        )
        z = self.conv_pre(z)  # (B, c_g, L)

        for res_block in self.res_stack:
            z = res_block(z, c)  # (B, c_g, L * s_0 * ... * s_i)

        return self.conv_post(z)  # (B, 1, L * prod(strides))

    def eval(self, inference: bool = False):
        r"""Sets the module to evaluation mode.

        Args:
            inference (bool): whether to remove weight normalization or not.
        """
        super().eval()
        # don't remove weight norm while validation in training loop
        if inference:
            self.remove_weight_norm()

    def remove_weight_norm(self) -> None:
        r"""Removes weight normalization from the module (for inference/export)."""
        print("Removing weight norm...")

        parametrize.remove_parametrizations(self.conv_pre, "weight")

        # Only the parametrized Conv1d inside conv_post has parameters; the
        # activation layers have an empty state_dict and are skipped.
        for layer in self.conv_post:
            if len(layer.state_dict()) != 0:
                parametrize.remove_parametrizations(layer, "weight")

        for res_block in self.res_stack:
            res_block.remove_weight_norm()

    def infer(self, c: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        r"""Infers the audio waveform from the mel-spectrogram conditioning sequence.

        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
            mel_lens (Tensor): the lengths of the mel-spectrogram conditioning sequence.

        Returns:
            Tensor: the generated audio waveform (batch, 1, out_length)
        """
        # Fill padded mel frames with the "silence" log-mel value so they do
        # not leak content into the generated audio.
        mel_mask = get_mask_from_lengths(mel_lens).unsqueeze(1).to(c.device)
        c = c.masked_fill(mel_mask, self.mel_mask_value)

        # Append 10 silent frames so the reflect-padded convolutions have a
        # clean tail; the corresponding audio is trimmed off below.
        zero = torch.full(
            (c.shape[0], self.mel_channel, 10),
            self.mel_mask_value,
            device=c.device,
        )
        mel = torch.cat((c, zero), dim=2)

        audio = self(mel)
        audio = audio[:, :, : -(self.hop_length * 10)]

        # Fix: use self.hop_length rather than a hard-coded 256 so the audio
        # mask stays consistent with the trim above for any hop length.
        # NOTE(review): assumes the product of gen.strides equals
        # preprocess hop_length (true for standard UnivNet configs) — confirm.
        audio_mask = get_mask_from_lengths(mel_lens * self.hop_length).unsqueeze(1)
        return audio.masked_fill(audio_mask, 0.0)