import torch import numpy as np import torch.nn.functional as F from torch.nn.utils import remove_weight_norm from torch.utils.checkpoint import checkpoint from torch.nn.utils.parametrizations import weight_norm LRELU_SLOPE = 0.1 class MRFLayer(torch.nn.Module): def __init__(self, channels, kernel_size, dilation): super().__init__() self.conv1 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation)) self.conv2 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1)) def forward(self, x): return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE)) def remove_weight_norm(self): remove_weight_norm(self.conv1) remove_weight_norm(self.conv2) class MRFBlock(torch.nn.Module): def __init__(self, channels, kernel_size, dilations): super().__init__() self.layers = torch.nn.ModuleList() for dilation in dilations: self.layers.append(MRFLayer(channels, kernel_size, dilation)) def forward(self, x): for layer in self.layers: x = layer(x) return x def remove_weight_norm(self): for layer in self.layers: layer.remove_weight_norm() class SineGenerator(torch.nn.Module): def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0): super(SineGenerator, self).__init__() self.sine_amp = sine_amp self.noise_std = noise_std self.harmonic_num = harmonic_num self.dim = self.harmonic_num + 1 self.sampling_rate = samp_rate self.voiced_threshold = voiced_threshold def _f02uv(self, f0): return torch.ones_like(f0) * (f0 > self.voiced_threshold) def _f02sine(self, f0_values): rad_values = (f0_values / self.sampling_rate) % 1 rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) rand_ini[:, 0] = 0 rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini tmp_over_one = torch.cumsum(rad_values, 1) % 1 tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 cumsum_shift = torch.zeros_like(rad_values) cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) def forward(self, f0): with torch.no_grad(): f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) f0_buf[:, :, 0] = f0[:, :, 0] for idx in np.arange(self.harmonic_num): f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) sine_waves = self._f02sine(f0_buf) * self.sine_amp uv = self._f02uv(f0) sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves)) return sine_waves class SourceModuleHnNSF(torch.nn.Module): def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0): super(SourceModuleHnNSF, self).__init__() self.sine_amp = sine_amp self.noise_std = add_noise_std self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold) self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) self.l_tanh = torch.nn.Tanh() def forward(self, x): return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype))) class HiFiGANMRFGenerator(torch.nn.Module): def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing=False): super().__init__() self.num_kernels = len(resblock_kernel_sizes) self.upp = int(np.prod(upsample_rates)) self.f0_upsample = torch.nn.Upsample(scale_factor=self.upp) self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num) self.conv_pre = weight_norm(torch.nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3)) self.checkpointing = checkpointing self.upsamples = torch.nn.ModuleList() self.upsampler = torch.nn.ModuleList() self.noise_convs = torch.nn.ModuleList() stride_f0s = [upsample_rates[1] * upsample_rates[2] * upsample_rates[3], upsample_rates[2] * upsample_rates[3], upsample_rates[3], 1] for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): if self.upp == 441: self.upsampler.append(torch.nn.Upsample(scale_factor=u, mode="linear")) self.upsamples.append(weight_norm(torch.nn.Conv1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=1))) self.noise_convs.append(torch.nn.Conv1d(in_channels=1, out_channels=upsample_initial_channel // (2 ** (i + 1)), kernel_size = 1)) else: self.upsampler.append(torch.nn.Identity()) self.upsamples.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=(k - u) // 2))) self.noise_convs.append(torch.nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1, stride=stride_f0s[i], padding=stride_f0s[i] // 2)) self.mrfs = torch.nn.ModuleList() for i in range(len(self.upsamples)): channel = upsample_initial_channel // (2 ** (i + 1)) self.mrfs.append(torch.nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)])) self.conv_post = weight_norm(torch.nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3)) if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1) def forward(self, x, f0, g = None): har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2) x = self.conv_pre(x) if g is not None: x += self.cond(g) for ups, upr, mrf, noise_conv in zip(self.upsamples, self.upsampler, self.mrfs, self.noise_convs): x = F.leaky_relu(x, LRELU_SLOPE) if self.training and self.checkpointing: if self.upp == 441: x = upr(x) x = checkpoint(ups, x, use_reentrant=False) else: if self.upp == 441: x = upr(x) x = ups(x) h = noise_conv(har_source) if self.upp == 441: h = torch.nn.functional.interpolate(h, size=x.shape[-1], mode="linear") x += h def mrf_sum(x, layers): return sum(layer(x) for layer in layers) / self.num_kernels x = checkpoint(mrf_sum, x, mrf, use_reentrant=False) if self.training and self.checkpointing else mrf_sum(x, mrf) return torch.tanh(self.conv_post(F.leaky_relu(x))) def remove_weight_norm(self): remove_weight_norm(self.conv_pre) for up in self.upsamples: remove_weight_norm(up) for mrf in self.mrfs: mrf.remove_weight_norm() remove_weight_norm(self.conv_post)