#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Train a Conv-TasNet speech-denoising model.

Reference implementation: https://github.com/kaituoxu/Conv-TasNet/tree/master/src

Quality targets:
    General use: SI-SNR >= 10 dB, suitable for telephony, basic voice assistants, etc.
    Demanding use (e.g. medical hearing aids, speech recognition): SI-SNR >= 14 dB,
    together with PESQ >= 3.0 and STOI >= 0.851812.

For comparison, the DeepFilterNet2 model was trained for 100 epochs on the DNS4
dataset (more than 500 hours of audio). https://arxiv.org/abs/2205.05474
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score


def get_args():
    """Parse the command-line arguments of the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=200, type=int)

    parser.add_argument("--batch_size", default=8, type=int)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=1234, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    """Configure console logging plus a daily-rotated `main.log` in `file_dir`.

    Returns the module logger with the file handler attached.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    """Stack dataset samples into `(clean_audios, noisy_audios)` batch tensors."""

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        """Stack `speech_wave` / `mix_wave` of every sample; raise on nan/inf values.

        torch.stack requires all waves in a batch to have the same shape.
        """
        clean_audios = list()
        noisy_audios = list()
        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


class _RunningAverage(object):
    """Running mean of named scalar metrics, reported rounded to 4 decimals."""

    def __init__(self):
        self.totals = dict()
        self.count = 0

    def update(self, values: dict):
        """Add one batch worth of `{name: float}` values."""
        for name, value in values.items():
            self.totals[name] = self.totals.get(name, 0.) + value
        self.count += 1

    def averages(self) -> dict:
        """Return `{name: mean}` in insertion order; empty if nothing was added."""
        return {name: round(total / self.count, 4) for name, total in self.totals.items()}


def _make_data_loader(dataset, batch_size: int) -> DataLoader:
    """Build an un-shuffled DataLoader that uses worker processes where possible."""
    if platform.system() == "Windows":
        # Worker subprocesses can be used for loading on Linux, but not on Windows.
        num_workers = 0
    else:
        # os.cpu_count() may return None; a single-core machine yields 0 workers.
        num_workers = (os.cpu_count() or 1) // 2

    extra_kwargs = dict()
    if num_workers > 0:
        # DataLoader rejects prefetch_factor when num_workers == 0 (torch >= 2.0).
        extra_kwargs["prefetch_factor"] = 2

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        # shuffle=True,
        sampler=None,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        **extra_kwargs,
    )


def _find_last_step_idx(serialization_dir: Path) -> int:
    """Return the largest N among `steps-N` checkpoint directories, or -1 if none exist."""
    last_step_idx = -1
    for step_dir in serialization_dir.glob("steps-*"):
        if not step_dir.is_dir():
            continue
        try:
            step_idx = int(step_dir.name.split("-", 1)[1])
        except ValueError:
            # Ignore directories such as "steps-old" that are not checkpoints.
            continue
        last_step_idx = max(last_step_idx, step_idx)
    return last_step_idx


def _compute_losses(denoise_audios: torch.Tensor, clean_audios: torch.Tensor, loss_fns: dict) -> dict:
    """Compute every loss term and their weighted sum.

    Returns a dict of tensors keyed "loss" (the weighted total), "ae_loss",
    "neg_si_snr_loss", "neg_stoi_loss", "mr_stft_loss" and "pesq_loss".
    """
    ae_loss = loss_fns["ae_loss"].forward(denoise_audios, clean_audios)
    neg_si_snr_loss = loss_fns["neg_si_snr_loss"].forward(denoise_audios, clean_audios)
    neg_stoi_loss = loss_fns["neg_stoi_loss"].forward(denoise_audios, clean_audios)
    mr_stft_loss = loss_fns["mr_stft_loss"].forward(denoise_audios, clean_audios)
    # Unlike the other terms, PesqLoss is called with (clean, denoised).
    pesq_loss = loss_fns["pesq_loss"].forward(clean_audios, denoise_audios)

    # Earlier weightings that were tried:
    # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
    # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
    # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
    # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
    # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
    loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss

    return {
        "loss": loss,
        "ae_loss": ae_loss,
        "neg_si_snr_loss": neg_si_snr_loss,
        "neg_stoi_loss": neg_stoi_loss,
        "mr_stft_loss": mr_stft_loss,
        "pesq_loss": pesq_loss,
    }


