nx_denoise / examples /conv_tasnet /step_2_train_model.py
HoneyTian's picture
update
1e6339d
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/kaituoxu/Conv-TasNet/tree/master/src
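Typical scenarios:
Target SI-SNR ≥ 10 dB; suitable for telephony, basic voice assistants, etc.
Demanding scenarios (e.g. medical hearing aids, speech recognition):
Require SI-SNR ≥ 14 dB, together with PESQ ≥ 3.0 and STOI ≥ 0.851812.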
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
from toolbox.torchaudio.metrics.pesq import run_pesq_score
def get_args():
    """Parse the command-line options of the training script.

    Every option is optional and falls back to the default listed below.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    # (flag, default, type) triples, registered in this order.
    option_specs = (
        ("--train_dataset", "train.xlsx", str),
        ("--valid_dataset", "valid.xlsx", str),
        ("--max_epochs", 200, int),
        ("--batch_size", 8, int),
        ("--num_serialized_models_to_keep", 10, int),
        ("--patience", 5, int),
        ("--serialization_dir", "serialization_dir", str),
        ("--seed", 1234, int),
        ("--config_file", "config.yaml", str),
    )

    parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type)

    return parser.parse_args()
def logging_config(file_dir: str):
    """Configure console logging plus a rotating ``main.log`` file.

    Args:
        file_dir: directory in which ``main.log`` is created.

    Returns:
        logging.Logger: this module's logger, with the file handler attached.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # Root logger configuration (no-op if the root logger already has handlers).
    logging.basicConfig(
        format=fmt,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # The log file rotates once per day and the last 7 rotated files are kept.
    log_path = os.path.join(file_dir, "main.log")
    rotating_handler = TimedRotatingFileHandler(
        filename=log_path,
        when="D",
        interval=1,
        backupCount=7,
        encoding="utf-8",
    )
    rotating_handler.setFormatter(logging.Formatter(fmt))
    rotating_handler.setLevel(logging.INFO)

    module_logger = logging.getLogger(__name__)
    module_logger.addHandler(rotating_handler)
    return module_logger
class CollateFunction(object):
    """Callable that batches dataset samples into (clean, noisy) tensors."""

    def __init__(self):
        pass

    def __call__(self, batch: List[dict]):
        """Stack the per-sample waveforms of ``batch``.

        Args:
            batch: samples, each providing a ``"speech_wave"`` and a
                ``"mix_wave"`` tensor (equal shapes are needed for stacking).

        Returns:
            Tuple ``(clean_audios, noisy_audios)`` with a new leading batch axis.

        Raises:
            AssertionError: if either stacked tensor contains NaN or Inf.
        """
        speech_waves = list()
        mix_waves = list()
        for item in batch:
            speech_waves.append(item["speech_wave"])
            mix_waves.append(item["mix_wave"])

        clean_audios = torch.stack(speech_waves)
        noisy_audios = torch.stack(mix_waves)

        # Fail fast on corrupt audio instead of propagating NaN/Inf into the loss.
        if not torch.isfinite(clean_audios).all():
            raise AssertionError("nan or inf in clean_audios")
        if not torch.isfinite(noisy_audios).all():
            raise AssertionError("nan or inf in noisy_audios")

        return clean_audios, noisy_audios
# Module-level collate callable handed to both DataLoaders in main().
collate_fn = CollateFunction()
# Metric names, in the order they are shown in progress bars and written to
# metrics_epoch.json.
_METRIC_NAMES = (
    "pesq_score",
    "loss",
    "ae_loss",
    "neg_si_snr_loss",
    "neg_stoi_loss",
    "mr_stft_loss",
    "pesq_loss",
)


def _new_metric_totals() -> dict:
    """Return zeroed running sums, one entry per reported metric."""
    return {name: 0. for name in _METRIC_NAMES}


def _accumulate_metrics(totals: dict, total_batches: float, pesq_score: float, losses: dict) -> dict:
    """Add one batch to the running sums and return the rounded running means.

    Args:
        totals: running sums keyed by metric name; updated in place.
        total_batches: number of batches accumulated so far, including this one.
        pesq_score: PESQ score of the current batch.
        losses: loss tensors of the current batch keyed by metric name.

    Returns:
        dict: mean of every metric over ``total_batches``, rounded to 4 digits.
    """
    totals["pesq_score"] += pesq_score
    for name, value in losses.items():
        totals[name] += value.item()
    return {name: round(totals[name] / total_batches, 4) for name in _METRIC_NAMES}


def main():
    """Train a Conv-TasNet denoising model with periodic evaluation.

    Steps:
      1. Parse CLI options, load the model config, set up logging and seeds.
      2. Build the train / valid datasets and data loaders.
      3. Build model and optimizer; resume from the newest ``steps-N``
         checkpoint found in the serialization directory, if any.
      4. Train; every ``config.eval_steps`` optimizer steps evaluate on the
         validation set, write a ``steps-N`` checkpoint with metrics, refresh
         the ``best`` copy when the PESQ score improves, and stop early after
         ``--patience`` evaluations without improvement.

    Raises:
        AssertionError: if ``config.lr_scheduler`` is not a supported name, or
            if the model output contains NaN / Inf during training.
    """
    args = get_args()

    config = ConvTasNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
    )

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info(f"set seed: {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"GPU available count: {n_gpu}; device: {device}")

    # datasets
    train_dataset = DenoiseJsonlDataset(
        jsonl_file=args.train_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
        # NOTE(review): hard-coded skip count; presumably left over from a
        # manual resume of an earlier run - confirm before reusing.
        skip=825000,
    )
    valid_dataset = DenoiseJsonlDataset(
        jsonl_file=args.valid_dataset,
        expected_sample_rate=config.sample_rate,
        max_wave_value=32768.0,
        min_snr_db=config.min_snr_db,
        max_snr_db=config.max_snr_db,
    )

    # Multiple worker subprocesses can be used for data loading on Linux but
    # not on Windows. os.cpu_count() may return None, hence the fallback.
    if platform.system() == "Windows":
        num_workers = 0
    else:
        num_workers = (os.cpu_count() or 1) // 2

    data_loader_kwargs = dict(
        batch_size=args.batch_size,
        sampler=None,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=False,
    )
    # prefetch_factor is only valid with worker processes; passing it together
    # with num_workers=0 raises ValueError in recent PyTorch versions.
    if num_workers > 0:
        data_loader_kwargs["prefetch_factor"] = 2

    train_data_loader = DataLoader(dataset=train_dataset, **data_loader_kwargs)
    valid_data_loader = DataLoader(dataset=valid_dataset, **data_loader_kwargs)

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    model = ConvTasNetPretrainedModel(config).to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

    # resume training: pick the checkpoint directory with the largest step index.
    last_step_idx = -1
    last_epoch = -1
    for step_dir in serialization_dir.glob("steps-*"):
        last_step_idx = max(last_step_idx, int(Path(step_dir).stem.split("-")[1]))

    if last_step_idx != -1:
        logger.info(f"resume from steps-{last_step_idx}.")
        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

        logger.info("load state dict for model.")
        with open(model_pt.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)

        logger.info("load state dict for optimizer.")
        with open(optimizer_pth.as_posix(), "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        optimizer.load_state_dict(state_dict)

    # NOTE(review): last_epoch is never updated on resume, so the scheduler
    # always starts a fresh schedule even when the optimizer state is restored.
    if config.lr_scheduler == "CosineAnnealingLR":
        # Scheduler hyper-parameters (e.g. T_max, eta_min) come from the config.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            last_epoch=last_epoch,
            **config.lr_scheduler_kwargs,
        )
    elif config.lr_scheduler == "MultiStepLR":
        # The scheduler is stepped once per batch, so milestones count steps.
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            last_epoch=last_epoch,
            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
        )
    else:
        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

    ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
    neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
    mr_stft_loss_fn = MultiResolutionSTFTLoss(
        fft_size_list=[256, 512, 1024],
        win_size_list=[120, 240, 480],
        hop_size_list=[25, 50, 100],
        factor_sc=1.5,
        factor_mag=1.0,
        reduction="mean"
    ).to(device)
    pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)

    def compute_losses(denoise_audios: torch.Tensor, clean_audios: torch.Tensor) -> dict:
        """Return every individual loss plus their weighted sum under "loss"."""
        ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
        neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
        # Argument order is (clean, denoise) here, unlike the other losses;
        # kept exactly as in the original call.
        pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)

        loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
        return {
            "loss": loss,
            "ae_loss": ae_loss,
            "neg_si_snr_loss": neg_si_snr_loss,
            "neg_stoi_loss": neg_stoi_loss,
            "mr_stft_loss": mr_stft_loss,
            "pesq_loss": pesq_loss,
        }

    def compute_pesq_score(denoise_audios: torch.Tensor, clean_audios: torch.Tensor):
        """Score the batch with run_pesq_score on detached CPU numpy copies."""
        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
        return run_pesq_score(
            clean_audios_list_r, denoise_audios_list_r,
            sample_rate=config.sample_rate, mode="nb",
        )

    # training loop

    # state
    # Placeholder averages; always overwritten by the first accumulated batch.
    averages = {name: 1000000000 for name in _METRIC_NAMES}

    model_list = list()
    best_epoch_idx = None
    best_step_idx = None
    best_metric = None
    patience_count = 0
    early_stop_flag = False

    step_idx = 0 if last_step_idx == -1 else last_step_idx

    logger.info("training")
    for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
        # train
        model.train()

        totals = _new_metric_totals()
        total_batches = 0.

        progress_bar_train = tqdm(
            initial=step_idx,
            desc="Training; epoch-{}".format(epoch_idx),
        )
        for train_batch in train_data_loader:
            clean_audios, noisy_audios = train_batch
            clean_audios: torch.Tensor = clean_audios.to(device)
            noisy_audios: torch.Tensor = noisy_audios.to(device)

            denoise_audios = model.forward(noisy_audios)
            denoise_audios = torch.squeeze(denoise_audios, dim=1)

            if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
                raise AssertionError("nan or inf in denoise_audios")

            losses = compute_losses(denoise_audios, clean_audios)
            loss = losses["loss"]
            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                # Skip the batch entirely: no optimizer step, no step counting.
                logger.info("find nan or inf in loss.")
                continue

            pesq_score = compute_pesq_score(denoise_audios, clean_audios)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            total_batches += 1
            averages = _accumulate_metrics(totals, total_batches, pesq_score, losses)

            progress_bar_train.update(1)
            progress_bar_train.set_postfix({
                "lr": lr_scheduler.get_last_lr()[0],
                **averages,
            })

            # evaluation
            step_idx += 1
            if step_idx % config.eval_steps != 0:
                continue

            # Evaluate in eval mode so train-only layer behaviour (e.g.
            # dropout, batch-statistics updates) does not affect validation.
            model.eval()
            with torch.no_grad():
                torch.cuda.empty_cache()

                totals = _new_metric_totals()
                total_batches = 0.

                progress_bar_train.close()
                progress_bar_eval = tqdm(
                    desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                )
                for eval_batch in valid_data_loader:
                    clean_audios, noisy_audios = eval_batch
                    clean_audios = clean_audios.to(device)
                    noisy_audios = noisy_audios.to(device)

                    denoise_audios = model.forward(noisy_audios)
                    denoise_audios = torch.squeeze(denoise_audios, dim=1)

                    losses = compute_losses(denoise_audios, clean_audios)
                    loss = losses["loss"]
                    if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                        logger.info("find nan or inf in loss.")
                        continue

                    pesq_score = compute_pesq_score(denoise_audios, clean_audios)

                    total_batches += 1
                    averages = _accumulate_metrics(totals, total_batches, pesq_score, losses)

                    progress_bar_eval.update(1)
                    progress_bar_eval.set_postfix({
                        "lr": lr_scheduler.get_last_lr()[0],
                        **averages,
                    })

                # Restart the training running sums after each evaluation.
                totals = _new_metric_totals()
                total_batches = 0.

                progress_bar_eval.close()
                progress_bar_train = tqdm(
                    initial=progress_bar_train.n,
                    postfix=progress_bar_train.postfix,
                    desc=progress_bar_train.desc,
                )
            model.train()

            # save path
            save_dir = serialization_dir / "steps-{}".format(step_idx)
            save_dir.mkdir(parents=True, exist_ok=False)

            # save models
            model.save_pretrained(save_dir.as_posix())

            # save optim
            torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())

            # save metric; a greater PESQ score is better.
            average_pesq_score = averages["pesq_score"]
            if best_metric is None or average_pesq_score > best_metric:
                best_epoch_idx = epoch_idx
                best_step_idx = step_idx
                best_metric = average_pesq_score
            is_best = best_epoch_idx == epoch_idx and best_step_idx == step_idx

            metrics = {
                "epoch_idx": epoch_idx,
                "best_epoch_idx": best_epoch_idx,
                "best_step_idx": best_step_idx,
                **averages,
            }
            metrics_filename = save_dir / "metrics_epoch.json"
            with open(metrics_filename, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=4, ensure_ascii=False)

            # save best
            best_dir = serialization_dir / "best"
            if is_best:
                if best_dir.exists():
                    shutil.rmtree(best_dir)
                shutil.copytree(save_dir, best_dir)

            # Rotate old checkpoints only after the new one is fully written,
            # keeping exactly num_serialized_models_to_keep directories.
            model_list.append(save_dir)
            while len(model_list) > args.num_serialized_models_to_keep:
                model_to_delete: Path = model_list.pop(0)
                shutil.rmtree(model_to_delete.as_posix())

            # early stop
            if is_best:
                patience_count = 0
            else:
                patience_count += 1

            if patience_count >= args.patience:
                early_stop_flag = True
                break

        progress_bar_train.close()

        # Leave the epoch loop as well; breaking only the batch loop would
        # let training continue in the next epoch.
        if early_stop_flag:
            logger.info(f"early stop at steps-{step_idx}; best: steps-{best_step_idx}.")
            break
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()