#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement | |
""" | |
import argparse | |
import json | |
import logging | |
from logging.handlers import TimedRotatingFileHandler | |
import os | |
import platform | |
from pathlib import Path | |
import random | |
import sys | |
import shutil | |
from typing import List | |
pwd = os.path.abspath(os.path.dirname(__file__)) | |
sys.path.append(os.path.join(pwd, "../../")) | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from torch.utils.data.dataloader import DataLoader | |
import torchaudio | |
from tqdm import tqdm | |
from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset | |
from toolbox.torchaudio.models.spectrum_dfnet.configuration_spectrum_dfnet import SpectrumDfNetConfig | |
from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel | |
def get_args():
    """Parse the command-line arguments for the training script.

    Returns:
        argparse.Namespace with dataset paths, training hyper-parameters,
        checkpoint settings and the model config file path.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--train_dataset", default="train.xlsx", type=str)
    add("--valid_dataset", default="valid.xlsx", type=str)

    add("--max_epochs", default=100, type=int)
    add("--batch_size", default=16, type=int)
    add("--learning_rate", default=1e-4, type=float)

    add("--num_serialized_models_to_keep", default=10, type=int)
    add("--patience", default=5, type=int)
    add("--serialization_dir", default="serialization_dir", type=str)
    add("--seed", default=0, type=int)

    add("--config_file", default="config.yaml", type=str)

    return parser.parse_args()
