#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
from torch.distributed import init_process_group
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminator, batch_pesq
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt, datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        clean_audios = list()
        noisy_audios = list()

        for sample in batch:
            # noise_wave: torch.Tensor = sample["noise_wave"]
            clean_audio: torch.Tensor = sample["speech_wave"]
            noisy_audio: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            clean_audios.append(clean_audio)
            noisy_audios.append(noisy_audio)

        clean_audios = torch.stack(clean_audios)
        noisy_audios = torch.stack(noisy_audios)

        # assert
        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
            raise AssertionError("nan or inf in clean_audios")
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
        return clean_audios, noisy_audios


collate_fn = CollateFunction()


def main():
    args = get_args()

    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        sampler=None,
        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    generator = MPNetPretrainedModel(config).to(device)
    discriminator = MetricDiscriminator().to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler")
    num_params = 0
    for p in generator.parameters():
        num_params += p.numel()
    logger.info("Total Parameters (generator): {:.3f}M".format(num_params / 1e6))

    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=-1)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)

    # training loop
    logger.info("training")
    for idx_epoch in range(args.max_epochs):
        generator.train()
        discriminator.train()

        total_loss_d = 0.
        total_loss_g = 0.
        total_batches = 0.
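
        # One training step below: the generator enhances the noisy magnitude/phase,
        # the enhanced waveform is re-analysed with an STFT, and the metric
        # discriminator (MetricGAN-style) is trained to predict 1.0 for clean speech
        # and the normalised PESQ score (from batch_pesq) for enhanced speech.
        # The generator then combines magnitude, phase, complex, STFT-consistency,
        # time and metric losses into a single weighted objective.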
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audio, noisy_audio = batch
            clean_audio = clean_audio.to(device, non_blocking=True)
            noisy_audio = noisy_audio.to(device, non_blocking=True)
            # size the label tensor to the actual batch (the last batch may be smaller than config.batch_size)
            one_labels = torch.ones(clean_audio.shape[0]).to(device, non_blocking=True)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

            audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
            batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)

            # Discriminator
            optim_d.zero_grad()
            metric_r = discriminator.forward(clean_mag, clean_mag)
            metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

            if batch_pesq_score is not None:
                loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
            else:
                print("pesq is None!")
                loss_disc_g = 0

            loss_disc_all = loss_disc_r + loss_disc_g
            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L2 Magnitude Loss
            loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # L2 Consistency Loss
            loss_stft = F.mse_loss(com_g, com_g_hat) * 2
            # Time Loss
            loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss
            metric_g = discriminator.forward(clean_mag, mag_g_hat)
            loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2

            loss_gen_all.backward()
            optim_g.step()

            total_loss_d += loss_disc_all.item()
            total_loss_g += loss_gen_all.item()
            total_batches += 1

            progress_bar.update(1)
            progress_bar.set_postfix({
                "loss_d": round(total_loss_d / total_batches, 4),
                "loss_g": round(total_loss_g / total_batches, 4),
            })

        generator.eval()
        torch.cuda.empty_cache()

        total_pesq_score = 0.
        total_mag_err = 0.
        total_pha_err = 0.
        total_com_err = 0.
        total_stft_err = 0.
        total_batches = 0.
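
        # Per-epoch validation below: PESQ of the enhanced waveforms plus the
        # magnitude, phase, complex and STFT-consistency errors, each averaged
        # over the validation batches and shown in the progress bar.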
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audio, noisy_audio = batch
                clean_audio = clean_audio.to(device, non_blocking=True)
                noisy_audio = noisy_audio.to(device, non_blocking=True)

                clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

                audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                total_pesq_score += pesq_score(
                    torch.split(clean_audio, 1, dim=0),
                    torch.split(audio_g, 1, dim=0),
                    config
                ).item()
                total_mag_err += F.mse_loss(clean_mag, mag_g).item()
                val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
                total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
                total_com_err += F.mse_loss(clean_com, com_g).item()
                total_stft_err += F.mse_loss(com_g, com_g_hat).item()

                total_batches += 1
                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": round(total_pesq_score / total_batches, 4),
                    "mag_err": round(total_mag_err / total_batches, 4),
                    "pha_err": round(total_pha_err / total_batches, 4),
                    "com_err": round(total_com_err / total_batches, 4),
                    "stft_err": round(total_stft_err / total_batches, 4),
                })

        # decay the learning rates once per epoch, following the referenced upstream train.py
        # (the schedulers were otherwise created but never stepped)
        scheduler_g.step()
        scheduler_d.step()

    return


if __name__ == '__main__':
    main()
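
# Example invocation (illustrative; the script filename and file paths are assumptions,
# not from the original — the flags and defaults come from get_args above):
#
#   python3 train.py \
#       --train_dataset train.xlsx \
#       --valid_dataset valid.xlsx \
#       --config_file config.yaml \
#       --serialization_dir serialization_dir \
#       --max_epochs 100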