Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/mpnet/discriminator.py
CHANGED
@@ -45,7 +45,6 @@ def metric_loss(metric_ref, metrics_gen):
|
|
45 |
class MetricDiscriminator(nn.Module):
|
46 |
def __init__(self, config: MPNetConfig):
|
47 |
super(MetricDiscriminator, self).__init__()
|
48 |
-
|
49 |
dim = config.discriminator_dim
|
50 |
in_channel = config.discriminator_in_channel
|
51 |
|
@@ -86,6 +85,7 @@ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
|
|
86 |
super(MetricDiscriminatorPretrainedModel, self).__init__(
|
87 |
config=config,
|
88 |
)
|
|
|
89 |
|
90 |
@classmethod
|
91 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
|
|
45 |
class MetricDiscriminator(nn.Module):
|
46 |
def __init__(self, config: MPNetConfig):
|
47 |
super(MetricDiscriminator, self).__init__()
|
|
|
48 |
dim = config.discriminator_dim
|
49 |
in_channel = config.discriminator_in_channel
|
50 |
|
|
|
85 |
super(MetricDiscriminatorPretrainedModel, self).__init__(
|
86 |
config=config,
|
87 |
)
|
88 |
+
self.config = config
|
89 |
|
90 |
@classmethod
|
91 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
toolbox/torchaudio/models/mpnet/modeling_mpnet.py
CHANGED
@@ -151,7 +151,6 @@ class TSTransformerBlock(nn.Module):
|
|
151 |
class MPNet(nn.Module):
|
152 |
def __init__(self, config: MPNetConfig, num_tsblocks=4):
|
153 |
super(MPNet, self).__init__()
|
154 |
-
self.config = config
|
155 |
self.num_tscblocks = num_tsblocks
|
156 |
self.dense_encoder = DenseEncoder(config, in_channel=2)
|
157 |
|
@@ -193,6 +192,7 @@ class MPNetPretrainedModel(MPNet):
|
|
193 |
super(MPNetPretrainedModel, self).__init__(
|
194 |
config=config,
|
195 |
)
|
|
|
196 |
|
197 |
@classmethod
|
198 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
@@ -251,7 +251,7 @@ def pesq_score(utts_r, utts_g, h):
|
|
251 |
pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
|
252 |
utts_r[i].squeeze().cpu().numpy(),
|
253 |
utts_g[i].squeeze().cpu().numpy(),
|
254 |
-
h.sample_rate)
|
255 |
for i in range(len(utts_r)))
|
256 |
pesq_score = np.mean(pesq_score)
|
257 |
|
@@ -260,7 +260,8 @@ def pesq_score(utts_r, utts_g, h):
|
|
260 |
|
261 |
def eval_pesq(clean_utt, esti_utt, sr):
|
262 |
try:
|
263 |
-
|
|
|
264 |
except:
|
265 |
pesq_score = -1
|
266 |
|
|
|
151 |
class MPNet(nn.Module):
|
152 |
def __init__(self, config: MPNetConfig, num_tsblocks=4):
|
153 |
super(MPNet, self).__init__()
|
|
|
154 |
self.num_tscblocks = num_tsblocks
|
155 |
self.dense_encoder = DenseEncoder(config, in_channel=2)
|
156 |
|
|
|
192 |
super(MPNetPretrainedModel, self).__init__(
|
193 |
config=config,
|
194 |
)
|
195 |
+
self.config = config
|
196 |
|
197 |
@classmethod
|
198 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
|
|
251 |
pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
|
252 |
utts_r[i].squeeze().cpu().numpy(),
|
253 |
utts_g[i].squeeze().cpu().numpy(),
|
254 |
+
h.sample_rate, )
|
255 |
for i in range(len(utts_r)))
|
256 |
pesq_score = np.mean(pesq_score)
|
257 |
|
|
|
260 |
|
261 |
def eval_pesq(clean_utt, esti_utt, sr):
|
262 |
try:
|
263 |
+
mode = "nb" if sr == 8000 else "wb"
|
264 |
+
pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
|
265 |
except:
|
266 |
pesq_score = -1
|
267 |
|