HoneyTian committed on
Commit
20d2f3e
·
1 Parent(s): 7a570aa
toolbox/torchaudio/models/mpnet/discriminator.py CHANGED
@@ -45,7 +45,6 @@ def metric_loss(metric_ref, metrics_gen):
45
  class MetricDiscriminator(nn.Module):
46
  def __init__(self, config: MPNetConfig):
47
  super(MetricDiscriminator, self).__init__()
48
-
49
  dim = config.discriminator_dim
50
  in_channel = config.discriminator_in_channel
51
 
@@ -86,6 +85,7 @@ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
86
  super(MetricDiscriminatorPretrainedModel, self).__init__(
87
  config=config,
88
  )
 
89
 
90
  @classmethod
91
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
 
45
  class MetricDiscriminator(nn.Module):
46
  def __init__(self, config: MPNetConfig):
47
  super(MetricDiscriminator, self).__init__()
 
48
  dim = config.discriminator_dim
49
  in_channel = config.discriminator_in_channel
50
 
 
85
  super(MetricDiscriminatorPretrainedModel, self).__init__(
86
  config=config,
87
  )
88
+ self.config = config
89
 
90
  @classmethod
91
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
toolbox/torchaudio/models/mpnet/modeling_mpnet.py CHANGED
@@ -151,7 +151,6 @@ class TSTransformerBlock(nn.Module):
151
  class MPNet(nn.Module):
152
  def __init__(self, config: MPNetConfig, num_tsblocks=4):
153
  super(MPNet, self).__init__()
154
- self.config = config
155
  self.num_tscblocks = num_tsblocks
156
  self.dense_encoder = DenseEncoder(config, in_channel=2)
157
 
@@ -193,6 +192,7 @@ class MPNetPretrainedModel(MPNet):
193
  super(MPNetPretrainedModel, self).__init__(
194
  config=config,
195
  )
 
196
 
197
  @classmethod
198
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -251,7 +251,7 @@ def pesq_score(utts_r, utts_g, h):
251
  pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
252
  utts_r[i].squeeze().cpu().numpy(),
253
  utts_g[i].squeeze().cpu().numpy(),
254
- h.sample_rate)
255
  for i in range(len(utts_r)))
256
  pesq_score = np.mean(pesq_score)
257
 
@@ -260,7 +260,8 @@ def pesq_score(utts_r, utts_g, h):
260
 
261
  def eval_pesq(clean_utt, esti_utt, sr):
262
  try:
263
- pesq_score = pesq(sr, clean_utt, esti_utt)
 
264
  except:
265
  pesq_score = -1
266
 
 
151
  class MPNet(nn.Module):
152
  def __init__(self, config: MPNetConfig, num_tsblocks=4):
153
  super(MPNet, self).__init__()
 
154
  self.num_tscblocks = num_tsblocks
155
  self.dense_encoder = DenseEncoder(config, in_channel=2)
156
 
 
192
  super(MPNetPretrainedModel, self).__init__(
193
  config=config,
194
  )
195
+ self.config = config
196
 
197
  @classmethod
198
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
 
251
  pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
252
  utts_r[i].squeeze().cpu().numpy(),
253
  utts_g[i].squeeze().cpu().numpy(),
254
+ h.sample_rate, )
255
  for i in range(len(utts_r)))
256
  pesq_score = np.mean(pesq_score)
257
 
 
260
 
261
  def eval_pesq(clean_utt, esti_utt, sr):
262
  try:
263
+ mode = "nb" if sr == 8000 else "wb"
264
+ pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
265
  except:
266
  pesq_score = -1
267