HoneyTian committed
Commit 8128494 · 1 Parent(s): df77126
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
-https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
+https://github.com/kaituoxu/Conv-TasNet/tree/master/src
 """
 import argparse
 import json
@@ -42,14 +42,11 @@ def get_args():
     parser.add_argument("--max_epochs", default=200, type=int)
 
     parser.add_argument("--batch_size", default=8, type=int)
-    parser.add_argument("--learning_rate", default=1e-3, type=float)
     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
     parser.add_argument("--patience", default=5, type=int)
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
     parser.add_argument("--seed", default=1234, type=int)
 
-    parser.add_argument("--eval_steps", default=25000, type=int)
-
     parser.add_argument("--config_file", default="config.yaml", type=str)
 
     args = parser.parse_args()
@@ -171,7 +168,7 @@ def main():
 
     # optimizer
     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)
+    optimizer = torch.optim.AdamW(model.parameters(), config.lr)
 
     # resume training
     last_epoch = -1
@@ -197,10 +194,21 @@ def main():
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
             optimizer.load_state_dict(state_dict)
 
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer,
-        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-    )
+    if config.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            last_epoch=last_epoch,
+            # T_max=10 * config.eval_steps,
+            # eta_min=0.01 * config.lr,
+            **config.lr_scheduler_kwargs,
+        )
+    elif config.lr_scheduler == "MultiStepLR":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+        )
+    else:
+        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
 
     ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
@@ -209,6 +217,8 @@ def main():
         fft_size_list=[256, 512, 1024],
         win_size_list=[120, 240, 480],
         hop_size_list=[25, 50, 100],
+        factor_sc=1.5,
+        factor_mag=1.0,
         reduction="mean"
     ).to(device)
 
@@ -222,7 +232,7 @@ def main():
     average_neg_stoi_loss = 1000000000
 
     model_list = list()
-    best_idx_epoch = None
+    best_steps = None
    best_metric = None
     patience_count = 0
 
@@ -260,7 +270,10 @@ def main():
             mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
 
             # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
-            loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+            # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
+            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
+            loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
 
             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -288,6 +301,7 @@ def main():
 
             progress_bar_train.update(1)
             progress_bar_train.set_postfix({
+                "lr": lr_scheduler.get_last_lr()[0],
                 "pesq_score": average_pesq_score,
                 "loss": average_loss,
                 "ae_loss": average_ae_loss,
@@ -298,7 +312,7 @@ def main():
 
             # evaluation
             total_steps += 1
-            if total_steps % args.eval_steps == 0:
+            if total_steps % config.eval_steps == 0:
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -311,7 +325,7 @@ def main():
 
                 progress_bar_train.close()
                 progress_bar_eval = tqdm(
-                    desc="Evaluation; step-{}".format(total_steps),
+                    desc="Evaluation; step-{}k".format(int(total_steps/1000)),
                 )
                 for eval_batch in valid_data_loader:
                     clean_audios, noisy_audios = eval_batch
@@ -327,7 +341,10 @@ def main():
                     mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
 
                     # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
-                    loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+                    # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+                    # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
+                    # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
+                    loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
 
                     denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                     clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -350,6 +367,7 @@ def main():
 
                     progress_bar_eval.update(1)
                     progress_bar_eval.set_postfix({
+                        "lr": lr_scheduler.get_last_lr()[0],
                         "pesq_score": average_pesq_score,
                         "loss": average_loss,
                         "ae_loss": average_ae_loss,
@@ -373,7 +391,7 @@ def main():
                 )
 
                 # save path
-                save_dir = serialization_dir / "steps-{}".format(total_steps)
+                save_dir = serialization_dir / "steps-{}k".format(int(total_steps/1000))
                 save_dir.mkdir(parents=True, exist_ok=False)
 
                 # save models
@@ -389,18 +407,18 @@ def main():
 
                 # save metric
                 if best_metric is None:
-                    best_idx_epoch = idx_epoch
+                    best_steps = total_steps
                     best_metric = average_pesq_score
                 elif average_pesq_score > best_metric:
                     # great is better.
-                    best_idx_epoch = idx_epoch
+                    best_steps = total_steps
                     best_metric = average_pesq_score
                 else:
                     pass
 
                 metrics = {
                     "idx_epoch": idx_epoch,
-                    "best_idx_epoch": best_idx_epoch,
+                    "best_steps": best_steps,
                     "pesq_score": average_pesq_score,
                     "loss": average_loss,
                     "ae_loss": average_ae_loss,
examples/conv_tasnet/yaml/config.yaml CHANGED
@@ -15,3 +15,11 @@ sub_blocks_kernel_size: 3
 norm_type: "gLN"
 causal: false
 mask_nonlinear: "relu"
+
+lr: 0.001
+lr_scheduler: "CosineAnnealingLR"
+lr_scheduler_kwargs:
+  T_max: 250000
+  eta_min: 0.00001
+
+eval_steps: 25000
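
Note: T_max (250000) is 10 * eval_steps (25000), matching the commented-out T_max=10 * config.eval_steps default in the training hunk above. A minimal sketch of how these keys are consumed, assuming the YAML is parsed with PyYAML (the loader itself is not part of this commit):

    import torch
    import yaml

    with open("examples/conv_tasnet/yaml/config.yaml") as f:
        cfg = yaml.safe_load(f)

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=cfg["lr"])  # lr: 0.001
    if cfg["lr_scheduler"] == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            **cfg["lr_scheduler_kwargs"],  # T_max=250000, eta_min=0.00001
        )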
toolbox/torchaudio/models/clean_unet/inference_clean_unet.py CHANGED
@@ -79,6 +79,7 @@ class InferenceCleanUNet(object):
         # enhanced_audio shape: [channels, num_samples]
         return enhanced_audio
 
+
 def main():
     model_zip_file = project_path / "trained_models/clean-unet-aishell-18-epoch.zip"
     infer_mpnet = InferenceCleanUNet(model_zip_file)
@@ -100,5 +101,5 @@ def main():
     return
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py CHANGED
@@ -27,6 +27,12 @@ class ConvTasNetConfig(PretrainedConfig):
         causal: bool = False,
         mask_nonlinear: str = "relu",
 
+        lr: float = 1e-3,
+        eval_steps: int = 25000,
+
+        lr_scheduler: str = "CosineAnnealingLR",
+        lr_scheduler_kwargs: dict = None,
+
         **kwargs
     ):
         super(ConvTasNetConfig, self).__init__(**kwargs)
@@ -47,6 +53,12 @@ class ConvTasNetConfig(PretrainedConfig):
         self.causal = causal
         self.mask_nonlinear = mask_nonlinear
 
+        self.lr = lr
+        self.eval_steps = eval_steps
+
+        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
+
 
 if __name__ == "__main__":
     pass
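
Note: lr_scheduler_kwargs defaults to None and is normalized with "or dict()", so every config instance gets a fresh dict instead of sharing a mutable default argument. A small usage sketch, assuming the remaining constructor arguments keep their defaults:

    from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig

    config = ConvTasNetConfig(lr=1e-3, lr_scheduler="MultiStepLR")
    print(config.lr_scheduler_kwargs)          # {} , a fresh dict per instance
    config.lr_scheduler_kwargs["gamma"] = 0.5  # mutating it affects no other config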
toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py ADDED
@@ -0,0 +1,112 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import logging
+from pathlib import Path
+import shutil
+import tempfile, time
+import zipfile
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
+from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNetPretrainedModel, MODEL_FILE
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceConvTasNet(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.model = model
+        self.model.to(device)
+        self.model.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = ConvTasNetConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model = ConvTasNetPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model.to(self.device)
+        model.eval()
+
+        shutil.rmtree(model_path)
+        return config, model
+
+    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
+        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+        noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+        # noisy_audio shape: [batch_size, n_samples]
+        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
+        # enhanced_audio shape: [channels, num_samples]
+        return enhanced_audio.cpu().numpy()
+
+    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError("The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            enhanced_audios = self.model.forward(noisy_audios)
+            # enhanced_audios shape: [batch_size, channels, num_samples]
+            # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
+
+        enhanced_audio = enhanced_audios[0]
+
+        # enhanced_audio shape: [channels, num_samples]
+        return enhanced_audio
+
+
+def main():
+    model_zip_file = project_path / "trained_models/conv-tasnet-dns3-575k-steps.zip"
+    infer_conv_tasnet = InferenceConvTasNet(model_zip_file)
+
+    sample_rate = 8000
+    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
+    noisy_audio, sample_rate = librosa.load(
+        noisy_audio_file.as_posix(),
+        sr=sample_rate,
+    )
+    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
+    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
+    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+    noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+    begin = time.time()
+    enhanced_audio = infer_conv_tasnet.enhancement_by_tensor(noisy_audio)
+    time_cost = time.time() - begin
+    print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
+    filename = "enhanced_audio.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
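
Note: main() above demos the tensor entry point; enhancement_by_ndarray wraps the same path for numpy input. A minimal usage sketch, assuming the checkpoint zip referenced in main() is available locally:

    import numpy as np
    from toolbox.torchaudio.models.conv_tasnet.inference_conv_tasnet import InferenceConvTasNet

    infer = InferenceConvTasNet("trained_models/conv-tasnet-dns3-575k-steps.zip")
    noisy = np.zeros(8000, dtype=np.float32)        # 1 s of silence at 8 kHz; samples must stay in [-1, 1]
    enhanced = infer.enhancement_by_ndarray(noisy)  # numpy array of the enhanced audio
    print(enhanced.shape)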