# Source: nickovchinnikov — initial commit 9d61c9b
import torch
from torch import nn
from torch.nn import Module
from torch.nn.utils import parametrize
from models.config import PreprocessingConfig, VocoderModelConfig
from models.helpers import get_mask_from_lengths
from .lvc_block import LVCBlock
class Generator(Module):
    """UnivNet vocoder generator.

    Maps a mel-spectrogram conditioning sequence plus a random noise tensor to
    a raw audio waveform through a stack of location-variable convolution
    (LVC) blocks, each of which upsamples by its stride.
    """

    def __init__(
        self,
        model_config: VocoderModelConfig,
        preprocess_config: PreprocessingConfig,
    ):
        r"""Initialize the UnivNet generator.

        Args:
            model_config (VocoderModelConfig): the model configuration.
            preprocess_config (PreprocessingConfig): the preprocessing configuration.
        """
        super().__init__()
        self.mel_channel = preprocess_config.stft.n_mel_channels
        self.noise_dim = model_config.gen.noise_dim
        self.hop_length = preprocess_config.stft.hop_length
        channel_size = model_config.gen.channel_size
        kpnet_conv_size = model_config.gen.kpnet_conv_size

        # cond_hop_length tracks the cumulative upsampling factor of the
        # stack so each LVC block sees conditioning at the right resolution.
        hop_length = 1
        self.res_stack = nn.ModuleList()
        for stride in model_config.gen.strides:
            hop_length = stride * hop_length
            self.res_stack.append(
                LVCBlock(
                    channel_size,
                    preprocess_config.stft.n_mel_channels,
                    stride=stride,
                    dilations=model_config.gen.dilations,
                    lReLU_slope=model_config.gen.lReLU_slope,
                    cond_hop_length=hop_length,
                    kpnet_conv_size=kpnet_conv_size,
                ),
            )

        # Pre-net: project the noise channels up to the generator width.
        self.conv_pre = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(
                model_config.gen.noise_dim,
                channel_size,
                7,
                padding=3,
                padding_mode="reflect",
            ),
        )
        # Post-net: collapse to a single waveform channel in [-1, 1].
        self.conv_post = nn.Sequential(
            nn.LeakyReLU(model_config.gen.lReLU_slope),
            nn.utils.parametrizations.weight_norm(
                nn.Conv1d(
                    channel_size,
                    1,
                    7,
                    padding=3,
                    padding_mode="reflect",
                ),
            ),
            nn.Tanh(),
        )

        # Mel value produced by STFT of an all-zero (silent) signal; used to
        # pad/mask the conditioning spectrogram in `infer`.
        self.mel_mask_value = -11.5129

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the Generator module.

        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)

        Returns:
            Tensor: the generated audio waveform (batch, 1, out_length)
        """
        # Noise input matches the conditioning length; dtype follows the
        # pre-net weights so mixed-precision checkpoints keep working.
        z = torch.randn(
            c.shape[0],
            self.noise_dim,
            c.shape[2],
            device=c.device,
            dtype=self.conv_pre.weight.data.dtype,
        )
        z = self.conv_pre(z)  # (B, c_g, L)

        for res_block in self.res_stack:
            z = res_block(z, c)  # (B, c_g, L * s_0 * ... * s_i)

        return self.conv_post(z)  # (B, 1, L * hop_length)

    def eval(self, inference: bool = False):
        r"""Sets the module to evaluation mode.

        Args:
            inference (bool): whether to remove weight normalization or not.
        """
        super().eval()
        # don't remove weight norm while validation in training loop
        if inference:
            self.remove_weight_norm()

    def remove_weight_norm(self) -> None:
        r"""Removes weight normalization from the module (inference-time only)."""
        print("Removing weight norm...")

        parametrize.remove_parametrizations(self.conv_pre, "weight")

        for layer in self.conv_post:
            # Only parameterized layers (the Conv1d) have state; LeakyReLU
            # and Tanh have empty state dicts and are skipped.
            if len(layer.state_dict()) != 0:
                parametrize.remove_parametrizations(layer, "weight")

        for res_block in self.res_stack:
            res_block.remove_weight_norm()

    def infer(self, c: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        r"""Infers the audio waveform from the mel-spectrogram conditioning sequence.

        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
            mel_lens (Tensor): the lengths of the mel-spectrogram conditioning sequence.

        Returns:
            Tensor: the generated audio waveform (batch, 1, out_length)
        """
        # Silence the padded region of each spectrogram in the batch.
        mel_mask = get_mask_from_lengths(mel_lens).unsqueeze(1).to(c.device)
        c = c.masked_fill(mel_mask, self.mel_mask_value)

        # Append 10 frames of "silence" so the receptive field at the end of
        # the signal is fully covered; the matching audio tail is cut below.
        zero = torch.full(
            (c.shape[0], self.mel_channel, 10), self.mel_mask_value, device=c.device,
        )
        mel = torch.cat((c, zero), dim=2)
        audio = self(mel)
        audio = audio[:, :, : -(self.hop_length * 10)]

        # FIX: scale lengths by the configured hop length (was hard-coded 256,
        # which breaks any config with a different hop length), and move the
        # mask to the audio device, mirroring the mel_mask handling above.
        audio_mask = (
            get_mask_from_lengths(mel_lens * self.hop_length)
            .unsqueeze(1)
            .to(audio.device)
        )
        return audio.masked_fill(audio_mask, 0.0)