from typing import List, Tuple
import torch
from torch import Tensor, nn
from torch.nn import Conv1d, ConvTranspose1d, Module
import torch.nn.functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from models.config import HifiGanConfig, HifiGanPretrainingConfig, PreprocessingConfig
from .utils import get_padding, init_weights
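
# NOTE: get_padding is assumed to be the usual HiFi-GAN helper that returns
# "same" padding for a dilated convolution, i.e.
# (kernel_size * dilation - dilation) // 2, so every Conv1d below preserves
# the length of the time axis.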

# Leaky ReLU slope
LRELU_SLOPE = HifiGanPretrainingConfig.lReLU_slope


class ResBlock1(Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
dilation: List[int] = [1, 3, 5],
):
r"""Initialize the ResBlock1 module.
Args:
channels (int): The number of channels for the ResBlock.
kernel_size (int, optional): The kernel size for the convolutional layers. Defaults to 3.
            dilation (List[int], optional): The dilations for the three dilated convolutional layers. Defaults to [1, 3, 5].
"""
super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
),
],
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
),
],
)
self.convs2.apply(init_weights)

    def forward(self, x: Tensor) -> Tensor:
r"""Forward pass of the ResBlock1 module.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x

    def remove_weight_norm(self):
r"""Remove the weight normalization from the convolutional layers."""
for layer in self.convs1:
remove_weight_norm(layer)
for layer in self.convs2:
remove_weight_norm(layer)
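
# A minimal shape-check sketch for ResBlock1 (the channel count here is
# hypothetical; the real values come from HifiGanConfig):
#
#     block = ResBlock1(channels=512, kernel_size=3, dilation=[1, 3, 5])
#     x = torch.randn(1, 512, 100)      # (batch, channels, frames)
#     assert block(x).shape == x.shape  # "same" padding preserves the shape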


class ResBlock2(Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
dilation: List[int] = [1, 3],
):
r"""Initialize the ResBlock2 module.
Args:
channels (int): The number of channels for the ResBlock.
kernel_size (int, optional): The kernel size for the convolutional layers. Defaults to 3.
            dilation (List[int], optional): The dilations for the two convolutional layers. Defaults to [1, 3].
"""
super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
),
],
)
self.convs.apply(init_weights)

    def forward(self, x: Tensor) -> Tensor:
r"""Forward pass of the ResBlock2 module.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
for layer in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = layer(xt)
x = xt + x
return x

    def remove_weight_norm(self):
r"""Remove the weight normalization from the convolutional layers."""
for layer in self.convs:
remove_weight_norm(layer)
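
# ResBlock2 is the lighter variant selected when h.resblock == "2" (see
# Generator below): two dilated convs with residual connections, versus the
# three conv pairs of ResBlock1.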


class Generator(Module):
def __init__(self, h: HifiGanConfig, p: PreprocessingConfig):
r"""Initialize the Generator module.
Args:
h (HifiGanConfig): The configuration for the Generator.
p (PreprocessingConfig): The configuration for the preprocessing.
"""
super().__init__()
self.h = h
self.p = p
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(
Conv1d(
self.p.stft.n_mel_channels,
h.upsample_initial_channel,
7,
1,
padding=3,
),
)
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
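        # Each transposed conv upsamples the time axis by rate u and halves
        # the channel count relative to the previous stage.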
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
),
),
)
self.resblocks = nn.ModuleList()
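        # For every upsampling stage, build num_kernels parallel residual
        # blocks; their outputs are averaged in forward().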
for i in range(len(self.ups)):
resblock_list = nn.ModuleList()
ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                resblock_list.append(resblock(ch, k, d))
self.resblocks.append(resblock_list)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)

    def forward(self, x: Tensor) -> Tensor:
r"""Forward pass of the Generator module.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
x = self.conv_pre(x)
for upsample_layer, resblock_group in zip(self.ups, self.resblocks):
x = F.leaky_relu(x, LRELU_SLOPE)
x = upsample_layer(x)
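            # Sum the parallel residual-block outputs for this stage, then
            # average (the HiFi-GAN paper's multi-receptive-field fusion).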
            xs = torch.zeros_like(x)
for resblock in resblock_group: # type: ignore
xs += resblock(x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
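
# A minimal end-to-end sketch (assumes default HifiGanConfig and
# PreprocessingConfig instances are constructible; shapes are illustrative):
#
#     generator = Generator(HifiGanConfig(), PreprocessingConfig())
#     mel = torch.randn(1, generator.p.stft.n_mel_channels, 32)
#     wav = generator(mel)  # (1, 1, 32 * prod(h.upsample_rates))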