#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/NVIDIA/CleanUNet/blob/main/train.py
https://github.com/NVIDIA/CleanUNet/blob/main/configs/DNS-large-full.json
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig
from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
from toolbox.torchaudio.models.clean_unet.metrics import run_pesq_score

torch.autograd.set_detect_anomaly(True)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=2e-4, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--config_file", default="config.yaml", type=str)
    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    logging.basicConfig(format=fmt, datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)
    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = CleanUNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)
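    # Reproducibility: seed Python's, NumPy's, and PyTorch's RNGs with the same value.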
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info(f"set seed: {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        # Multiple data-loading worker subprocesses can be used on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        # Multiple data-loading worker subprocesses can be used on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = CleanUNetPretrainedModel(config).to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)

    # resume training
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        epoch_i = Path(epoch_i)
        epoch_idx = epoch_i.stem.split("-")[1]
        epoch_idx = int(epoch_idx)
        if epoch_idx > last_epoch:
            last_epoch = epoch_idx

    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
        optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    # The schedule is iteration-based: `iteration` is the step it starts from,
    # 0 for a fresh run; on resume, skip the steps already taken (assuming one
    # scheduler step per training batch).
    lr_scheduler = LinearWarmupCosineDecay(
        optimizer,
        lr_max=args.learning_rate,
        n_iter=250000,
        iteration=max(0, last_epoch + 1) * len(train_data_loader),
        divider=25,
        warmup_proportion=0.05,
        phase=("linear", "cosine"),
    )

    # ae_loss_fn = nn.MSELoss(reduction="mean")
    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)

    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_sizes=[256, 512, 1024],
        hop_sizes=[25, 50, 120],
        win_lengths=[120, 240, 600],
        sc_lambda=0.5,
        mag_lambda=0.5,
        band="full",
    ).to(device)

    # training loop

    # state
    average_pesq_score = 10000000000
    average_loss = 10000000000
    average_ae_loss = 10000000000
    average_sc_loss = 10000000000
    average_mag_loss = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    logger.info("training")
    for idx_epoch in range(max(0, last_epoch + 1), args.max_epochs):
        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_sc_loss = 0.
        total_mag_loss = 0.
        total_batches = 0.
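        # Running totals are re-averaged after every batch so the progress bar
        # shows epoch-to-date means rather than per-batch values.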
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audios, noisy_audios = batch
            clean_audios = clean_audios.to(device)
            noisy_audios = noisy_audios.to(device)

            enhanced_audios = model.forward(noisy_audios)
            enhanced_audios = torch.squeeze(enhanced_audios, dim=1)

            ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
            sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
            loss = ae_loss + sc_loss + mag_loss

            # PESQ is computed on detached CPU copies of the waveforms.
            enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(
                clean_audios_list_r,
                enhanced_audios_list_r,
                sample_rate=8000,
                mode="nb",
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_ae_loss += ae_loss.item()
            total_sc_loss += sc_loss.item()
            total_mag_loss += mag_loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)
            average_ae_loss = round(total_ae_loss / total_batches, 4)
            average_sc_loss = round(total_sc_loss / total_batches, 4)
            average_mag_loss = round(total_mag_loss / total_batches, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "pesq_score": average_pesq_score,
                "loss": average_loss,
                "ae_loss": average_ae_loss,
                "sc_loss": average_sc_loss,
                "mag_loss": average_mag_loss,
            })
        progress_bar.close()

        # evaluation
        model.eval()
        torch.cuda.empty_cache()

        total_pesq_score = 0.
        total_loss = 0.
        total_ae_loss = 0.
        total_sc_loss = 0.
        total_mag_loss = 0.
        total_batches = 0.

        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audios, noisy_audios = batch
                clean_audios = clean_audios.to(device)
                noisy_audios = noisy_audios.to(device)

                enhanced_audios = model.forward(noisy_audios)
                enhanced_audios = torch.squeeze(enhanced_audios, dim=1)

                ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
                sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
                loss = ae_loss + sc_loss + mag_loss

                enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
                clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                pesq_score = run_pesq_score(
                    clean_audios_list_r,
                    enhanced_audios_list_r,
                    sample_rate=8000,
                    mode="nb",
                )

                total_pesq_score += pesq_score
                total_loss += loss.item()
                total_ae_loss += ae_loss.item()
                total_sc_loss += sc_loss.item()
                total_mag_loss += mag_loss.item()
                total_batches += 1

                average_pesq_score = round(total_pesq_score / total_batches, 4)
                average_loss = round(total_loss / total_batches, 4)
                average_ae_loss = round(total_ae_loss / total_batches, 4)
                average_sc_loss = round(total_sc_loss / total_batches, 4)
                average_mag_loss = round(total_mag_loss / total_batches, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                    "ae_loss": average_ae_loss,
                    "sc_loss": average_sc_loss,
                    "mag_loss": average_mag_loss,
                })
        progress_bar.close()

        # scheduler
        # The lr_scheduler is iteration-based and is already stepped once per
        # training batch, so it is not stepped again at epoch boundaries.

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        model_list.append(epoch_dir)
        if len(model_list) > args.num_serialized_models_to_keep:
            # drop the oldest checkpoint, keeping the most recent
            # num_serialized_models_to_keep epoch directories.
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save optim
        torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = average_pesq_score
        elif average_pesq_score > best_metric:
            # greater is better.
            best_idx_epoch = idx_epoch
            best_metric = average_pesq_score
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "pesq_score": average_pesq_score,
            "loss": average_loss,
            "ae_loss": average_ae_loss,
            "sc_loss": average_sc_loss,
            "mag_loss": average_mag_loss,
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        if early_stop_flag:
            break
    return


if __name__ == "__main__":
    main()
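# Example invocation (a sketch; the script name and the dataset/config paths
# below are placeholders, not names taken from the repository):
#
#   python3 train.py \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --config_file config.yaml \
#       --serialization_dir serialization_dir \
#       --batch_size 64 \
#       --learning_rate 2e-4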