from typing import List

import torch
from torch import Tensor, nn
from torch.nn import Conv1d, ConvTranspose1d, Module
import torch.nn.functional as F
from torch.nn.utils import remove_weight_norm, weight_norm

from models.config import HifiGanConfig, HifiGanPretrainingConfig, PreprocessingConfig

from .utils import get_padding, init_weights

# Negative slope shared by all leaky ReLU activations in this module.
LRELU_SLOPE = HifiGanPretrainingConfig.lReLU_slope


class ResBlock1(Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: List[int] = [1, 3, 5],
    ):
        r"""Initialize the ResBlock1 module.

        Args:
            channels (int): The number of channels for the ResBlock.
            kernel_size (int, optional): The kernel size for the convolutional layers. Defaults to 3.
            dilation (List[int], optional): The dilations for the first convolution stack. Defaults to [1, 3, 5].
        """
        super().__init__()
        # First stack: dilated convolutions, one per entry in `dilation`.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    ),
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    ),
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    ),
                ),
            ],
        )
        self.convs1.apply(init_weights)
        # Second stack: plain (dilation=1) convolutions paired with the stack above.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    ),
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    ),
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    ),
                ),
            ],
        )
        self.convs2.apply(init_weights)

    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass of the ResBlock1 module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            # Dilated conv followed by a plain conv, with a residual connection.
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        r"""Remove the weight normalization from the convolutional layers."""
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
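
# A minimal usage sketch for ResBlock1 (illustrative only; the channel count and
# frame count below are assumptions, not values taken from the configs):
#
#     block = ResBlock1(channels=512, kernel_size=3, dilation=[1, 3, 5])
#     features = torch.randn(1, 512, 100)  # (batch, channels, frames)
#     out = block(features)                # shape is preserved: (1, 512, 100)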


class ResBlock2(Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: List[int] = [1, 3],
    ):
        r"""Initialize the ResBlock2 module.

        Args:
            channels (int): The number of channels for the ResBlock.
            kernel_size (int, optional): The kernel size for the convolutional layers. Defaults to 3.
            dilation (List[int], optional): The dilations for the convolutional layers. Defaults to [1, 3].
        """
        super().__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    ),
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    ),
                ),
            ],
        )
        self.convs.apply(init_weights)

    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass of the ResBlock2 module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        for layer in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = layer(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        r"""Remove the weight normalization from the convolutional layers."""
        for layer in self.convs:
            remove_weight_norm(layer)
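
# ResBlock2 is the lighter variant, selected in the Generator below whenever
# HifiGanConfig.resblock is not "1". A minimal sketch (shapes are illustrative
# assumptions, not config values):
#
#     block = ResBlock2(channels=256, kernel_size=3, dilation=[1, 3])
#     out = block(torch.randn(1, 256, 100))  # length and channel count preserved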


class Generator(Module):
    def __init__(self, h: HifiGanConfig, p: PreprocessingConfig):
        r"""Initialize the Generator module.

        Args:
            h (HifiGanConfig): The configuration for the Generator.
            p (PreprocessingConfig): The configuration for the preprocessing.
        """
        super().__init__()
        self.h = h
        self.p = p
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)

        # Pre-convolution maps the mel-spectrogram channels to the initial width.
        self.conv_pre = weight_norm(
            Conv1d(
                self.p.stft.n_mel_channels,
                h.upsample_initial_channel,
                7,
                1,
                padding=3,
            ),
        )

        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        # Transposed convolutions upsample the sequence, halving the channel
        # count at every stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    ),
                ),
            )

        # One group of residual blocks (one per kernel size) after each upsample stage.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            resblock_list = nn.ModuleList()
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                resblock_list.append(resblock(ch, k, d))
            self.resblocks.append(resblock_list)

        # Post-convolution collapses the remaining channels to a single waveform channel.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
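
    # Channel schedule example (an assumption for illustration, not a config
    # default): with upsample_initial_channel=512 and four upsample stages the
    # widths go 512 -> 256 -> 128 -> 64 -> 32, and conv_post maps the final 32
    # channels to one output channel.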

    def forward(self, x: Tensor) -> Tensor:
        r"""Forward pass of the Generator module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        x = self.conv_pre(x)
        for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = upsample_layer(x)
            # Multi-receptive-field fusion: average the outputs of all residual
            # blocks in this group.
            xs = torch.zeros(x.shape, dtype=x.dtype, device=x.device)
            for resblock in resblock_group:  # type: ignore
                xs += resblock(x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
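

# A minimal end-to-end sketch (illustrative only; `hifi_cfg` and `preprocess_cfg`
# stand for instances of HifiGanConfig and PreprocessingConfig, and the frame
# count is an arbitrary assumption):
#
#     generator = Generator(h=hifi_cfg, p=preprocess_cfg)
#     mel = torch.randn(1, preprocess_cfg.stft.n_mel_channels, 80)  # (B, n_mels, frames)
#     wav = generator(mel)
#     # wav has shape (B, 1, frames * prod(hifi_cfg.upsample_rates)), values in (-1, 1)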