nx_denoise / examples /nx_mpnet /step_2_train_model.py
HoneyTian's picture
update
33aff71
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.discriminator import MetricDiscriminatorPretrainedModel
from toolbox.torchaudio.models.nx_mpnet.modeling_nx_mpnet import NXMPNet, NXMPNetPretrainedModel
from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft
from toolbox.torchaudio.models.nx_mpnet.metrics import run_batch_pesq, run_pesq_score
from toolbox.torchaudio.models.nx_mpnet.loss import phase_losses
def get_args():
    """Parse the training script's command-line options.

    Returns:
        argparse.Namespace with the train/valid Excel paths, epoch limit, number of
        checkpoints to keep, early-stop patience, output directory and config file path.
    """
    # (flag, default, type) — registered in this order.
    option_specs = (
        ("--train_dataset", "train.xlsx", str),
        ("--valid_dataset", "valid.xlsx", str),
        ("--max_epochs", 100, int),
        ("--num_serialized_models_to_keep", 10, int),
        ("--patience", 5, int),
        ("--serialization_dir", "serialization_dir", str),
        ("--config_file", "config.yaml", str),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        arg_parser.add_argument(flag, default=default_value, type=value_type)
    return arg_parser.parse_args()
def logging_config(file_dir: str):
    """Set up console logging and attach a daily-rotating file log.

    Args:
        file_dir: directory in which ``main.log`` is written; the handler rotates
            once per day and keeps 7 backups.

    Returns:
        The logger named after this module, with the file handler attached.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt="%m/%d/%Y %H:%M:%S")

    log_path = os.path.join(file_dir, "main.log")
    rotating_handler = TimedRotatingFileHandler(
        log_path, when="D", interval=1, backupCount=7, encoding="utf-8",
    )
    rotating_handler.setFormatter(logging.Formatter(log_format))
    rotating_handler.setLevel(logging.INFO)

    module_logger = logging.getLogger(__name__)
    module_logger.addHandler(rotating_handler)
    return module_logger
class CollateFunction(object):
    """Collate dataset samples into ``(clean_audios, noisy_audios)`` batch tensors.

    Each sample is a dict whose ``"speech_wave"`` (clean) and ``"mix_wave"`` (noisy)
    tensors are stacked along a new leading batch dimension; other keys are ignored.
    """

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        """Stack the clean and noisy waveforms of ``batch``.

        Raises:
            AssertionError: if either stacked tensor contains NaN or +/-inf
                (the clean tensor is checked first).
        """
        wave_pairs = [(sample["speech_wave"], sample["mix_wave"]) for sample in batch]
        clean_audios = torch.stack([clean for clean, _ in wave_pairs])
        noisy_audios = torch.stack([noisy for _, noisy in wave_pairs])

        # Reject non-finite values early so they never reach the model.
        checks = (
            (clean_audios, "nan or inf in clean_audios"),
            (noisy_audios, "nan or inf in noisy_audios"),
        )
        for stacked, message in checks:
            if not bool(torch.all(torch.isfinite(stacked))):
                raise AssertionError(message)
        return clean_audios, noisy_audios


collate_fn = CollateFunction()
def _num_workers() -> int:
    """Return the DataLoader worker-process count for the current platform."""
    # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
    if platform.system() == "Windows":
        return 0
    # os.cpu_count() returns None when the CPU count cannot be determined.
    return (os.cpu_count() or 2) // 2


def _build_data_loader(dataset, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader that batches with the module-level ``collate_fn``."""
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=None,
        num_workers=_num_workers(),
        collate_fn=collate_fn,
        pin_memory=False,
    )


def _epoch_index(path: Path) -> int:
    """Parse N from a checkpoint directory named ``epoch-N``."""
    return int(path.stem.split("-")[1])


def _list_epoch_dirs(serialization_dir: Path) -> List[Path]:
    """Return the existing ``epoch-N`` checkpoint directories, sorted by N ascending."""
    epoch_dirs = [p for p in serialization_dir.glob("epoch-*") if p.is_dir()]
    return sorted(epoch_dirs, key=_epoch_index)


