HoneyTian commited on
Commit
31b64ca
·
1 Parent(s): 54cd20f
toolbox/torchaudio/models/nx_mpnet/discriminator.py CHANGED
@@ -5,7 +5,10 @@ from typing import Optional, Union
5
 
6
  import torch
7
  import torch.nn as nn
8
- import torchaudio
 
 
 
9
 
10
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
  from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
@@ -16,23 +19,10 @@ class MetricDiscriminator(nn.Module):
16
  def __init__(self, config: NXMPNetConfig):
17
  super(MetricDiscriminator, self).__init__()
18
  dim = config.discriminator_dim
19
- self.in_channel = config.discriminator_in_channel
20
-
21
- self.n_fft = config.n_fft
22
- self.win_length = config.win_length
23
- self.hop_length = config.hop_length
24
-
25
- self.transform = torchaudio.transforms.Spectrogram(
26
- n_fft=self.n_fft,
27
- win_length=self.win_length,
28
- hop_length=self.hop_length,
29
- power=1.0,
30
- window_fn=torch.hann_window,
31
- # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
32
- )
33
 
34
  self.layers = nn.Sequential(
35
- nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
36
  nn.InstanceNorm2d(dim, affine=True),
37
  nn.PReLU(dim),
38
  nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
@@ -54,9 +44,6 @@ class MetricDiscriminator(nn.Module):
54
  )
55
 
56
  def forward(self, x, y):
57
- x = self.transform.forward(x)
58
- y = self.transform.forward(y)
59
-
60
  xy = torch.stack((x, y), dim=1)
61
  return self.layers(xy)
62
 
@@ -111,22 +98,5 @@ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
111
  return save_directory
112
 
113
 
114
- def main():
115
- config = NXMPNetConfig()
116
- discriminator = MetricDiscriminator(config=config)
117
-
118
- # shape: [batch_size, num_samples]
119
- # x = torch.ones([4, int(4.5 * 16000)])
120
- # y = torch.ones([4, int(4.5 * 16000)])
121
- x = torch.ones([4, 16000])
122
- y = torch.ones([4, 16000])
123
-
124
- output = discriminator.forward(x, y)
125
- print(output.shape)
126
- print(output)
127
-
128
- return
129
-
130
-
131
- if __name__ == "__main__":
132
- main()
 
5
 
6
  import torch
7
  import torch.nn as nn
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from pesq import pesq
11
+ from joblib import Parallel, delayed
12
 
13
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
  from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
 
19
  def __init__(self, config: NXMPNetConfig):
20
  super(MetricDiscriminator, self).__init__()
21
  dim = config.discriminator_dim
22
+ in_channel = config.discriminator_in_channel
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  self.layers = nn.Sequential(
25
+ nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
26
  nn.InstanceNorm2d(dim, affine=True),
27
  nn.PReLU(dim),
28
  nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
 
44
  )
45
 
46
  def forward(self, x, y):
 
 
 
47
  xy = torch.stack((x, y), dim=1)
48
  return self.layers(xy)
49
 
 
98
  return save_directory
99
 
100
 
101
+ if __name__ == '__main__':
102
+ pass