# nx_denoise/examples/mpnet_aishell/step_2_train_model.py
# (Hugging Face page-scrape residue converted to a comment: author "HoneyTian",
#  commit "f74ae8e", "raw / history / blame", 15.3 kB)
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import numpy as np
import torch
from torch.distributed import init_process_group
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminator, batch_pesq
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
def get_args():
    """Build and parse the command-line arguments for this training script."""
    parser = argparse.ArgumentParser()

    # dataset manifests (Excel files listing the audio pairs)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    # optimization settings
    parser.add_argument("--max_epochs", default=100, type=int)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)

    # checkpointing / early stopping
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

    # reproducibility and model configuration
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--config_file", default="config.yaml", type=str)

    return parser.parse_args()
def logging_config(file_dir: str):
    """Set up console logging and a daily-rotating file log under *file_dir*.

    Configures the root logger via ``basicConfig`` and attaches a
    ``TimedRotatingFileHandler`` (rotates daily, keeps 7 backups) writing to
    ``<file_dir>/main.log``. Returns this module's logger.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # console / root configuration
    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # rotating file log: one file per day, 7 days retained
    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setLevel(logging.INFO)
    rotating_handler.setFormatter(logging.Formatter(log_format))

    module_logger = logging.getLogger(__name__)
    module_logger.addHandler(rotating_handler)
    return module_logger
class CollateFunction(object):
    """Collate denoising samples into (mix_spec, speech_irm, snr_db) batches.

    For each sample the power spectrogram of noise / speech / mixture is
    computed, the ideal-ratio mask (IRM) is derived from the speech/noise
    power ratio, and a per-frame SNR (dB) track is produced by averaging an
    unfolded local window over frequency and time.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        # exponent applied to the IRM (1.0 leaves the mask unchanged)
        self.irm_beta = irm_beta
        # small constant guarding divisions and log10 against zeros
        self.epsilon = epsilon

        # power=2.0 yields power (magnitude-squared) spectrograms
        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=2.0,
            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
        )

    @staticmethod
    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
        """Unfold *x* of shape [b, c, freq, time] with a [freq, n_time_steps]
        kernel, edge-padded along time so the frame count is preserved for
        odd ``n_time_steps``. Returns shape [b, c * freq * n_time_steps, time].
        """
        batch_size, channels, freq_dim, time_steps = x.shape

        # kernel: [freq_dim, n_time_steps]
        kernel_size = (freq_dim, n_time_steps)

        # replicate `pad` frames on each side of the time axis.
        # Guard pad == 0: the original unconditionally concatenated
        # x[..., -0:], which is the WHOLE tensor, silently doubling the
        # time axis for n_time_steps == 1.
        pad = n_time_steps // 2
        if pad > 0:
            x = torch.concat(tensors=[
                x[:, :, :, :pad],
                x,
                x[:, :, :, -pad:],
            ], dim=-1)

        x = F.unfold(
            input=x,
            kernel_size=kernel_size,
        )
        # x shape: [batch_size, fold, time_steps]
        return x

    def __call__(self, batch: List[dict]):
        """Collate a list of sample dicts (keys: noise_wave, speech_wave,
        mix_wave) into stacked tensors; raises AssertionError on nan/inf.
        """
        mix_spec_list = list()
        speech_irm_list = list()
        snr_db_list = list()
        for sample in batch:
            noise_wave: torch.Tensor = sample["noise_wave"]
            speech_wave: torch.Tensor = sample["speech_wave"]
            mix_wave: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            noise_spec = self.transform.forward(noise_wave)
            speech_spec = self.transform.forward(speech_wave)
            mix_spec = self.transform.forward(mix_wave)

            # ideal ratio mask: speech power over total power
            # noise_irm = noise_spec / (noise_spec + speech_spec)
            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
            speech_irm = torch.pow(speech_irm, self.irm_beta)

            # noise_spec, speech_spec, mix_spec, speech_irm
            # shape: [freq_dim, time_steps]

            snr_db: torch.Tensor = 10 * torch.log10(
                speech_spec / (noise_spec + self.epsilon)
            )
            # NOTE(review): clamping a dB value to epsilon forces SNR to be
            # non-negative — confirm this floor is intended.
            snr_db = torch.clamp(snr_db, min=self.epsilon)

            snr_db_ = torch.unsqueeze(snr_db, dim=0)
            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
            snr_db_ = torch.squeeze(snr_db_, dim=0)
            # snr_db_ shape: [fold, time_steps]

            # average over the local window -> one SNR value per frame
            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
            # snr_db shape: [1, time_steps]

            mix_spec_list.append(mix_spec)
            speech_irm_list.append(speech_irm)
            snr_db_list.append(snr_db)

        mix_spec_list = torch.stack(mix_spec_list)
        speech_irm_list = torch.stack(speech_irm_list)
        snr_db_list = torch.stack(snr_db_list)  # shape: [batch_size, 1, time_steps]

        # drop the Nyquist bin so freq_dim becomes n_fft // 2
        mix_spec_list = mix_spec_list[:, :-1, :]
        speech_irm_list = speech_irm_list[:, :-1, :]

        # mix_spec_list shape: [batch_size, freq_dim, time_steps]
        # speech_irm_list shape: [batch_size, freq_dim, time_steps]
        # snr_db shape: [batch_size, 1, time_steps]

        # assert
        if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
            raise AssertionError("nan or inf in mix_spec_list")
        if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
            raise AssertionError("nan or inf in speech_irm_list")
        if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
            raise AssertionError("nan or inf in snr_db_list")
        return mix_spec_list, speech_irm_list, snr_db_list
