|
import torch
|
|
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
|
|
from torch.nn.utils import remove_weight_norm
|
|
from torch.utils.checkpoint import checkpoint
|
|
from torch.nn.utils.parametrizations import weight_norm
|
|
|
|
# Negative slope handed to F.leaky_relu by MRFLayer.forward and by the
# per-stage activation in HiFiGANMRFGenerator.forward.
LRELU_SLOPE = 0.1
|
|
|
|
class MRFLayer(torch.nn.Module):
|
|
def __init__(self, channels, kernel_size, dilation):
|
|
super().__init__()
|
|
self.conv1 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
|
|
self.conv2 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))
|
|
|
|
def forward(self, x):
|
|
return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))
|
|
|
|
def remove_weight_norm(self):
|
|
remove_weight_norm(self.conv1)
|
|
remove_weight_norm(self.conv2)
|
|
|
|
class MRFBlock(torch.nn.Module):
    """Chain of ``MRFLayer`` units applied one after another.

    Args:
        channels: Channel count handed to every ``MRFLayer``.
        kernel_size: Kernel size handed to every ``MRFLayer``.
        dilations: Iterable of dilation values; one layer is built per entry,
            in iteration order.
    """

    def __init__(self, channels, kernel_size, dilations):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MRFLayer(channels, kernel_size, rate) for rate in dilations]
        )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        out = x
        for unit in self.layers:
            out = unit(out)
        return out

    def remove_weight_norm(self):
        """Ask each contained layer to remove its weight normalisation."""
        for unit in self.layers:
            unit.remove_weight_norm()
