HoneyTian commited on
Commit
a88ebd1
·
1 Parent(s): 32aa651
toolbox/torchaudio/models/mpnet/discriminator.py CHANGED
@@ -15,22 +15,22 @@ from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
15
  from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
16
 
17
 
18
- def cal_pesq(clean, noisy, sr=16000):
19
- try:
20
- pesq_score = pesq(sr, clean, noisy, 'wb')
21
- except:
22
- # error can happen due to silent period
23
- pesq_score = -1
24
- return pesq_score
25
-
26
-
27
- def batch_pesq(clean, noisy):
28
- pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
29
- pesq_score = np.array(pesq_score)
30
- if -1 in pesq_score:
31
- return None
32
- pesq_score = (pesq_score - 1) / 3.5
33
- return torch.FloatTensor(pesq_score)
34
 
35
 
36
  def metric_loss(metric_ref, metrics_gen):
 
15
  from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
16
 
17
 
18
+ # def cal_pesq(clean, noisy, sr=16000):
19
+ # try:
20
+ # pesq_score = pesq(sr, clean, noisy, 'wb')
21
+ # except:
22
+ # # error can happen due to silent period
23
+ # pesq_score = -1
24
+ # return pesq_score
25
+
26
+
27
+ # def batch_pesq(clean, noisy):
28
+ # pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
29
+ # pesq_score = np.array(pesq_score)
30
+ # if -1 in pesq_score:
31
+ # return None
32
+ # pesq_score = (pesq_score - 1) / 3.5
33
+ # return torch.FloatTensor(pesq_score)
34
 
35
 
36
  def metric_loss(metric_ref, metrics_gen):
toolbox/torchaudio/models/mpnet/modeling_mpnet.py CHANGED
@@ -250,26 +250,26 @@ def anti_wrapping_function(x):
250
  return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
251
 
252
 
253
- def pesq_score(utts_r, utts_g, h):
254
-
255
- pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
256
- utts_r[i].squeeze().cpu().numpy(),
257
- utts_g[i].squeeze().cpu().numpy(),
258
- h.sample_rate, )
259
- for i in range(len(utts_r)))
260
- pesq_score = np.mean(pesq_score)
261
-
262
- return pesq_score
263
-
264
-
265
- def eval_pesq(clean_utt, esti_utt, sr):
266
- try:
267
- mode = "nb" if sr == 8000 else "wb"
268
- pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
269
- except:
270
- pesq_score = -1
271
-
272
- return pesq_score
273
 
274
 
275
  def main():
 
250
  return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
251
 
252
 
253
+ # def pesq_score(utts_r, utts_g, h):
254
+ #
255
+ # pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
256
+ # utts_r[i].squeeze().cpu().numpy(),
257
+ # utts_g[i].squeeze().cpu().numpy(),
258
+ # h.sample_rate, )
259
+ # for i in range(len(utts_r)))
260
+ # pesq_score = np.mean(pesq_score)
261
+ #
262
+ # return pesq_score
263
+ #
264
+ #
265
+ # def eval_pesq(clean_utt, esti_utt, sr):
266
+ # try:
267
+ # mode = "nb" if sr == 8000 else "wb"
268
+ # pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
269
+ # except:
270
+ # pesq_score = -1
271
+ #
272
+ # return pesq_score
273
 
274
 
275
  def main():