#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Train an LSTM speech-denoising model that predicts an ideal ratio mask (IRM).

https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List, Optional, Tuple

# Make the repository root importable before the ``toolbox`` imports below.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.metrics.pesq import run_pesq_score
from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel


def get_args():
    """Parse the command-line arguments of the training script.

    NOTE(review): ``main`` reads batch size, learning rate and the epoch
    count from the config file, not from ``--batch_size``,
    ``--learning_rate`` or ``--max_epochs``; those flags (and ``--seed``
    aside) are parsed but currently unused.

    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
    parser.add_argument("--patience", default=10, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    """Configure console logging plus a daily-rotating ``main.log`` file.

    :param file_dir: directory in which ``main.log`` is created.
    :return: the logger of this module, with the file handler attached.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt, datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO)

    # Rotate once per day and keep one week of history.
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))

    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    """Turn dataset samples into spectral features and IRM training targets.

    Each sample is a dict from which the keys ``"noise_wave"``,
    ``"speech_wave"`` and ``"mix_wave"`` are read. The instance also offers
    :meth:`enhance` to rebuild a waveform from a predicted mask.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        """
        :param n_fft: FFT size of the STFT.
        :param win_length: analysis window length.
        :param hop_length: hop between successive frames.
        :param window_fn: ``"hamming"`` selects a Hamming window; any other
            value selects a Hann window.
        :param irm_beta: exponent applied to the ratio mask.
        :param epsilon: added to the mask denominator to avoid division by zero.
        """
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.irm_beta = irm_beta
        self.epsilon = epsilon

        window = torch.hamming_window if window_fn == "hamming" else torch.hann_window

        # Magnitude spectrogram (power=1.0), used for the clean / noise targets.
        self.stft_mag = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=1.0,
            window_fn=window,
        )
        # Complex spectrogram (power=None), used for the noisy mixture.
        self.stft_complex = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=None,
            window_fn=window,
        )
        self.istft = torchaudio.transforms.InverseSpectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            window_fn=window,
        )

    def __call__(self, batch: List[dict]):
        """Collate a list of samples into stacked tensors.

        :param batch: samples providing ``"noise_wave"``, ``"speech_wave"``
            and ``"mix_wave"`` tensors. The tensors are stacked, so they are
            assumed to share one shape across the batch.
        :return: ``(mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios)``.
        :raises AssertionError: if any of the returned tensors contains NaN.
        """
        mag_noisy_audios = list()
        pha_noisy_audios = list()
        irm_gth = list()
        clean_audios = list()

        for sample in batch:
            noise_audio: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]

            mag_noise = self.stft_mag.forward(noise_audio)
            mag_clean = self.stft_mag.forward(clean_audio)
            stft_noisy = self.stft_complex.forward(noisy_audio)

            # Ideal ratio mask: share of the clean magnitude in each bin.
            irm_clean = mag_clean / (mag_noise + mag_clean + self.epsilon)
            irm_clean = torch.pow(irm_clean, self.irm_beta)

            real = torch.real(stft_noisy)
            imag = torch.imag(stft_noisy)
            mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
            pha_noisy = torch.atan2(imag, real)

            mag_noisy_audios.append(mag_noisy)
            pha_noisy_audios.append(pha_noisy)
            irm_gth.append(irm_clean)
            clean_audios.append(clean_audio)

        mag_noisy_audios = torch.stack(mag_noisy_audios)
        pha_noisy_audios = torch.stack(pha_noisy_audios)
        irm_gth = torch.stack(irm_gth)
        clean_audios = torch.stack(clean_audios)

        # Fail fast on corrupted features instead of training on NaN.
        named_tensors = (
            ("mag_noisy_audios", mag_noisy_audios),
            ("pha_noisy_audios", pha_noisy_audios),
            ("irm_gth", irm_gth),
            ("clean_audios", clean_audios),
        )
        for name, tensor in named_tensors:
            if torch.any(torch.isnan(tensor)):
                raise AssertionError("nan in {} Tensor".format(name))

        return mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios

    def enhance(self, mag_noisy: torch.Tensor, pha_noisy: torch.Tensor, irm_speech: torch.Tensor):
        """Apply a ratio mask to the noisy spectrum and invert it to a waveform.

        :param mag_noisy: magnitude of the noisy STFT.
        :param pha_noisy: phase of the noisy STFT.
        :param irm_speech: mask multiplied element-wise with ``mag_noisy``.
        :return: the reconstructed waveform, on the device of ``mag_noisy``.
        """
        mag_denoise = mag_noisy * irm_speech
        stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))

        # The inverse transform keeps its window as a buffer that is never
        # moved off its creation device, so run the inversion where the
        # window lives and hand the result back on the caller's device.
        window_device = self.istft.window.device
        denoise = self.istft.forward(stft_denoise.to(window_device))
        return denoise.to(mag_noisy.device)


# NOTE(review): built with the default STFT settings, not with values from
# the model config; confirm that they agree.
collate_fn = CollateFunction()


def _make_data_loader(dataset, batch_size: int) -> DataLoader:
    """Build a DataLoader with platform-appropriate worker settings."""
    # On Linux multiple worker subprocesses can be used to load data;
    # on Windows they cannot.
    if platform.system() == "Windows":
        num_workers = 0
    else:
        # cpu_count() may return None; a single core yields zero workers.
        num_workers = (os.cpu_count() or 1) // 2

    # shuffle is left at the DataLoader default.
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=None,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor is only accepted together with worker processes.
        prefetch_factor=2 if num_workers > 0 else None,
    )


def _find_last_step_idx(serialization_dir: Path) -> int:
    """Return the largest N among ``steps-N`` entries, or -1 if there is none."""
    last_step_idx = -1
    for step_dir in serialization_dir.glob("steps-*"):
        step_idx = int(Path(step_dir).stem.split("-")[1])
        if step_idx > last_step_idx:
            last_step_idx = step_idx
    return last_step_idx


def _build_lr_scheduler(config, optimizer, last_epoch: int):
    """Create the learning-rate scheduler named by ``config.lr_scheduler``."""
    if config.lr_scheduler == "CosineAnnealingLR":
        # Scheduler options such as T_max / eta_min come from the config.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            **config.lr_scheduler_kwargs,
        )
    if config.lr_scheduler == "MultiStepLR":
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000],
            gamma=0.5
        )
    raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")


def _evaluate(model,
              valid_data_loader,
              mse_loss_fn,
              device,
              sample_rate: int,
              step_idx: int,
              learning_rate: float,
              ) -> Optional[Tuple[float, float]]:
    """Run one pass over the validation data without updating the model.

    :return: ``(average_pesq_score, average_loss)`` rounded to 4 decimals,
        or ``None`` when the validation loader yields no batch.
    """
    total_pesq_score = 0.
    total_loss = 0.
    total_batches = 0.
    result = None

    progress_bar_eval = tqdm(
        desc="Evaluation; steps-{}k".format(int(step_idx / 1000)),
    )
    model.eval()
    try:
        with torch.no_grad():
            for eval_batch in valid_data_loader:
                mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = eval_batch
                mag_noisy_audios = mag_noisy_audios.to(device)
                pha_noisy_audios = pha_noisy_audios.to(device)
                irm_gth = irm_gth.to(device)
                clean_audios = clean_audios.to(device)

                irm = model.forward(mag_noisy_audios)
                denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
                loss = mse_loss_fn.forward(irm, irm_gth)

                denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
                pesq_score = run_pesq_score(
                    clean_audios_list_r, denoise_audios_list_r,
                    sample_rate=sample_rate, mode="nb",
                )

                total_pesq_score += pesq_score
                total_loss += loss.item()
                total_batches += 1

                result = (
                    round(total_pesq_score / total_batches, 4),
                    round(total_loss / total_batches, 4),
                )

                progress_bar_eval.update(1)
                progress_bar_eval.set_postfix({
                    "lr": learning_rate,
                    "pesq_score": result[0],
                    "loss": result[1],
                })
    finally:
        # Always hand the model back in training mode.
        model.train()
        progress_bar_eval.close()

    return result


def main():
    """Train the model, evaluating and checkpointing every ``config.eval_steps`` steps.

    Checkpoints are written to ``<serialization_dir>/steps-<N>``; the best
    one (highest validation PESQ) is mirrored to ``<serialization_dir>/best``.
    Training stops early after ``--patience`` evaluations without a new best.
    """
    args = get_args()

    config = LstmConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    # datasets
    logger.info("prepare datasets")
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )
    train_data_loader = _make_data_loader(train_dataset, config.batch_size)
    valid_data_loader = _make_data_loader(valid_dataset, config.batch_size)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = LstmPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training
    last_step_idx = _find_last_step_idx(serialization_dir)
    # NOTE(review): last_epoch is never restored, so after a resume the
    # epoch counter and the learning-rate schedule start from scratch.
    last_epoch = -1

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        # NOTE(review): assumes model.save_pretrained() writes "model.pt"
        # into the checkpoint directory - confirm against the model class.
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    lr_scheduler = _build_lr_scheduler(config, optimizer, last_epoch)

    mse_loss_fn = nn.MSELoss(
        reduction="mean",
    ).to(device)

    # training loop
    logger.info("training")

    average_pesq_score = 1000000000
    average_loss = 1000000000

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    early_stop_flag = False
    for epoch_idx in range(max(0, last_epoch + 1), config.max_epochs):
        if early_stop_flag:
            break

        # train
        model.train()

        total_pesq_score = 0.
        total_loss = 0.
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch: {}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = train_batch
            mag_noisy_audios = mag_noisy_audios.to(device)
            pha_noisy_audios = pha_noisy_audios.to(device)
            irm_gth = irm_gth.to(device)
            clean_audios = clean_audios.to(device)

            irm = model.forward(mag_noisy_audios)
            loss = mse_loss_fn.forward(irm, irm_gth)

            # The enhanced audio only feeds the PESQ metric, so keep it out
            # of the autograd graph.
            with torch.no_grad():
                denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm.detach())

            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
            pesq_score = run_pesq_score(
                clean_audios_list_r, denoise_audios_list_r,
                sample_rate=config.sample_rate, mode="nb",
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
            optimizer.step()
            lr_scheduler.step()

            total_pesq_score += pesq_score
            total_loss += loss.item()
            total_batches += 1

            average_pesq_score = round(total_pesq_score / total_batches, 4)
            average_loss = round(total_loss / total_batches, 4)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                "pesq_score": average_pesq_score,
                "loss": average_loss,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps != 0:
                continue

            torch.cuda.empty_cache()
            progress_bar_train.close()

            eval_result = _evaluate(
                model=model,
                valid_data_loader=valid_data_loader,
                mse_loss_fn=mse_loss_fn,
                device=device,
                sample_rate=config.sample_rate,
                step_idx=step_idx,
                learning_rate=lr_scheduler.get_last_lr()[0],
            )
            if eval_result is not None:
                average_pesq_score, average_loss = eval_result

            # Restart the running training averages after each evaluation.
            total_pesq_score = 0.
            total_loss = 0.
            total_batches = 0.

            progress_bar_train = tqdm(
                initial=progress_bar_train.n,
                postfix=progress_bar_train.postfix,
                desc=progress_bar_train.desc,
            )

            # save path: one directory per evaluated step, matching the
            # "steps-*" pattern that the resume logic scans for.
            save_dir = serialization_dir / "steps-{}".format(step_idx)
            save_dir.mkdir(parents=True, exist_ok=False)

            # save models
            model.save_pretrained(save_dir.as_posix())
            # The resume logic also expects the optimizer state.
            torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())

            # save metric
            if best_metric is None or average_pesq_score >= best_metric:
                # greater is better.
                best_epoch_idx = epoch_idx
                best_step_idx = step_idx
                best_metric = average_pesq_score

            metrics = {
                "epoch_idx": epoch_idx,
                "step_idx": step_idx,
                "best_epoch_idx": best_epoch_idx,
                "best_step_idx": best_step_idx,
                "pesq_score": average_pesq_score,
                "loss": average_loss,
            }
            metrics_filename = save_dir / "metrics_epoch.json"
            with open(metrics_filename, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=4, ensure_ascii=False)

            # save best: only when this very checkpoint set the new best, so
            # a later, worse checkpoint of the same epoch cannot replace it.
            is_new_best = best_epoch_idx == epoch_idx and best_step_idx == step_idx
            best_dir = serialization_dir / "best"
            if is_new_best:
                if best_dir.exists():
                    shutil.rmtree(best_dir)
                shutil.copytree(save_dir, best_dir)

            # Prune old checkpoints last, after the newest one has been
            # fully written and (if applicable) copied to "best".
            model_list.append(save_dir)
            if len(model_list) >= args.num_serialized_models_to_keep:
                model_to_delete: Path = model_list.pop(0)
                shutil.rmtree(model_to_delete.as_posix())

            # early stop
            early_stop_flag = False
            if is_new_best:
                patience_count = 0
            else:
                patience_count += 1
                if patience_count >= args.patience:
                    early_stop_flag = True

            if early_stop_flag:
                break

        progress_bar_train.close()

    return


if __name__ == '__main__':
    main()