HoneyTian committed on
Commit
4f045d5
·
1 Parent(s): b06a791
examples/clean_unet_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -32,7 +32,7 @@ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanU
32
  from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
33
  from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
34
  from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
35
- from toolbox.torchaudio.models.clean_unet.metrics import batch_pesq
36
 
37
  torch.autograd.set_detect_anomaly(True)
38
 
@@ -217,7 +217,7 @@ def main():
217
  # train
218
  model.train()
219
 
220
- total_pesq_metric = 0.
221
  total_loss = 0.
222
  total_ae_loss = 0.
223
  total_sc_loss = 0.
@@ -243,25 +243,21 @@ def main():
243
 
244
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
245
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
246
- pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
247
- if pesq_metric is None:
248
- pesq_metric = 0
249
- else:
250
- pesq_metric = torch.mean(pesq_metric).item()
251
 
252
  optimizer.zero_grad()
253
  loss.backward()
254
  optimizer.step()
255
  lr_scheduler.step()
256
 
257
- total_pesq_metric += pesq_metric
258
  total_loss += loss.item()
259
  total_ae_loss += ae_loss.item()
260
  total_sc_loss += sc_loss.item()
261
  total_mag_loss += mag_loss.item()
262
  total_batches += 1
263
 
264
- average_pesq_metric = round(total_pesq_metric / total_batches, 4)
265
  average_loss = round(total_loss / total_batches, 4)
266
  average_ae_loss = round(total_ae_loss / total_batches, 4)
267
  average_sc_loss = round(total_sc_loss / total_batches, 4)
@@ -269,7 +265,7 @@ def main():
269
 
270
  progress_bar.update(1)
