#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Reference:
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNetPretrainedModel
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import MultiResolutionSTFTLoss
from toolbox.torchaudio.losses.perceptual import NegSTOILoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score


def get_args():
    parser = argparse.ArgumentParser()
    # note: despite the .xlsx defaults, the DenoiseJsonlDataset used below
    # expects JSONL files.
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=200, type=int)

    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=1234, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # fail fast on corrupted inputs
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()
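
# For reference, a minimal sketch of what the imported NegativeSISNRLoss is
# expected to compute (an illustration only, not the toolbox implementation,
# which may differ in details such as eps handling or reduction):
def _negative_si_snr_sketch(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative scale-invariant SNR; (batch, num_samples) tensors -> scalar."""
    # remove per-utterance DC offset so the measure is shift-invariant
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # project the estimate onto the target to obtain the scaled reference
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    target_energy = torch.sum(target ** 2, dim=-1, keepdim=True) + eps
    s_target = dot / target_energy * target
    e_noise = estimate - s_target
    si_snr = 10.0 * torch.log10(
        (torch.sum(s_target ** 2, dim=-1) + eps) / (torch.sum(e_noise ** 2, dim=-1) + eps)
    )
    # negate so that minimizing the loss maximizes SI-SNR
    return -si_snr.mean()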
def main():
    args = get_args()

    config = ConvTasNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info(f"set seed: {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )

    # On Linux the data can be loaded with multiple worker subprocesses;
    # on Windows it cannot, so fall back to the main process there.
    num_workers = 0 if platform.system() == "Windows" else (os.cpu_count() or 2) // 2
    dataloader_kwargs = dict(
        batch_size=args.batch_size,
        # shuffle=True,
        sampler=None,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
    )
    if num_workers > 0:
        # prefetch_factor is only valid when worker subprocesses are used.
        dataloader_kwargs["prefetch_factor"] = 2
    train_data_loader = DataLoader(dataset=train_dataset, **dataloader_kwargs)
    valid_data_loader = DataLoader(dataset=valid_dataset, **dataloader_kwargs)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = ConvTasNetPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn")
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    # resume training
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        epoch_idx = int(epoch_i.stem.split("-")[1])
        if epoch_idx > last_epoch:
            last_epoch = epoch_idx

    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
        optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    # note: the scheduler state is not serialized, so on resume it restarts
    # from step 0; its milestones are counted in optimizer steps, not epochs.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )

    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        # fft_size_list=[256, 512, 1024],
        win_size_list=[120, 240, 480],
        hop_size_list=[25, 50, 100],
        reduction="mean"
    ).to(device)
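
    # The objective used in both loops below is an equal-weight (0.25 each) sum
    # of the L1 waveform loss, negative SI-SNR, negative STOI, and the
    # multi-resolution STFT loss; the equal weights are this script's choice,
    # not tuned values.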
    # training loop

    # state
    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_ae_loss = 1000000000
    average_neg_si_snr_loss = 1000000000
    average_neg_stoi_loss = 1000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    logger.info("training")
    for idx_epoch in range(max(0, last_epoch + 1), args.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_neg_si_snr_loss = 0.
        total_neg_stoi_loss = 0.
        total_batches = 0.

        progress_bar = tqdm(
            desc="Training; epoch-{}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audios, noisy_audios = batch
            clean_audios = clean_audios.to(device)
            noisy_audios = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)
            if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
                raise AssertionError("nan or inf in denoise_audios")

            ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
            neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)

            loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # the scheduler is stepped once per optimizer update, matching the
            # step-count milestones defined above.
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_ae_loss += ae_loss.item()
            total_neg_si_snr_loss += neg_si_snr_loss.item()
            total_neg_stoi_loss += neg_stoi_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_ae_loss = round(total_ae_loss / total_batches, 4)
            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
            average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "ae_loss": average_ae_loss,
                "neg_si_snr_loss": average_neg_si_snr_loss,
                "neg_stoi_loss": average_neg_stoi_loss,
            })
        progress_bar.close()

        # evaluation
        model.eval()
        torch.cuda.empty_cache()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_neg_si_snr_loss = 0.
        total_neg_stoi_loss = 0.
        total_batches = 0.
        progress_bar = tqdm(
            desc="Evaluation; epoch-{}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audios, noisy_audios = batch
                clean_audios = clean_audios.to(device)
                noisy_audios = noisy_audios.to(device)

                denoise_audios = model.forward(noisy_audios)
                denoise_audios = torch.squeeze(denoise_audios, dim=1)

                ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
                neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
                neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
                mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)

                loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss

                denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")

                total_pesq_score += pesq_score
                total_loss += loss.item()
                total_ae_loss += ae_loss.item()
                total_neg_si_snr_loss += neg_si_snr_loss.item()
                total_neg_stoi_loss += neg_stoi_loss.item()
                total_batches += 1

                average_pesq_score = round(total_pesq_score / total_batches, 4)
                average_loss = round(total_loss / total_batches, 4)
                average_ae_loss = round(total_ae_loss / total_batches, 4)
                average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                    "ae_loss": average_ae_loss,
                    "neg_si_snr_loss": average_neg_si_snr_loss,
                    "neg_stoi_loss": average_neg_stoi_loss,
                })
        progress_bar.close()

        # note: the scheduler is already stepped per batch in the training
        # loop, so no additional per-epoch lr_scheduler.step() is done here.

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        model_list.append(epoch_dir)
        if len(model_list) > args.num_serialized_models_to_keep:
            # keep exactly num_serialized_models_to_keep checkpoints on disk
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save optim
        torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = average_loss
        elif average_loss < best_metric:
            # lower validation loss is better.
            best_idx_epoch = idx_epoch
            best_metric = average_loss
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "pesq_score": average_pesq_score,
            "loss": average_loss,
            "ae_loss": average_ae_loss,
            "neg_si_snr_loss": average_neg_si_snr_loss,
            "neg_stoi_loss": average_neg_stoi_loss,
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
        if patience_count >= args.patience:
            logger.info(f"early stop at epoch-{idx_epoch}; best epoch: {best_idx_epoch}.")
            break

    return


if __name__ == "__main__":
    main()
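
# Example invocation (the script name and dataset paths below are
# hypothetical; adjust them to your local setup, noting that the dataset
# loader expects JSONL files):
#
#   python3 step_2_train_model.py \
#       --train_dataset train.jsonl \
#       --valid_dataset valid.jsonl \
#       --config_file config.yaml \
#       --serialization_dir serialization_dir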