import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.utils import weights_init
from .res_stack import ResStack
# from res_stack import ResStack

MAX_WAV_VALUE = 32768.0  # full-scale value of signed 16-bit PCM audio


class VOCGenerator(nn.Module):
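    """MelGAN-style vocoder generator with multi-scale outputs.

    A mel spectrogram is upsampled by prod(ratios) (256 for the default
    [4, 4, 2, 2, 2, 2]) through six transposed-convolution stages, each
    followed by a stack of dilated residual blocks (ResStack). The input mel
    is also upsampled directly and added to the feature maps of the last four
    stages (skip_upsample_1..4), and extra output heads (sub_out_1..4) tap the
    network at 16x, 32x, 64x and 128x, so forward() returns four intermediate
    waveforms in addition to the final full-rate output.
    """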
    def __init__(self, mel_channel, n_residual_layers, ratios=[4, 4, 2, 2, 2, 2], mult=256, out_band=1):
        super(VOCGenerator, self).__init__()
        self.mel_channel = mel_channel

        self.start = nn.Sequential(
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mel_channel, mult * 2, kernel_size=7, stride=1))
        )

        r = ratios[0]
        self.upsample_1 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.res_stack_1 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        r = ratios[1]
        mult = mult // 2
        self.upsample_2 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.res_stack_2 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        self.sub_out_1 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

        r = ratios[2]
        mult = mult // 2
        self.upsample_3 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_1 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=64, stride=32,
                                                                       padding=16,
                                                                       output_padding=0))
        self.res_stack_3 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        self.sub_out_2 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

        r = ratios[3]
        mult = mult // 2
        self.upsample_4 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_2 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=128, stride=64,
                                                                       padding=32,
                                                                       output_padding=0))
        self.res_stack_4 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        self.sub_out_3 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

        r = ratios[4]
        mult = mult // 2
        self.upsample_5 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_3 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=256, stride=128,
                                                                       padding=64,
                                                                       output_padding=0))
        self.res_stack_5 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        self.sub_out_4 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

        r = ratios[5]
        mult = mult // 2
        self.upsample_6 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_4 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=512, stride=256,
                                                                       padding=128,
                                                                       output_padding=0))
        self.res_stack_6 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        self.out = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

        self.apply(weights_init)

    def forward(self, mel):
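        """Generate waveforms from a log-mel spectrogram of shape [B, num_mels, T].

        Returns (out1, out2, out3, out4, out): waveforms at 16x, 32x, 64x, 128x
        and 256x the mel frame rate; `out` is the full-rate signal.
        """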
        mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
        # Mel Shape [B, num_mels, T] -> torch.Size([3, 80, 10])
        x = self.start(mel)  # [B, dim*2, T] -> torch.Size([3, 512, 10])

        x = self.upsample_1(x)
        x = self.res_stack_1(x)  # [B, dim, T*4] -> torch.Size([3, 256, 40])

        x = self.upsample_2(x)
        x = self.res_stack_2(x)  # [B, dim/2, T*16] -> torch.Size([3, 128, 160])
        out1 = self.sub_out_1(x)  # [B, 1, T*16] -> torch.Size([3, 1, 160])

        x = self.upsample_3(x)
        x = x + self.skip_upsample_1(mel)
        x = self.res_stack_3(x)  # [B, dim/4, T*32] -> torch.Size([3, 64, 320])
        out2 = self.sub_out_2(x)  # [B, 1, T*32] -> torch.Size([3, 1, 320])

        x = self.upsample_4(x)
        x = x + self.skip_upsample_2(mel)
        x = self.res_stack_4(x)  # [B, dim/8, T*64] -> torch.Size([3, 32, 640])
        out3 = self.sub_out_3(x)  # [B, 1, T*64] -> torch.Size([3, 1, 640])

        x = self.upsample_5(x)
        x = x + self.skip_upsample_3(mel)
        x = self.res_stack_5(x)  # [B, dim/16, T*128] -> torch.Size([3, 16, 1280])
        out4 = self.sub_out_4(x)  # [B, 1, T*128] -> torch.Size([3, 1, 1280])

        x = self.upsample_6(x)
        x = x + self.skip_upsample_4(mel)
        x = self.res_stack_6(x)  # [B, dim/32, T*256] -> torch.Size([3, 8, 2560])
        out = self.out(x)  # [B, 1, T*256] -> torch.Size([3, 1, 2560])

        return out1, out2, out3, out4, out

    def inference(self, mel):
        hop_length = 256  # total upsampling factor; not used below
        # pad the input mel with "silence" frames (log(1e-5) ~= -11.5129) to cut the edge artifact,
        # see https://github.com/seungwonpark/melgan/issues/8
        zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
        mel = torch.cat((mel, zero), dim=2)

        _, _, _, _, audio = self.forward(mel)
        return audio


class ModifiedGenerator(nn.Module):
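    """Single-output variant of VOCGenerator.

    Same upsampling backbone and mel skip connections, but without the
    intermediate sub_out heads: forward() returns only the final full-rate
    waveform. eval(inference=True) strips weight normalization for deployment.
    """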
    def __init__(self, mel_channel, n_residual_layers, ratios=[4, 4, 2, 2, 2, 2], mult=256, out_band=1):
        super(ModifiedGenerator, self).__init__()
        self.mel_channel = mel_channel

        self.start = nn.Sequential(
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mel_channel, mult * 2, kernel_size=7, stride=1))
        )

        r = ratios[0]
        self.upsample_1 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.res_stack_1 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        r = ratios[1]
        mult = mult // 2
        self.upsample_2 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.res_stack_2 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        r = ratios[2]
        mult = mult // 2
        self.upsample_3 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_1 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=64, stride=32,
                                                                       padding=16,
                                                                       output_padding=0))
        self.res_stack_3 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        r = ratios[3]
        mult = mult // 2
        self.upsample_4 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_2 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=128, stride=64,
                                                                       padding=32,
                                                                       output_padding=0))
        self.res_stack_4 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        r = ratios[4]
        mult = mult // 2
        self.upsample_5 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_3 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=256, stride=128,
                                                                       padding=64,
                                                                       output_padding=0))
        self.res_stack_5 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        r = ratios[5]
        mult = mult // 2
        self.upsample_6 = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.ConvTranspose1d(mult * 2, mult,
                                                    kernel_size=r * 2, stride=r,
                                                    padding=r // 2 + r % 2,
                                                    output_padding=r % 2))
        )
        self.skip_upsample_4 = nn.utils.weight_norm(nn.ConvTranspose1d(mel_channel, mult,
                                                                       kernel_size=512, stride=256,
                                                                       padding=128,
                                                                       output_padding=0))
        self.res_stack_6 = nn.Sequential(*[ResStack(mult, dilation=3 ** j) for j in range(n_residual_layers)])

        self.out = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(3),
            nn.utils.weight_norm(nn.Conv1d(mult, out_band, kernel_size=7, stride=1)),
            nn.Tanh(),
        )

        self.apply(weights_init)

    def forward(self, mel):
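        """Generate the full-rate waveform [B, out_band, T*256] from a log-mel input [B, num_mels, T]."""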
        mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
        # Mel Shape [B, num_mels, T] -> torch.Size([3, 80, 10])
        x = self.start(mel)  # [B, dim*2, T] -> torch.Size([3, 512, 10])

        x = self.upsample_1(x)
        x = self.res_stack_1(x)  # [B, dim, T*4] -> torch.Size([3, 256, 40])

        x = self.upsample_2(x)
        x = self.res_stack_2(x)  # [B, dim/2, T*16] -> torch.Size([3, 128, 160])
        # out1 = self.sub_out_1(x)  # [B, 1, T*16] -> torch.Size([3, 1, 160])

        x = self.upsample_3(x)
        x = x + self.skip_upsample_1(mel)
        x = self.res_stack_3(x)  # [B, dim/4, T*32] -> torch.Size([3, 64, 320])
        # out2 = self.sub_out_2(x)  # [B, 1, T*32] -> torch.Size([3, 1, 320])

        x = self.upsample_4(x)
        x = x + self.skip_upsample_2(mel)
        x = self.res_stack_4(x)  # [B, dim/8, T*64] -> torch.Size([3, 32, 640])
        # out3 = self.sub_out_3(x)  # [B, 1, T*64] -> torch.Size([3, 1, 640])

        x = self.upsample_5(x)
        x = x + self.skip_upsample_3(mel)
        x = self.res_stack_5(x)  # [B, dim/16, T*128] -> torch.Size([3, 16, 1280])
        # out4 = self.sub_out_4(x)  # [B, 1, T*128] -> torch.Size([3, 1, 1280])

        x = self.upsample_6(x)
        x = x + self.skip_upsample_4(mel)
        x = self.res_stack_6(x)  # [B, dim/32, T*256] -> torch.Size([3, 8, 2560])
        out = self.out(x)  # [B, 1, T*256] -> torch.Size([3, 1, 2560])

        return out  # out1, out2, out3, out4, out

    def eval(self, inference=False):
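        """Switch to eval mode; optionally strip weight norm for deployment.

        Weight norm is removed only when inference=True, so validation inside
        the training loop keeps the weight-normalized parametrization.
        """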
        super(ModifiedGenerator, self).eval()
        # don't remove weight norm while validating in the training loop
        if inference:
            self.remove_weight_norm()

    # def remove_weight_norm(self):
    #     for idx, layer in enumerate(self.generator):
    #         if len(layer.state_dict()) != 0:
    #             try:
    #                 nn.utils.remove_weight_norm(layer)
    #             except:
    #                 layer.remove_weight_norm()

    def remove_weight_norm(self):
        """Remove weight normalization from all of the layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all Conv1d / ConvTranspose1d layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def inference(self, mel):
        hop_length = 256  # total upsampling factor; not used below
        # pad the input mel with "silence" frames (log(1e-5) ~= -11.5129) to cut the edge artifact,
        # see https://github.com/seungwonpark/melgan/issues/8
        zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
        mel = torch.cat((mel, zero), dim=2)

        audio = self.forward(mel)
        return audio


'''
to run this file standalone, replace the relative imports at the top with local
ones, e.g. use the commented `from res_stack import ResStack` instead of
`from .res_stack import ResStack` (and import `weights_init` the same way)
'''


if __name__ == '__main__':
    '''
    expected output:
    torch.Size([3, 80, 10])
    torch.Size([3, 1, 2560])
    4715698
    '''
    model = VOCGenerator(80, 4)

    x = torch.randn(3, 80, 10)  # (B, channels, T)
    print(x.shape)

    out1, out2, out3, out4, out = model(x)  # out: (B, 1, T * prod(ratios))
    assert out.shape == torch.Size([3, 1, 2560])  # For normal melgan torch.Size([3, 1, 2560])
    assert out4.shape == torch.Size([3, 1, 1280])
    assert out3.shape == torch.Size([3, 1, 640])
    assert out2.shape == torch.Size([3, 1, 320])
    assert out1.shape == torch.Size([3, 1, 160])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)
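
    # A minimal additional check (a sketch, not part of the original tests):
    # ModifiedGenerator shares the same backbone but returns only the final
    # full-rate waveform, so the same toy input yields a single tensor.
    modified = ModifiedGenerator(80, 4)
    y = modified(x)
    assert y.shape == torch.Size([3, 1, 2560])
    print(sum(p.numel() for p in modified.parameters() if p.requires_grad))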