|
|
|
|
class SineGenerator(torch.nn.Module):
    """Turn an F0 track into sine waves (fundamental plus harmonics) with noise.

    Args:
        samp_rate: Value the F0 track is divided by to obtain the per-step
            phase increment (presumably the audio sampling rate in Hz, with
            F0 in Hz as well -- confirm against the caller).
        harmonic_num: Number of harmonics generated above the fundamental.
        sine_amp: Amplitude applied to the generated sines.
        noise_std: Scale of the Gaussian noise added on voiced steps.
        voiced_threshold: F0 values strictly above this count as voiced.
    """

    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Output channels: the fundamental plus every harmonic.
        self.dim = harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a mask shaped like ``f0``: 1 where voiced, 0 elsewhere."""
        is_voiced = f0 > self.voiced_threshold
        return torch.ones_like(f0) * is_voiced

    def _f02sine(self, f0_values):
        """Integrate per-step frequencies into phase and return its sine.

        Args:
            f0_values: Tensor indexed as (batch, step, channel).

        Returns:
            Tensor with the same shape as ``f0_values``.
        """
        # Phase increment per step in cycles, wrapped into [0, 1).
        phase_step = (f0_values / self.sampling_rate) % 1

        # Random start phase per (batch, channel); channel 0 starts at phase 0.
        start_phase = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        start_phase[:, 0] = 0
        phase_step[:, 0, :] = phase_step[:, 0, :] + start_phase

        # Find the steps at which the running phase wraps past a full cycle.
        wrapped_phase = torch.cumsum(phase_step, 1) % 1
        did_wrap = (wrapped_phase[:, 1:, :] - wrapped_phase[:, :-1, :]) < 0

        # Subtract one full cycle at each wrap so the accumulated phase stays small.
        wrap_correction = torch.zeros_like(phase_step)
        wrap_correction[:, 1:, :] = did_wrap * -1.0

        phase = torch.cumsum(phase_step + wrap_correction, dim=1)
        return torch.sin(phase * 2 * np.pi)

    def forward(self, f0):
        """Generate the noisy sine source for ``f0``.

        Args:
            f0: Tensor indexed as (batch, step, channel); only channel 0 is
                used for the frequencies, the full tensor for the voiced mask.

        Returns:
            Tensor computed without gradient tracking whose last dimension is
            ``self.dim`` when ``f0`` has a single channel.
        """
        with torch.no_grad():
            n_batch, n_steps = f0.shape[0], f0.shape[1]
            freqs = torch.zeros(n_batch, n_steps, self.dim, device=f0.device)
            freqs[:, :, 0] = f0[:, :, 0]

            # Channel k holds the (k + 1)-th multiple of the fundamental.
            for k in range(1, self.harmonic_num + 1):
                freqs[:, :, k] = freqs[:, :, 0] * (k + 1)

            sines = self._f02sine(freqs) * self.sine_amp
            uv = self._f02uv(f0)

            # Voiced steps get noise scaled by noise_std; unvoiced steps get
            # noise scaled by sine_amp / 3 and no sine component.
            noise_scale = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            return sines * uv + (noise_scale * torch.randn_like(sines))
|
|
|
|
class SourceModuleHnNSF(torch.nn.Module):
    """Sine generator followed by a linear merge to one channel and tanh.

    Args:
        sampling_rate: Forwarded to ``SineGenerator``.
        harmonic_num: Harmonics above the fundamental; the linear layer takes
            ``harmonic_num + 1`` inputs.
        sine_amp: Sine amplitude, forwarded to ``SineGenerator``.
        add_noise_std: Noise scale, forwarded to ``SineGenerator``.
        voiced_threshold: Voicing threshold, forwarded to ``SineGenerator``.
    """

    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        self.l_sin_gen = SineGenerator(
            samp_rate=sampling_rate,
            harmonic_num=harmonic_num,
            sine_amp=sine_amp,
            noise_std=add_noise_std,
            voiced_threshold=voiced_threshold,
        )
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """Generate sines from ``x`` and merge them into a single channel.

        Args:
            x: Input handed unchanged to the sine generator.

        Returns:
            ``tanh(linear(sines))`` in the dtype of the linear layer's weight.
        """
        sines = self.l_sin_gen(x)
        # Match the linear layer's parameter dtype before the matmul.
        sines = sines.to(dtype=self.l_linear.weight.dtype)
        merged = self.l_linear(sines)
        return self.l_tanh(merged)
|
|
|
|
class HiFiGANMRFGenerator(torch.nn.Module):
|
|
def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing=False):
|
|
super().__init__()
|
|
self.num_kernels = len(resblock_kernel_sizes)
|
|
|
|
self.upp = int(np.prod(upsample_rates))
|
|
self.f0_upsample = torch.nn.Upsample(scale_factor=self.upp)
|
|
self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
|
|
|
|
self.conv_pre = weight_norm(torch.nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
|
|
self.checkpointing = checkpointing
|
|
|
|
self.upsamples = torch.nn.ModuleList()
|
|
self.upsampler = torch.nn.ModuleList()
|
|
self.noise_convs = torch.nn.ModuleList()
|
|
|
|
stride_f0s = [upsample_rates[1] * upsample_rates[2] * upsample_rates[3], upsample_rates[2] * upsample_rates[3], upsample_rates[3], 1]
|
|
|
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
|
if self.upp == 441:
|
|
self.upsampler.append(torch.nn.Upsample(scale_factor=u, mode="linear"))
|
|
self.upsamples.append(weight_norm(torch.nn.Conv1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=1)))
|
|
self.noise_convs.append(torch.nn.Conv1d(in_channels=1, out_channels=upsample_initial_channel // (2 ** (i + 1)), kernel_size = 1))
|
|
else:
|
|
self.upsampler.append(torch.nn.Identity())
|
|
self.upsamples.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=(k - u) // 2)))
|
|
self.noise_convs.append(torch.nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1, stride=stride_f0s[i], padding=stride_f0s[i] // 2))
|
|
|
|
self.mrfs = torch.nn.ModuleList()
|
|
for i in range(len(self.upsamples)):
|
|
channel = upsample_initial_channel // (2 ** (i + 1))
|
|
self.mrfs.append(torch.nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))
|
|
|
|
self.conv_post = weight_norm(torch.nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
|
|
if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
|
|
|
def forward(self, x, f0, g = None):
|
|
har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
|
|
x = self.conv_pre(x)
|
|
if g is not None: x += self.cond(g)
|
|
|
|
for ups, upr, mrf, noise_conv in zip(self.upsamples, self.upsampler, self.mrfs, self.noise_convs):
|
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
|
|
|
if self.training and self.checkpointing:
|
|
if self.upp == 441: x = upr(x)
|
|
x = checkpoint(ups, x, use_reentrant=False)
|
|
else:
|
|
if self.upp == 441: x = upr(x)
|
|
x = ups(x)
|
|
|
|
h = noise_conv(har_source)
|
|
if self.upp == 441: h = torch.nn.functional.interpolate(h, size=x.shape[-1], mode="linear")
|
|
x += h
|
|
|
|
def mrf_sum(x, layers):
|
|
return sum(layer(x) for layer in layers) / self.num_kernels
|
|
|
|
x = checkpoint(mrf_sum, x, mrf, use_reentrant=False) if self.training and self.checkpointing else mrf_sum(x, mrf)
|
|
|
|
return torch.tanh(self.conv_post(F.leaky_relu(x)))
|
|
|
|
def remove_weight_norm(self):
|
|
remove_weight_norm(self.conv_pre)
|
|
|
|
for up in self.upsamples:
|
|
remove_weight_norm(up)
|
|
|
|
for mrf in self.mrfs:
|
|
mrf.remove_weight_norm()
|
|
|
|
remove_weight_norm(self.conv_post) |