Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_mpnet/discriminator.py
CHANGED
@@ -5,7 +5,10 @@ from typing import Optional, Union
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
-
import
|
|
|
|
|
|
|
9 |
|
10 |
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
|
11 |
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
|
@@ -16,23 +19,10 @@ class MetricDiscriminator(nn.Module):
|
|
16 |
def __init__(self, config: NXMPNetConfig):
|
17 |
super(MetricDiscriminator, self).__init__()
|
18 |
dim = config.discriminator_dim
|
19 |
-
|
20 |
-
|
21 |
-
self.n_fft = config.n_fft
|
22 |
-
self.win_length = config.win_length
|
23 |
-
self.hop_length = config.hop_length
|
24 |
-
|
25 |
-
self.transform = torchaudio.transforms.Spectrogram(
|
26 |
-
n_fft=self.n_fft,
|
27 |
-
win_length=self.win_length,
|
28 |
-
hop_length=self.hop_length,
|
29 |
-
power=1.0,
|
30 |
-
window_fn=torch.hann_window,
|
31 |
-
# window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
|
32 |
-
)
|
33 |
|
34 |
self.layers = nn.Sequential(
|
35 |
-
nn.utils.spectral_norm(nn.Conv2d(
|
36 |
nn.InstanceNorm2d(dim, affine=True),
|
37 |
nn.PReLU(dim),
|
38 |
nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
|
@@ -54,9 +44,6 @@ class MetricDiscriminator(nn.Module):
|
|
54 |
)
|
55 |
|
56 |
def forward(self, x, y):
|
57 |
-
x = self.transform.forward(x)
|
58 |
-
y = self.transform.forward(y)
|
59 |
-
|
60 |
xy = torch.stack((x, y), dim=1)
|
61 |
return self.layers(xy)
|
62 |
|
@@ -111,22 +98,5 @@ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
|
|
111 |
return save_directory
|
112 |
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
discriminator = MetricDiscriminator(config=config)
|
117 |
-
|
118 |
-
# shape: [batch_size, num_samples]
|
119 |
-
# x = torch.ones([4, int(4.5 * 16000)])
|
120 |
-
# y = torch.ones([4, int(4.5 * 16000)])
|
121 |
-
x = torch.ones([4, 16000])
|
122 |
-
y = torch.ones([4, 16000])
|
123 |
-
|
124 |
-
output = discriminator.forward(x, y)
|
125 |
-
print(output.shape)
|
126 |
-
print(output)
|
127 |
-
|
128 |
-
return
|
129 |
-
|
130 |
-
|
131 |
-
if __name__ == "__main__":
|
132 |
-
main()
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
+
import numpy as np
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from pesq import pesq
|
11 |
+
from joblib import Parallel, delayed
|
12 |
|
13 |
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
|
14 |
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
|
|
|
19 |
def __init__(self, config: NXMPNetConfig):
|
20 |
super(MetricDiscriminator, self).__init__()
|
21 |
dim = config.discriminator_dim
|
22 |
+
in_channel = config.discriminator_in_channel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
self.layers = nn.Sequential(
|
25 |
+
nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
|
26 |
nn.InstanceNorm2d(dim, affine=True),
|
27 |
nn.PReLU(dim),
|
28 |
nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
|
|
|
44 |
)
|
45 |
|
46 |
def forward(self, x, y):
|
|
|
|
|
|
|
47 |
xy = torch.stack((x, y), dim=1)
|
48 |
return self.layers(xy)
|
49 |
|
|
|
98 |
return save_directory
|
99 |
|
100 |
|
101 |
+
if __name__ == '__main__':
|
102 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|