def _load_state_dict(path: Path):
    """Load a tensor-only (``weights_only``) state dict from ``path`` onto the CPU."""
    with open(path.as_posix(), "rb") as f:
        return torch.load(f, map_location="cpu", weights_only=True)


def _restore_best(serialization_dir: Path, last_epoch: int):
    """Recover ``(best_idx_epoch, best_metric)`` from ``best/metrics_epoch.json`` when resuming.

    ``best/`` is a copy of the epoch directory that was best when it was written, so its
    ``idx_epoch`` / ``pesq_metric`` entries are the best epoch and its score. Returns
    ``(None, None)`` when not resuming or when the file is missing or inconsistent.
    """
    metrics_file = serialization_dir / "best" / "metrics_epoch.json"
    if last_epoch < 0 or not metrics_file.is_file():
        return None, None
    with open(metrics_file, "r", encoding="utf-8") as f:
        best_metrics = json.load(f)
    best_idx_epoch = best_metrics.get("idx_epoch")
    best_metric = best_metrics.get("pesq_metric")
    if best_idx_epoch is None or best_metric is None or best_idx_epoch > last_epoch:
        return None, None
    return best_idx_epoch, best_metric


def _train_one_epoch(generator, discriminator, optim_g, optim_d, data_loader, config, device, idx_epoch: int):
    """Run one epoch of alternating discriminator / generator updates.

    Returns:
        ``(loss_d, loss_g)``: running means of the discriminator and generator losses,
        rounded to 4 decimals (10000000000 each if the loader yields no batches).
    """
    generator.train()
    discriminator.train()

    stft_args = (config.n_fft, config.hop_size, config.win_size, config.compress_factor)
    total_loss_d = 0.
    total_loss_g = 0.
    total_batches = 0
    loss_d = 10000000000
    loss_g = 10000000000
    progress_bar = tqdm(
        total=len(data_loader),
        desc="Training; epoch: {}".format(idx_epoch),
    )
    for clean_audio, noisy_audio in data_loader:
        clean_audio = clean_audio.to(device)
        noisy_audio = noisy_audio.to(device)
        one_labels = torch.ones(clean_audio.shape[0]).to(device)

        clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, *stft_args)
        noisy_mag, noisy_pha, _ = mag_pha_stft(noisy_audio, *stft_args)
        mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
        audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)
        # Re-analyse the enhanced waveform; feeds the consistency loss and the discriminator.
        mag_g_hat, _, com_g_hat = mag_pha_stft(audio_g, *stft_args)

        audio_list_r = list(clean_audio.cpu().numpy())
        audio_list_g = list(audio_g.detach().cpu().numpy())
        pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb")

        # Discriminator: regress 1.0 for (clean, clean) and normalised PESQ for (clean, enhanced).
        optim_d.zero_grad()
        metric_r = discriminator.forward(clean_mag, clean_mag)
        metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
        loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
        if -1 in pesq_score_list:
            # NOTE(review): -1 presumably marks an item whose PESQ could not be computed; the
            # enhanced-branch term is then dropped for the whole batch. Confirm in run_batch_pesq.
            loss_disc_g = 0
        else:
            # (score - 1) / 3.5 maps PESQ 1.0..4.5 onto 0..1, matching the 1.0 "clean" target.
            pesq_target = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
            loss_disc_g = F.mse_loss(pesq_target.to(device), metric_g.flatten())
        loss_disc_all = loss_disc_r + loss_disc_g
        loss_disc_all.backward()
        optim_d.step()

        # Generator
        optim_g.zero_grad()
        loss_mag = F.mse_loss(clean_mag, mag_g)  # L2 magnitude loss
        loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)  # anti-wrapping phase loss
        loss_pha = loss_ip + loss_gd + loss_iaf
        loss_com = F.mse_loss(clean_com, com_g) * 2  # L2 complex loss
        loss_stft = F.mse_loss(com_g, com_g_hat) * 2  # L2 consistency loss
        loss_time = F.l1_loss(clean_audio, audio_g)  # time-domain loss
        # Metric loss: push the discriminator's score for the enhanced speech towards 1.0.
        metric_g = discriminator.forward(clean_mag, mag_g_hat)
        loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
        loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2
        loss_gen_all.backward()
        optim_g.step()

        total_loss_d += loss_disc_all.item()
        total_loss_g += loss_gen_all.item()
        total_batches += 1
        loss_d = round(total_loss_d / total_batches, 4)
        loss_g = round(total_loss_g / total_batches, 4)
        progress_bar.update(1)
        progress_bar.set_postfix({
            "loss_d": loss_d,
            "loss_g": loss_g,
        })
    progress_bar.close()
    return loss_d, loss_g


