Spaces:
Running
Running
update
Browse files
examples/dfnet2/step_2_train_model.py
CHANGED
@@ -318,10 +318,11 @@ def main():
|
|
318 |
# evaluation
|
319 |
step_idx += 1
|
320 |
if step_idx % config.eval_steps == 0:
|
321 |
-
model.eval()
|
322 |
with torch.no_grad():
|
323 |
torch.cuda.empty_cache()
|
324 |
|
|
|
|
|
325 |
total_pesq_score = 0.
|
326 |
total_loss = 0.
|
327 |
total_mr_stft_loss = 0.
|
@@ -384,82 +385,83 @@ def main():
|
|
384 |
"lsnr_loss": average_lsnr_loss,
|
385 |
})
|
386 |
|
387 |
-
total_pesq_score = 0.
|
388 |
-
total_loss = 0.
|
389 |
-
total_mr_stft_loss = 0.
|
390 |
-
total_neg_si_snr_loss = 0.
|
391 |
-
total_mask_loss = 0.
|
392 |
-
total_lsnr_loss = 0.
|
393 |
-
total_batches = 0.
|
394 |
-
|
395 |
-
progress_bar_eval.close()
|
396 |
-
progress_bar_train = tqdm(
|
397 |
-
initial=progress_bar_train.n,
|
398 |
-
postfix=progress_bar_train.postfix,
|
399 |
-
desc=progress_bar_train.desc,
|
400 |
-
)
|
401 |
-
|
402 |
-
# save path
|
403 |
-
save_dir = serialization_dir / "steps-{}".format(step_idx)
|
404 |
-
save_dir.mkdir(parents=True, exist_ok=False)
|
405 |
-
|
406 |
-
# save models
|
407 |
-
model.save_pretrained(save_dir.as_posix())
|
408 |
-
|
409 |
-
model_list.append(save_dir)
|
410 |
-
if len(model_list) >= args.num_serialized_models_to_keep:
|
411 |
-
model_to_delete: Path = model_list.pop(0)
|
412 |
-
shutil.rmtree(model_to_delete.as_posix())
|
413 |
-
|
414 |
-
# save metric
|
415 |
-
if best_metric is None:
|
416 |
-
best_epoch_idx = epoch_idx
|
417 |
-
best_step_idx = step_idx
|
418 |
-
best_metric = average_pesq_score
|
419 |
-
elif average_pesq_score >= best_metric:
|
420 |
-
# greater is better.
|
421 |
-
best_epoch_idx = epoch_idx
|
422 |
-
best_step_idx = step_idx
|
423 |
-
best_metric = average_pesq_score
|
424 |
-
else:
|
425 |
-
pass
|
426 |
-
|
427 |
-
metrics = {
|
428 |
-
"epoch_idx": epoch_idx,
|
429 |
-
"best_epoch_idx": best_epoch_idx,
|
430 |
-
"best_step_idx": best_step_idx,
|
431 |
-
"pesq_score": average_pesq_score,
|
432 |
-
"loss": average_loss,
|
433 |
-
"mr_stft_loss": average_mr_stft_loss,
|
434 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
435 |
-
"mask_loss": average_mask_loss,
|
436 |
-
"lsnr_loss": average_lsnr_loss,
|
437 |
-
}
|
438 |
-
metrics_filename = save_dir / "metrics_epoch.json"
|
439 |
-
with open(metrics_filename, "w", encoding="utf-8") as f:
|
440 |
-
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
441 |
-
|
442 |
-
# save best
|
443 |
-
best_dir = serialization_dir / "best"
|
444 |
-
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
445 |
-
if best_dir.exists():
|
446 |
-
shutil.rmtree(best_dir)
|
447 |
-
shutil.copytree(save_dir, best_dir)
|
448 |
-
|
449 |
-
# early stop
|
450 |
-
early_stop_flag = False
|
451 |
-
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
452 |
-
patience_count = 0
|
453 |
-
else:
|
454 |
-
patience_count += 1
|
455 |
-
if patience_count >= args.patience:
|
456 |
-
early_stop_flag = True
|
457 |
-
|
458 |
-
# early stop
|
459 |
-
if early_stop_flag:
|
460 |
-
break
|
461 |
model.train()
|
462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
return
|
464 |
|
465 |
|
|
|
318 |
# evaluation
|
319 |
step_idx += 1
|
320 |
if step_idx % config.eval_steps == 0:
|
|
|
321 |
with torch.no_grad():
|
322 |
torch.cuda.empty_cache()
|
323 |
|
324 |
+
model.eval()
|
325 |
+
|
326 |
total_pesq_score = 0.
|
327 |
total_loss = 0.
|
328 |
total_mr_stft_loss = 0.
|
|
|
385 |
"lsnr_loss": average_lsnr_loss,
|
386 |
})
|
387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
model.train()
|
389 |
|
390 |
+
total_pesq_score = 0.
|
391 |
+
total_loss = 0.
|
392 |
+
total_mr_stft_loss = 0.
|
393 |
+
total_neg_si_snr_loss = 0.
|
394 |
+
total_mask_loss = 0.
|
395 |
+
total_lsnr_loss = 0.
|
396 |
+
total_batches = 0.
|
397 |
+
|
398 |
+
progress_bar_eval.close()
|
399 |
+
progress_bar_train = tqdm(
|
400 |
+
initial=progress_bar_train.n,
|
401 |
+
postfix=progress_bar_train.postfix,
|
402 |
+
desc=progress_bar_train.desc,
|
403 |
+
)
|
404 |
+
|
405 |
+
# save path
|
406 |
+
save_dir = serialization_dir / "steps-{}".format(step_idx)
|
407 |
+
save_dir.mkdir(parents=True, exist_ok=False)
|
408 |
+
|
409 |
+
# save models
|
410 |
+
model.save_pretrained(save_dir.as_posix())
|
411 |
+
|
412 |
+
model_list.append(save_dir)
|
413 |
+
if len(model_list) >= args.num_serialized_models_to_keep:
|
414 |
+
model_to_delete: Path = model_list.pop(0)
|
415 |
+
shutil.rmtree(model_to_delete.as_posix())
|
416 |
+
|
417 |
+
# save metric
|
418 |
+
if best_metric is None:
|
419 |
+
best_epoch_idx = epoch_idx
|
420 |
+
best_step_idx = step_idx
|
421 |
+
best_metric = average_pesq_score
|
422 |
+
elif average_pesq_score >= best_metric:
|
423 |
+
# greater is better.
|
424 |
+
best_epoch_idx = epoch_idx
|
425 |
+
best_step_idx = step_idx
|
426 |
+
best_metric = average_pesq_score
|
427 |
+
else:
|
428 |
+
pass
|
429 |
+
|
430 |
+
metrics = {
|
431 |
+
"epoch_idx": epoch_idx,
|
432 |
+
"best_epoch_idx": best_epoch_idx,
|
433 |
+
"best_step_idx": best_step_idx,
|
434 |
+
"pesq_score": average_pesq_score,
|
435 |
+
"loss": average_loss,
|
436 |
+
"mr_stft_loss": average_mr_stft_loss,
|
437 |
+
"neg_si_snr_loss": average_neg_si_snr_loss,
|
438 |
+
"mask_loss": average_mask_loss,
|
439 |
+
"lsnr_loss": average_lsnr_loss,
|
440 |
+
}
|
441 |
+
metrics_filename = save_dir / "metrics_epoch.json"
|
442 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
443 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
444 |
+
|
445 |
+
# save best
|
446 |
+
best_dir = serialization_dir / "best"
|
447 |
+
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
448 |
+
if best_dir.exists():
|
449 |
+
shutil.rmtree(best_dir)
|
450 |
+
shutil.copytree(save_dir, best_dir)
|
451 |
+
|
452 |
+
# early stop
|
453 |
+
early_stop_flag = False
|
454 |
+
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
455 |
+
patience_count = 0
|
456 |
+
else:
|
457 |
+
patience_count += 1
|
458 |
+
if patience_count >= args.patience:
|
459 |
+
early_stop_flag = True
|
460 |
+
|
461 |
+
# early stop
|
462 |
+
if early_stop_flag:
|
463 |
+
break
|
464 |
+
|
465 |
return
|
466 |
|
467 |
|
main.py
CHANGED
@@ -18,9 +18,11 @@ import numpy as np
|
|
18 |
import log
|
19 |
from project_settings import environment, project_path, log_directory
|
20 |
from toolbox.os.command import Command
|
21 |
-
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
|
22 |
-
from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
|
23 |
from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
log.setup_size_rotating(log_directory=log_directory)
|
@@ -66,6 +68,18 @@ def shell(cmd: str):
|
|
66 |
|
67 |
|
68 |
denoise_engines = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
"dfnet-nx-dns3": {
|
70 |
"infer_cls": InferenceDfNet,
|
71 |
"kwargs": {
|
|
|
18 |
import log
|
19 |
from project_settings import environment, project_path, log_directory
|
20 |
from toolbox.os.command import Command
|
|
|
|
|
21 |
from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet
|
22 |
+
from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2
|
23 |
+
from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN
|
24 |
+
from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
|
25 |
+
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
|
26 |
|
27 |
|
28 |
log.setup_size_rotating(log_directory=log_directory)
|
|
|
68 |
|
69 |
|
70 |
denoise_engines = {
|
71 |
+
"dtln-nx-dns3": {
|
72 |
+
"infer_cls": InferenceDTLN,
|
73 |
+
"kwargs": {
|
74 |
+
"pretrained_model_path_or_zip_file": (project_path / "trained_models/dtln-nx-dns3.zip").as_posix()
|
75 |
+
}
|
76 |
+
},
|
77 |
+
"dfnet2-nx-dns3": {
|
78 |
+
"infer_cls": InferenceDfNet2,
|
79 |
+
"kwargs": {
|
80 |
+
"pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet2-nx-dns3.zip").as_posix()
|
81 |
+
}
|
82 |
+
},
|
83 |
"dfnet-nx-dns3": {
|
84 |
"infer_cls": InferenceDfNet,
|
85 |
"kwargs": {
|
toolbox/torchaudio/models/dfnet2/inference_dfnet2.py
CHANGED
@@ -20,7 +20,7 @@ from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2PretrainedMod
|
|
20 |
logger = logging.getLogger("toolbox")
|
21 |
|
22 |
|
23 |
-
class
|
24 |
def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
|
25 |
self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
|
26 |
self.device = torch.device(device)
|
@@ -99,7 +99,7 @@ class InferenceDfNet(object):
|
|
99 |
|
100 |
def main():
|
101 |
model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
|
102 |
-
infer_model =
|
103 |
|
104 |
sample_rate = 8000
|
105 |
noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
|
|
|
20 |
logger = logging.getLogger("toolbox")
|
21 |
|
22 |
|
23 |
+
class InferenceDfNet2(object):
|
24 |
def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
|
25 |
self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
|
26 |
self.device = torch.device(device)
|
|
|
99 |
|
100 |
def main():
|
101 |
model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
|
102 |
+
infer_model = InferenceDfNet2(model_zip_file)
|
103 |
|
104 |
sample_rate = 8000
|
105 |
noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
|