Spaces:
Running
Running
File size: 9,241 Bytes
2f5f13b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import math
from typing import Optional

import torch
from torch.nn.utils import parametrize
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.utils.checkpoint import checkpoint

from rvc.lib.algorithm.commons import init_weights
from rvc.lib.algorithm.generators.hifigan import SineGenerator
from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock
class SourceModuleHnNSF(torch.nn.Module):
    """
    Harmonic excitation source for NSF-style vocoders.

    A ``SineGenerator`` produces the sine excitation (the fundamental plus
    ``harmonic_num`` overtones). A learned linear layer then folds those
    ``harmonic_num + 1`` channels into a single channel, followed by ``tanh``.

    Args:
        sample_rate (int): Sampling rate of the audio in Hz.
        harmonic_num (int, optional): Number of overtones generated above the
            fundamental frequency (F0). Defaults to 0.
        sine_amp (float, optional): Sine amplitude, forwarded to
            ``SineGenerator``. Defaults to 0.1.
        add_noise_std (float, optional): Noise standard deviation, forwarded to
            ``SineGenerator``. Defaults to 0.003.
        voiced_threshod (float, optional): F0 voicing threshold, forwarded to
            ``SineGenerator``. The misspelled name is kept for backward
            compatibility. Defaults to 0.
    """

    def __init__(
        self,
        sample_rate: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        add_noise_std: float = 0.003,
        voiced_threshod: float = 0,
    ):
        super().__init__()
        # Kept as attributes; forward() itself does not read them.
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGenerator(
            sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )
        # Maps the (harmonic_num + 1) sine channels onto one excitation channel.
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x: torch.Tensor, upsample_factor: int = 1):
        """
        Build the merged harmonic excitation.

        Args:
            x: Input forwarded to ``SineGenerator`` (HiFiGANNSFGenerator in this
                file passes F0 here).
            upsample_factor: Upsampling factor forwarded to ``SineGenerator``.

        Returns:
            tuple: ``(merged, None, None)``. ``merged`` is ``tanh`` of the
            linear projection of the sine channels, so its last dimension is 1.
        """
        sines, _, _ = self.l_sin_gen(x, upsample_factor)
        projection = self.l_linear
        # Match the projection's parameter dtype (e.g. under reduced precision).
        sines = sines.to(dtype=projection.weight.dtype)
        merged = self.l_tanh(projection(sines))
        return merged, None, None
class HiFiGANNSFGenerator(torch.nn.Module):
"""
Generator module based on the Neural Source Filter (NSF) architecture.
This generator synthesizes audio by first generating a source excitation signal
(harmonic and noise) and then filtering it through a series of upsampling and
residual blocks. Global conditioning can be applied to influence the generation.
Args:
initial_channel (int): Number of input channels to the initial convolutional layer.
resblock_kernel_sizes (list): List of kernel sizes for the residual blocks.
resblock_dilation_sizes (list): List of lists of dilation rates for the residual blocks, corresponding to each kernel size.
upsample_rates (list): List of upsampling factors for each upsampling layer.
upsample_initial_channel (int): Number of output channels from the initial convolutional layer, which is also the input to the first upsampling layer.
upsample_kernel_sizes (list): List of kernel sizes for the transposed convolutional layers used for upsampling.
gin_channels (int): Number of input channels for the global conditioning. If 0, no global conditioning is used.
sr (int): Sampling rate of the audio.
checkpointing (bool, optional): Whether to use gradient checkpointing to save memory during training. Defaults to False.
"""
def __init__(
self,
initial_channel: int,
resblock_kernel_sizes: list,
resblock_dilation_sizes: list,
upsample_rates: list,
upsample_initial_channel: int,
upsample_kernel_sizes: list,
gin_channels: int,
sr: int,
checkpointing: bool = False,
):
super(HiFiGANNSFGenerator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.checkpointing = checkpointing
self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
self.conv_pre = torch.nn.Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.ups = torch.nn.ModuleList()
self.noise_convs = torch.nn.ModuleList()
channels = [
upsample_initial_channel // (2 ** (i + 1))
for i in range(len(upsample_rates))
]
stride_f0s = [
math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1
for i in range(len(upsample_rates))
]
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
# handling odd upsampling rates
if u % 2 == 0:
# old method
padding = (k - u) // 2
else:
padding = u // 2 + u % 2
self.ups.append(
weight_norm(
torch.nn.ConvTranspose1d(
upsample_initial_channel // (2**i),
channels[i],
k,
u,
padding=padding,
output_padding=u % 2,
)
)
)
""" handling odd upsampling rates
# s k p
# 40 80 20
# 32 64 16
# 4 8 2
# 2 3 1
# 63 125 31
# 9 17 4
# 3 5 1
# 1 1 0
"""
stride = stride_f0s[i]
kernel = 1 if stride == 1 else stride * 2 - stride % 2
padding = 0 if stride == 1 else (kernel - stride) // 2
self.noise_convs.append(
torch.nn.Conv1d(
1,
channels[i],
kernel_size=kernel,
stride=stride,
padding=padding,
)
)
self.resblocks = torch.nn.ModuleList(
[
ResBlock(channels[i], k, d)
for i in range(len(self.ups))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)
]
)
self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
self.upp = math.prod(upsample_rates)
self.lrelu_slope = LRELU_SLOPE
def forward(
self, x: torch.Tensor, f0: torch.Tensor, g: Optional[torch.Tensor] = None
):
har_source, _, _ = self.m_source(f0, self.upp)
har_source = har_source.transpose(1, 2)
# new tensor
x = self.conv_pre(x)
if g is not None:
# in-place call
x += self.cond(g)
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
# in-place call
x = torch.nn.functional.leaky_relu_(x, self.lrelu_slope)
# Apply upsampling layer
if self.training and self.checkpointing:
x = checkpoint(ups, x, use_reentrant=False)
else:
x = ups(x)
# Add noise excitation
x += noise_convs(har_source)
# Apply residual blocks
def resblock_forward(x, blocks):
return sum(block(x) for block in blocks) / len(blocks)
blocks = self.resblocks[i * self.num_kernels : (i + 1) * self.num_kernels]
# Checkpoint or regular computation for ResBlocks
if self.training and self.checkpointing:
x = checkpoint(resblock_forward, x, blocks, use_reentrant=False)
else:
x = resblock_forward(x, blocks)
# in-place call
x = torch.nn.functional.leaky_relu_(x)
# in-place call
x = torch.tanh_(self.conv_post(x))
return x
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
def __prepare_scriptable__(self):
for l in self.ups:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
remove_weight_norm(l)
for l in self.resblocks:
for hook in l._forward_pre_hooks.values():
if (
hook.__module__ == "torch.nn.utils.parametrizations.weight_norm"
and hook.__class__.__name__ == "WeightNorm"
):
remove_weight_norm(l)
return self
|