# Module-level collate instance shared by the train and valid DataLoaders.
# NOTE(review): this collate returns (mix_spec, speech_irm, snr_db) triples,
# while main() unpacks each batch as (clean_audio, noisy_audio) — these look
# inconsistent; confirm which dataset/collate pairing is intended.
collate_fn = CollateFunction()
def main():
    """Train the MP-SENet generator and metric discriminator.

    Loads train/valid datasets from Excel manifests, then for each epoch
    alternates discriminator and generator updates per batch and reports
    PESQ / magnitude / phase / complex / STFT-consistency metrics on the
    validation set.

    NOTE(review): no checkpoint is ever saved, and the
    --num_serialized_models_to_keep / --patience arguments are unused —
    confirm whether checkpointing was intended here.
    """
    args = get_args()

    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # NOTE(review): seeding uses config.seed; the --seed CLI flag is unused.
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )

    # NOTE(review): the module-level collate_fn returns
    # (mix_spec, speech_irm, snr_db), but the loops below unpack batches as
    # (clean_audio, noisy_audio) — confirm the intended collate for this
    # model before running.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        # Multiple worker subprocesses are usable on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        sampler=None,
        # Multiple worker subprocesses are usable on Linux but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    generator = MPNetPretrainedModel(config).to(device)
    discriminator = MetricDiscriminator().to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    num_params = sum(p.numel() for p in generator.parameters())
    print("Total Parameters (generator): {:.3f}M".format(num_params / 1e6))

    # NOTE(review): optimizers use config.learning_rate; --learning_rate is unused.
    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=-1)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)

    # training loop
    logger.info("training")

    for idx_epoch in range(args.max_epochs):
        generator.train()
        discriminator.train()

        total_loss_d = 0.
        total_loss_g = 0.
        total_batches = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            clean_audio, noisy_audio = batch
            clean_audio = clean_audio.to(device, non_blocking=True)
            noisy_audio = noisy_audio.to(device, non_blocking=True)
            # Size the target labels by the actual batch: the last batch of
            # an epoch can be smaller than config.batch_size, which made the
            # fixed-size ones tensor crash the MSE losses below.
            one_labels = torch.ones(clean_audio.size(0)).to(device, non_blocking=True)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

            audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
            mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

            audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
            batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)

            # Discriminator
            optim_d.zero_grad()
            metric_r = discriminator.forward(clean_mag, clean_mag)
            metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
            loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())

            if batch_pesq_score is not None:
                loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
            else:
                # batch_pesq returned no scores; skip the generated-side term.
                print("pesq is None!")
                loss_disc_g = 0

            loss_disc_all = loss_disc_r + loss_disc_g
            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()
            # L2 Magnitude Loss
            loss_mag = F.mse_loss(clean_mag, mag_g)
            # Anti-wrapping Phase Loss
            loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
            loss_pha = loss_ip + loss_gd + loss_iaf
            # L2 Complex Loss
            loss_com = F.mse_loss(clean_com, com_g) * 2
            # L2 Consistency Loss
            loss_stft = F.mse_loss(com_g, com_g_hat) * 2
            # Time Loss
            loss_time = F.l1_loss(clean_audio, audio_g)
            # Metric Loss
            metric_g = discriminator.forward(clean_mag, mag_g_hat)
            loss_metric = F.mse_loss(metric_g.flatten(), one_labels)

            loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2

            loss_gen_all.backward()
            optim_g.step()

            total_loss_d += loss_disc_all.item()
            total_loss_g += loss_gen_all.item()
            total_batches += 1

            progress_bar.update(1)
            progress_bar.set_postfix({
                "loss_d": round(total_loss_d / total_batches, 4),
                "loss_g": round(total_loss_g / total_batches, 4),
            })
        progress_bar.close()

        # evaluation
        generator.eval()
        torch.cuda.empty_cache()
        total_pesq_score = 0.
        total_mag_err = 0.
        total_pha_err = 0.
        total_com_err = 0.
        total_stft_err = 0.
        total_batches = 0.
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        with torch.no_grad():
            for batch in valid_data_loader:
                clean_audio, noisy_audio = batch
                clean_audio = clean_audio.to(device, non_blocking=True)
                noisy_audio = noisy_audio.to(device, non_blocking=True)

                clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)

                audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
                mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)

                total_pesq_score += pesq_score(
                    torch.split(clean_audio, 1, dim=0),
                    torch.split(audio_g, 1, dim=0),
                    config
                ).item()
                total_mag_err += F.mse_loss(clean_mag, mag_g).item()
                val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
                total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
                total_com_err += F.mse_loss(clean_com, com_g).item()
                total_stft_err += F.mse_loss(com_g, com_g_hat).item()

                total_batches += 1
                progress_bar.update(1)
                progress_bar.set_postfix({
                    "pesq_score": round(total_pesq_score / total_batches, 4),
                    "mag_err": round(total_mag_err / total_batches, 4),
                    "pha_err": round(total_pha_err / total_batches, 4),
                    "com_err": round(total_com_err / total_batches, 4),
                    "stft_err": round(total_stft_err / total_batches, 4),
                })
        progress_bar.close()

        # Advance the LR schedulers once per epoch. They were created above
        # but never stepped in the original code, so the learning rate never
        # actually decayed.
        scheduler_g.step()
        scheduler_d.step()

    return
# Script entry point: run training when executed directly (not on import).
if __name__ == '__main__':
    main()