def _batch_pesq_score(clean_audios: torch.Tensor, denoise_audios: torch.Tensor, sample_rate: int):
    """Narrow-band PESQ of a batch, computed on CPU numpy copies of the waves."""
    denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
    clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
    return run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=sample_rate, mode="nb")


def _scalar_metrics(pesq_score, losses: dict) -> dict:
    """Flatten a batch's PESQ score and loss tensors into `{name: float}`."""
    metrics = {"pesq_score": pesq_score}
    metrics.update({name: value.item() for name, value in losses.items()})
    return metrics


def _evaluate(model: nn.Module,
              data_loader: DataLoader,
              loss_fns: dict,
              device: torch.device,
              sample_rate: int,
              lr: float,
              step_idx: int,
              logger: logging.Logger) -> dict:
    """Run one validation pass and return the averaged metrics.

    The model is put in eval mode (disabling dropout / using running norm
    statistics) for the pass and switched back to train mode afterwards.
    Returns an empty dict when every batch was skipped because of a nan/inf loss.
    """
    averager = _RunningAverage()
    progress_bar_eval = tqdm(
        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
    )
    model.eval()
    try:
        with torch.no_grad():
            torch.cuda.empty_cache()
            for eval_batch in data_loader:
                clean_audios, noisy_audios = eval_batch
                clean_audios = clean_audios.to(device)
                noisy_audios = noisy_audios.to(device)

                denoise_audios = model.forward(noisy_audios)
                denoise_audios = torch.squeeze(denoise_audios, dim=1)

                losses = _compute_losses(denoise_audios, clean_audios, loss_fns)
                loss = losses["loss"]
                if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                    logger.info(f"find nan or inf in loss.")
                    continue

                pesq_score = _batch_pesq_score(clean_audios, denoise_audios, sample_rate)
                averager.update(_scalar_metrics(pesq_score, losses))

                progress_bar_eval.update(1)
                progress_bar_eval.set_postfix({"lr": lr, **averager.averages()})
    finally:
        progress_bar_eval.close()
        model.train()
    return averager.averages()


