Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/discriminator.py
CHANGED
@@ -22,15 +22,16 @@ class MetricDiscriminator(nn.Module):
|
|
22 |
self.win_length = config.win_length
|
23 |
self.hop_length = config.hop_length
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
self.layers = nn.Sequential(
|
26 |
-
torchaudio.transforms.Spectrogram(
|
27 |
-
n_fft=self.n_fft,
|
28 |
-
win_length=self.win_length,
|
29 |
-
hop_length=self.hop_length,
|
30 |
-
power=1.0,
|
31 |
-
window_fn=torch.hamming_window,
|
32 |
-
# window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
|
33 |
-
),
|
34 |
nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
|
35 |
nn.InstanceNorm2d(dim, affine=True),
|
36 |
nn.PReLU(dim),
|
@@ -53,6 +54,9 @@ class MetricDiscriminator(nn.Module):
|
|
53 |
)
|
54 |
|
55 |
def forward(self, x, y):
|
|
|
|
|
|
|
56 |
xy = torch.stack((x, y), dim=1)
|
57 |
return self.layers(xy)
|
58 |
|
|
|
22 |
self.win_length = config.win_length
|
23 |
self.hop_length = config.hop_length
|
24 |
|
25 |
+
self.transform = torchaudio.transforms.Spectrogram(
|
26 |
+
n_fft=self.n_fft,
|
27 |
+
win_length=self.win_length,
|
28 |
+
hop_length=self.hop_length,
|
29 |
+
power=1.0,
|
30 |
+
window_fn=torch.hann_window,
|
31 |
+
# window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
|
32 |
+
)
|
33 |
+
|
34 |
self.layers = nn.Sequential(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
|
36 |
nn.InstanceNorm2d(dim, affine=True),
|
37 |
nn.PReLU(dim),
|
|
|
54 |
)
|
55 |
|
56 |
def forward(self, x, y):
|
57 |
+
x = self.transform.forward(x)
|
58 |
+
y = self.transform.forward(y)
|
59 |
+
|
60 |
xy = torch.stack((x, y), dim=1)
|
61 |
return self.layers(xy)
|
62 |
|