HoneyTian committed
Commit 45bf211 · 1 Parent(s): 55d487a
examples/dfnet2/step_2_train_model.py CHANGED
@@ -318,10 +318,11 @@ def main():
             # evaluation
             step_idx += 1
             if step_idx % config.eval_steps == 0:
-                model.eval()
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
+                    model.eval()
+
                     total_pesq_score = 0.
                     total_loss = 0.
                     total_mr_stft_loss = 0.
@@ -384,82 +385,83 @@ def main():
                             "lsnr_loss": average_lsnr_loss,
                         })
 
-                    total_pesq_score = 0.
-                    total_loss = 0.
-                    total_mr_stft_loss = 0.
-                    total_neg_si_snr_loss = 0.
-                    total_mask_loss = 0.
-                    total_lsnr_loss = 0.
-                    total_batches = 0.
-
-                    progress_bar_eval.close()
-                    progress_bar_train = tqdm(
-                        initial=progress_bar_train.n,
-                        postfix=progress_bar_train.postfix,
-                        desc=progress_bar_train.desc,
-                    )
-
-                    # save path
-                    save_dir = serialization_dir / "steps-{}".format(step_idx)
-                    save_dir.mkdir(parents=True, exist_ok=False)
-
-                    # save models
-                    model.save_pretrained(save_dir.as_posix())
-
-                    model_list.append(save_dir)
-                    if len(model_list) >= args.num_serialized_models_to_keep:
-                        model_to_delete: Path = model_list.pop(0)
-                        shutil.rmtree(model_to_delete.as_posix())
-
-                    # save metric
-                    if best_metric is None:
-                        best_epoch_idx = epoch_idx
-                        best_step_idx = step_idx
-                        best_metric = average_pesq_score
-                    elif average_pesq_score >= best_metric:
-                        # great is better.
-                        best_epoch_idx = epoch_idx
-                        best_step_idx = step_idx
-                        best_metric = average_pesq_score
-                    else:
-                        pass
-
-                    metrics = {
-                        "epoch_idx": epoch_idx,
-                        "best_epoch_idx": best_epoch_idx,
-                        "best_step_idx": best_step_idx,
-                        "pesq_score": average_pesq_score,
-                        "loss": average_loss,
-                        "mr_stft_loss": average_mr_stft_loss,
-                        "neg_si_snr_loss": average_neg_si_snr_loss,
-                        "mask_loss": average_mask_loss,
-                        "lsnr_loss": average_lsnr_loss,
-                    }
-                    metrics_filename = save_dir / "metrics_epoch.json"
-                    with open(metrics_filename, "w", encoding="utf-8") as f:
-                        json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-                    # save best
-                    best_dir = serialization_dir / "best"
-                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                        if best_dir.exists():
-                            shutil.rmtree(best_dir)
-                        shutil.copytree(save_dir, best_dir)
-
-                    # early stop
-                    early_stop_flag = False
-                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                        patience_count = 0
-                    else:
-                        patience_count += 1
-                        if patience_count >= args.patience:
-                            early_stop_flag = True
-
-                    # early stop
-                    if early_stop_flag:
-                        break
                 model.train()
 
+                total_pesq_score = 0.
+                total_loss = 0.
+                total_mr_stft_loss = 0.
+                total_neg_si_snr_loss = 0.
+                total_mask_loss = 0.
+                total_lsnr_loss = 0.
+                total_batches = 0.
+
+                progress_bar_eval.close()
+                progress_bar_train = tqdm(
+                    initial=progress_bar_train.n,
+                    postfix=progress_bar_train.postfix,
+                    desc=progress_bar_train.desc,
+                )
+
+                # save path
+                save_dir = serialization_dir / "steps-{}".format(step_idx)
+                save_dir.mkdir(parents=True, exist_ok=False)
+
+                # save models
+                model.save_pretrained(save_dir.as_posix())
+
+                model_list.append(save_dir)
+                if len(model_list) >= args.num_serialized_models_to_keep:
+                    model_to_delete: Path = model_list.pop(0)
+                    shutil.rmtree(model_to_delete.as_posix())
+
+                # save metric
+                if best_metric is None:
+                    best_epoch_idx = epoch_idx
+                    best_step_idx = step_idx
+                    best_metric = average_pesq_score
+                elif average_pesq_score >= best_metric:
+                    # great is better.
+                    best_epoch_idx = epoch_idx
+                    best_step_idx = step_idx
+                    best_metric = average_pesq_score
+                else:
+                    pass
+
+                metrics = {
+                    "epoch_idx": epoch_idx,
+                    "best_epoch_idx": best_epoch_idx,
+                    "best_step_idx": best_step_idx,
+                    "pesq_score": average_pesq_score,
+                    "loss": average_loss,
+                    "mr_stft_loss": average_mr_stft_loss,
+                    "neg_si_snr_loss": average_neg_si_snr_loss,
+                    "mask_loss": average_mask_loss,
+                    "lsnr_loss": average_lsnr_loss,
+                }
+                metrics_filename = save_dir / "metrics_epoch.json"
+                with open(metrics_filename, "w", encoding="utf-8") as f:
+                    json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                # save best
+                best_dir = serialization_dir / "best"
+                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                    if best_dir.exists():
+                        shutil.rmtree(best_dir)
+                    shutil.copytree(save_dir, best_dir)
+
+                # early stop
+                early_stop_flag = False
+                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                    patience_count = 0
+                else:
+                    patience_count += 1
+                    if patience_count >= args.patience:
+                        early_stop_flag = True
+
+                # early stop
+                if early_stop_flag:
+                    break
+
     return
 
 
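The training-script change above does two things: model.eval() now runs inside the torch.no_grad() block, and the checkpoint / metrics / early-stopping bookkeeping moves after model.train(), so it executes outside the no_grad context. Below is a minimal sketch of the resulting evaluation-step layout; config.eval_steps and the eval/train mode switches come from the diff, while the surrounding loop and the evaluate() / save_and_maybe_stop() helpers are assumptions standing in for the validation loop and the bookkeeping block.

# Sketch only: mirrors the post-commit control flow of step_2_train_model.py.
# `evaluate` and `save_and_maybe_stop` are hypothetical helpers, not functions
# from the repository.
import torch


def run_training(model, train_batches, evaluate, save_and_maybe_stop, config):
    step_idx = 0
    for batch in train_batches:
        # ... forward pass, loss, backward, optimizer step ...

        # evaluation
        step_idx += 1
        if step_idx % config.eval_steps == 0:
            with torch.no_grad():
                torch.cuda.empty_cache()
                model.eval()                    # eval mode switched on inside no_grad, as in the commit
                eval_metrics = evaluate(model)  # accumulate PESQ and losses over the valid set
            model.train()                       # restore train mode first

            # checkpointing, metrics json and early stopping now run here,
            # outside the no_grad block
            if save_and_maybe_stop(eval_metrics, step_idx):
                break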
main.py CHANGED
@@ -18,9 +18,11 @@ import numpy as np
 import log
 from project_settings import environment, project_path, log_directory
 from toolbox.os.command import Command
-from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
-from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
 from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet
+from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2
+from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN
+from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
+from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
 
 
 log.setup_size_rotating(log_directory=log_directory)
@@ -66,6 +68,18 @@ def shell(cmd: str):
 
 
 denoise_engines = {
+    "dtln-nx-dns3": {
+        "infer_cls": InferenceDTLN,
+        "kwargs": {
+            "pretrained_model_path_or_zip_file": (project_path / "trained_models/dtln-nx-dns3.zip").as_posix()
+        }
+    },
+    "dfnet2-nx-dns3": {
+        "infer_cls": InferenceDfNet2,
+        "kwargs": {
+            "pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet2-nx-dns3.zip").as_posix()
+        }
+    },
     "dfnet-nx-dns3": {
         "infer_cls": InferenceDfNet,
         "kwargs": {
toolbox/torchaudio/models/dfnet2/inference_dfnet2.py CHANGED
@@ -20,7 +20,7 @@ from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2PretrainedMod
 logger = logging.getLogger("toolbox")
 
 
-class InferenceDfNet(object):
+class InferenceDfNet2(object):
     def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
         self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
         self.device = torch.device(device)
@@ -99,7 +99,7 @@ class InferenceDfNet(object):
 
 def main():
     model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
-    infer_model = InferenceDfNet(model_zip_file)
+    infer_model = InferenceDfNet2(model_zip_file)
 
     sample_rate = 8000
     noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
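The rename from InferenceDfNet to InferenceDfNet2 keeps this module consistent with the InferenceDfNet2 import and the "dfnet2-nx-dns3" registry entry added in main.py. A short construction sketch follows, using only what the diff shows; the call that actually runs enhancement is not visible in these hunks and is left as a placeholder comment.

# Construction sketch for the renamed class. The zip path and device argument
# mirror the diff and the main.py registry entry; the enhancement call is a
# placeholder because that method is defined elsewhere in inference_dfnet2.py.
from project_settings import project_path
from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2

model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
infer_model = InferenceDfNet2(model_zip_file.as_posix(), device="cpu")
# enhanced = infer_model.<enhancement method>(noisy_signal)  # see the rest of the module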