def main():
    """Train Conv-TasNet, periodically validating, checkpointing and early-stopping."""
    args = get_args()

    config = ConvTasNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info(f"set seed: {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=825000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = _make_data_loader(train_dataset, args.batch_size)
    valid_data_loader = _make_data_loader(valid_dataset, args.batch_size)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = ConvTasNetPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training from the newest `steps-N` checkpoint, if any
    last_step_idx = _find_last_step_idx(serialization_dir)
    last_epoch = -1
    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info(f"load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info(f"load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    # NOTE(review): the lr scheduler state is not checkpointed, so after a resume
    # the schedule restarts from step 0 (starting at the restored optimizer lr).
    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    loss_fns = {
        "ae_loss": nn.L1Loss(reduction="mean").to(device),
        "neg_si_snr_loss": NegativeSISNRLoss(reduction="mean").to(device),
        "neg_stoi_loss": NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device),
        "mr_stft_loss": MultiResolutionSTFTLoss(
            fft_size_list=[256, 512, 1024],
            win_size_list=[120, 240, 480],
            hop_size_list=[25, 50, 100],
            factor_sc=1.5,
            factor_mag=1.0,
            reduction="mean"
        ).to(device),
        "pesq_loss": PesqLoss(0.5, sample_rate=config.sample_rate).to(device),
    }

    # training loop state
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    if last_step_idx != -1:
        # Restore the best validation score so a worse post-resume checkpoint
        # does not overwrite `best/`.
        best_metrics_file = serialization_dir / "best" / "metrics_epoch.json"
        if best_metrics_file.exists():
            with open(best_metrics_file, "r", encoding="utf-8") as f:
                best_metrics = json.load(f)
            best_metric = best_metrics.get("pesq_score")
            if best_metric is not None:
                best_epoch_idx = best_metrics.get("best_epoch_idx")
                best_step_idx = best_metrics.get("best_step_idx")
                logger.info(f"restore best metric: pesq_score={best_metric}, best_step_idx={best_step_idx}.")

    model_list = list()
    patience_count = 0
    # Metrics of the most recent train/eval averaging; evaluation overwrites them.
    latest_averages = dict()
    stop_training = False

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
        # train
        model.train()
        train_averager = _RunningAverage()

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)

            if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
                raise AssertionError("nan or inf in denoise_audios")

            losses = _compute_losses(denoise_audios, clean_audios, loss_fns)
            loss = losses["loss"]
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info(f"find nan or inf in loss.")
                continue

            pesq_score = _batch_pesq_score(clean_audios, denoise_audios, config.sample_rate)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            train_averager.update(_scalar_metrics(pesq_score, losses))
            latest_averages = train_averager.averages()

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                **latest_averages,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps != 0:
                continue

            progress_bar_train.close()
            eval_averages = _evaluate(
                model=model,
                data_loader=valid_data_loader,
                loss_fns=loss_fns,
                device=device,
                sample_rate=config.sample_rate,
                lr=lr_scheduler.get_last_lr()[0],
                step_idx=step_idx,
                logger=logger,
            )
            if eval_averages:
                latest_averages = eval_averages
            else:
                logger.warning("no valid batch in evaluation; keep the previous metrics.")

            # training averages restart after every evaluation
            train_averager = _RunningAverage()
            progress_bar_train = tqdm(
                initial=progress_bar_train.n,
                postfix=progress_bar_train.postfix,
                desc=progress_bar_train.desc,
            )

            # save model and optimizer
            save_dir = serialization_dir / "steps-{}".format(step_idx)
            save_dir.mkdir(parents=True, exist_ok=False)
            model.save_pretrained(save_dir.as_posix())
            torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
            model_list.append(save_dir)

            # update best (greater pesq_score is better)
            if best_metric is None or latest_averages["pesq_score"] > best_metric:
                best_epoch_idx = epoch_idx
                best_step_idx = step_idx
                best_metric = latest_averages["pesq_score"]
            is_best = best_epoch_idx == epoch_idx and best_step_idx == step_idx

            # save metric
            metrics = {
                "epoch_idx": epoch_idx,
                "best_epoch_idx": best_epoch_idx,
                "best_step_idx": best_step_idx,
                **latest_averages,
            }
            metrics_filename = save_dir / "metrics_epoch.json"
            with open(metrics_filename, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=4, ensure_ascii=False)

            # save best
            best_dir = serialization_dir / "best"
            if is_best:
                if best_dir.exists():
                    shutil.rmtree(best_dir)
                shutil.copytree(save_dir, best_dir)

            # Keep at most `num_serialized_models_to_keep` checkpoints, never
            # deleting the one that was just written.
            while len(model_list) > max(1, args.num_serialized_models_to_keep):
                model_to_delete: Path = model_list.pop(0)
                shutil.rmtree(model_to_delete.as_posix())

            # early stop
            if is_best:
                patience_count = 0
            else:
                patience_count += 1
                if patience_count >= args.patience:
                    logger.info(f"early stop at steps-{step_idx}; best steps-{best_step_idx}.")
                    stop_training = True
                    break

        progress_bar_train.close()
        # `break` above only leaves the batch loop; stop the epoch loop too.
        if stop_training:
            break

    return


if __name__ == "__main__":
    main()