def logging_config(file_dir: str):
    """Set up console logging plus a daily-rotating file log.

    Args:
        file_dir: directory in which ``main.log`` is created.

    Returns:
        The module logger with the rotating file handler attached.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # Console handler via the root logger.
    logging.basicConfig(
        format=log_format,
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # File log rotated once per day, keeping 7 backups.
    rotating_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7,
    )
    rotating_handler.setFormatter(logging.Formatter(log_format))
    rotating_handler.setLevel(logging.INFO)

    logger = logging.getLogger(__name__)
    logger.addHandler(rotating_handler)
    return logger
class CollateFunction(object):
    """Collate a batch of denoise samples into spectrogram training tensors.

    For every sample the speech/mix waveforms are turned into complex
    spectrograms, an ideal-ratio-mask (IRM) target is computed from the
    power spectrograms of speech and noise, and a per-frame SNR (dB)
    target is derived by averaging an unfolded SNR map over frequency.
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.irm_beta = irm_beta
        self.epsilon = epsilon

        # Choose the window once; anything other than "hamming" falls back
        # to the Hann window (same selection logic as before, hoisted).
        window = torch.hamming_window if window_fn == "hamming" else torch.hann_window

        # Complex spectrogram (power=None) used as model input/target.
        self.complex_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=None,
            window_fn=window,
        )
        # Power spectrogram (power=2.0) used for the IRM and SNR targets.
        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=2.0,
            window_fn=window,
        )

    def make_unfold_snr_db(self, x: torch.Tensor, n_time_steps: int = 3):
        """Unfold an SNR map into sliding time windows spanning all frequencies.

        Bug fix: the original definition omitted ``self``, so the call
        ``self.make_unfold_snr_db(snr_db_, n_time_steps=3)`` bound the
        tensor to ``n_time_steps`` positionally and then received the
        keyword again — a TypeError on every batch.

        Args:
            x: tensor of shape [batch_size, channels, freq_dim, time_steps].
            n_time_steps: width of the sliding time window.

        Returns:
            Tensor of shape [batch_size, freq_dim * n_time_steps, time_steps].
        """
        batch_size, channels, freq_dim, time_steps = x.shape
        # kernel: [freq_dim, n_time_steps]
        kernel_size = (freq_dim, n_time_steps)

        # Replicate edge frames so unfolding keeps the time length.
        pad = n_time_steps // 2
        if pad > 0:
            # Guard: with pad == 0 the original slices (x[..., :0] and
            # x[..., -0:]) would have prepended nothing and appended the
            # whole tensor, corrupting the output.
            x = torch.concat(tensors=[
                x[:, :, :, :pad],
                x,
                x[:, :, :, -pad:],
            ], dim=-1)

        x = F.unfold(
            input=x,
            kernel_size=kernel_size,
        )
        # x shape: [batch_size, freq_dim * n_time_steps, time_steps]
        return x

    def __call__(self, batch: List[dict]):
        """Build batched tensors from a list of dataset sample dicts.

        Args:
            batch: each dict must contain "noise_wave", "speech_wave" and
                "mix_wave" tensors (assumes equal-length waveforms within
                a batch — TODO confirm against DenoiseExcelDataset).

        Returns:
            (speech_complex_spec, mix_complex_spec, speech_irm, snr_db)
            stacked over the batch dimension; the last frequency bin is
            dropped from the first three.

        Raises:
            AssertionError: if any produced tensor contains NaN or Inf.
        """
        speech_complex_spec_list = list()
        mix_complex_spec_list = list()
        speech_irm_list = list()
        snr_db_list = list()
        for sample in batch:
            noise_wave: torch.Tensor = sample["noise_wave"]
            speech_wave: torch.Tensor = sample["speech_wave"]
            mix_wave: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            noise_spec = self.transform.forward(noise_wave)
            speech_spec = self.transform.forward(speech_wave)

            speech_complex_spec = self.complex_transform.forward(speech_wave)
            mix_complex_spec = self.complex_transform.forward(mix_wave)

            # noise_irm = noise_spec / (noise_spec + speech_spec)
            # IRM: speech energy ratio, optionally sharpened by irm_beta.
            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
            speech_irm = torch.pow(speech_irm, self.irm_beta)

            # noise_spec, speech_spec, mix_spec, speech_irm
            # shape: [freq_dim, time_steps]

            snr_db: torch.Tensor = 10 * torch.log10(
                speech_spec / (noise_spec + self.epsilon)
            )
            # NOTE(review): clamping SNR(dB) at epsilon zeroes out all
            # negative-SNR frames — confirm this is intended.
            snr_db = torch.clamp(snr_db, min=self.epsilon)

            # Lift to [1, 1, freq, time], unfold over sliding time windows,
            # then average over the unfolded (freq x window) axis.
            snr_db_ = torch.unsqueeze(snr_db, dim=0)
            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
            snr_db_ = torch.squeeze(snr_db_, dim=0)
            # snr_db_ shape: [fold, time_steps]

            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
            # snr_db shape: [1, time_steps]

            speech_complex_spec_list.append(speech_complex_spec)
            mix_complex_spec_list.append(mix_complex_spec)
            speech_irm_list.append(speech_irm)
            snr_db_list.append(snr_db)

        speech_complex_spec_list = torch.stack(speech_complex_spec_list)
        mix_complex_spec_list = torch.stack(mix_complex_spec_list)
        speech_irm_list = torch.stack(speech_irm_list)
        snr_db_list = torch.stack(snr_db_list)  # shape: [batch_size, 1, time_steps]

        # Drop the last frequency bin (n_fft/2 + 1 -> n_fft/2), presumably
        # to match the model's expected freq dim — TODO confirm.
        speech_complex_spec_list = speech_complex_spec_list[:, :-1, :]
        mix_complex_spec_list = mix_complex_spec_list[:, :-1, :]
        speech_irm_list = speech_irm_list[:, :-1, :]

        # speech_complex_spec_list shape: [batch_size, freq_dim, time_steps]
        # mix_complex_spec_list shape: [batch_size, freq_dim, time_steps]
        # speech_irm_list shape: [batch_size, freq_dim, time_steps]
        # snr_db_list shape: [batch_size, 1, time_steps]

        # Fail fast on numerically broken batches.
        if torch.any(torch.isnan(speech_complex_spec_list)) or torch.any(torch.isinf(speech_complex_spec_list)):
            raise AssertionError("nan or inf in speech_complex_spec_list")
        if torch.any(torch.isnan(mix_complex_spec_list)) or torch.any(torch.isinf(mix_complex_spec_list)):
            raise AssertionError("nan or inf in mix_complex_spec_list")
        if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
            raise AssertionError("nan or inf in speech_irm_list")
        if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
            raise AssertionError("nan or inf in snr_db_list")

        return speech_complex_spec_list, mix_complex_spec_list, speech_irm_list, snr_db_list
