#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Adversarial training script for the MPNet (MP-SENet) speech enhancement model.

Adapted from:
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py

The generator is trained with magnitude, phase, complex, consistency, time
and metric losses; the metric discriminator is trained to predict the
PESQ-derived score returned by ``batch_pesq``.
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

# Make the project root importable so that the ``toolbox`` package resolves
# when this file is executed directly as a script.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
from torch.distributed import init_process_group
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminator, batch_pesq
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft


def get_args():
    """Parse the command-line arguments of the training script.

    NOTE(review): ``--learning_rate``, ``--num_serialized_models_to_keep``,
    ``--patience`` and ``--seed`` are accepted but not read by ``main``;
    the learning rate and seed are taken from the config file instead.

    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    """Configure console logging plus a daily-rotating ``main.log`` file.

    :param file_dir: directory in which ``main.log`` is created.
    :return: the logger for this module, with the file handler attached.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s  %(filename)s:%(lineno)d >  %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    # Rotate once per day and keep one week of history.
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    """Collate dataset samples into spectrogram, IRM and SNR tensors.

    This collate function produces spectral-mask training targets. It is
    kept for backward compatibility; the MPNet training loop in ``main``
    consumes raw waveforms and uses ``WaveformCollateFunction`` instead.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        """
        :param n_fft: FFT size of the power spectrogram.
        :param win_length: analysis window length in samples.
        :param hop_length: hop between successive frames in samples.
        :param window_fn: ``"hamming"`` selects a Hamming window; any other
            value selects a Hann window.
        :param irm_beta: exponent applied to the ideal ratio mask.
        :param epsilon: small constant that keeps divisions finite.
        """
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.irm_beta = irm_beta
        self.epsilon = epsilon

        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=2.0,
            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
        )

    @staticmethod
    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
        """Unfold ``x`` into sliding windows of ``n_time_steps`` frames.

        The time axis is padded on each side with ``n_time_steps // 2``
        frames copied from the corresponding end of ``x``; for odd
        ``n_time_steps`` the output keeps the input number of time steps.

        :param x: tensor of shape [batch_size, channels, freq_dim, time_steps].
        :param n_time_steps: number of frames per window.
        :return: tensor of shape [batch_size, fold, time_steps].
        """
        batch_size, channels, freq_dim, time_steps = x.shape

        # kernel: [freq_dim, n_time_step]
        kernel_size = (freq_dim, n_time_steps)

        # pad
        pad = n_time_steps // 2
        # When pad == 0 the slice ``x[..., -0:]`` would select the whole
        # tensor and double the time axis, so padding is skipped entirely.
        if pad > 0:
            x = torch.concat(tensors=[
                x[:, :, :, :pad],
                x,
                x[:, :, :, -pad:],
            ], dim=-1)

        x = F.unfold(
            input=x,
            kernel_size=kernel_size,
        )
        # x shape: [batch_size, fold, time_steps]
        return x

    def __call__(self, batch: List[dict]):
        """Build a batch of mixture spectrograms, speech IRMs and SNR tracks.

        :param batch: samples providing ``noise_wave``, ``speech_wave`` and
            ``mix_wave`` tensors. Presumably all waves share one length,
            since the per-sample results are stacked - TODO confirm
            against the dataset.
        :return: ``(mix_spec_list, speech_irm_list, snr_db_list)``.
        :raises AssertionError: if any output contains nan or inf.
        """
        mix_spec_list = list()
        speech_irm_list = list()
        snr_db_list = list()
        for sample in batch:
            noise_wave: torch.Tensor = sample["noise_wave"]
            speech_wave: torch.Tensor = sample["speech_wave"]
            mix_wave: torch.Tensor = sample["mix_wave"]

            noise_spec = self.transform(noise_wave)
            speech_spec = self.transform(speech_wave)
            mix_spec = self.transform(mix_wave)

            # Ideal ratio mask of the speech component.
            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
            speech_irm = torch.pow(speech_irm, self.irm_beta)

            # noise_spec, speech_spec, mix_spec, speech_irm
            # shape: [freq_dim, time_steps]

            snr_db: torch.Tensor = 10 * torch.log10(
                speech_spec / (noise_spec + self.epsilon)
            )
            # The lower clamp also maps -inf (log10 of zero) to a finite value.
            snr_db = torch.clamp(snr_db, min=self.epsilon)

            # Add batch and channel axes so the map can be unfolded.
            snr_db_ = torch.unsqueeze(snr_db, dim=0)
            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
            snr_db_ = torch.squeeze(snr_db_, dim=0)
            # snr_db_ shape: [fold, time_steps]

            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
            # snr_db shape: [1, time_steps]

            mix_spec_list.append(mix_spec)
            speech_irm_list.append(speech_irm)
            snr_db_list.append(snr_db)

        mix_spec_list = torch.stack(mix_spec_list)
        speech_irm_list = torch.stack(speech_irm_list)
        snr_db_list = torch.stack(snr_db_list)  # shape: [batch_size, 1, time_steps]

        # Drop the last frequency bin.
        mix_spec_list = mix_spec_list[:, :-1, :]
        speech_irm_list = speech_irm_list[:, :-1, :]

        # mix_spec_list shape: [batch_size, freq_dim, time_steps]
        # speech_irm_list shape: [batch_size, freq_dim, time_steps]
        # snr_db shape: [batch_size, 1, time_steps]

        # assert
        if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
            raise AssertionError("nan or inf in mix_spec_list")
        if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
            raise AssertionError("nan or inf in speech_irm_list")
        if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
            raise AssertionError("nan or inf in snr_db_list")

        return mix_spec_list, speech_irm_list, snr_db_list


collate_fn = CollateFunction()


class WaveformCollateFunction(object):
    """Collate dataset samples into (clean, noisy) waveform batches.

    The training loop in ``main`` runs its own STFT on raw audio, so it
    needs waveforms rather than the spectrogram tensors produced by
    ``CollateFunction``.
    """

    def __call__(self, batch: List[dict]):
        """Stack the clean and mixed waveforms of every sample.

        :param batch: samples providing ``speech_wave`` and ``mix_wave``
            tensors. Presumably all waves share one length, since they are
            stacked - TODO confirm against the dataset.
        :return: ``(clean_audios, noisy_audios)``, each with a leading
            batch axis.
        :raises AssertionError: if either output contains nan or inf.
        """
        clean_audios = torch.stack([sample["speech_wave"] for sample in batch])
        noisy_audios = torch.stack([sample["mix_wave"] for sample in batch])

        # assert
        if not torch.all(torch.isfinite(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if not torch.all(torch.isfinite(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")

        return clean_audios, noisy_audios


waveform_collate_fn = WaveformCollateFunction()


def main():
    """Train the MPNet generator and metric discriminator, validating each epoch."""
    args = get_args()

    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(str(serialization_dir))

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )

    # Multiple worker subprocesses can be used for data loading on Linux,
    # but not on Windows. os.cpu_count() may return None, hence the guard.
    num_workers = 0 if platform.system() == "Windows" else (os.cpu_count() or 1) // 2

    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        num_workers=num_workers,
        # The loop below unpacks (clean_audio, noisy_audio) waveforms.
        collate_fn=waveform_collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        num_workers=num_workers,
        collate_fn=waveform_collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    generator = MPNetPretrainedModel(config).to(device)
    discriminator = MetricDiscriminator().to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    num_params = sum(p.numel() for p in generator.parameters())
    print("Total Parameters (generator): {:.3f}M".format(num_params/1e6))

    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=-1)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)

    # Positional STFT parameters shared by every analysis/synthesis call.
    stft_args = (config.n_fft, config.hop_size, config.win_size, config.compress_factor)

    # training loop
    logger.info("training")
    for idx_epoch in range(args.max_epochs):
        # train
        generator.train()
        discriminator.train()

        total_loss_d = 0.
        total_loss_g = 0.
        total_batches = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audio, noisy_audio = batch
            clean_audio = clean_audio.to(device, non_blocking=True)
            noisy_audio = noisy_audio.to(device, non_blocking=True)
            # Size the labels from the actual batch: the last batch of an
            # epoch may hold fewer samples than the configured batch size.
            one_labels = torch.ones(clean_audio.shape[0], device=device)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, *stft_args)
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, *stft_args)

            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

            # Re-analyse the synthesized audio for the consistency loss.
            audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)
            mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, *stft_args)

            audio_list_r = list(clean_audio.cpu().numpy())
            audio_list_g = list(audio_g.detach().cpu().numpy())
            batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)

            # Discriminator
            optim_d.zero_grad()
            metric_r = discriminator.forward(clean_mag, clean_mag)
            # Detached so this step does not back-propagate into the generator.
            metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

            if batch_pesq_score is not None:
                loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
            else:
                # No PESQ target for this batch: train on the real pair only.
                logger.warning("pesq is None!")
                loss_disc_g = 0

            loss_disc_all = loss_disc_r + loss_disc_g
            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()
            # L2 Magnitude Loss
            loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # L2 Consistency Loss
            loss_stft = F.mse_loss(com_g, com_g_hat) * 2
            # Time Loss
            loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss
            metric_g = discriminator.forward(clean_mag, mag_g_hat)
            loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2

            loss_gen_all.backward()
            optim_g.step()

            total_loss_d += loss_disc_all.item()
            total_loss_g += loss_gen_all.item()
            total_batches += 1

            progress_bar.update(1)
            progress_bar.set_postfix({
                "loss_d": round(total_loss_d / total_batches, 4),
                "loss_g": round(total_loss_g / total_batches, 4),
            })
        progress_bar.close()

        # evaluation
        generator.eval()
        torch.cuda.empty_cache()

        total_pesq_score = 0.
        total_mag_err = 0.
        total_pha_err = 0.
        total_com_err = 0.
        total_stft_err = 0.
        total_batches = 0.

        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audio, noisy_audio = batch
                clean_audio = clean_audio.to(device, non_blocking=True)
                noisy_audio = noisy_audio.to(device, non_blocking=True)

                clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, *stft_args)
                noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, *stft_args)

                mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

                audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)
                mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, *stft_args)

                total_pesq_score += pesq_score(
                    torch.split(clean_audio, 1, dim=0),
                    torch.split(audio_g, 1, dim=0),
                    config
                ).item()
                total_mag_err += F.mse_loss(clean_mag, mag_g).item()
                val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
                total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
                total_com_err += F.mse_loss(clean_com, com_g).item()
                total_stft_err += F.mse_loss(com_g, com_g_hat).item()

                total_batches += 1

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": round(total_pesq_score / total_batches, 4),
                    "mag_err": round(total_mag_err / total_batches, 4),
                    "pha_err": round(total_pha_err / total_batches, 4),
                    "com_err": round(total_com_err / total_batches, 4),
                    "stft_err": round(total_stft_err / total_batches, 4),
                })
        progress_bar.close()

        # Apply the exponential learning-rate decay once per epoch.
        scheduler_g.step()
        scheduler_d.step()

    return


if __name__ == '__main__':
    main()