# Copyright 2025 ByteDance and/or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch.nn as nn import torch.nn.functional as F import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn from torch.nn.utils import weight_norm, remove_weight_norm from torch.nn import Conv1d import numpy as np def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(mean, std) def get_padding(kernel_size, dilation=1): return int((kernel_size*dilation - dilation)/2) class Upsample(nn.Module): def __init__(self, mult, r): super(Upsample, self).__init__() self.r = r self.upsample = nn.Sequential(nn.Upsample(mode="nearest", scale_factor=r), nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, mult // 2, kernel_size=7, stride=1)) ) r_kernel = r if r >= 5 else 5 self.trans_upsample = nn.Sequential(nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult, mult // 2, kernel_size=r_kernel * 2, stride=r, padding=r_kernel - r // 2, output_padding=r % 2) )) def forward(self, x): x = torch.sin(x) + x out1 = self.upsample(x) out2 = self.trans_upsample(x) return out1 + out2 class Downsample(nn.Module): def __init__(self, mult, r): super(Downsample, self).__init__() self.r = r r_kernel = r if r >= 5 else 5 self.trans_downsample = nn.Sequential(nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.Conv1d(mult, mult * 2, kernel_size=r_kernel * 2, stride=r, padding=r_kernel - r // 2) )) def forward(self, x): out = self.trans_downsample(x) return out def weights_init(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find("BatchNorm2d") != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) def weights_zero_init(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.fill_(0.0) m.bias.data.fill_(0.0) def WNConv1d(*args, **kwargs): return weight_norm(nn.Conv1d(*args, **kwargs)) def WNConvTranspose1d(*args, **kwargs): return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) class Audio2Mel(nn.Module): def __init__( self, hop_length=300, sampling_rate=24000, n_mel_channels=80, mel_fmin=0., mel_fmax=None, frame_size=0.05, device='cpu' ): super().__init__() ############################################## # FFT Parameters # ############################################## self.n_fft = int(np.power(2., np.ceil(np.log(sampling_rate * frame_size) / np.log(2)))) window = torch.hann_window(int(sampling_rate * frame_size)).float() mel_basis = librosa_mel_fn( sampling_rate, self.n_fft, n_mel_channels, mel_fmin, mel_fmax ) # Mel filter (by librosa) mel_basis = torch.from_numpy(mel_basis).float() self.register_buffer("mel_basis", mel_basis) self.register_buffer("window", window) self.hop_length = hop_length self.win_length = int(sampling_rate * frame_size) self.sampling_rate = sampling_rate self.n_mel_channels = n_mel_channels def forward(self, audio): fft = torch.stft( audio.squeeze(1), n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=self.window, center=True, ) real_part, imag_part = fft.unbind(-1) magnitude = torch.sqrt(torch.clamp(real_part ** 2 + imag_part ** 2, min=1e-5)) mel_output = torch.matmul(self.mel_basis, magnitude) log_mel_spec = 20 * torch.log10(torch.clamp(mel_output, min=1e-5)) - 20 norm_mel = (log_mel_spec + 115.) / 115. mel_comp = torch.clamp(norm_mel * 8. - 4., -4., 4.) return mel_comp class ResnetBlock(nn.Module): def __init__(self, dim, dilation=1, dim_in=None): super().__init__() if dim_in is None: dim_in = dim self.block = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(dilation), WNConv1d(dim_in, dim, kernel_size=3, dilation=dilation), nn.LeakyReLU(0.2), WNConv1d(dim, dim, kernel_size=1), ) self.shortcut = WNConv1d(dim_in, dim, kernel_size=1) def forward(self, x): return self.shortcut(x) + self.block(x) ''' 参照hifigan(https://arxiv.org/pdf/2010.05646.pdf)v2结构 多尺度主要是kernel_size不同,3组并行卷积模块,每个卷积模块内部采用不同的串行dilation size,且中间交叉正常无dilation卷积层 ''' class ResBlockMRFV2(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlockMRFV2, self).__init__() self.convs1 = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))) ]) self.convs1.apply(init_weights) self.convs2 = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) ]) self.convs2.apply(init_weights) def forward(self, x): for c1, c2 in zip(self.convs1, self.convs2): xt = F.leaky_relu(x, 0.2) xt = c1(xt) xt = F.leaky_relu(xt, 0.2) xt = c2(xt) x = xt + x return x def remove_weight_norm(self): for l in self.convs1: remove_weight_norm(l) for l in self.convs2: remove_weight_norm(l) class ResBlockMRFV2Inter(torch.nn.Module): def __init__(self, channels, kernel_size=3): super(ResBlockMRFV2Inter, self).__init__() self.block1 = ResBlockMRFV2(channels) self.block2 = ResBlockMRFV2(channels, 7) self.block3 = ResBlockMRFV2(channels, 11) def forward(self, x): xs = self.block1(x) xs += self.block2(x) xs += self.block3(x) x = xs / 3 return x class Generator(nn.Module): def __init__(self, input_size_, ngf, n_residual_layers, num_band, args, ratios=[5, 5, 4, 3], onnx_export=False, device='cpu'): super().__init__() self.hop_length = args.frame_shift self.args = args self.onnx_export = onnx_export # ------------- Define upsample layers ---------------- mult = int(2 ** len(ratios)) model_up = [] input_size = input_size_ model_up += [ nn.ReflectionPad1d(3), WNConv1d(input_size, mult * ngf, kernel_size=7, padding=0), ] # Upsample to raw audio scale for i, r in enumerate(ratios): model_up += [Upsample(mult * ngf, r)] model_up += [ResBlockMRFV2Inter(mult * ngf // 2)] mult //= 2 model_up += [ nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), WNConv1d(ngf, num_band, kernel_size=7, padding=0), nn.Tanh(), ] if not args.use_tanh: model_up[-1] = nn.Conv1d(num_band, num_band, 1) model_up[-2].apply(weights_zero_init) self.model_up = nn.Sequential(*model_up) self.apply(weights_init) def forward(self, mel, step=None): # mel input: (batch_size, seq_num, 80) if self.onnx_export: mel = mel.transpose(1, 2) # on onnx, for engineering, mel input: (batch_size, 80, seq_num) # Between Down and up x = mel # Upsample pipline cnt_after_upsample = 0 for i, m in enumerate(self.model_up): x = m(x) if type(m) == Upsample: cnt_after_upsample += 1 return x