# Module-level collate instance (default STFT settings), shared by both
# the train and valid DataLoaders below.
collate_fn = CollateFunction()
def main():
    """Train the SpectrumDfNet denoise model.

    Pipeline: parse args, set up logging and RNG seeds, build datasets and
    dataloaders, instantiate the model/optimizer/LR schedule, then run an
    epoch loop of training + evaluation with checkpoint rotation,
    best-model tracking and patience-based early stopping.
    """
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    # Seed all RNG sources for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    # datasets
    logger.info("prepare datasets")
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux multiple worker subprocesses can load data; on Windows they cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    config = SpectrumDfNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
        # num_labels=vocabulary.get_vocab_size(namespace="labels")
    )
    model = SpectrumDfNetPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    param_optimizer = model.parameters()
    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=2000
    # )
    # Halve the LR at fixed step milestones (scheduler stepped per batch below).
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )
    # Three MSE terms: complex speech spectrum, IRM mask, and frame-level SNR.
    speech_mse_loss = nn.MSELoss(
        reduction="mean",
    )
    irm_mse_loss = nn.MSELoss(
        reduction="mean",
    )
    snr_mse_loss = nn.MSELoss(
        reduction="mean",
    )

    # training loop
    logger.info("training")

    training_loss = 10000000000
    evaluation_loss = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    for idx_epoch in range(args.max_epochs):
        total_loss = 0.
        total_examples = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )
        for batch in train_data_loader:
            speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
            speech_complex_spec = speech_complex_spec.to(device)
            mix_complex_spec = mix_complex_spec.to(device)
            speech_irm_target = speech_irm.to(device)
            snr_db_target = snr_db.to(device)

            speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
            # Fail fast on numerically broken predictions.
            if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
                raise AssertionError("nan or inf in speech_spec_prediction")
            if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
                raise AssertionError("nan or inf in speech_irm_prediction")
            if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                raise AssertionError("nan or inf in lsnr_prediction")

            # Complex target compared as real/imag pairs via view_as_real.
            speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
            snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
            loss = speech_loss + irm_loss + snr_loss

            # NOTE(review): total_loss sums per-batch *mean* losses but is
            # divided by the example count — the reported value is not a
            # per-example loss; confirm intended.
            total_loss += loss.item()
            total_examples += mix_complex_spec.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            training_loss = total_loss / total_examples
            training_loss = round(training_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "training_loss": training_loss,
            })

        total_loss = 0.
        total_examples = 0.
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        for batch in valid_data_loader:
            speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
            speech_complex_spec = speech_complex_spec.to(device)
            mix_complex_spec = mix_complex_spec.to(device)
            speech_irm_target = speech_irm.to(device)
            snr_db_target = snr_db.to(device)

            # Evaluation forward pass without gradient tracking.
            with torch.no_grad():
                speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)

            if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
                raise AssertionError("nan or inf in speech_spec_prediction")
            if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
                raise AssertionError("nan or inf in speech_irm_prediction")
            if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                raise AssertionError("nan or inf in lsnr_prediction")

            speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
            snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
            loss = speech_loss + irm_loss + snr_loss

            total_loss += loss.item()
            total_examples += mix_complex_spec.size(0)

            evaluation_loss = total_loss / total_examples
            evaluation_loss = round(evaluation_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "evaluation_loss": evaluation_loss,
            })

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        # Rotate checkpoints, deleting the oldest.
        # NOTE(review): ">=" means only num_serialized_models_to_keep - 1
        # checkpoints survive on disk — confirm the off-by-one is intended.
        model_list.append(epoch_dir)
        if len(model_list) >= args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss
        elif evaluation_loss < best_metric:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "training_loss": training_loss,
            "evaluation_loss": evaluation_loss,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        # early stop
        if early_stop_flag:
            break
    return
# Script entry point.
if __name__ == '__main__':
    main()