HoneyTian commited on
Commit
e07fc8f
·
1 Parent(s): 78c7ce6
toolbox/torchaudio/models/nx_clean_unet/discriminator.py CHANGED
@@ -22,15 +22,16 @@ class MetricDiscriminator(nn.Module):
22
  self.win_length = config.win_length
23
  self.hop_length = config.hop_length
24
 
 
 
 
 
 
 
 
 
 
25
  self.layers = nn.Sequential(
26
- torchaudio.transforms.Spectrogram(
27
- n_fft=self.n_fft,
28
- win_length=self.win_length,
29
- hop_length=self.hop_length,
30
- power=1.0,
31
- window_fn=torch.hamming_window,
32
- # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
33
- ),
34
  nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
35
  nn.InstanceNorm2d(dim, affine=True),
36
  nn.PReLU(dim),
@@ -53,6 +54,9 @@ class MetricDiscriminator(nn.Module):
53
  )
54
 
55
  def forward(self, x, y):
 
 
 
56
  xy = torch.stack((x, y), dim=1)
57
  return self.layers(xy)
58
 
 
22
  self.win_length = config.win_length
23
  self.hop_length = config.hop_length
24
 
25
+ self.transform = torchaudio.transforms.Spectrogram(
26
+ n_fft=self.n_fft,
27
+ win_length=self.win_length,
28
+ hop_length=self.hop_length,
29
+ power=1.0,
30
+ window_fn=torch.hann_window,
31
+ # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
32
+ )
33
+
34
  self.layers = nn.Sequential(
 
 
 
 
 
 
 
 
35
  nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
36
  nn.InstanceNorm2d(dim, affine=True),
37
  nn.PReLU(dim),
 
54
  )
55
 
56
  def forward(self, x, y):
57
+ x = self.transform.forward(x)
58
+ y = self.transform.forward(y)
59
+
60
  xy = torch.stack((x, y), dim=1)
61
  return self.layers(xy)
62