import torch import torch.nn as nn import torch.nn.functional as F from ..utils.utils import weights_init from .res_stack import ResStack # from res_stack import ResStack MAX_WAV_VALUE = 32768.0 class VOCGenerator(nn.Module): def __init__(self, mel_channel, n_residual_layers, ratios=[4, 4, 2, 2, 2, 2], mult=256, out_band=1): super(VOCGenerator, self).__init__() self.mel_channel = mel_channel self.start = nn.Sequential( nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mel_channel, mult * 2, kernel_size=7, stride=1)) ) r = ratios[0] self.upsample_1 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.res_stack_1 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) r = ratios[1] mult = mult // 2 self.upsample_2 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.res_stack_2 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) self.sub_out_1 = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), nn.Tanh(), ) r = ratios[2] mult = mult // 2 self.upsample_3 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_1 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=64, stride=32, padding=16, output_padding=0) ) self.res_stack_3 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) self.sub_out_2 = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), nn.Tanh(), ) r = ratios[3] mult = mult // 2 self.upsample_4 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_2 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=128, stride=64, padding=32, output_padding=0) ) self.res_stack_4 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) self.sub_out_3 = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), nn.Tanh(), ) r = ratios[4] mult = mult // 2 self.upsample_5 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_3 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=256, stride=128, padding=64, output_padding=0) ) self.res_stack_5 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) self.sub_out_4 = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), nn.Tanh(), ) r = ratios[5] mult = mult // 2 self.upsample_6 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_4 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=512, stride=256, padding=128, output_padding=0) ) self.res_stack_6 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) self.out = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), nn.Tanh(), ) self.apply(weights_init) def forward(self, mel): mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram # Mel Shape [B, num_mels, T] -> torch.Size([3, 80, 10]) x = self.start(mel) # [B, dim*2, T] -> torch.Size([3, 512, 10]) x = self.upsample_1(x) x = self.res_stack_1(x) # [B, dim, T*4] -> torch.Size([3, 256, 40]) x = self.upsample_2(x) x = self.res_stack_2(x) # [B, dim/2, T*16] -> torch.Size([3, 128, 160]) out1 = self.sub_out_1(x) # [B, 1, T*16] -> torch.Size([3, 1, 160]) x = self.upsample_3(x) x = x + self.skip_upsample_1(mel) x = self.res_stack_3(x) # [B, dim/4, T*32] -> torch.Size([3, 64, 320]) out2 = self.sub_out_2(x) # [B, 1, T*32] -> torch.Size([3, 1, 320]) x = self.upsample_4(x) x = x + self.skip_upsample_2(mel) x = self.res_stack_4(x) # [B, dim/8, T*64] -> torch.Size([3, 32, 640]) out3 = self.sub_out_3(x) # [B, 1, T*64] -> torch.Size([3, 1, 640]) x = self.upsample_5(x) x = x + self.skip_upsample_3(mel) x = self.res_stack_5(x) # [B, dim/16, T*128] -> torch.Size([3, 16, 1280]) out4 = self.sub_out_4(x) # [B, 1, T*128] -> torch.Size([3, 1, 1280]) x = self.upsample_6(x) x = x + self.skip_upsample_4(mel) x = self.res_stack_6(x) # [B, dim/32, T*256] -> torch.Size([3, 8, 2560]) out = self.out(x) # [B, 1, T*256] -> torch.Size([3, 1, 2560]) return out1, out2, out3, out4, out def inference(self, mel): hop_length = 256 # pad input mel with zeros to cut artifact # see https://github.com/seungwonpark/melgan/issues/8 zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device) mel = torch.cat((mel, zero), dim=2) _, _, _, _, audio = self.forward(mel) return audio class ModifiedGenerator(nn.Module): def __init__(self, mel_channel, n_residual_layers, ratios=[4, 4, 2, 2, 2, 2], mult=256, out_band=1): super(ModifiedGenerator, self).__init__() self.mel_channel = mel_channel self.start = nn.Sequential( nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mel_channel, mult * 2, kernel_size=7, stride=1)) ) r = ratios[0] self.upsample_1 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.res_stack_1 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) r = ratios[1] mult = mult // 2 self.upsample_2 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.res_stack_2 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) r = ratios[2] mult = mult // 2 self.upsample_3 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_1 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=64, stride=32, padding=16, output_padding=0) ) self.res_stack_3 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) r = ratios[3] mult = mult // 2 self.upsample_4 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_2 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=128, stride=64, padding=32, output_padding=0) ) self.res_stack_4 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) r = ratios[4] mult = mult // 2 self.upsample_5 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_3 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=256, stride=128, padding=64, output_padding=0) ) self.res_stack_5 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) r = ratios[5] mult = mult // 2 self.upsample_6 = nn.Sequential( nn.LeakyReLU(0.2), nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult, kernel_size=r * 2, stride=r, padding=r // 2 + r % 2, output_padding=r % 2) ) ) self.skip_upsample_4 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult, kernel_size=512, stride=256, padding=128, output_padding=0) ) self.res_stack_6 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)]) self.out = nn.Sequential( nn.LeakyReLU(0.2), nn.ReflectionPad1d(3), nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)), nn.Tanh(), ) self.apply(weights_init) def forward(self, mel): mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram # Mel Shape [B, num_mels, T] -> torch.Size([3, 80, 10]) x = self.start(mel) # [B, dim*2, T] -> torch.Size([3, 512, 10]) x = self.upsample_1(x) x = self.res_stack_1(x) # [B, dim, T*4] -> torch.Size([3, 256, 40]) x = self.upsample_2(x) x = self.res_stack_2(x) # [B, dim/2, T*16] -> torch.Size([3, 128, 160]) # out1 = self.sub_out_1(x) # [B, 1, T*16] -> torch.Size([3, 1, 160]) x = self.upsample_3(x) x = x + self.skip_upsample_1(mel) x = self.res_stack_3(x) # [B, dim/4, T*32] -> torch.Size([3, 64, 320]) # out2 = self.sub_out_2(x) # [B, 1, T*32] -> torch.Size([3, 1, 320]) x = self.upsample_4(x) x = x + self.skip_upsample_2(mel) x = self.res_stack_4(x) # [B, dim/8, T*64] -> torch.Size([3, 32, 640]) # out3 = self.sub_out_3(x) # [B, 1, T*64] -> torch.Size([3, 1, 640]) x = self.upsample_5(x) x = x + self.skip_upsample_3(mel) x = self.res_stack_5(x) # [B, dim/16, T*128] -> torch.Size([3, 16, 1280]) # out4 = self.sub_out_4(x) # [B, 1, T*128] -> torch.Size([3, 1, 1280]) x = self.upsample_6(x) x = x + self.skip_upsample_4(mel) x = self.res_stack_6(x) # [B, dim/32, T*256] -> torch.Size([3, 8, 2560]) out = self.out(x) # [B, 1, T*256] -> torch.Size([3, 1, 2560]) return out #out1, out2, out3, out4, out def eval(self, inference=False): super(ModifiedGenerator, self).eval() # don't remove weight norm while validation in training loop if inference: self.remove_weight_norm() # def remove_weight_norm(self): # for idx, layer in enumerate(self.generator): # if len(layer.state_dict()) != 0: # try: # nn.utils.remove_weight_norm(layer) # except: # layer.remove_weight_norm() def remove_weight_norm(self): """Remove weight normalization module from all of the layers.""" def _remove_weight_norm(m): try: torch.nn.utils.remove_weight_norm(m) except ValueError: # this module didn't have weight norm return self.apply(_remove_weight_norm) def apply_weight_norm(self): """Apply weight normalization module from all of the layers.""" def _apply_weight_norm(m): if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): torch.nn.utils.weight_norm(m) self.apply(_apply_weight_norm) def inference(self, mel): hop_length = 256 # pad input mel with zeros to cut artifact # see https://github.com/seungwonpark/melgan/issues/8 zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device) mel = torch.cat((mel, zero), dim=2) audio = self.forward(mel) return audio ''' to run this, fix from . import ResStack into from res_stack import ResStack ''' if __name__ == '__main__': ''' torch.Size([3, 80, 10]) torch.Size([3, 1, 2560]) 4715698 ''' model = VOCGenerator(80, 4) x = torch.randn(3, 80, 10) # (B, channels, T). print(x.shape) out1, out2, out3, out4, out = model(x) # (B, 1, T ** prod(upsample_scales) assert out.shape == torch.Size([3, 1, 2560]) # For normal melgan torch.Size([3, 1, 2560]) assert out4.shape == torch.Size([3, 1, 1280]) assert out3.shape == torch.Size([3, 1, 640]) assert out2.shape == torch.Size([3, 1, 320]) assert out1.shape == torch.Size([3, 1, 160]) pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(pytorch_total_params)