import torch

from Modules.GeneralLayers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock


class HiFiGAN(torch.nn.Module):

    def __init__(self,
                 in_channels=128,
                 out_channels=1,
                 channels=768,
                 kernel_size=7,
                 upsample_scales=(8, 6, 2, 2, 2),
                 upsample_kernel_sizes=(16, 12, 4, 4, 4),
                 resblock_kernel_sizes=(3, 7, 11),
                 resblock_dilations=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
                 use_additional_convs=True,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.1},
                 weights=None):
""" |
|
Initialize HiFiGANGenerator module. |
|
|
|
Args: |
|
in_channels (int): Number of input channels. |
|
out_channels (int): Number of output channels. |
|
channels (int): Number of hidden representation channels. |
|
kernel_size (int): Kernel size of initial and final conv layer. |
|
upsample_scales (list): List of upsampling scales. |
|
upsample_kernel_sizes (list): List of kernel sizes for upsampling layers. |
|
resblock_kernel_sizes (list): List of kernel sizes for residual blocks. |
|
resblock_dilations (list): List of dilation list for residual blocks. |
|
use_additional_convs (bool): Whether to use additional conv layers in residual blocks. |
|
bias (bool): Whether to add bias parameter in convolution layers. |
|
nonlinear_activation (str): Activation function module name. |
|
nonlinear_activation_params (dict): Hyperparameters for activation function. |
|
use_weight_norm (bool): Whether to use weight norm. |
|
If set to true, it will be applied to all of the conv layers. |
|
""" |
|
        super().__init__()

        assert kernel_size % 2 == 1, "Kernel size must be odd number."
        assert len(upsample_scales) == len(upsample_kernel_sizes)
        assert len(resblock_dilations) == len(resblock_kernel_sizes)

        self.num_upsamples = len(upsample_kernel_sizes)
        self.num_blocks = len(resblock_kernel_sizes)
        self.input_conv = torch.nn.Conv1d(in_channels,
                                          channels,
                                          kernel_size,
                                          1,
                                          padding=(kernel_size - 1) // 2)
        self.upsamples = torch.nn.ModuleList()
        self.blocks = torch.nn.ModuleList()
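        # Each upsampling stage halves the channel count and stretches the time axis by
        # its upsample scale; every stage is followed by num_blocks residual blocks whose
        # outputs are averaged in forward() (multi-receptive field fusion).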
        for i in range(len(upsample_kernel_sizes)):
            self.upsamples += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                                                   torch.nn.ConvTranspose1d(channels // (2 ** i),
                                                                            channels // (2 ** (i + 1)),
                                                                            upsample_kernel_sizes[i],
                                                                            upsample_scales[i],
                                                                            padding=(upsample_kernel_sizes[i] - upsample_scales[i]) // 2))]
            for j in range(len(resblock_kernel_sizes)):
                self.blocks += [ResidualBlock(kernel_size=resblock_kernel_sizes[j],
                                              channels=channels // (2 ** (i + 1)),
                                              dilations=resblock_dilations[j],
                                              bias=bias,
                                              use_additional_convs=use_additional_convs,
                                              nonlinear_activation=nonlinear_activation,
                                              nonlinear_activation_params=nonlinear_activation_params)]
        self.output_conv = torch.nn.Sequential(torch.nn.LeakyReLU(),
                                               torch.nn.Conv1d(channels // (2 ** (i + 1)),
                                                               out_channels,
                                                               kernel_size,
                                                               1,
                                                               padding=(kernel_size - 1) // 2),
                                               torch.nn.Tanh())

        self.apply_weight_norm()
        self.reset_parameters()

        if weights is not None:
            self.load_state_dict(weights)

    def forward(self, c):
        """
        Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T * prod(upsample_scales)).
        """
        c = self.input_conv(c)
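        # For each stage: upsample, then average the outputs of that stage's residual
        # blocks before passing the result on to the next stage.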
        for i in range(self.num_upsamples):
            c = self.upsamples[i](c)
            cs = 0.0
            for j in range(self.num_blocks):
                cs += self.blocks[i * self.num_blocks + j](c)
            c = cs / self.num_blocks
        c = self.output_conv(c)

        return c

    def reset_parameters(self):
        """
        Reset parameters.

        This initialization follows the official implementation manner.
        https://github.com/jik876/hifi-gan/blob/master/models.py
        """

        def _reset_parameters(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.01)

        self.apply(_reset_parameters)

    def remove_weight_norm(self):
        """
        Remove weight normalization module from all of the layers.
        """

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module has no weight norm applied to it
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """
        Apply weight normalization module to all of the layers.
        """

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def inference(self, c, normalize_before=False):
        """
        Perform inference.

        Args:
            c (Union[Tensor, ndarray]): Input tensor (T, in_channels).
            normalize_before (bool): Whether to normalize the input with stored statistics.

        Returns:
            Tensor: Output tensor (T * prod(upsample_scales), out_channels).
        """
        if not isinstance(c, torch.Tensor):
            c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device)
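        # self.mean and self.scale are not created in __init__; they are expected to be
        # registered on the module (e.g. as buffers holding feature statistics) before
        # calling inference with normalize_before=True.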
        if normalize_before:
            c = (c - self.mean) / self.scale
        c = self.forward(c.transpose(1, 0).unsqueeze(0))
        return c.squeeze(0).transpose(1, 0)


if __name__ == "__main__":
    hifi = HiFiGAN()
    print(f"HiFiGAN parameter count: {sum(p.numel() for p in hifi.parameters() if p.requires_grad)}")
    print(hifi(torch.randn([1, 128, 100]))[0].shape)
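
    # A minimal inference sketch (assumption: no mean/scale statistics are registered,
    # so normalize_before stays at its default of False). Weight norm is removed first,
    # which is the usual step before deployment.
    hifi.remove_weight_norm()
    wave = hifi.inference(torch.randn([100, 128]))
    print(f"Inference output shape: {wave.shape}")  # (100 * 8 * 6 * 2 * 2 * 2, 1) = (38400, 1)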