def _evaluate(generator, discriminator, data_loader, config, device, idx_epoch: int) -> dict:
    """Score the generator on ``data_loader`` without gradient tracking.

    Returns:
        Dict with keys ``pesq_metric``, ``mag_err``, ``pha_err``, ``com_err``, ``stft_err``:
        per-batch means rounded to 4 decimals (10000000000 each if there are no batches).
    """
    generator.eval()
    discriminator.eval()
    torch.cuda.empty_cache()

    stft_args = (config.n_fft, config.hop_size, config.win_size, config.compress_factor)
    total_pesq_score = 0.
    total_mag_err = 0.
    total_pha_err = 0.
    total_com_err = 0.
    total_stft_err = 0.
    total_batches = 0
    metrics = {
        "pesq_metric": 10000000000,
        "mag_err": 10000000000,
        "pha_err": 10000000000,
        "com_err": 10000000000,
        "stft_err": 10000000000,
    }
    progress_bar = tqdm(
        total=len(data_loader),
        desc="Evaluation; epoch: {}".format(idx_epoch),
    )
    with torch.no_grad():
        for clean_audio, noisy_audio in data_loader:
            clean_audio = clean_audio.to(device)
            noisy_audio = noisy_audio.to(device)

            clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, *stft_args)
            noisy_mag, noisy_pha, _ = mag_pha_stft(noisy_audio, *stft_args)
            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
            audio_g = mag_pha_istft(mag_g, pha_g, *stft_args)
            _, _, com_g_hat = mag_pha_stft(audio_g, *stft_args)

            clean_audio_list = [t.squeeze().cpu().numpy() for t in torch.split(clean_audio, 1, dim=0)]
            enhanced_audio_list = [t.squeeze().cpu().numpy() for t in torch.split(audio_g, 1, dim=0)]
            pesq_score = run_pesq_score(
                clean_audio_list,
                enhanced_audio_list,
                sample_rate=config.sample_rate,
                mode="nb",
            )

            total_pesq_score += pesq_score
            total_mag_err += F.mse_loss(clean_mag, mag_g).item()
            val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
            total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
            total_com_err += F.mse_loss(clean_com, com_g).item()
            total_stft_err += F.mse_loss(com_g, com_g_hat).item()
            total_batches += 1

            metrics = {
                "pesq_metric": round(total_pesq_score / total_batches, 4),
                "mag_err": round(total_mag_err / total_batches, 4),
                "pha_err": round(total_pha_err / total_batches, 4),
                "com_err": round(total_com_err / total_batches, 4),
                "stft_err": round(total_stft_err / total_batches, 4),
            }
            progress_bar.update(1)
            progress_bar.set_postfix(metrics)
    progress_bar.close()
    return metrics


