#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()
        snr_db_list = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)
            snr_db_list.append(snr_db)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)
        # snr_db values are plain floats, so build a tensor from them instead of stacking.
        snr_db_list = torch.tensor(snr_db_list, dtype=torch.float32)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios, snr_db_list


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = DfNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )

    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Multiple worker subprocesses can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # Multiple worker subprocesses can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = DfNetPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    # AdamW expects an iterable of parameters, so pass parameters() rather than named_parameters().
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training
    last_step_idx = -1
    last_epoch = -1
    for step_dir in serialization_dir.glob("steps-*"):
        step_idx = int(step_dir.stem.split("-")[1])
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000],
            gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)
    lsnr_loss_fn = nn.L1Loss(reduction="mean")

    # training loop

    # state
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_mask_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0
    # initialized here so the epoch loop can also check it after each inner batch loop.
    early_stop_flag = False

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_neg_si_snr_loss = 0.
        total_mask_loss = 0.
        total_batches = 0.
        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios, snr_db_list = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)
            snr_db_list: torch.Tensor = snr_db_list.to(device)

            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)

            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
            # mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
            # lsnr_loss = lsnr_loss_fn.forward(lsnr, snr_db_list)

            loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                logger.info("found nan or inf in loss; skip this batch.")
                continue

            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(
                clean_audios_list_r,
                denoise_audios_list_r,
                sample_rate=config.sample_rate,
                mode="nb",
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_mask_loss += mask_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_mask_loss = round(total_mask_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "mask_loss": average_mask_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps == 0:
                with torch.no_grad():
                    torch.cuda.empty_cache()
                    # switch to eval mode for validation; train mode is restored after evaluation.
                    model.eval()

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_batches = 0.
                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios, snr_db_list = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)
                        snr_db_list: torch.Tensor = snr_db_list.to(device)

                        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)

                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)

                        loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info("found nan or inf in loss; skip this batch.")
                            continue

                        denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(
                            clean_audios_list_r,
                            denoise_audios_list_r,
                            sample_rate=config.sample_rate,
                            mode="nb",
                        )

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_mask_loss += mask_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                        average_mask_loss = round(total_mask_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                            "mask_loss": average_mask_loss,
                        })

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_mask_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    # restore train mode before the training loop continues.
                    model.train()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                    # save path
                    save_dir = serialization_dir / "steps-{}".format(step_idx)
                    save_dir.mkdir(parents=True, exist_ok=False)

                    # save models
                    model.save_pretrained(save_dir.as_posix())

                    model_list.append(save_dir)
                    if len(model_list) >= args.num_serialized_models_to_keep:
                        model_to_delete: Path = model_list.pop(0)
                        shutil.rmtree(model_to_delete.as_posix())

                    # save optim
                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())

                    # save metric
                    if best_metric is None:
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    elif average_pesq_score > best_metric:
                        # greater is better.
                        best_epoch_idx = epoch_idx
                        best_step_idx = step_idx
                        best_metric = average_pesq_score
                    else:
                        pass

                    metrics = {
                        "epoch_idx": epoch_idx,
                        "best_epoch_idx": best_epoch_idx,
                        "best_step_idx": best_step_idx,
                        "pesq_score": average_pesq_score,
                        "loss": average_loss,
                        "neg_si_snr_loss": average_neg_si_snr_loss,
                        "mask_loss": average_mask_loss,
                    }
                    metrics_filename = save_dir / "metrics_epoch.json"
                    with open(metrics_filename, "w", encoding="utf-8") as f:
                        json.dump(metrics, f, indent=4, ensure_ascii=False)

                    # save best
                    best_dir = serialization_dir / "best"
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        if best_dir.exists():
                            shutil.rmtree(best_dir)
                        shutil.copytree(save_dir, best_dir)

                    # early stop
                    early_stop_flag = False
                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                        patience_count = 0
                    else:
                        patience_count += 1
                        if patience_count >= args.patience:
                            early_stop_flag = True

                    # early stop: leave the batch loop.
                    if early_stop_flag:
                        break

        # early stop: also leave the epoch loop once patience is exhausted.
        if early_stop_flag:
            break

    return


if __name__ == "__main__":
    main()
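
# A hypothetical invocation of this training script. The script file name and the
# data paths below are illustrative placeholders, not taken from the repository;
# the flags themselves correspond to the arguments defined in get_args() above.
#
#   python3 step_2_train_model.py \
#       --train_dataset /path/to/train.jsonl \
#       --valid_dataset /path/to/valid.jsonl \
#       --serialization_dir ./serialization_dir \
#       --config_file ./config.yaml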