#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
from pathlib import Path
import platform
import random
import shutil
import sys
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import MultiResolutionSTFTLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score
from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    # note: max_epochs, batch_size and learning_rate are also defined in the
    # config file; the config values are the ones actually used below.
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)

    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=10, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)

    # keep a rotating copy of the log in the serialization directory.
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))

    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")

        return clean_audios, noisy_audios


collate_fn = CollateFunction()
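# A minimal sketch (not in the original) of what the collate function returns,
# assuming every sample carries equal-length "speech_wave" / "mix_wave"
# tensors, as torch.stack requires:
#
#   batch = [{"speech_wave": torch.zeros(16000), "mix_wave": torch.zeros(16000)}]
#   clean, noisy = collate_fn(batch)
#   assert clean.shape == noisy.shape == torch.Size([1, 16000])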
def main():
    args = get_args()

    config = LstmConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    # datasets
    logger.info("prepare datasets")
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # skip=225000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # multiple worker processes can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        # shuffle=True,
        sampler=None,
        # multiple worker processes can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        prefetch_factor=None if platform.system() == "Windows" else 2,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = LstmPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training.
    # note: the save loop below writes `epoch-{idx}` directories via
    # save_pretrained, so `steps-*` checkpoints (model.pt / optimizer.pth)
    # are only found here if they were produced by another script.
    last_step_idx = -1
    last_epoch = -1
    for step_idx_str in serialization_dir.glob("steps-*"):
        step_idx_str = Path(step_idx_str)
        step_idx = step_idx_str.stem.split("-")[1]
        step_idx = int(step_idx)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
            # last_epoch = 1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    if config.lr_scheduler == "CosineAnnealingLR":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            # T_max=10 * config.eval_steps,
            # eta_min=0.01 * config.lr,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000],
            gamma=0.5,
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[256, 512, 1024],
        hop_size_list=[128, 256, 512],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean",
    ).to(device)

    # training loop
    logger.info("training")

    average_pesq_score = 1000000000
    average_loss = 1000000000
    average_mr_stft_loss = 1000000000
    average_neg_si_snr_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx
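    # optional sketch (not in the original): log the model size once before
    # training starts; uses only torch.nn.Module.parameters().
    num_params = sum(p.numel() for p in model.parameters())
    logger.info("model parameters: {}".format(num_params))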
logger.info("training") early_stop_flag = False for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): if early_stop_flag: break # train model.train() total_pesq_score = 0. total_loss = 0. total_mr_stft_loss = 0. total_neg_si_snr_loss = 0. total_batches = 0. progress_bar_train = tqdm( initial=step_idx, desc="Training; epoch: {}".format(epoch_idx), ) for train_batch in train_data_loader: clean_audios, noisy_audios = train_batch clean_audios: torch.Tensor = clean_audios.to(device) noisy_audios: torch.Tensor = noisy_audios.to(device) denoise_audios, _, _ = model.forward(noisy_audios) mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): logger.info(f"find nan or inf in loss.") continue denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) optimizer.step() lr_scheduler.step() total_pesq_score += pesq_score total_loss += loss.item() total_mr_stft_loss += mr_stft_loss.item() total_neg_si_snr_loss += neg_si_snr_loss.item() total_batches += 1 average_pesq_score = round(total_pesq_score / total_batches, 4) average_loss = round(total_loss / total_batches, 4) average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) progress_bar_train.update(1) progress_bar_train.set_postfix({ "lr": lr_scheduler.get_last_lr()[0], "pesq_score": average_pesq_score, "loss": average_loss, "mr_stft_loss": average_mr_stft_loss, "neg_si_snr_loss": average_neg_si_snr_loss, }) # evaluation step_idx += 1 if step_idx % config.eval_steps == 0: model.eval() with torch.no_grad(): torch.cuda.empty_cache() total_pesq_score = 0. total_loss = 0. total_mr_stft_loss = 0. total_neg_si_snr_loss = 0. total_batches = 0. 
                    progress_bar_train.close()
                    progress_bar_eval = tqdm(
                        desc="Evaluation; steps-{}k".format(int(step_idx / 1000)),
                    )
                    for eval_batch in valid_data_loader:
                        clean_audios, noisy_audios = eval_batch
                        clean_audios: torch.Tensor = clean_audios.to(device)
                        noisy_audios: torch.Tensor = noisy_audios.to(device)

                        denoise_audios, _, _ = model.forward(noisy_audios)

                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)

                        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                            logger.info("find nan or inf in loss.")
                            continue

                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")

                        total_pesq_score += pesq_score
                        total_loss += loss.item()
                        total_mr_stft_loss += mr_stft_loss.item()
                        total_neg_si_snr_loss += neg_si_snr_loss.item()
                        total_batches += 1

                        average_pesq_score = round(total_pesq_score / total_batches, 4)
                        average_loss = round(total_loss / total_batches, 4)
                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)

                        progress_bar_eval.update(1)
                        progress_bar_eval.set_postfix({
                            "lr": lr_scheduler.get_last_lr()[0],
                            "pesq_score": average_pesq_score,
                            "loss": average_loss,
                            "mr_stft_loss": average_mr_stft_loss,
                            "neg_si_snr_loss": average_neg_si_snr_loss,
                        })

                    total_pesq_score = 0.
                    total_loss = 0.
                    total_mr_stft_loss = 0.
                    total_neg_si_snr_loss = 0.
                    total_batches = 0.

                    progress_bar_eval.close()
                    progress_bar_train = tqdm(
                        initial=progress_bar_train.n,
                        postfix=progress_bar_train.postfix,
                        desc=progress_bar_train.desc,
                    )

                # save path.
                # note: this assumes `config.eval_steps` is large enough that
                # at most one evaluation falls in each epoch; otherwise the
                # mkdir below raises on the second evaluation of an epoch.
                epoch_dir = serialization_dir / "epoch-{}".format(epoch_idx)
                epoch_dir.mkdir(parents=True, exist_ok=False)

                # save models
                model.save_pretrained(epoch_dir.as_posix())

                model_list.append(epoch_dir)
                # keep at most `num_serialized_models_to_keep` checkpoints on disk.
                if len(model_list) > args.num_serialized_models_to_keep:
                    model_to_delete: Path = model_list.pop(0)
                    shutil.rmtree(model_to_delete.as_posix())

                # save metric
                if best_metric is None:
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = average_pesq_score
                elif average_pesq_score >= best_metric:
                    # greater is better.
                    best_epoch_idx = epoch_idx
                    best_step_idx = step_idx
                    best_metric = average_pesq_score
                else:
                    pass

                metrics = {
                    "epoch_idx": epoch_idx,
                    "best_epoch_idx": best_epoch_idx,
                    "best_step_idx": best_step_idx,
                    "pesq_score": average_pesq_score,
                    "loss": average_loss,
                }
                metrics_filename = epoch_dir / "metrics_epoch.json"
                with open(metrics_filename, "w", encoding="utf-8") as f:
                    json.dump(metrics, f, indent=4, ensure_ascii=False)

                # save best
                best_dir = serialization_dir / "best"
                if best_epoch_idx == epoch_idx:
                    if best_dir.exists():
                        shutil.rmtree(best_dir)
                    shutil.copytree(epoch_dir, best_dir)

                # early stop
                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                    patience_count = 0
                else:
                    patience_count += 1

                if patience_count >= args.patience:
                    early_stop_flag = True

                # early stop
                if early_stop_flag:
                    break

                # switch back to training mode for the next batches.
                model.train()

    return


if __name__ == '__main__':
    main()
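# Usage sketch (the script filename and data paths are illustrative, not from
# the original; the flags are the ones defined in get_args above):
#
#   python3 train_lstm.py \
#       --train_dataset train.jsonl \
#       --valid_dataset valid.jsonl \
#       --config_file config.yaml \
#       --serialization_dir serialization_dir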