Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
class BasicDiscriminatorBlock(nn.Module):
    """Plain (non-residual) stack of four weight-normalised 1-D convolutions.

    Every convolution uses ``kernel_size=3`` and ``padding=1``.

    - The first convolution maps ``in_channel`` -> ``out_channel`` with
      stride 2, so the time axis shrinks to ``(L - 1) // 2 + 1``.
    - The remaining three keep both the channel count and the length.
    - ``LeakyReLU(0.2)`` (in-place) sits between consecutive convolutions.
    - No activation follows the final convolution.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by every convolution in the stack.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        source_channels = in_channel
        # One stride-2 conv followed by three stride-1 convs.
        for position, conv_stride in enumerate((2, 1, 1, 1)):
            if position > 0:
                # Activation goes *between* convs, never after the last one.
                stages.append(nn.LeakyReLU(0.2, True))
            stages.append(nn.utils.weight_norm(nn.Conv1d(
                source_channels,
                out_channel,
                kernel_size=3,
                stride=conv_stride,
                padding=1,
            )))
            source_channels = out_channel
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv stack.

        Assumes ``x`` is ``(batch, in_channel, length)``, as required by Conv1d.
        """
        return self.block(x)
class ResDiscriminatorBlock(nn.Module):
    """Two chained residual units built from weight-normalised 1-D convolutions.

    Unit 1 (downsampling):

    - Main branch: ``conv3(stride 2) -> LeakyReLU(0.2) -> conv3``, mapping
      ``in_channel`` -> ``out_channel``.
    - Skip branch: a ``kernel_size=1``, stride-2 convolution, so both
      branches output the same shape before they are summed.

    Unit 2 (length-preserving):

    - Main branch: ``conv3 -> LeakyReLU(0.2) -> conv3`` on ``out_channel``
      channels.
    - Skip branch: a ``kernel_size=1`` convolution.

    No activation is applied after either residual sum.

    Args:
        in_channel: channels of the incoming feature map.
        out_channel: channels produced by every convolution in the block.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def wn_conv(source_channels, step, width=3, pad=1):
            # Every conv in this block emits ``out_channel`` channels.
            return nn.utils.weight_norm(nn.Conv1d(
                source_channels,
                out_channel,
                kernel_size=width,
                stride=step,
                padding=pad,
            ))

        # Construction order is kept as block1, shortcut1, block2, shortcut2.
        self.block1 = nn.Sequential(
            wn_conv(in_channel, step=2),
            nn.LeakyReLU(0.2, True),
            wn_conv(out_channel, step=1),
        )
        self.shortcut1 = wn_conv(in_channel, step=2, width=1, pad=0)
        self.block2 = nn.Sequential(
            wn_conv(out_channel, step=1),
            nn.LeakyReLU(0.2, True),
            wn_conv(out_channel, step=1),
        )
        self.shortcut2 = wn_conv(out_channel, step=1, width=1, pad=0)

    def forward(self, x):
        """Run both residual units.

        Assumes ``x`` is ``(batch, in_channel, length)``. The output has
        ``out_channel`` channels and length ``(length - 1) // 2 + 1``.
        """
        # The in-place LeakyReLUs act on conv outputs, so ``x`` and ``mid``
        # are still intact when the skip branches read them.
        mid = self.block1(x) + self.shortcut1(x)
        out = self.block2(mid) + self.shortcut2(mid)
        return out
class ResNet18Discriminator(nn.Module):
    """1-D convolutional discriminator: a strided stem followed by four stages.

    Layout:

    - ``input``: weight-normalised ``Conv1d(stft_channel -> in_channel,
      kernel_size=7, stride=2, padding=1)`` followed by ``LeakyReLU(0.2)``.
    - ``df1``: ``BasicDiscriminatorBlock`` keeping ``in_channel`` channels.
    - ``df2`` .. ``df4``: ``ResDiscriminatorBlock`` stages that double the
      channel count each time (``in_channel`` -> 2x -> 4x -> 8x).

    Args:
        stft_channel: channels of the input sequence.
        in_channel: base channel width after the stem (default 64).
    """

    def __init__(self, stft_channel, in_channel=64):
        super().__init__()
        # NOTE(review): padding=1 with kernel_size=7 trims the length to
        # (L - 5) // 2 + 1, so inputs need at least 5 time steps.
        # Confirm this is intended rather than padding=3.
        stem_conv = nn.utils.weight_norm(nn.Conv1d(
            stft_channel, in_channel, kernel_size=7, stride=2, padding=1,
        ))
        self.input = nn.Sequential(stem_conv, nn.LeakyReLU(0.2, True))
        self.df1 = BasicDiscriminatorBlock(in_channel, in_channel)
        self.df2 = ResDiscriminatorBlock(in_channel, 2 * in_channel)
        self.df3 = ResDiscriminatorBlock(2 * in_channel, 4 * in_channel)
        self.df4 = ResDiscriminatorBlock(4 * in_channel, 8 * in_channel)

    def forward(self, x):
        """Run the stem and all four stages in sequence.

        Assumes ``x`` is ``(batch, stft_channel, length)``. The result has
        ``8 * in_channel`` channels.
        """
        for stage in (self.input, self.df1, self.df2, self.df3, self.df4):
            x = stage(x)
        return x
class FrequencyDiscriminator(nn.Module):
    """Discriminator that judges a waveform through its STFT.

    The waveform's one-sided STFT is split into real and imaginary parts. Each
    part is scored separately by the *same* (weight-shared)
    ``ResNet18Discriminator``, whose input channel count is the number of
    frequency bins, ``fft_size // 2 + 1``.

    Args:
        in_channel: base channel width passed to ``ResNet18Discriminator``.
        fft_size: FFT size (``n_fft``) for ``torch.stft``.
        hop_length: hop between consecutive STFT frames.
        win_length: analysis window length.
        window: name of a ``torch`` window factory (e.g. ``"hann_window"``).
            It is called as ``getattr(torch, window)(win_length)``.
    """

    def __init__(self, in_channel=64, fft_size=1024, hop_length=256,
                 win_length=1024, window="hann_window"):
        super().__init__()
        self.fft_size = fft_size
        self.hop_length = hop_length
        self.win_length = win_length
        # Non-trainable Parameter, so it follows .to()/.cuda() and is stored
        # in the state_dict under the key "window".
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        # Number of frequency bins in a one-sided STFT of size fft_size.
        self.stft_channel = fft_size // 2 + 1
        self.resnet_disc = ResNet18Discriminator(self.stft_channel, in_channel)

    def forward(self, x):
        """Score the real and imaginary STFT components of ``x``.

        Args:
            x: waveform accepted by ``torch.stft``: ``(T,)`` or ``(B, T)``.
                A ``(B, 1, T)`` tensor is also accepted; its singleton
                channel dimension is squeezed away first.

        Returns:
            Tuple ``(x_real, x_imag)``: the discriminator outputs for the real
            and the imaginary part of the STFT.
        """
        if x.dim() == 3 and x.size(1) == 1:
            x = x.squeeze(1)
        # return_complex must be explicit: PyTorch >= 2.0 raises for real input
        # without it, and the old real-valued (..., 2) output is deprecated.
        x_stft = torch.stft(
            x,
            self.fft_size,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        # Same values as the legacy x_stft[..., 0] / x_stft[..., 1] split.
        real = x_stft.real
        imag = x_stft.imag
        x_real = self.resnet_disc(real)
        x_imag = self.resnet_disc(imag)
        return x_real, x_imag