def main():
    """Train the NX-MPNet generator against a metric discriminator.

    Reads CLI options from ``get_args()``. Resumes from the newest ``epoch-N`` directory in
    ``--serialization_dir`` when one exists. After every epoch it:

    * saves model and optimizer state to ``epoch-N/``;
    * writes ``epoch-N/metrics_epoch.json``;
    * mirrors the epoch with the highest validation PESQ to ``best/``;
    * keeps at most ``--num_serialized_models_to_keep`` epoch directories.

    Training stops early after ``--patience`` consecutive epochs without a new best.
    """
    args = get_args()

    config = NXMPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    logger.info(f"set seed: {config.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = _build_data_loader(train_dataset, config.batch_size, shuffle=True)
    # Validation is not shuffled: a fixed batch composition keeps the per-epoch metric
    # (a mean of per-batch means) comparable across epochs for best-model selection.
    valid_data_loader = _build_data_loader(valid_dataset, config.batch_size, shuffle=False)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    generator = NXMPNetPretrainedModel(config).to(device)
    discriminator = MetricDiscriminatorPretrainedModel(config).to(device)

    # optimizer
    logger.info("prepare optimizer, lr_scheduler")
    num_params = sum(p.numel() for p in generator.parameters())
    logger.info("total parameters (generator): {:.3f}M".format(num_params / 1e6))

    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

    # resume training
    epoch_dirs = _list_epoch_dirs(serialization_dir)
    last_epoch = _epoch_index(epoch_dirs[-1]) if epoch_dirs else -1
    if last_epoch != -1:
        logger.info(f"resume from epoch-{last_epoch}.")
        last_dir = epoch_dirs[-1]
        # NOTE(review): assumes save_pretrained() writes generator.pt / discriminator.pt; confirm.
        logger.info("load state dict for generator.")
        generator.load_state_dict(_load_state_dict(last_dir / "generator.pt"), strict=True)
        logger.info("load state dict for discriminator.")
        discriminator.load_state_dict(_load_state_dict(last_dir / "discriminator.pt"), strict=True)
        logger.info("load state dict for optim_g.")
        optim_g.load_state_dict(_load_state_dict(last_dir / "optim_g.pth"))
        logger.info("load state dict for optim_d.")
        optim_d.load_state_dict(_load_state_dict(last_dir / "optim_d.pth"))

    # Schedulers are built with the default last_epoch=-1 even when resuming.
    # ExponentialLR takes one step in its constructor. With last_epoch=k, that step
    # multiplies the lr restored from optim_*.pth (already decayed once per finished
    # epoch) by gamma again, so each resume added an extra decay. With -1 the
    # restored lr is kept, and the per-epoch step() calls continue the decay from it.
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay)

    # training state; on resume, existing checkpoints count toward the retention limit and
    # the best epoch / patience counter continue from where the previous run stopped.
    model_list: List[Path] = list(epoch_dirs)
    best_idx_epoch, best_metric = _restore_best(serialization_dir, last_epoch)
    patience_count = 0 if best_idx_epoch is None else last_epoch - best_idx_epoch
    if best_idx_epoch is not None:
        logger.info(f"restored best epoch: {best_idx_epoch}; pesq_metric: {best_metric}")
    keep_count = max(0, args.num_serialized_models_to_keep)

    logger.info("training")
    for idx_epoch in range(max(0, last_epoch + 1), args.max_epochs):
        loss_d, loss_g = _train_one_epoch(
            generator, discriminator, optim_g, optim_d, train_data_loader, config, device, idx_epoch,
        )
        eval_metrics = _evaluate(generator, discriminator, valid_data_loader, config, device, idx_epoch)
        pesq_metric = eval_metrics["pesq_metric"]

        # scheduler
        scheduler_g.step()
        scheduler_d.step()

        # save models and optimizers
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)
        generator.save_pretrained(epoch_dir.as_posix())
        discriminator.save_pretrained(epoch_dir.as_posix())
        torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
        torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
        model_list.append(epoch_dir)

        # best model: greater PESQ is better
        if best_metric is None or pesq_metric > best_metric:
            best_idx_epoch = idx_epoch
            best_metric = pesq_metric

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "loss_d": loss_d,
            "loss_g": loss_g,
            "pesq_metric": eval_metrics["pesq_metric"],
            "mag_err": eval_metrics["mag_err"],
            "pha_err": eval_metrics["pha_err"],
            "com_err": eval_metrics["com_err"],
            "stft_err": eval_metrics["stft_err"],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # Prune old checkpoints only after the best copy, so the just-saved epoch_dir still
        # exists when copied (pruning first crashed copytree when keep_count == 1).
        while len(model_list) > keep_count:
            shutil.rmtree(model_list.pop(0).as_posix())

        # early stop
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
        if patience_count >= args.patience:
            break
    return


if __name__ == "__main__":
    main()