#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Train a SpectrumDfNet speech-enhancement model on Excel-described denoising datasets.

https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.spectrum_dfnet.configuration_spectrum_dfnet import SpectrumDfNetConfig
from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel


def get_args():
    """Parse the command-line options for training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    """Configure console logging and attach a daily-rotating file handler.

    :param file_dir: directory in which ``main.log`` is written (7 daily backups kept).
    :return: the logger for this module, with the file handler attached.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


def _raise_if_not_finite(tensor: torch.Tensor, name: str) -> None:
    """Raise ``AssertionError("nan or inf in <name>")`` if ``tensor`` has any NaN or +/-inf."""
    if torch.any(torch.isnan(tensor)) or torch.any(torch.isinf(tensor)):
        raise AssertionError("nan or inf in {}".format(name))


class CollateFunction(object):
    """Turn a list of dataset samples into batched spectral training targets.

    Each sample is a dict holding ``"noise_wave"``, ``"speech_wave"`` and ``"mix_wave"``
    tensors (presumably 1-D waveforms: their spectra are handled as 2-D
    ``[freq_dim, time_steps]`` below -- TODO confirm against DenoiseExcelDataset).
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        """
        :param n_fft: FFT size of both spectrogram transforms.
        :param win_length: analysis window length in samples.
        :param hop_length: hop between frames in samples.
        :param window_fn: ``"hamming"`` selects a Hamming window; any other value selects Hann.
        :param irm_beta: exponent applied to the ideal ratio mask.
        :param epsilon: small constant added to denominators (and used as the SNR floor).
        """
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.irm_beta = irm_beta
        self.epsilon = epsilon

        window = torch.hamming_window if window_fn == "hamming" else torch.hann_window
        # power=None -> complex STFT output.
        self.complex_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=None,
            window_fn=window,
        )
        # power=2.0 -> power spectrogram, used for the IRM and SNR targets.
        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=2.0,
            window_fn=window,
        )

    @staticmethod
    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
        """Extract sliding patches covering all frequency bins and ``n_time_steps`` frames.

        :param x: tensor of shape ``[batch_size, channels, freq_dim, time_steps]``.
        :param n_time_steps: width of the time window. The output keeps ``time_steps``
            columns only when this is odd.
        :return: ``F.unfold`` output of shape
            ``[batch_size, channels * freq_dim * n_time_steps, time_steps]`` (odd window).
        """
        _, _, freq_dim, _ = x.shape
        # kernel: [freq_dim, n_time_step]
        kernel_size = (freq_dim, n_time_steps)

        pad = n_time_steps // 2
        if pad > 0:
            # Copy the first/last `pad` frames onto each edge so every frame gets a window.
            # Guarded because `x[..., -0:]` would select the WHOLE tensor, not nothing.
            x = torch.concat(tensors=[
                x[:, :, :, :pad],
                x,
                x[:, :, :, -pad:],
            ], dim=-1)

        x = F.unfold(
            input=x,
            kernel_size=kernel_size,
        )
        return x

    def __call__(self, batch: List[dict]):
        """Collate ``batch`` into four stacked tensors.

        :return: ``(speech_complex_spec, mix_complex_spec, speech_irm, snr_db)`` where the
            first three have shape ``[batch_size, freq_dim, time_steps]`` (the last frequency
            bin is dropped) and ``snr_db`` has shape ``[batch_size, 1, time_steps]``.
        :raises AssertionError: if any output contains NaN or inf.
        """
        speech_complex_spec_list = list()
        mix_complex_spec_list = list()
        speech_irm_list = list()
        snr_db_list = list()
        for sample in batch:
            noise_wave: torch.Tensor = sample["noise_wave"]
            speech_wave: torch.Tensor = sample["speech_wave"]
            mix_wave: torch.Tensor = sample["mix_wave"]

            noise_spec = self.transform.forward(noise_wave)
            speech_spec = self.transform.forward(speech_wave)

            speech_complex_spec = self.complex_transform.forward(speech_wave)
            mix_complex_spec = self.complex_transform.forward(mix_wave)

            # Ideal ratio mask of speech power over total power, raised to irm_beta.
            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
            speech_irm = torch.pow(speech_irm, self.irm_beta)

            # shape: [freq_dim, time_steps]
            snr_db: torch.Tensor = 10 * torch.log10(
                speech_spec / (noise_spec + self.epsilon)
            )
            # The clamp turns -inf (silent speech bins) into a finite value, but it also maps
            # every negative dB value to ~0.
            # NOTE(review): confirm that clipping negative SNR is intended.
            snr_db = torch.clamp(snr_db, min=self.epsilon)

            # unfold needs a 4-D input: [1, 1, freq_dim, time_steps].
            snr_db_ = torch.unsqueeze(snr_db, dim=0)
            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
            snr_db_ = torch.squeeze(snr_db_, dim=0)
            # snr_db_ shape: [fold, time_steps]

            # Average over all frequency bins and neighbouring frames.
            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
            # snr_db shape: [1, time_steps]

            speech_complex_spec_list.append(speech_complex_spec)
            mix_complex_spec_list.append(mix_complex_spec)
            speech_irm_list.append(speech_irm)
            snr_db_list.append(snr_db)

        speech_complex_specs = torch.stack(speech_complex_spec_list)
        mix_complex_specs = torch.stack(mix_complex_spec_list)
        speech_irms = torch.stack(speech_irm_list)
        snr_dbs = torch.stack(snr_db_list)
        # snr_dbs shape: [batch_size, 1, time_steps]

        # Drop the last frequency bin.
        speech_complex_specs = speech_complex_specs[:, :-1, :]
        mix_complex_specs = mix_complex_specs[:, :-1, :]
        speech_irms = speech_irms[:, :-1, :]

        _raise_if_not_finite(speech_complex_specs, "speech_complex_spec_list")
        _raise_if_not_finite(mix_complex_specs, "mix_complex_spec_list")
        _raise_if_not_finite(speech_irms, "speech_irm_list")
        _raise_if_not_finite(snr_dbs, "snr_db_list")

        return speech_complex_specs, mix_complex_specs, speech_irms, snr_dbs


collate_fn = CollateFunction()


def _compute_batch_loss(model, batch, device, speech_mse_loss, irm_mse_loss, snr_mse_loss):
    """Run ``model`` on one collated batch and return ``(loss, batch_size)``.

    The loss is the unweighted sum of three MSE terms:
    - the predicted speech spectrum vs. the target complex spectrum (viewed as real/imag pairs);
    - the predicted vs. target IRM;
    - the predicted vs. target SNR (dB).

    :raises AssertionError: if any model output contains NaN or inf.
    """
    speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
    speech_complex_spec = speech_complex_spec.to(device)
    mix_complex_spec = mix_complex_spec.to(device)
    speech_irm_target = speech_irm.to(device)
    snr_db_target = snr_db.to(device)

    speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
    _raise_if_not_finite(speech_spec_prediction, "speech_spec_prediction")
    _raise_if_not_finite(speech_irm_prediction, "speech_irm_prediction")
    _raise_if_not_finite(lsnr_prediction, "lsnr_prediction")

    speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
    irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
    snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)

    loss = speech_loss + irm_loss + snr_loss
    return loss, mix_complex_spec.size(0)


def main():
    """Train SpectrumDfNet, validating and checkpointing each epoch, with early stopping.

    Writes ``epoch-<n>/`` checkpoints (keeping at most ``--num_serialized_models_to_keep``),
    a ``metrics_epoch.json`` file per epoch, and a copy of the lowest-validation-loss
    epoch under ``best/``.
    """
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir.as_posix())

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    # datasets
    logger.info("prepare datasets")
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    # On Linux multiple worker subprocesses can load data; on Windows they cannot.
    # os.cpu_count() may return None when the CPU count cannot be determined.
    num_workers = 0 if platform.system() == "Windows" else (os.cpu_count() or 0) // 2
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    config = SpectrumDfNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )
    model = SpectrumDfNetPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    param_optimizer = model.parameters()
    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    # Stepped once per training batch, so milestones count optimizer steps.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )

    speech_mse_loss = nn.MSELoss(
        reduction="mean",
    )
    irm_mse_loss = nn.MSELoss(
        reduction="mean",
    )
    snr_mse_loss = nn.MSELoss(
        reduction="mean",
    )

    # training loop
    logger.info("training")

    training_loss = 10000000000
    evaluation_loss = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    for idx_epoch in range(args.max_epochs):
        # ---- training ----
        model.train()
        total_loss = 0.
        total_examples = 0
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            loss, batch_size = _compute_batch_loss(
                model, batch, device, speech_mse_loss, irm_mse_loss, snr_mse_loss
            )
            # `loss` is a per-batch mean; weight by batch size to get a per-example average.
            total_loss += loss.item() * batch_size
            total_examples += batch_size

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            training_loss = total_loss / total_examples
            training_loss = round(training_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "training_loss": training_loss,
            })
        progress_bar.close()

        # ---- evaluation ----
        # eval() puts dropout/normalization layers in inference mode for validation.
        model.eval()
        total_loss = 0.
        total_examples = 0
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                loss, batch_size = _compute_batch_loss(
                    model, batch, device, speech_mse_loss, irm_mse_loss, snr_mse_loss
                )
                total_loss += loss.item() * batch_size
                total_examples += batch_size

                evaluation_loss = total_loss / total_examples
                evaluation_loss = round(evaluation_loss, 4)

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "evaluation_loss": evaluation_loss,
                })
        progress_bar.close()

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())
        model_list.append(epoch_dir)

        # save metric
        if best_metric is None or evaluation_loss < best_metric:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "training_loss": training_loss,
            "evaluation_loss": evaluation_loss,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # Prune old checkpoints only after this epoch's files are fully written, keeping at
        # most `num_serialized_models_to_keep` of them (the best epoch survives in best_dir).
        while len(model_list) > args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # early stop
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                break


if __name__ == '__main__':
    main()