271
  progress_bar.set_postfix({
272
- "pesq_metric": average_pesq_metric,
273
  "loss": average_loss,
274
  "ae_loss": average_ae_loss,
275
  "sc_loss": average_sc_loss,
@@ -281,7 +277,7 @@ def main():
281
 
282
  torch.cuda.empty_cache()
283
 
284
- total_pesq_metric = 0.
285
  total_loss = 0.
286
  total_ae_loss = 0.
287
  total_sc_loss = 0.
@@ -308,20 +304,16 @@ def main():
308
 
309
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
310
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
311
- pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
312
- if pesq_metric is None:
313
- pesq_metric = 0
314
- else:
315
- pesq_metric = torch.mean(pesq_metric).item()
316
 
317
- total_pesq_metric += pesq_metric
318
  total_loss += loss.item()
319
  total_ae_loss += ae_loss.item()
320
  total_sc_loss += sc_loss.item()
321
  total_mag_loss += mag_loss.item()
322
  total_batches += 1
323
 
324
- average_pesq_metric = round(total_pesq_metric / total_batches, 4)
325
  average_loss = round(total_loss / total_batches, 4)
326
  average_ae_loss = round(total_ae_loss / total_batches, 4)
327
  average_sc_loss = round(total_sc_loss / total_batches, 4)
@@ -329,7 +321,7 @@ def main():
329
 
330
  progress_bar.update(1)
331
  progress_bar.set_postfix({
332
- "pesq_metric": average_pesq_metric,
333
  "loss": average_loss,
334
  "ae_loss": average_ae_loss,
335
  "sc_loss": average_sc_loss,
 
32
  from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
33
  from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
34
  from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
35
+ from toolbox.torchaudio.models.clean_unet.metrics import run_pesq_score
36
 
37
  torch.autograd.set_detect_anomaly(True)
38
 
 
217
  # train
218
  model.train()
219
 
220
+ total_pesq_score = 0.
221
  total_loss = 0.
222
  total_ae_loss = 0.
223
  total_sc_loss = 0.
 
243
 
244
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
245
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
246
+ pesq_score = run_pesq_score(enhanced_audios_list_r, clean_audios_list_r, sample_rate=8000, mode="nb")
 
 
 
 
247
 
248
  optimizer.zero_grad()
249
  loss.backward()
250
  optimizer.step()
251
  lr_scheduler.step()
252
 
253
+ total_pesq_score += pesq_score
254
  total_loss += loss.item()
255
  total_ae_loss += ae_loss.item()
256
  total_sc_loss += sc_loss.item()
257
  total_mag_loss += mag_loss.item()
258
  total_batches += 1
259
 
260
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
261
  average_loss = round(total_loss / total_batches, 4)
262
  average_ae_loss = round(total_ae_loss / total_batches, 4)
263
  average_sc_loss = round(total_sc_loss / total_batches, 4)
 
265
 
266
  progress_bar.update(1)
267
  progress_bar.set_postfix({
268
+ "pesq_score": average_pesq_score,
269
  "loss": average_loss,
270
  "ae_loss": average_ae_loss,
271
  "sc_loss": average_sc_loss,
 
277
 
278
  torch.cuda.empty_cache()
279
 
280
+ total_pesq_score = 0.
281
  total_loss = 0.
282
  total_ae_loss = 0.
283
  total_sc_loss = 0.
 
304
 
305
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
306
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
307
+ pesq_score = run_pesq_score(enhanced_audios_list_r, clean_audios_list_r, sample_rate=8000, mode="nb")
 
 
 
 
308
 
309
+ total_pesq_score += pesq_score
310
  total_loss += loss.item()
311
  total_ae_loss += ae_loss.item()
312
  total_sc_loss += sc_loss.item()
313
  total_mag_loss += mag_loss.item()
314
  total_batches += 1
315
 
316
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
317
  average_loss = round(total_loss / total_batches, 4)
318
  average_ae_loss = round(total_ae_loss / total_batches, 4)
319
  average_sc_loss = round(total_sc_loss / total_batches, 4)
 
321
 
322
  progress_bar.update(1)
323
  progress_bar.set_postfix({
324
+ "pesq_score": average_pesq_score,
325
  "loss": average_loss,
326
  "ae_loss": average_ae_loss,
327
  "sc_loss": average_sc_loss,
toolbox/torchaudio/models/clean_unet/metrics.py CHANGED
@@ -3,38 +3,76 @@
3
  from joblib import Parallel, delayed
4
  import numpy as np
5
  from pesq import pesq
6
- import torch
7
 
 
8
 
9
- def cal_pesq(clean, noisy, sr=16000):
 
 
 
 
 
 
 
10
  try:
11
- pesq_score = pesq(sr, clean, noisy, "wb")
 
 
12
  except Exception as e:
13
- # print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
14
- # error can happen due to silent period
15
  pesq_score = -1
16
  return pesq_score
17
 
18
 
19
- def batch_pesq(clean, noisy):
20
- pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
21
- pesq_score = np.array(pesq_score)
22
- if -1 in pesq_score:
23
- return None
24
- pesq_score = (pesq_score - 1) / 3.5
25
- return torch.FloatTensor(pesq_score)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def main():
 
 
29
 
30
- prediction = torch.rand(size=(1, 160000), dtype=torch.float32)
31
- ground_truth = torch.rand(size=(1, 160000), dtype=torch.float32)
32
 
33
- prediction_list_r = list(prediction.cpu().numpy())
34
- ground_truth_list_r = list(ground_truth.cpu().numpy())
35
 
36
- pesq_score = batch_pesq(prediction_list_r, ground_truth_list_r)
37
  print(pesq_score)
 
38
  return
39
 
40
 
 
3
  from joblib import Parallel, delayed
4
  import numpy as np
5
  from pesq import pesq
6
+ from typing import List
7
 
8
+ from pesq import cypesq
9
 
10
+
11
+ def run_pesq(clean_audio: np.ndarray,
12
+ noisy_audio: np.ndarray,
13
+ sample_rate: int = 16000,
14
+ mode: str = "wb",
15
+ ) -> float:
16
+ if sample_rate == 8000 and mode == "wb":
17
+ raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
  try:
19
+ pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
+ except cypesq.NoUtterancesError as e:
21
+ pesq_score = -1
22
  except Exception as e:
23
+ print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
 
24
  pesq_score = -1
25
  return pesq_score
26
 
27
 
28
+ def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
+ noisy_audio_list: List[np.ndarray],
30
+ sample_rate: int = 16000,
31
+ mode: str = "wb",
32
+ n_jobs: int = 4,
33
+ ) -> List[float]:
34
+ parallel = Parallel(n_jobs=n_jobs)
35
+
36
+ parallel_tasks = list()
37
+ for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
+ parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
+ parallel_tasks.append(parallel_task)
40
+
41
+ pesq_score_list = parallel.__call__(parallel_tasks)
42
+ return pesq_score_list
43
+
44
+
45
+ def run_pesq_score(clean_audio_list: List[np.ndarray],
46
+ noisy_audio_list: List[np.ndarray],
47
+ sample_rate: int = 16000,
48
+ mode: str = "wb",
49
+ n_jobs: int = 4,
50
+ ) -> List[float]:
51
+
52
+ pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
+ noisy_audio_list=noisy_audio_list,
54
+ sample_rate=sample_rate,
55
+ mode=mode,
56
+ n_jobs=n_jobs,
57
+ )
58
+
59
+ pesq_score = np.mean(pesq_score_list)
60
+ return pesq_score
61
 
62
 
63
  def main():
64
+ clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
+ noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
 
67
+ clean_audio_list = list(clean_audio)
68
+ noisy_audio_list = list(noisy_audio)
69
 
70
+ pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
+ print(pesq_score_list)
72
 
73
+ pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
  print(pesq_score)
75
+
76
  return
77
 
78
 
toolbox/torchaudio/models/mpnet/metrics.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from joblib import Parallel, delayed
4
+ import numpy as np
5
+ from pesq import pesq
6
+ from typing import List
7
+
8
+ from pesq import cypesq
9
+
10
+
11
+ def run_pesq(clean_audio: np.ndarray,
12
+ noisy_audio: np.ndarray,
13
+ sample_rate: int = 16000,
14
+ mode: str = "wb",
15
+ ) -> float:
16
+ if sample_rate == 8000 and mode == "wb":
17
+ raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
+ try:
19
+ pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
+ except cypesq.NoUtterancesError as e:
21
+ pesq_score = -1
22
+ except Exception as e:
23
+ print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
+ pesq_score = -1
25
+ return pesq_score
26
+
27
+
28
+ def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
+ noisy_audio_list: List[np.ndarray],
30
+ sample_rate: int = 16000,
31
+ mode: str = "wb",
32
+ n_jobs: int = 4,
33
+ ) -> List[float]:
34
+ parallel = Parallel(n_jobs=n_jobs)
35
+
36
+ parallel_tasks = list()
37
+ for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
+ parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
+ parallel_tasks.append(parallel_task)
40
+
41
+ pesq_score_list = parallel.__call__(parallel_tasks)
42
+ return pesq_score_list
43
+
44
+
45
+ def run_pesq_score(clean_audio_list: List[np.ndarray],
46
+ noisy_audio_list: List[np.ndarray],
47
+ sample_rate: int = 16000,
48
+ mode: str = "wb",
49
+ n_jobs: int = 4,
50
+ ) -> List[float]:
51
+
52
+ pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
+ noisy_audio_list=noisy_audio_list,
54
+ sample_rate=sample_rate,
55
+ mode=mode,
56
+ n_jobs=n_jobs,
57
+ )
58
+
59
+ pesq_score = np.mean(pesq_score_list)
60
+ return pesq_score
61
+
62
+
63
+ def main():
64
+ clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
+ noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
+
67
+ clean_audio_list = list(clean_audio)
68
+ noisy_audio_list = list(noisy_audio)
69
+
70
+ pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
+ print(pesq_score_list)
72
+
73
+ pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
+ print(pesq_score)
75
+
76
+ return
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()