yerfor's picture
init
22871e7
import torch
import torch.nn as nn
class BasicDiscriminatorBlock(nn.Module):
    """Plain (non-residual) stack of four weight-normalised 1-D convolutions.

    Every convolution uses ``kernel_size=3`` and ``padding=1``. The first one
    maps ``in_channel -> out_channel`` with stride 2; the remaining three map
    ``out_channel -> out_channel`` with stride 1. An in-place
    ``LeakyReLU(0.2)`` sits between consecutive convolutions; there is no
    activation after the last convolution.

    Args:
        in_channel: Number of channels of the input fed to the first conv.
        out_channel: Number of channels produced by every conv in the stack.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # (input channels, stride) of each convolution, in execution order.
        conv_specs = [(in_channel, 2)] + [(out_channel, 1)] * 3

        layers = []
        for position, (channels, stride) in enumerate(conv_specs):
            if position > 0:
                # Activations go *between* convolutions only.
                layers.append(nn.LeakyReLU(0.2, True))
            conv = nn.Conv1d(
                channels,
                out_channel,
                kernel_size=3,
                stride=stride,
                padding=1,
            )
            layers.append(nn.utils.weight_norm(conv))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.block(x)
class ResDiscriminatorBlock(nn.Module):
    """Two consecutive residual stages built from weight-normalised Conv1d layers.

    Stage 1: ``block1`` (conv k=3 stride 2 -> LeakyReLU(0.2) -> conv k=3
    stride 1) is summed with ``shortcut1`` (conv k=1 stride 2), both applied
    to the block input.

    Stage 2: ``block2`` (conv k=3 stride 1 -> LeakyReLU(0.2) -> conv k=3
    stride 1) is summed with ``shortcut2`` (conv k=1 stride 1), both applied
    to the output of stage 1.

    The shortcuts are learned 1x1 convolutions, not identity connections.

    Args:
        in_channel: Number of channels of the block input.
        out_channel: Number of channels produced by every conv in the block.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def conv3(src_channels, stride):
            # Main-path convolution: kernel 3 with padding 1.
            return nn.utils.weight_norm(
                nn.Conv1d(src_channels, out_channel, kernel_size=3, stride=stride, padding=1)
            )

        def conv1(src_channels, stride):
            # Shortcut convolution: kernel 1, no padding.
            return nn.utils.weight_norm(
                nn.Conv1d(src_channels, out_channel, kernel_size=1, stride=stride)
            )

        # Modules are created in the same order as they are registered so
        # that seeded initialisation is reproducible.
        self.block1 = nn.Sequential(
            conv3(in_channel, 2),
            nn.LeakyReLU(0.2, True),
            conv3(out_channel, 1),
        )
        self.shortcut1 = conv1(in_channel, 2)

        self.block2 = nn.Sequential(
            conv3(out_channel, 1),
            nn.LeakyReLU(0.2, True),
            conv3(out_channel, 1),
        )
        self.shortcut2 = conv1(out_channel, 1)

    def forward(self, x):
        """Apply both residual stages to ``x`` and return the final sum."""
        main_path = self.block1(x)
        skip_path = self.shortcut1(x)
        stage1 = main_path + skip_path

        return self.block2(stage1) + self.shortcut2(stage1)
class ResNet18Discriminator(nn.Module):
    """1-D convolutional discriminator trunk: a stem followed by four blocks.

    The stem (``input``) is a weight-normalised Conv1d with kernel 7,
    stride 2 and padding 1, followed by an in-place ``LeakyReLU(0.2)``.
    It is followed by one ``BasicDiscriminatorBlock`` (``df1``) that keeps
    the channel count and three ``ResDiscriminatorBlock`` stages
    (``df2``..``df4``) that each double it, ending at ``in_channel * 8``.

    Args:
        stft_channel: Number of channels of the tensor passed to ``forward``.
        in_channel: Channel width after the stem (default 64).
    """

    def __init__(self, stft_channel, in_channel=64):
        super().__init__()

        stem_conv = nn.Conv1d(
            stft_channel,
            in_channel,
            kernel_size=7,
            stride=2,
            padding=1,
        )
        self.input = nn.Sequential(
            nn.utils.weight_norm(stem_conv),
            nn.LeakyReLU(0.2, True),
        )

        # Channel widths after each residual stage.
        width2 = in_channel * 2
        width4 = in_channel * 4
        width8 = in_channel * 8

        self.df1 = BasicDiscriminatorBlock(in_channel, in_channel)
        self.df2 = ResDiscriminatorBlock(in_channel, width2)
        self.df3 = ResDiscriminatorBlock(width2, width4)
        self.df4 = ResDiscriminatorBlock(width4, width8)

    def forward(self, x):
        """Pass ``x`` through the stem and ``df1``..``df4`` in order."""
        for stage in (self.input, self.df1, self.df2, self.df3, self.df4):
            x = stage(x)
        return x
class FrequencyDiscriminator(nn.Module):
    """Discriminator that scores the real and imaginary STFT parts of a signal.

    The input is transformed with ``torch.stft``; the real and imaginary
    parts of the spectrogram are then fed, separately, through the *same*
    ``ResNet18Discriminator`` (shared weights), with the frequency bins
    acting as the Conv1d channel axis.

    Args:
        in_channel: Base channel width of the ResNet trunk (default 64).
        fft_size: FFT size passed to ``torch.stft`` (default 1024).
        hop_length: Hop length passed to ``torch.stft`` (default 256).
        win_length: Window length passed to ``torch.stft`` (default 1024).
        window: Name of a ``torch`` window factory (e.g. ``"hann_window"``)
            that is called with ``win_length``.
    """

    def __init__(self, in_channel=64, fft_size=1024, hop_length=256, win_length=1024, window="hann_window"):
        super(FrequencyDiscriminator, self).__init__()
        self.fft_size = fft_size
        self.hop_length = hop_length
        self.win_length = win_length
        # Kept as a frozen Parameter (not trained) so the window follows the
        # module across devices and is stored under the "window" state key.
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        # A one-sided STFT yields fft_size // 2 + 1 frequency bins; they are
        # used as the input channels of the trunk.
        self.stft_channel = fft_size // 2 + 1
        self.resnet_disc = ResNet18Discriminator(self.stft_channel, in_channel)

    def forward(self, x):
        """Return the trunk outputs for the real and imaginary STFT parts.

        Args:
            x: Input signal accepted by ``torch.stft``
                (presumably ``(batch, samples)`` -- confirm against callers).

        Returns:
            Tuple ``(x_real, x_imag)`` of trunk outputs for the real and the
            imaginary part of the spectrogram, respectively.
        """
        # Recent PyTorch releases require ``return_complex`` to be given
        # explicitly for real inputs, and the legacy trailing (real, imag)
        # axis is deprecated. Ask for the complex spectrogram and split it
        # instead of indexing ``[..., 0]`` / ``[..., 1]``.
        x_stft = torch.stft(
            x,
            n_fft=self.fft_size,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        x_real = self.resnet_disc(x_stft.real)
        x_imag = self.resnet_disc(x_stft.imag)
        return x_real, x_imag