diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..29890e6063c072e1b4ba58a5ab544e3e38e02772 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,5 @@ + +.git/ +.idea/ + +/examples/ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d4c366e78320418ab6b957108ea211aa7251edec --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ + +.gradio/ +.git/ +.idea/ + +**/evaluation_audio/ +**/file_dir/ +**/flagged/ +**/log/ +**/logs/ +**/__pycache__/ + +/data/ +/docs/ +/dotenv/ +/hub_datasets/ +/script/ +/thirdparty/ +/trained_models/ +/temp/ + +**/*.wav +**/*.xlsx + +requirements-python-3-9-9.txt diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9906487796975ccef50244eb784461e592dd2b0b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.12 + +WORKDIR /code + +COPY . /code + +RUN apt-get update +RUN apt-get install -y ffmpeg build-essential + +RUN pip install --upgrade pip +RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt + +RUN useradd -m -u 1000 user + +USER user + +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +WORKDIR $HOME/app + +COPY --chown=user . 
$HOME/app
+
+CMD ["python3", "main.py"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..19217cf8f9bdf1c720d7e7c366f9627caa63e7d4
--- /dev/null
+++ b/README.md
@@ -0,0 +1,129 @@
+---
+title: NX Denoise
+emoji: 🐢
+colorFrom: purple
+colorTo: blue
+sdk: docker
+pinned: false
+license: apache-2.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## NX Denoise
+
+
+### datasets
+
+```text
+
+AISHELL (15G)
+https://openslr.trmal.net/resources/33/
+
+AISHELL-3 (19G)
+http://www.openslr.org/93/
+
+DNS3
+https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
+The noise data comes from DEMAND, FreeSound, and AudioSet.
+
+MS-SNSD
+https://github.com/microsoft/MS-SNSD
+The noise data comes from DEMAND and FreeSound.
+
+MUSAN
+https://www.openslr.org/17/
+Contains music, noise, and speech.
+music is purely instrumental music; noise contains free-sound and sound-bible, and the sound-bible part may work as supplementary data.
+Overall, not much of it is useful; collecting noise data ourselves will probably remain the main and more reliable approach.
+
+CHiME-4
+https://www.chimechallenge.org/challenges/chime4/download.html
+
+freesound
+https://freesound.org/
+
+AudioSet
+https://research.google.com/audioset/index.html
+```
+
+
+### Create the training container
+
+```text
+Training the model inside a container requires the GPU to be accessible from the container, see:
+https://hub.docker.com/r/ollama/ollama
+
+docker run -itd \
+--name nx_denoise \
+--network host \
+--gpus all \
+--privileged \
+--ipc=host \
+-v /data/tianxing/HuggingDatasets/nx_noise/data:/data/tianxing/HuggingDatasets/nx_noise/data \
+-v /data/tianxing/PycharmProjects/nx_denoise:/data/tianxing/PycharmProjects/nx_denoise \
+python:3.12
+
+
+Check the GPU
+nvidia-smi
+watch -n 1 -d nvidia-smi
+
+
+```
+
+```text
+Accessing the GPU from inside a container
+
+Reference:
+https://blog.csdn.net/footless_bird/article/details/136291344
+Steps:
+# install
+yum install -y nvidia-container-toolkit
+
+# edit the file /etc/docker/daemon.json
+cat /etc/docker/daemon.json
+{
+    "data-root": "/data/lib/docker",
+    "default-runtime": "nvidia",
+    "runtimes": {
+        "nvidia": {
+            "path": "/usr/bin/nvidia-container-runtime",
+            "runtimeArgs": []
+        }
+    },
+    "registry-mirrors": [
+        "https://docker.m.daocloud.io",
+        "https://dockerproxy.com",
+        "https://docker.mirrors.ustc.edu.cn",
+        "https://docker.nju.edu.cn"
+    ]
+}
+
+# restart docker
+systemctl restart docker
+systemctl daemon-reload
+
+# test whether the GPU is accessible inside a container.
+docker run --gpus all python:3.12-slim nvidia-smi
+
+# a container started this way can see the GPU, but it has no GPU driver inside, so nvidia-smi does not work.
+docker run -it --privileged python:3.12-slim /bin/bash
+apt update
+apt install -y pciutils
+lspci | grep -i nvidia
+#00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
+
+# this way of starting the container is what is found online, but nvidia-smi still does not work inside it.
+docker run \
+--device /dev/nvidia0:/dev/nvidia0 \
+--device /dev/nvidiactl:/dev/nvidiactl \
+--device /dev/nvidia-uvm:/dev/nvidia-uvm \
+-v /usr/local/nvidia:/usr/local/nvidia \
+-it --privileged python:3.12-slim /bin/bash
+
+
+# started this way, nvidia-smi works inside the container; the key should be the --gpus all flag.
+docker run -itd --gpus all --name open_unsloth python:3.12-slim /bin/bash +docker run -itd --gpus all --name Qwen2-7B-Instruct python:3.12-slim /bin/bash + +``` diff --git a/examples/clean_unet/run.sh b/examples/clean_unet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..3ae4611e58deb6afeb901a0c33148d6e9acb99b3 --- /dev/null +++ b/examples/clean_unet/run.sh @@ -0,0 +1,181 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/clean_unet/step_1_prepare_data.py b/examples/clean_unet/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfbdf266a33498229c6003bb4858c51ba9e28c0 --- /dev/null +++ b/examples/clean_unet/step_1_prepare_data.py @@ -0,0 +1,201 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=10000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/clean_unet/step_2_train_model.py b/examples/clean_unet/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..eb049b090c97f01f4fb3178a5f7862b3aec7994c --- /dev/null +++ b/examples/clean_unet/step_2_train_model.py @@ -0,0 +1,419 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/NVIDIA/CleanUNet/blob/main/train.py + 
+https://github.com/NVIDIA/CleanUNet/blob/main/configs/DNS-large-full.json +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig +from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel +from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay +from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss +from toolbox.torchaudio.models.clean_unet.metrics import run_pesq_score + +torch.autograd.set_detect_anomaly(True) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=2e-4, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = CleanUNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + 
serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info(f"set seed: {args.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = CleanUNetPretrainedModel(config).to(device) + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate) + + # resume training + last_epoch = -1 + for epoch_i in serialization_dir.glob("epoch-*"): + epoch_i = Path(epoch_i) + epoch_idx = epoch_i.stem.split("-")[1] + epoch_idx = int(epoch_idx) + if epoch_idx > last_epoch: + last_epoch = epoch_idx + + if last_epoch != -1: + logger.info(f"resume from epoch-{last_epoch}.") + model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt" + optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optimizer.") + with open(optimizer_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optimizer.load_state_dict(state_dict) + + lr_scheduler = LinearWarmupCosineDecay( + optimizer, + lr_max=args.learning_rate, + n_iter=250000, + iteration=250000, + divider=25, + warmup_proportion=0.05, + phase=("linear", "cosine"), + ) + + # ae_loss_fn = nn.MSELoss(reduction="mean") + ae_loss_fn = nn.L1Loss(reduction="mean").to(device) + + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_sizes=[256, 512, 1024], + hop_sizes=[25, 50, 120], + win_lengths=[120, 240, 600], + sc_lambda=0.5, + mag_lambda=0.5, + band="full" + ).to(device) + + # training loop + + # state + average_pesq_score = 10000000000 + average_loss = 10000000000 + average_ae_loss = 10000000000 + average_sc_loss = 10000000000 + average_mag_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + logger.info("training") + for idx_epoch in range(max(0, last_epoch+1), args.max_epochs): + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_ae_loss = 0. + total_sc_loss = 0. + total_mag_loss = 0. + total_batches = 0. 
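+        # training loop: the objective below is the L1 waveform loss plus the spectral-convergence and
+        # log-magnitude terms of the multi-resolution STFT loss; narrow-band PESQ at 8 kHz is computed
+        # per batch for monitoring only.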
+ + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + for batch in train_data_loader: + clean_audios, noisy_audios = batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + + enhanced_audios = model.forward(noisy_audios) + enhanced_audios = torch.squeeze(enhanced_audios, dim=1) + + ae_loss = ae_loss_fn(enhanced_audios, clean_audios) + sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios) + + loss = ae_loss + sc_loss + mag_loss + + enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb") + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_ae_loss += ae_loss.item() + total_sc_loss += sc_loss.item() + total_mag_loss += mag_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_ae_loss = round(total_ae_loss / total_batches, 4) + average_sc_loss = round(total_sc_loss / total_batches, 4) + average_mag_loss = round(total_mag_loss / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "pesq_score": average_pesq_score, + "loss": average_loss, + "ae_loss": average_ae_loss, + "sc_loss": average_sc_loss, + "mag_loss": average_mag_loss, + }) + + # evaluation + model.eval() + + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_ae_loss = 0. + total_sc_loss = 0. + total_mag_loss = 0. + total_batches = 0. + + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + with torch.no_grad(): + for batch in valid_data_loader: + clean_audios, noisy_audios = batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + + enhanced_audios = model.forward(noisy_audios) + enhanced_audios = torch.squeeze(enhanced_audios, dim=1) + + ae_loss = ae_loss_fn(enhanced_audios, clean_audios) + sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios) + + loss = ae_loss + sc_loss + mag_loss + + enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_ae_loss += ae_loss.item() + total_sc_loss += sc_loss.item() + total_mag_loss += mag_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_ae_loss = round(total_ae_loss / total_batches, 4) + average_sc_loss = round(total_sc_loss / total_batches, 4) + average_mag_loss = round(total_mag_loss / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "pesq_score": average_pesq_score, + "loss": average_loss, + "ae_loss": average_ae_loss, + "sc_loss": average_sc_loss, + "mag_loss": average_mag_loss, + }) + + # scheduler + lr_scheduler.step() + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= 
args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save optim + torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = average_pesq_score + elif average_pesq_score > best_metric: + # great is better. + best_idx_epoch = idx_epoch + best_metric = average_pesq_score + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + + "pesq_score": average_pesq_score, + "loss": average_loss, + "ae_loss": average_ae_loss, + "sc_loss": average_sc_loss, + "mag_loss": average_mag_loss, + + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + + return + + +if __name__ == "__main__": + main() diff --git a/examples/clean_unet/step_3_evaluation.py b/examples/clean_unet/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/examples/clean_unet/step_3_evaluation.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/examples/clean_unet/yaml/config.yaml b/examples/clean_unet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c225846b9555521a3181d53b3f77045ef970ada2 --- /dev/null +++ b/examples/clean_unet/yaml/config.yaml @@ -0,0 +1,14 @@ +model_name: "clean_unet" + +channels_input: 1 +channels_output: 1 +channels_h: 64 +max_h: 768 +encoder_n_layers: 8 +kernel_size: 4 +stride: 2 +tsfm_n_layers: 5 +tsfm_n_head: 8 +tsfm_d_model: 512 +tsfm_d_inner: 2048 + diff --git a/examples/conv_tasnet/run.sh b/examples/conv_tasnet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..eaf165dd4a67d52b6094e32414798c1696ed2364 --- /dev/null +++ b/examples/conv_tasnet/run.sh @@ -0,0 +1,154 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \ +--max_epochs 400 + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo 
"$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/conv_tasnet/step_1_prepare_data.py b/examples/conv_tasnet/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..df073013368b4c4f0c76f9f9461684870d27875e --- /dev/null +++ b/examples/conv_tasnet/step_1_prepare_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=4.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=10000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=100000, + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count > 0: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 300 / 1): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/conv_tasnet/step_2_train_model.py b/examples/conv_tasnet/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c55bb675e959f4366cc21f09cbaa076117167f0f --- /dev/null +++ b/examples/conv_tasnet/step_2_train_model.py @@ -0,0 +1,509 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/kaituoxu/Conv-TasNet/tree/master/src + +一般场景: + +目标 SI-SNR ≥ 10 dB,适用于电话通信、基础语音助手等。 + +高要求场景(如医疗助听、语音识别): +需 SI-SNR ≥ 14 dB,并配合 PESQ ≥ 3.0 和 STOI ≥ 0.851812。 + +DeepFilterNet2 模型在 DNS4 数据集,超过500小时的音频上训练了 100 个 epoch。 +https://arxiv.org/abs/2205.05474 + +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, 
"../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset +from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig +from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=200, type=int) + + parser.add_argument("--batch_size", default=8, type=int) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=1234, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = ConvTasNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info(f"set seed: {args.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = 
DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = ConvTasNetPretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optimizer.") + with open(optimizer_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optimizer.load_state_dict(state_dict) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + ae_loss_fn = nn.L1Loss(reduction="mean").to(device) + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[120, 240, 480], + hop_size_list=[25, 50, 100], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device) + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_ae_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + average_neg_stoi_loss = 
1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + for epoch_idx in range(max(0, last_epoch+1), args.max_epochs): + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_ae_loss = 0. + total_neg_si_snr_loss = 0. + total_neg_stoi_loss = 0. + total_mr_stft_loss = 0. + total_pesq_loss = 0. + total_batches = 0. + + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)): + raise AssertionError("nan or inf in denoise_audios") + + ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios) + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios) + + # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss + # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss + # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss + # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss + loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_ae_loss += ae_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_neg_stoi_loss += neg_stoi_loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_pesq_loss += pesq_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_ae_loss = round(total_ae_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_pesq_loss = round(total_pesq_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "ae_loss": average_ae_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "neg_stoi_loss": average_neg_stoi_loss, + "mr_stft_loss": 
average_mr_stft_loss, + "pesq_loss": average_pesq_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_ae_loss = 0. + total_neg_si_snr_loss = 0. + total_neg_stoi_loss = 0. + total_mr_stft_loss = 0. + total_pesq_loss = 0. + total_batches = 0. + + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + + denoise_audios = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios) + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios) + + # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss + # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss + # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss + # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss + loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_ae_loss += ae_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_neg_stoi_loss += neg_stoi_loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_pesq_loss += pesq_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_ae_loss = round(total_ae_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_pesq_loss = round(total_pesq_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "ae_loss": average_ae_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "neg_stoi_loss": average_neg_stoi_loss, + "mr_stft_loss": average_mr_stft_loss, + "pesq_loss": average_pesq_loss, + }) + + total_pesq_score = 0. + total_loss = 0. + total_ae_loss = 0. + total_neg_si_snr_loss = 0. + total_neg_stoi_loss = 0. + total_mr_stft_loss = 0. + total_pesq_loss = 0. + total_batches = 0. 
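+                    # reset the running totals so the training-progress averages start from zero
+                    # when training resumes after this evaluation pass.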
+ + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save optim + torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score > best_metric: + # great is better. + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "ae_loss": average_ae_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "neg_stoi_loss": average_neg_stoi_loss, + "mr_stft_loss": average_mr_stft_loss, + "pesq_loss": average_pesq_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/conv_tasnet/yaml/config.yaml b/examples/conv_tasnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8c18fc2e48b868e2a6e5a6531442bf1806cc57b1 --- /dev/null +++ b/examples/conv_tasnet/yaml/config.yaml @@ -0,0 +1,28 @@ +model_name: "conv_tasnet" + +sample_rate: 8000 +segment_size: 4 + +win_size: 20 +freq_bins: 256 +bottleneck_channels: 256 +num_speakers: 1 +num_blocks: 4 +num_sub_blocks: 8 +sub_blocks_channels: 512 +sub_blocks_kernel_size: 3 + +norm_type: "gLN" +causal: false +mask_nonlinear: "relu" + +min_snr_db: -10 +max_snr_db: 20 + +lr: 0.005 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.00005 + +eval_steps: 25000 diff --git a/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..3edc3f4a5fb335cba037231db704ec29bf387db4 --- /dev/null +++ b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py @@ -0,0 +1,90 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh + +1.2G +wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2 + +14G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2 + +38G +wget 
https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2 + +247M +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2 + + +""" +import argparse +import os +from pathlib import Path +import sys + +import numpy as np +from tqdm import tqdm + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +from scipy.io import wavfile + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data_dir", + default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech", + type=str + ) + parser.add_argument( + "--output_dir", + default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k", + type=str + ) + parser.add_argument("--sample_rate", default=8000, type=int) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # finished_set + finished_set = set() + for filename in tqdm(output_dir.glob("**/*.wav")): + name = filename.stem + finished_set.add(name) + print(f"finished_set count: {len(finished_set)}") + + for filename in tqdm(data_dir.glob("**/*.wav")): + label = filename.parts[-2] + name = filename.stem + # print(f"filename: {filename.as_posix()}") + if name in finished_set: + continue + + signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate) + + signal = signal * (1 << 15) + signal = np.array(signal, dtype=np.int16) + + to_file = output_dir / f"{label}/{name}.wav" + to_file.parent.mkdir(parents=True, exist_ok=True) + wavfile.write( + to_file.as_posix(), + rate=args.sample_rate, + data=signal, + ) + return + + +if __name__ == "__main__": + main() diff --git a/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..94a2d1eb49aaeab117753a7e19fa344bb677abdb --- /dev/null +++ b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py @@ -0,0 +1,129 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh + +1.2G +wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2 + +14G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2 + +38G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2 + +12G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.french_data.tar.bz2 + +43G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.german_speech.tar.bz2 + +7.9G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.italian_speech.tar.bz2 + +12G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.mandarin_speech.tar.bz2 + +3.1G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.russian_speech.tar.bz2 + +9.7G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.spanish_speech.tar.bz2 + +617M +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.singing_voice.tar.bz2 + +""" +import argparse +import os 
+from pathlib import Path +import sys + +import numpy as np +from tqdm import tqdm + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +from scipy.io import wavfile + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data_dir", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.french_data\datasets\clean\french_data", + default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech", + # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.spanish_speech\datasets\clean\spanish_speech", + type=str + ) + parser.add_argument( + "--output_dir", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-french-speech-8k", + default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k", + # default=r"E:\programmer\asr_datasets\denoise\dns-clean-spanish-speech-8k", + type=str + ) + parser.add_argument("--sample_rate", default=8000, type=int) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # finished_set + finished_set = set() + for filename in tqdm(output_dir.glob("**/*.wav")): + filename = Path(filename) + relative_name = filename.relative_to(output_dir) + relative_name_ = relative_name.as_posix() + finished_set.add(relative_name_) + print(f"finished_set count: {len(finished_set)}") + + for filename in tqdm(data_dir.glob("**/*.wav")): + relative_name = filename.relative_to(data_dir) + relative_name_ = relative_name.as_posix() + if relative_name_ in finished_set: + continue + finished_set.add(relative_name_) + + try: + signal, _ = librosa.load(filename.as_posix(), mono=False, sr=args.sample_rate) + except Exception: + print(f"skip file: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError + + signal = signal * (1 << 15) + signal = np.array(signal, dtype=np.int16) + + to_file = output_dir / relative_name.as_posix() + to_file.parent.mkdir(parents=True, exist_ok=True) + wavfile.write( + to_file.as_posix(), + rate=args.sample_rate, + data=signal, + ) + return + + +if __name__ == "__main__": + main() diff --git a/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py new file mode 100644 index 0000000000000000000000000000000000000000..a18973b9780774793565d4862fc157dd2b343b1d --- /dev/null +++ 
b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py @@ -0,0 +1,71 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh + +1.2G +wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2 + +""" +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +import numpy as np + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +from scipy.io import wavfile + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data_dir", + default=r"E:\programmer\asr_datasets\dns-challenge\DEMAND\demand", + type=str + ) + parser.add_argument( + "--output_dir", + default=r"E:\programmer\asr_datasets\denoise\demand-8k", + type=str + ) + parser.add_argument("--sample_rate", default=8000, type=int) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for filename in data_dir.glob("**/ch01.wav"): + label = filename.parts[-2] + name = filename.stem + + signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate) + + signal = signal * (1 << 15) + signal = np.array(signal, dtype=np.int16) + + to_file = output_dir / f"{label}/{name}.wav" + to_file.parent.mkdir(parents=True, exist_ok=True) + wavfile.write( + to_file.as_posix(), + rate=args.sample_rate, + data=signal, + ) + return + + +if __name__ == '__main__': + main() diff --git a/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py new file mode 100644 index 0000000000000000000000000000000000000000..eed3ec1e2a9e3e2b1a42a0198be106d2b828454b --- /dev/null +++ b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py @@ -0,0 +1,93 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh + +1.2G +wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2 + +14G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2 + +38G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2 + +247M +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2 + +240M +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.impulse_responses.tar.bz2 + + +""" +import argparse +import os +from pathlib import Path +import sys + +import numpy as np +from tqdm import tqdm + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +from scipy.io import wavfile + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data_dir", + default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech", + type=str + ) + parser.add_argument( + "--output_dir", + default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k", + type=str + ) + parser.add_argument("--sample_rate", default=8000, type=int) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) 
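+    # Note: the argparse defaults in get_args() above still point at the
+    # emotional_speech source and output directories, so --data_dir and
+    # --output_dir should be passed explicitly when converting the DNS3
+    # impulse responses.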
+ output_dir.mkdir(parents=True, exist_ok=True) + + # finished_set + finished_set = set() + for filename in tqdm(output_dir.glob("**/*.wav")): + name = filename.stem + finished_set.add(name) + print(f"finished_set count: {len(finished_set)}") + + for filename in tqdm(data_dir.glob("**/*.wav")): + label = filename.parts[-2] + name = filename.stem + # print(f"filename: {filename.as_posix()}") + if name in finished_set: + continue + + signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate) + + signal = signal * (1 << 15) + signal = np.array(signal, dtype=np.int16) + + to_file = output_dir / f"{label}/{name}.wav" + to_file.parent.mkdir(parents=True, exist_ok=True) + wavfile.write( + to_file.as_posix(), + rate=args.sample_rate, + data=signal, + ) + return + + +if __name__ == "__main__": + main() diff --git a/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff2593ea86a13a31517c7f1c91610d246770b01 --- /dev/null +++ b/examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py @@ -0,0 +1,77 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh + +1.2G +wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2 + +14G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2 + +38G +wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2 + +""" +import argparse +import os +from pathlib import Path +import sys + +import numpy as np +from tqdm import tqdm + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +from scipy.io import wavfile + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data_dir", + default=r"E:\programmer\asr_datasets\dns-challenge\datasets.noise\datasets", + type=str + ) + parser.add_argument( + "--output_dir", + default=r"E:\programmer\asr_datasets\denoise\dns-noise-8k", + type=str + ) + parser.add_argument("--sample_rate", default=8000, type=int) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + for filename in tqdm(data_dir.glob("**/*.wav")): + label = filename.parts[-2] + name = filename.stem + # print(f"filename: {filename.as_posix()}") + + signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate) + + signal = signal * (1 << 15) + signal = np.array(signal, dtype=np.int16) + + to_file = output_dir / f"{label}/{name}.wav" + to_file.parent.mkdir(parents=True, exist_ok=True) + wavfile.write( + to_file.as_posix(), + rate=args.sample_rate, + data=signal, + ) + return + + +if __name__ == '__main__': + main() diff --git a/examples/data_preprocess/dns_challenge_to_8k/process_musan.py b/examples/data_preprocess/dns_challenge_to_8k/process_musan.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c77c004bb525b0c70d0ab34597c43ad8bb89bd --- /dev/null +++ b/examples/data_preprocess/dns_challenge_to_8k/process_musan.py @@ -0,0 +1,8 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://www.openslr.org/17/ +""" + +if __name__ == '__main__': + pass diff --git a/examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py 
b/examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py new file mode 100644 index 0000000000000000000000000000000000000000..c2499a23cb540cd8c8f72f7b54ef69fcae486ffb --- /dev/null +++ b/examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py @@ -0,0 +1,70 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +MS-SNSD +https://github.com/microsoft/MS-SNSD +""" +import argparse +import os +from pathlib import Path +import sys + +import numpy as np +from tqdm import tqdm + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +from scipy.io import wavfile + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data_dir", + default=r"E:\programmer\asr_datasets\MS-SNSD", + type=str + ) + parser.add_argument( + "--output_dir", + default=r"E:\programmer\asr_datasets\denoise\ms-snsd-noise-8k", + type=str + ) + parser.add_argument("--sample_rate", default=8000, type=int) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + for filename in tqdm(data_dir.glob("**/*.wav")): + label = filename.parts[-2] + name = filename.stem + + if label not in ["noise_train", "noise_test", "clean_train", "clean_test"]: + continue + + signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate) + + signal = signal * (1 << 15) + signal = np.array(signal, dtype=np.int16) + + to_file = output_dir / f"{label}/{name}.wav" + to_file.parent.mkdir(parents=True, exist_ok=True) + wavfile.write( + to_file.as_posix(), + rate=args.sample_rate, + data=signal, + ) + return + + +if __name__ == "__main__": + main() diff --git a/examples/dfnet/run.sh b/examples/dfnet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..cbde2d32010635a5f8f5c1e6222f1d2b31edd598 --- /dev/null +++ b/examples/dfnet/run.sh @@ -0,0 +1,156 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech" + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-nx-dns3 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped 
quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/dfnet/step_1_prepare_data.py b/examples/dfnet/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8fa9bf732f2f1d2682a76f22cea3f426c60c1b --- /dev/null +++ b/examples/dfnet/step_1_prepare_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=4.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=10000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=100000, + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset jsonl") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count > 0: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 300 / 1): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dfnet/step_2_train_model.py b/examples/dfnet/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..804f6cbebfc54f26971e855d75b876f239c2b592 --- /dev/null +++ b/examples/dfnet/step_2_train_model.py @@ -0,0 +1,461 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/Rikorose/DeepFilterNet +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +from fontTools.varLib.plot import stops + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from 
toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig +from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=10, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + snr_db_list = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = DfNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + 
dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = DfNetPretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_mr_stft_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + average_mask_loss = 1000000000 + average_lsnr_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_lsnr_loss = 0. + total_batches = 0. 
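+        # The training objective below combines four terms: multi-resolution STFT
+        # loss and negative SI-SNR on the enhanced waveform, plus the model's own
+        # mask loss and a down-weighted (0.3) LSNR loss on the predicted local SNR.
+        # Batches whose combined loss is NaN/Inf are skipped before backprop.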
+ + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios) + + mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(est_wav.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_lsnr_loss += lsnr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + average_lsnr_loss = round(total_lsnr_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + "lsnr_loss": average_lsnr_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_lsnr_loss = 0. + total_batches = 0. 
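+                    # Evaluation runs under torch.no_grad() over the full validation split;
+                    # the same loss weighting as in training is reported, and the narrow-band
+                    # PESQ score from run_pesq_score is the quantity tracked as best_metric.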
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios) + + mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(est_wav.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_lsnr_loss += lsnr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + average_lsnr_loss = round(total_lsnr_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + "lsnr_loss": average_lsnr_loss, + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_lsnr_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + "lsnr_loss": average_lsnr_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dfnet/yaml/config.yaml b/examples/dfnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51369c4da056b8edde7fd84c4fde30b30b3122b9 --- /dev/null +++ b/examples/dfnet/yaml/config.yaml @@ -0,0 +1,74 @@ +model_name: "dfnet" + +# spec +sample_rate: 8000 +nfft: 512 +win_size: 200 +hop_size: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +embedding_hidden_size: 256 +encoder_combine_op: "concat" + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +encoder_linear_groups: 32 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +df_decoder_hidden_size: 256 +df_num_layers: 2 +df_order: 5 +df_bins: 96 +df_gru_skip: "grouped_linear" +df_decoder_linear_groups: 16 +df_pathway_kernel_size_t: 5 +df_lookahead: 2 + +# lsnr +n_frame: 3 +lsnr_max: 30 +lsnr_min: -15 +norm_tau: 1. 
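+# lsnr_min / lsnr_max bound the frame-wise local SNR value (in dB) predicted by
+# the LSNR head that lsnr_loss is computed against in step_2_train_model.py;
+# norm_tau is presumably the smoothing time constant for input feature
+# normalization, as in the DeepFilterNet code this recipe is based on.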
+ +# data +min_snr_db: -10 +max_snr_db: 20 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 8 +batch_size: 64 +eval_steps: 10000 + +# runtime +use_post_filter: true diff --git a/examples/dfnet2/run.sh b/examples/dfnet2/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..d6d88abda59587085f3f53d566b95ffaac74010a --- /dev/null +++ b/examples/dfnet2/run.sh @@ -0,0 +1,164 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech" + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dfnet2-nx2-dns3 --final_model_name dfnet2-nx2-dns3 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=-1 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/dfnet2/step_1_prepare_data.py b/examples/dfnet2/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfb4d8d988cd137a7cd4f53194ac26dcbc45525 --- /dev/null +++ b/examples/dfnet2/step_1_prepare_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=-1, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=100000, + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset jsonl") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count > 0: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 300 / 1): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dfnet2/step_2_train_model.py b/examples/dfnet2/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5991663a6fd0c9541c44eee9a5ba1150ecc9d825 --- /dev/null +++ b/examples/dfnet2/step_2_train_model.py @@ -0,0 +1,469 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/Rikorose/DeepFilterNet +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +from fontTools.varLib.plot import stops + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from 
toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config +from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=30, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + snr_db_list = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = DfNet2Config.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + 
dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = DfNet2PretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_mr_stft_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + average_mask_loss = 1000000000 + average_lsnr_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_lsnr_loss = 0. + total_batches = 0. 
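+        # Same loop structure as examples/dfnet/step_2_train_model.py; the notable
+        # differences are that est_wav is squeezed from [b, 1, n_samples] to
+        # [b, n_samples] before the losses, and the LSNR term is weighted 0.01
+        # instead of 0.3 in the combined loss.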
+ + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios) + # est_wav shape: [b, 1, n_samples] + est_wav = torch.squeeze(est_wav, dim=1) + # est_wav shape: [b, n_samples] + + mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss. continue.") + continue + + denoise_audios_list_r = list(est_wav.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_lsnr_loss += lsnr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + average_lsnr_loss = round(total_lsnr_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + "lsnr_loss": average_lsnr_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + with torch.no_grad(): + torch.cuda.empty_cache() + + model.eval() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_lsnr_loss = 0. + total_batches = 0. 
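# Explanatory note: the evaluation pass below recomputes the same four losses on
# the validation loader without an optimizer step. The averaged validation PESQ
# then drives model selection:
#   - a checkpoint is written to serialization_dir/steps-{step_idx}
#   - old step checkpoints are rotated out once `--num_serialized_models_to_keep`
#     directories have accumulated
#   - the checkpoint with the highest eval PESQ seen so far is copied to "best"
#   - training stops early after `--patience` evaluations without a new best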
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios) + # est_wav shape: [b, 1, n_samples] + est_wav = torch.squeeze(est_wav, dim=1) + # est_wav shape: [b, n_samples] + + mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss. continue.") + continue + + denoise_audios_list_r = list(est_wav.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_lsnr_loss += lsnr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + average_lsnr_loss = round(total_lsnr_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + "lsnr_loss": average_lsnr_loss, + }) + + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_lsnr_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + "lsnr_loss": average_lsnr_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dfnet2/yaml/config.yaml b/examples/dfnet2/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1f3aa86a642e0f8fc385d7881bd80f8da8cec62f --- /dev/null +++ b/examples/dfnet2/yaml/config.yaml @@ -0,0 +1,75 @@ +model_name: "dfnet2" + +# spec +sample_rate: 8000 +nfft: 512 +win_size: 200 +hop_size: 80 + +spec_bins: 256 +erb_bins: 32 +min_freq_bins_for_erb: 2 +use_ema_norm: true + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +convt_kernel_size_inner: + - 1 + - 3 + +embedding_hidden_size: 256 +encoder_combine_op: "concat" + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +encoder_linear_groups: 32 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +df_decoder_hidden_size: 256 +df_num_layers: 2 +df_order: 5 +df_bins: 96 +df_gru_skip: "grouped_linear" +df_decoder_linear_groups: 16 +df_pathway_kernel_size_t: 5 +df_lookahead: 2 + +# lsnr +n_frame: 3 +lsnr_max: 30 +lsnr_min: -15 +norm_tau: 1. 
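# Note (explanatory; interpretation based on DeepFilterNet-style conventions,
# not on this repo's docs): lsnr_min/lsnr_max appear to bound the local SNR (dB)
# that the auxiliary LSNR head is trained to predict, and norm_tau the time
# constant of the exponential moving-average feature normalization
# (use_ema_norm above).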
+ +# data +min_snr_db: -5 +max_snr_db: 40 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 8 +batch_size: 96 +eval_steps: 10000 + +# runtime +use_post_filter: true diff --git a/examples/dtln/run.sh b/examples/dtln/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..5065cb79dc80556b99efc5363e5ae71fb2482abf --- /dev/null +++ b/examples/dtln/run.sh @@ -0,0 +1,171 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \ +--config_file "yaml/config-256.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + + +sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \ +--config_file "yaml/config-512.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \ +--config_file "yaml/config-1024.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3 --final_model_name dtln-256-nx2-dns3 \ +--config_file "yaml/config-256.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=-1 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/dtln/step_1_prepare_data.py b/examples/dtln/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfb4d8d988cd137a7cd4f53194ac26dcbc45525 --- /dev/null +++ b/examples/dtln/step_1_prepare_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=-1, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=100000, + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset jsonl") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count > 0: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 300 / 1): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dtln/step_2_train_model.py b/examples/dtln/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7395110fc489894d183dd511c0d8ae72fc636c10 --- /dev/null +++ b/examples/dtln/step_2_train_model.py @@ -0,0 +1,437 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/breizhn/DTLN + +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import 
DenoiseJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig +from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=30, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + snr_db_list = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = DTLNConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + 
sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = DTLNPretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_mr_stft_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. 
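# Explanatory note: unlike the dfnet2 recipe, this DTLN recipe optimizes only two
# terms, loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss. Since
# lr_scheduler.step() is called once per batch, the CosineAnnealingLR T_max in
# the yaml config is measured in optimizer steps, not epochs.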
+ + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. 
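# Explanatory note: evaluation is triggered every `eval_steps` optimizer steps
# (15000 in the DTLN yaml configs) rather than once per epoch. As in the dfnet2
# recipe, batches whose loss turns nan/inf are skipped so the running averages
# stay finite, and the checkpoint with the highest eval PESQ is copied to "best".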
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dtln/yaml/config-1024.yaml b/examples/dtln/yaml/config-1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b622f176b7eae1b1855f4e420035a4726554f04 --- /dev/null +++ b/examples/dtln/yaml/config-1024.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 512 +hop_size: 128 +win_type: hann + +# data +min_snr_db: -5 +max_snr_db: 25 + +# model +encoder_size: 1024 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/examples/dtln/yaml/config-256.yaml b/examples/dtln/yaml/config-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0da3b47d03beeb581fac6c7d3c9d8459879c79d6 --- /dev/null +++ b/examples/dtln/yaml/config-256.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 256 +hop_size: 128 +win_type: hann + +# data +min_snr_db: -5 +max_snr_db: 25 + +# model +encoder_size: 256 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/examples/dtln/yaml/config-512.yaml b/examples/dtln/yaml/config-512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..517fd3dea974816df14a395c58173bbbe234aada --- /dev/null +++ b/examples/dtln/yaml/config-512.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 512 +hop_size: 128 +win_type: hann + +# data +min_snr_db: -5 +max_snr_db: 25 + +# model +encoder_size: 512 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/examples/dtln_mp3_to_wav/run.sh b/examples/dtln_mp3_to_wav/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..328fcbe7089bf5b8eefa6afded8253390c4bc867 --- /dev/null +++ b/examples/dtln_mp3_to_wav/run.sh @@ -0,0 +1,168 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \ +--config_file "yaml/config-256.yaml" \ +--noise_dir 
"/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + + +sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \ +--config_file "yaml/config-512.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \ +--config_file "yaml/config-1024.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" + + +bash run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3-mp3 --final_model_name dtln-256-nx2-dns3-mp3 \ +--config_file "yaml/config-256.yaml" \ +--audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \ + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data + +max_count=-1 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --audio_dir "${audio_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/dtln_mp3_to_wav/step_1_prepare_data.py b/examples/dtln_mp3_to_wav/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2aeafb115b817fd7fcee4d1444cb5829f5766010 --- /dev/null +++ b/examples/dtln_mp3_to_wav/step_1_prepare_data.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--audio_dir", + default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=4.0, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=-1, type=int) + + args = parser.parse_args() + return args + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 1): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + audio_dir = Path(args.audio_dir) + + audio_generator = target_second_signal_generator( + audio_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + count = 0 + process_bar = tqdm(desc="build dataset jsonl") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for audio in audio_generator: + if count >= args.max_count > 0: + break + + filename = audio["filename"] + raw_duration = audio["raw_duration"] + offset = audio["offset"] + duration = audio["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "filename": filename, + "raw_duration": raw_duration, + "offset": offset, + "duration": duration, + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 300): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + "duration_hours": round(duration_hours, 4), + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dtln_mp3_to_wav/step_2_train_model.py b/examples/dtln_mp3_to_wav/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..617670c9580d6b94312853c94d342acbbb82b7da --- /dev/null +++ b/examples/dtln_mp3_to_wav/step_2_train_model.py @@ -0,0 +1,445 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/breizhn/DTLN + +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig +from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=30, type=int) + 
parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + mp3_waveform_list = list() + wav_waveform_list = list() + + for sample in batch: + mp3_waveform: torch.Tensor = sample["mp3_waveform"] + wav_waveform: torch.Tensor = sample["wav_waveform"] + + mp3_waveform_list.append(mp3_waveform) + wav_waveform_list.append(wav_waveform) + + mp3_waveform_list = torch.stack(mp3_waveform_list) + wav_waveform_list = torch.stack(wav_waveform_list) + + # assert + if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)): + raise AssertionError("nan or inf in mp3_waveform_list") + if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)): + raise AssertionError("nan or inf in wav_waveform_list") + + return mp3_waveform_list, wav_waveform_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = DTLNConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = Mp3ToWavJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + # skip=225000, + ) + valid_dataset = Mp3ToWavJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + + # models + logger.info(f"prepare models. 
config_file: {args.config_file}") + model = DTLNPretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + audio_l1_loss_fn = nn.L1Loss(reduction="mean") + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_mr_stft_loss = 1000000000 + average_audio_l1_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_audio_l1_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. 
+ + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + mp3_audios, wav_audios = train_batch + noisy_audios: torch.Tensor = mp3_audios.to(device) + clean_audios: torch.Tensor = wav_audios.to(device) + + denoise_audios = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_audio_l1_loss += audio_l1_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "audio_l1_loss": average_audio_l1_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_audio_l1_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. 
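# Explanatory note: after each evaluation the running totals are cleared and a
# fresh training progress bar is created at the same global step, so the
# displayed training averages restart with every eval window while the global
# step counter (used to name the step checkpoints) keeps increasing.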
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + mp3_audios, wav_audios = eval_batch + noisy_audios: torch.Tensor = mp3_audios.to(device) + clean_audios: torch.Tensor = wav_audios.to(device) + + denoise_audios = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_audio_l1_loss += audio_l1_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "audio_l1_loss": average_audio_l1_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_audio_l1_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "audio_l1_loss": average_audio_l1_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/dtln_mp3_to_wav/yaml/config-1024.yaml b/examples/dtln_mp3_to_wav/yaml/config-1024.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b622f176b7eae1b1855f4e420035a4726554f04 --- /dev/null +++ b/examples/dtln_mp3_to_wav/yaml/config-1024.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 512 +hop_size: 128 +win_type: hann + +# data +min_snr_db: -5 +max_snr_db: 25 + +# model +encoder_size: 1024 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/examples/dtln_mp3_to_wav/yaml/config-256.yaml b/examples/dtln_mp3_to_wav/yaml/config-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0da3b47d03beeb581fac6c7d3c9d8459879c79d6 --- /dev/null +++ b/examples/dtln_mp3_to_wav/yaml/config-256.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 256 +hop_size: 128 +win_type: hann + +# data +min_snr_db: -5 +max_snr_db: 25 + +# model +encoder_size: 256 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/examples/dtln_mp3_to_wav/yaml/config-512.yaml b/examples/dtln_mp3_to_wav/yaml/config-512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..517fd3dea974816df14a395c58173bbbe234aada --- /dev/null +++ b/examples/dtln_mp3_to_wav/yaml/config-512.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 512 +hop_size: 128 +win_type: hann + +# data +min_snr_db: -5 +max_snr_db: 25 + +# model +encoder_size: 512 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/examples/frcrn/run.sh b/examples/frcrn/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..0cdd5a10c69a4827bfdb443d594a42886c35ea1f --- /dev/null +++ b/examples/frcrn/run.sh @@ -0,0 +1,159 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir 
--final_model_name frcrn-20-512-nx-dns3 \ +--config_file "yaml/config-10.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \ +--config_file "yaml/config-10.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir 
"${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/frcrn/step_1_prepare_data.py b/examples/frcrn/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfb4d8d988cd137a7cd4f53194ac26dcbc45525 --- /dev/null +++ b/examples/frcrn/step_1_prepare_data.py @@ -0,0 +1,164 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=-1, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=100000, + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset jsonl") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count > 0: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 300 / 1): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/frcrn/step_2_train_model.py b/examples/frcrn/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e0b690e5192a53394021e1a01bc57ce7c635f7 --- /dev/null +++ b/examples/frcrn/step_2_train_model.py @@ -0,0 +1,457 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2206.07293 + +FRCRN 论文中: +在 WSJ0 数据集上训练了 120 个 epoch 得到 pesq 3.62, stoi 98.24, si-snr 21.33 + +WSJ0 包含约 80小时的纯净英语语音录音. 
+ +我的音频大约是 1300 小时, 则预期大约需要 10个 epoch +""" +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig +from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=30, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = FRCRNConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: 
{n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = FRCRNPretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 0 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + # logger.info(f"load state dict for optimizer.") + # with open(optimizer_pth.as_posix(), "rb") as f: + # state_dict = torch.load(f, map_location="cpu", weights_only=True) + # optimizer.load_state_dict(state_dict) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + average_mask_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == 
-1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_batches = 0. + + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + est_spec, est_wav, est_mask = model.forward(noisy_audios) + denoise_audios = est_wav + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_batches = 0. 
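# For reference: neg_si_snr_loss_fn (NegativeSISNRLoss) used in the training and
# evaluation loops returns the negated scale-invariant SNR, so minimising the loss
# maximises SI-SNR. A minimal sketch of the usual definition is given below; the
# toolbox implementation may differ in reduction and epsilon handling.
def neg_si_snr_sketch(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # remove any DC offset from both signals
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # project the estimate onto the target; the residual counts as noise
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    s_target = dot * target / (torch.sum(target ** 2, dim=-1, keepdim=True) + eps)
    e_noise = estimate - s_target
    si_snr = 10.0 * torch.log10(
        (torch.sum(s_target ** 2, dim=-1) + eps) / (torch.sum(e_noise ** 2, dim=-1) + eps)
    )
    # negate so that a lower loss corresponds to a higher SI-SNR
    return -si_snr.mean()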
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + + est_spec, est_wav, est_mask = model.forward(noisy_audios) + denoise_audios = est_wav + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
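# PESQ (ITU-T P.862) scores range roughly from -0.5 to 4.5, higher is better; at
# the 8000 Hz sample_rate used here only the narrow-band mode ("nb") applies.
# run_pesq_score presumably averages per-utterance scores; a single pair could be
# scored with the `pesq` package, for example:
#     from pesq import pesq
#     score = pesq(8000, clean_audio, denoise_audio, "nb")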
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/frcrn/yaml/config-10.yaml b/examples/frcrn/yaml/config-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2be0c666314c955a33db19e7f25ffea30d784d92 --- /dev/null +++ b/examples/frcrn/yaml/config-10.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 128 +win_size: 128 +hop_size: 64 +win_type: hann + +use_complex_networks: true +model_depth: 10 +model_complexity: -1 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 20000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/examples/frcrn/yaml/config-14.yaml b/examples/frcrn/yaml/config-14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..707a9506ae459ba555c17b187d6631bbf63d681f --- /dev/null +++ b/examples/frcrn/yaml/config-14.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 640 +win_size: 640 +hop_size: 320 +win_type: hann + +use_complex_networks: true +model_depth: 14 +model_complexity: -1 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/examples/frcrn/yaml/config-20.yaml b/examples/frcrn/yaml/config-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3182a4024e561560d8688bdc92617dfcb000d6e7 --- /dev/null +++ b/examples/frcrn/yaml/config-20.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 512 +win_size: 512 +hop_size: 256 +win_type: hann + +use_complex_networks: true +model_depth: 20 +model_complexity: 45 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/examples/frcrn_mp3_to_wav/run.sh b/examples/frcrn_mp3_to_wav/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..56a78e3598f6ea3da32ca8d383861879d096f371 --- /dev/null +++ b/examples/frcrn_mp3_to_wav/run.sh @@ -0,0 +1,156 @@ 
+#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \ +--config_file "yaml/config-10.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \ +--config_file "yaml/config-10.yaml" \ +--audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \ + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --audio_dir "${audio_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ 
${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/frcrn_mp3_to_wav/step_1_prepare_data.py b/examples/frcrn_mp3_to_wav/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea3ac05ed29805d186564b19a82be62095a2152 --- /dev/null +++ b/examples/frcrn_mp3_to_wav/step_1_prepare_data.py @@ -0,0 +1,127 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import os +from pathlib import Path +import random +import sys + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--audio_dir", + default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech", + type=str + ) + + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--duration", default=4.0, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=-1, type=int) + + args = parser.parse_args() + return args + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 1): + data_dir = Path(data_dir) + for epoch_idx in range(max_epoch): + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + if np.sum(signal[begin: begin+win_size]) == 0: + continue + row = { + "epoch_idx": epoch_idx, + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def main(): + args = get_args() + + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + audio_dir = Path(args.audio_dir) + + audio_generator = target_second_signal_generator( + audio_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate, + max_epoch=1, + ) + count = 0 + process_bar = tqdm(desc="build dataset jsonl") + with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid: + for audio in audio_generator: + if count >= args.max_count > 0: + break + + filename = audio["filename"] + raw_duration = audio["raw_duration"] + offset = audio["offset"] + duration = audio["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "filename": filename, + "raw_duration": raw_duration, + "offset": offset, + "duration": duration, + + "random1": random1, + } + row = json.dumps(row, ensure_ascii=False) + if random2 < (1 / 10): + fvalid.write(f"{row}\n") + else: + ftrain.write(f"{row}\n") + + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + "duration_hours": round(duration_hours, 4), + }) + + return + + +if __name__ == "__main__": + main() diff --git a/examples/frcrn_mp3_to_wav/step_2_train_model.py b/examples/frcrn_mp3_to_wav/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..007f9d80dbabb348b918453beb02811a3094a513 --- /dev/null +++ b/examples/frcrn_mp3_to_wav/step_2_train_model.py @@ -0,0 +1,442 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig +from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=30, type=int) + parser.add_argument("--serialization_dir", 
default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + mp3_waveform_list = list() + wav_waveform_list = list() + + for sample in batch: + mp3_waveform: torch.Tensor = sample["mp3_waveform"] + wav_waveform: torch.Tensor = sample["wav_waveform"] + + mp3_waveform_list.append(mp3_waveform) + wav_waveform_list.append(wav_waveform) + + mp3_waveform_list = torch.stack(mp3_waveform_list) + wav_waveform_list = torch.stack(wav_waveform_list) + + # assert + if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)): + raise AssertionError("nan or inf in mp3_waveform_list") + if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)): + raise AssertionError("nan or inf in wav_waveform_list") + + return mp3_waveform_list, wav_waveform_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = FRCRNConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = Mp3ToWavJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + # skip=225000, + ) + valid_dataset = Mp3ToWavJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=2, + ) + + # models + logger.info(f"prepare models. 
config_file: {args.config_file}") + model = FRCRNPretrainedModel(config).to(device) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 0 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + # logger.info(f"load state dict for optimizer.") + # with open(optimizer_pth.as_posix(), "rb") as f: + # state_dict = torch.load(f, map_location="cpu", weights_only=True) + # optimizer.load_state_dict(state_dict) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + + # state + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + average_mask_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_batches = 0. 
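# For reference: mr_stft_loss_fn (MultiResolutionSTFTLoss, configured above with
# FFT sizes 256/512/1024) sums, per resolution, a spectral-convergence term weighted
# by factor_sc and a log-magnitude L1 term weighted by factor_mag. A minimal
# single-resolution sketch of those two terms follows; the toolbox class may differ
# in windowing and normalisation details.
def single_resolution_stft_loss_sketch(estimate: torch.Tensor,
                                       target: torch.Tensor,
                                       fft_size: int = 512,
                                       hop_size: int = 256,
                                       win_size: int = 512,
                                       factor_sc: float = 1.5,
                                       factor_mag: float = 1.0) -> torch.Tensor:
    window = torch.hann_window(win_size, device=estimate.device)
    est_mag = torch.stft(estimate, fft_size, hop_size, win_size,
                         window=window, return_complex=True).abs().clamp(min=1e-7)
    ref_mag = torch.stft(target, fft_size, hop_size, win_size,
                         window=window, return_complex=True).abs().clamp(min=1e-7)
    # spectral convergence: relative Frobenius error of the magnitude spectrogram
    sc_loss = torch.norm(ref_mag - est_mag, p="fro") / torch.norm(ref_mag, p="fro")
    # L1 distance between log-magnitude spectrograms
    mag_loss = torch.nn.functional.l1_loss(torch.log(est_mag), torch.log(ref_mag))
    return factor_sc * sc_loss + factor_mag * mag_loss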
+ + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + mp3_audios, wav_audios = train_batch + noisy_audios: torch.Tensor = mp3_audios.to(device) + clean_audios: torch.Tensor = wav_audios.to(device) + + est_spec, est_wav, est_mask = model.forward(noisy_audios) + denoise_audios = est_wav + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_batches = 0. 
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx/1000)), + ) + for eval_batch in valid_data_loader: + mp3_audios, wav_audios = eval_batch + noisy_audios: torch.Tensor = mp3_audios.to(device) + clean_audios: torch.Tensor = wav_audios.to(device) + + est_spec, est_wav, est_mask = model.forward(noisy_audios) + denoise_audios = est_wav + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_mask_loss += mask_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + average_mask_loss = round(total_mask_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_mask_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
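# When the comparison above selects a new best checkpoint, the code further below
# copies it into the "best" directory and resets patience_count; otherwise
# patience_count keeps growing until args.patience is reached and training stops early.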
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + "mask_loss": average_mask_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/frcrn_mp3_to_wav/yaml/config-10.yaml b/examples/frcrn_mp3_to_wav/yaml/config-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2be0c666314c955a33db19e7f25ffea30d784d92 --- /dev/null +++ b/examples/frcrn_mp3_to_wav/yaml/config-10.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 128 +win_size: 128 +hop_size: 64 +win_type: hann + +use_complex_networks: true +model_depth: 10 +model_complexity: -1 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 20000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/examples/frcrn_mp3_to_wav/yaml/config-14.yaml b/examples/frcrn_mp3_to_wav/yaml/config-14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..707a9506ae459ba555c17b187d6631bbf63d681f --- /dev/null +++ b/examples/frcrn_mp3_to_wav/yaml/config-14.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 640 +win_size: 640 +hop_size: 320 +win_type: hann + +use_complex_networks: true +model_depth: 14 +model_complexity: -1 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/examples/frcrn_mp3_to_wav/yaml/config-20.yaml b/examples/frcrn_mp3_to_wav/yaml/config-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3182a4024e561560d8688bdc92617dfcb000d6e7 --- /dev/null +++ b/examples/frcrn_mp3_to_wav/yaml/config-20.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 512 +win_size: 512 +hop_size: 256 +win_type: hann + +use_complex_networks: true +model_depth: 20 +model_complexity: 45 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/examples/lstm/run.sh b/examples/lstm/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..52d9691c1c3ced86192dfae6da6bed1106306390 --- 
/dev/null +++ b/examples/lstm/run.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + 
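# Stages 4-6 below export, collect and clean up, gated the same way as the stages
# above; an illustrative invocation (the final_model_name value here is only an
# example) might be:
#   sh run.sh --stage 4 --stop_stage 6 --system_version centos \
#   --file_folder_name file_dir --final_model_name lstm-denoise-8k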
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/lstm/step_1_prepare_data.py b/examples/lstm/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..16398c263d39e9accfff6ffb4a64650037eba385 --- /dev/null +++ b/examples/lstm/step_1_prepare_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_nsr_db", default=-20, type=float) + parser.add_argument("--max_nsr_db", default=5, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/lstm/step_2_train_model.py b/examples/lstm/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b41e437cceb8f41737df94bfaf1dc89a8936755b --- /dev/null +++ b/examples/lstm/step_2_train_model.py @@ -0,0 +1,444 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json +import logging +from logging.handlers import 
TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig +from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-3, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=10, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + snr_db_list = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = LstmConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = LstmPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optimizer.") + with open(optimizer_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optimizer.load_state_dict(state_dict) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + logger.info("training") + + average_pesq_score = 1000000000 + average_loss = 1000000000 + 
average_mr_stft_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 if last_step_idx == -1 else last_step_idx + + logger.info("training") + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. + + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch: {}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios, _, _ = model.forward(noisy_audios) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. 
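+                    # Evaluation pass (overview): the running totals were just reset and the
+                    # same quantities as in training -- multi-resolution STFT loss, negative
+                    # SI-SNR loss and narrow-band PESQ -- are averaged over the validation
+                    # loader; the averaged PESQ is what later drives best-model selection
+                    # and early stopping.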
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx / 1000)), + ) + + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios, _, _ = model.forward(noisy_audios) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(epoch_idx) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
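+                        # (higher PESQ is better, so this step becomes the new best checkpoint.)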
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + return + + +if __name__ == '__main__': + main() diff --git a/examples/lstm/step_3_evaluation.py b/examples/lstm/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..3b7736b5e5e4813714f6ea33d2d3fc53661eb9a0 --- /dev/null +++ b/examples/lstm/step_3_evaluation.py @@ -0,0 +1,239 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). 
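+    # Target-SNR mixing (sketch of the arithmetic used below): with
+    #   SNR_dB = 10 * log10(P_speech / P_noise)
+    # the noise power that realises a requested snr_db is
+    #   P_noise = P_speech / 10 ** (snr_db / 10)
+    # so the noise is rescaled to that power (divide by its own RMS, multiply by
+    # sqrt(P_noise)) and then added to the speech.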
+ + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + model = LstmPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. 
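+    # Evaluation loop (overview): each row of the excel file mixes one speech and
+    # one noise excerpt at the stored snr_db; the target is the ideal ratio mask
+    # speech_spec / (speech_spec + noise_spec) computed from power spectrograms,
+    # the model predicts that mask from the noisy power spectrogram, and the MSE
+    # between the two is accumulated. enhance() then applies the predicted mask to
+    # the complex mixture STFT and inverts it so the separated speech and noise
+    # can be saved for listening.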
+ progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/lstm/yaml/config.yaml b/examples/lstm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbc362dccca99c9dcd7bc746998df1c169ef4c10 --- /dev/null +++ b/examples/lstm/yaml/config.yaml @@ -0,0 +1,32 @@ +model_name: "lstm" + +# spec +sample_rate: 8000 +segment_size: 32000 +n_fft: 320 +win_size: 320 +hop_size: 160 +win_type: hann + +# data +max_snr_db: 20 +min_snr_db: -10 + +# model +hidden_size: 512 +num_layers: 3 +dropout: 0.1 + +# train +max_epochs: 100 +batch_size: 32 +num_workers: 4 +seed: 1234 + +lr: 0.001 +lr_scheduler: CosineAnnealingLR +lr_scheduler_kwargs: {} + +weight_decay: 0.00001 +clip_grad_norm: 10.0 +eval_steps: 25000 diff --git a/examples/mpnet/run.sh b/examples/mpnet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..2d6482450be49f1d78985a47ba71afc556b7ed70 --- /dev/null +++ b/examples/mpnet/run.sh @@ -0,0 +1,166 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \ +--noise_dir 
"/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + +sh run.sh --stage 5 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \ +--max_epochs 100 + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/mpnet/step_1_prepare_data.py b/examples/mpnet/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..91afbbd02541035f6aa18ae4424a343d755dc53c --- /dev/null +++ b/examples/mpnet/step_1_prepare_data.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--scale", default=1, type=float) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + flag = random.random() + if flag > args.scale: + continue + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "count": count, + + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/mpnet/step_2_train_model.py b/examples/mpnet/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5b02a9fe2bf5fd533fdf9b6cc146a8d648a203 --- /dev/null +++ b/examples/mpnet/step_2_train_model.py @@ -0,0 +1,450 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/yxlu-0102/MP-SENet/blob/main/train.py +""" +import argparse 
+import json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel +from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses +from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft +from toolbox.torchaudio.models.mpnet.metrics import run_batch_pesq, run_pesq_score + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = MPNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = 
torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + generator = MPNetPretrainedModel(config).to(device) + discriminator = MetricDiscriminatorPretrainedModel(config).to(device) + + # optimizer + logger.info("prepare optimizer, lr_scheduler") + num_params = 0 + for p in generator.parameters(): + num_params += p.numel() + logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6)) + + optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + + # resume training + last_epoch = -1 + for epoch_i in serialization_dir.glob("epoch-*"): + epoch_i = Path(epoch_i) + epoch_idx = epoch_i.stem.split("-")[1] + epoch_idx = int(epoch_idx) + if epoch_idx > last_epoch: + last_epoch = epoch_idx + + if last_epoch != -1: + logger.info(f"resume from epoch-{last_epoch}.") + generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt" + discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt" + optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth" + optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth" + + logger.info(f"load state dict for generator.") + with open(generator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + generator.load_state_dict(state_dict, strict=True) + logger.info(f"load state dict for discriminator.") + with open(discriminator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + discriminator.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optim_g.") + with open(optim_g_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_g.load_state_dict(state_dict) + logger.info(f"load state dict for optim_d.") + with open(optim_d_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_d.load_state_dict(state_dict) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch) + + # training loop + + # state + loss_d = 10000000000 + loss_g = 10000000000 + pesq_metric = 10000000000 + mag_err = 10000000000 + pha_err = 10000000000 + com_err = 10000000000 + stft_err = 
10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + logger.info("training") + early_stop_flag = False + for idx_epoch in range(max(0, last_epoch+1), args.max_epochs): + if early_stop_flag: + break + + # train + generator.train() + discriminator.train() + + total_loss_d = 0. + total_loss_g = 0. + total_batches = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + for batch in train_data_loader: + clean_audio, noisy_audio = batch + clean_audio = clean_audio.to(device) + noisy_audio = noisy_audio.to(device) + one_labels = torch.ones(clean_audio.shape[0]).to(device) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha) + + audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy()) + pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb") + + # Discriminator + optim_d.zero_grad() + metric_r = discriminator.forward(clean_mag, clean_mag) + metric_g = discriminator.forward(clean_mag, mag_g_hat.detach()) + loss_disc_r = F.mse_loss(one_labels, metric_r.flatten()) + + if -1 in pesq_score_list: + # print("-1 in batch_pesq_score!") + loss_disc_g = 0 + else: + pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32) + loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten()) + + loss_disc_all = loss_disc_r + loss_disc_g + loss_disc_all.backward() + optim_d.step() + + # Generator + optim_g.zero_grad() + # L2 Magnitude Loss + loss_mag = F.mse_loss(clean_mag, mag_g) + # Anti-wrapping Phase Loss + loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g) + loss_pha = loss_ip + loss_gd + loss_iaf + # L2 Complex Loss + loss_com = F.mse_loss(clean_com, com_g) * 2 + # L2 Consistency Loss + loss_stft = F.mse_loss(com_g, com_g_hat) * 2 + # Time Loss + loss_time = F.l1_loss(clean_audio, audio_g) + # Metric Loss + metric_g = discriminator.forward(clean_mag, mag_g_hat) + loss_metric = F.mse_loss(metric_g.flatten(), one_labels) + + loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2 + + loss_gen_all.backward() + optim_g.step() + + total_loss_d += loss_disc_all.item() + total_loss_g += loss_gen_all.item() + total_batches += 1 + + loss_d = round(total_loss_d / total_batches, 4) + loss_g = round(total_loss_g / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "loss_d": loss_d, + "loss_g": loss_g, + }) + + # evaluation + generator.eval() + discriminator.eval() + + torch.cuda.empty_cache() + total_pesq_score = 0. + total_mag_err = 0. + total_pha_err = 0. + total_com_err = 0. + total_stft_err = 0. + total_batches = 0. 
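+        # Validation metrics (overview): narrow-band PESQ between clean and enhanced
+        # audio plus magnitude, phase, complex and STFT-consistency errors, each
+        # averaged over the validation batches; the averaged PESQ is the value used
+        # below for best-epoch selection and early stopping.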
+ + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + with torch.no_grad(): + for batch in valid_data_loader: + clean_audio, noisy_audio = batch + clean_audio = clean_audio.to(device) + noisy_audio = noisy_audio.to(device) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha) + + audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + clean_audio_list = torch.split(clean_audio, 1, dim=0) + enhanced_audio_list = torch.split(audio_g, 1, dim=0) + clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list] + enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list] + pesq_score = run_pesq_score( + clean_audio_list, + enhanced_audio_list, + sample_rate = config.sample_rate, + mode = "nb", + ) + total_pesq_score += pesq_score + total_mag_err += F.mse_loss(clean_mag, mag_g).item() + val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g) + total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item() + total_com_err += F.mse_loss(clean_com, com_g).item() + total_stft_err += F.mse_loss(com_g, com_g_hat).item() + + total_batches += 1 + + pesq_metric = round(total_pesq_score / total_batches, 4) + mag_err = round(total_mag_err / total_batches, 4) + pha_err = round(total_pha_err / total_batches, 4) + com_err = round(total_com_err / total_batches, 4) + stft_err = round(total_stft_err / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + "stft_err": stft_err, + }) + + # scheduler + scheduler_g.step() + scheduler_d.step() + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + generator.save_pretrained(epoch_dir.as_posix()) + discriminator.save_pretrained(epoch_dir.as_posix()) + + # save optim + torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix()) + torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = pesq_metric + elif pesq_metric > best_metric: + # great is better. 
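+            # (the validation PESQ improved, so this epoch is recorded as the new best.)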
+ best_idx_epoch = idx_epoch + best_metric = pesq_metric + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "loss_d": loss_d, + "loss_g": loss_g, + + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + "stft_err": stft_err, + + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + + return + + +if __name__ == "__main__": + main() diff --git a/examples/mpnet/step_3_evaluation.py b/examples/mpnet/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4c1ba29b83196dca3f3c52e0521383b4aa6b21 --- /dev/null +++ b/examples/mpnet/step_3_evaluation.py @@ -0,0 +1,187 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/yxlu-0102/MP-SENet/blob/main/inference.py +""" +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel +from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). 
+ + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +def save_audios(noise_audio: torch.Tensor, + clean_audio: torch.Tensor, + noisy_audio: torch.Tensor, + enhanced_audio: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_audio.wav" + torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate, bits_per_sample=16) + filename = output_dir / "clean_audio.wav" + torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate, bits_per_sample=16) + filename = output_dir / "noisy_audio.wav" + torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate, bits_per_sample=16) + + filename = output_dir / "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate, bits_per_sample=16) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + config = MPNetConfig.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + generator = MPNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + generator.to(device) + generator.eval() + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_audio, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + clean_audio, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + noisy_audio: np.ndarray = mix_speech_and_noise( + speech=clean_audio, + noise=noise_audio, + snr_db=snr_db, + ) + noise_audio = torch.tensor(noise_audio, dtype=torch.float32) + clean_audio = torch.tensor(clean_audio, dtype=torch.float32) + noisy_audio: torch.Tensor = torch.tensor(noisy_audio, dtype=torch.float32) + + noise_audio = noise_audio.unsqueeze(dim=0) + clean_audio = clean_audio.unsqueeze(dim=0) + noisy_audio: torch.Tensor = noisy_audio.unsqueeze(dim=0) + + # inference + clean_audio = clean_audio.to(device) + noisy_audio = noisy_audio.to(device) + with torch.no_grad(): + noisy_mag, noisy_pha, noisy_com = mag_pha_stft( + noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor + ) + mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha) + audio_g = mag_pha_istft( + mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor + ) + enhanced_audio = audio_g.detach() + + save_audios( + noise_audio, clean_audio, noisy_audio, + enhanced_audio, + args.evaluation_audio_dir + ) + + progress_bar.update(1) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/mpnet/yaml/config.yaml 
b/examples/mpnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86f043265ea8057d865e9fabf20bd2f5a6013c48 --- /dev/null +++ b/examples/mpnet/yaml/config.yaml @@ -0,0 +1,30 @@ +model_name: "mpnet" + +num_gpus: 0 +batch_size: 3 +learning_rate: 0.0005 +adam_b1: 0.8 +adam_b2: 0.99 +lr_decay: 0.99 +seed: 1234 + +dense_channel: 64 +compress_factor: 0.3 +num_tsconformers: 4 +beta: 2.0 + +sample_rate: 8000 +segment_size: 16000 +n_fft: 512 +hop_size: 80 +win_size: 200 + +num_workers: 4 + +dist_config: + dist_backend: nccl + dist_url: tcp://localhost:54321 + world_size: 1 + +discriminator_dim: 32 +discriminator_in_channel: 2 diff --git a/examples/nx_clean_unet/run.sh b/examples/nx_clean_unet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..9e8ad4c5f09308a6942be414af869028eadd293c --- /dev/null +++ b/examples/nx_clean_unet/run.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \ +--max_epochs 100 + + +sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \ +--max_epochs 100 --max_count 10000 + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/nx_clean_unet/step_1_prepare_data.py b/examples/nx_clean_unet/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfbdf266a33498229c6003bb4858c51ba9e28c0 --- /dev/null +++ b/examples/nx_clean_unet/step_1_prepare_data.py @@ -0,0 +1,201 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=10000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/nx_clean_unet/step_2_train_model.py b/examples/nx_clean_unet/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa1338023926ff3fdf1c78ce1fb8e91a4690c83 --- /dev/null +++ b/examples/nx_clean_unet/step_2_train_model.py @@ -0,0 +1,440 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/yxlu-0102/MP-SENet/blob/main/train.py +""" +import argparse +import json 
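The rows written to dataset.xlsx above only record file paths, window offsets and a target `snr_db`; the actual cutting and mixing happens at training time inside `DenoiseExcelDataset`, which is not part of this diff. A minimal sketch of how one row could be rendered into a (clean, noisy) pair — the helper names are hypothetical, but the output keys mirror what `CollateFunction` in this training script expects:

```python
# Illustrative only: not the repository's DenoiseExcelDataset.
import librosa
import numpy as np


def load_segment(filename: str, offset: float, duration: float, sample_rate: int = 8000) -> np.ndarray:
    # Decode only the window described by the row.
    signal, _ = librosa.load(filename, sr=sample_rate, offset=offset, duration=duration)
    return signal


def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8) -> np.ndarray:
    # Scale the noise so that 10 * log10(P_speech / P_noise) equals snr_db, then add.
    speech_power = float(np.mean(speech ** 2)) + eps
    noise_power = float(np.mean(noise ** 2)) + eps
    target_noise_power = speech_power / (10.0 ** (snr_db / 10.0))
    return speech + noise * np.sqrt(target_noise_power / noise_power)


def render_row(row: dict, sample_rate: int = 8000) -> dict:
    speech = load_segment(row["speech_filename"], row["speech_offset"], row["speech_duration"], sample_rate)
    noise = load_segment(row["noise_filename"], row["noise_offset"], row["noise_duration"], sample_rate)
    return {
        "speech_wave": speech,
        "mix_wave": mix_at_snr(speech, noise, row["snr_db"]),
        "snr_db": row["snr_db"],
    }
```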
+import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig +from toolbox.torchaudio.models.nx_clean_unet.discriminator import MetricDiscriminator, MetricDiscriminatorPretrainedModel +from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNet, NXCleanUNetPretrainedModel +from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score +from toolbox.torchaudio.models.nx_clean_unet.utils import mag_pha_stft, mag_pha_istft +from toolbox.torchaudio.models.nx_clean_unet.loss import phase_losses + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = NXCleanUNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + 
logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=16, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=16, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + generator = NXCleanUNetPretrainedModel(config).to(device) + discriminator = MetricDiscriminatorPretrainedModel(config).to(device) + + # optimizer + logger.info("prepare optimizer, lr_scheduler") + num_params = 0 + for p in generator.parameters(): + num_params += p.numel() + logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6)) + + optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + + # resume training + last_epoch = -1 + for epoch_i in serialization_dir.glob("epoch-*"): + epoch_i = Path(epoch_i) + epoch_idx = epoch_i.stem.split("-")[1] + epoch_idx = int(epoch_idx) + if epoch_idx > last_epoch: + last_epoch = epoch_idx + + if last_epoch != -1: + logger.info(f"resume from epoch-{last_epoch}.") + generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt" + discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt" + optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth" + optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth" + + logger.info(f"load state dict for generator.") + with open(generator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + generator.load_state_dict(state_dict, strict=True) + logger.info(f"load state dict for discriminator.") + with open(discriminator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + discriminator.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optim_g.") + with open(optim_g_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_g.load_state_dict(state_dict) + logger.info(f"load state dict for optim_d.") + with open(optim_d_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_d.load_state_dict(state_dict) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch) + + # training loop + + # state + loss_d = 10000000000 + loss_g = 
10000000000 + pesq_metric = 10000000000 + mag_err = 10000000000 + pha_err = 10000000000 + com_err = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + logger.info("training") + for idx_epoch in range(max(0, last_epoch+1), args.max_epochs): + # train + generator.train() + discriminator.train() + + total_loss_d = 0. + total_loss_g = 0. + total_batches = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + for batch in train_data_loader: + clean_audios, noisy_audios = batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + one_labels = torch.ones(clean_audios.shape[0]).to(device) + + audio_g = generator.forward(noisy_audios) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + + clean_audio_list = torch.split(clean_audios, 1, dim=0) + enhanced_audio_list = torch.split(audio_g, 1, dim=0) + clean_audio_list = [t.squeeze().detach().cpu().numpy() for t in clean_audio_list] + enhanced_audio_list = [t.squeeze().detach().cpu().numpy() for t in enhanced_audio_list] + + pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb") + + # Discriminator + optim_d.zero_grad() + metric_r = discriminator.forward(clean_audios, clean_audios) + metric_g = discriminator.forward(clean_audios, audio_g.detach()) + loss_disc_r = F.mse_loss(one_labels, metric_r.flatten()) + + if -1 in pesq_score_list: + # print("-1 in batch_pesq_score!") + loss_disc_g = 0 + else: + pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32) + loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten()) + + loss_disc_all = loss_disc_r + loss_disc_g + loss_disc_all.backward() + optim_d.step() + + # Generator + optim_g.zero_grad() + # L2 Magnitude Loss + loss_mag = F.mse_loss(clean_mag, mag_g) + # Anti-wrapping Phase Loss + loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g) + loss_pha = loss_ip + loss_gd + loss_iaf + # L2 Complex Loss + loss_com = F.mse_loss(clean_com, com_g) * 2 + # L2 Consistency Loss + # Time Loss + loss_time = F.l1_loss(clean_audios, audio_g) + # Metric Loss + metric_g = discriminator.forward(clean_audios, audio_g) + # metric_g = discriminator.forward(clean_audios, audio_g.detach()) + loss_metric = F.mse_loss(metric_g.flatten(), one_labels) + + # loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2 + loss_gen_all = loss_mag * 0.1 + loss_pha * 0.1 + loss_com * 0.1 + loss_metric * 0.9 + loss_time * 0.9 + + loss_gen_all.backward() + optim_g.step() + + total_loss_d += loss_disc_all.item() + total_loss_g += loss_gen_all.item() + total_batches += 1 + + loss_d = round(total_loss_d / total_batches, 4) + loss_g = round(total_loss_g / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "loss_d": loss_d, + "loss_g": loss_g, + }) + + # evaluation + generator.eval() + discriminator.eval() + + torch.cuda.empty_cache() + total_pesq_score = 0. + total_mag_err = 0. + total_pha_err = 0. + total_com_err = 0. + total_batches = 0. 
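`run_batch_pesq` and `run_pesq_score`, used in the discriminator update above and in the evaluation loop below, are imported from the toolbox but not included in this diff. A plausible minimal implementation on top of the `pesq` package is sketched here (hypothetical; the toolbox version may differ). The `-1` sentinel is what the training loop checks before building the PESQ-based discriminator target, and `(score - 1) / 3.5` then maps narrow-band PESQ scores from roughly [1, 4.5] onto roughly [0, 1].

```python
# Hypothetical sketch of the PESQ helpers imported above; not the repository's code.
from typing import List

import numpy as np
from pesq import pesq  # pip install pesq


def run_single_pesq(clean: np.ndarray, enhanced: np.ndarray, sample_rate: int = 8000, mode: str = "nb") -> float:
    try:
        return pesq(sample_rate, clean, enhanced, mode)
    except Exception:
        # The training loop treats -1 as "PESQ failed for this pair" and then
        # skips the PESQ-based discriminator target for the whole batch.
        return -1.0


def run_batch_pesq(clean_list: List[np.ndarray], enhanced_list: List[np.ndarray],
                   sample_rate: int = 8000, mode: str = "nb") -> List[float]:
    return [run_single_pesq(c, e, sample_rate, mode) for c, e in zip(clean_list, enhanced_list)]


def run_pesq_score(clean_list: List[np.ndarray], enhanced_list: List[np.ndarray],
                   sample_rate: int = 8000, mode: str = "nb") -> float:
    scores = [s for s in run_batch_pesq(clean_list, enhanced_list, sample_rate, mode) if s != -1.0]
    return float(np.mean(scores)) if scores else -1.0
```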
+ + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + with torch.no_grad(): + for batch in valid_data_loader: + clean_audios, noisy_audios = batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + + audio_g = generator.forward(noisy_audios) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + + clean_audio_list = torch.split(clean_audios, 1, dim=0) + enhanced_audio_list = torch.split(audio_g, 1, dim=0) + clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list] + enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list] + pesq_score = run_pesq_score( + clean_audio_list, + enhanced_audio_list, + sample_rate = config.sample_rate, + mode = "nb", + ) + total_pesq_score += pesq_score + total_mag_err += F.mse_loss(clean_mag, mag_g).item() + val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g) + total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item() + total_com_err += F.mse_loss(clean_com, com_g).item() + + total_batches += 1 + + pesq_metric = round(total_pesq_score / total_batches, 4) + mag_err = round(total_mag_err / total_batches, 4) + pha_err = round(total_pha_err / total_batches, 4) + com_err = round(total_com_err / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + }) + + # scheduler + scheduler_g.step() + scheduler_d.step() + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + generator.save_pretrained(epoch_dir.as_posix()) + discriminator.save_pretrained(epoch_dir.as_posix()) + + # save optim + torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix()) + torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = pesq_metric + elif pesq_metric > best_metric: + # great is better. 
+ best_idx_epoch = idx_epoch + best_metric = pesq_metric + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "loss_d": loss_d, + "loss_g": loss_g, + + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + + return + + +if __name__ == "__main__": + main() diff --git a/examples/nx_clean_unet/step_3_evaluation.py b/examples/nx_clean_unet/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..7e86e343655195684e02b4ea095db136798319d5 --- /dev/null +++ b/examples/nx_clean_unet/step_3_evaluation.py @@ -0,0 +1,59 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel +from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def main(): + return + + +if __name__ == '__main__': + main() diff --git a/examples/nx_clean_unet/yaml/config.yaml b/examples/nx_clean_unet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48015af285e634d46a257bf9c15397a92ca1b2f0 --- /dev/null +++ b/examples/nx_clean_unet/yaml/config.yaml @@ -0,0 +1,42 @@ +model_name: "nx_clean_unet" + +sample_rate: 8000 +segment_size: 16000 +n_fft: 512 +win_size: 200 +hop_size: 80 + +down_sampling_num_layers: 6 +down_sampling_in_channels: 1 +down_sampling_hidden_channels: 64 +down_sampling_kernel_size: 4 +down_sampling_stride: 2 + +causal_in_channels: 1 +causal_out_channels: 1 +causal_kernel_size: 3 +causal_bias: false +causal_separable: true +causal_f_stride: 1 +causal_num_layers: 5 + +tsfm_hidden_size: 256 +tsfm_attention_heads: 8 +tsfm_num_blocks: 6 +tsfm_dropout_rate: 0.1 +tsfm_max_length: 
512 +tsfm_chunk_size: 1 +tsfm_num_left_chunks: 128 +tsfm_num_right_chunks: 4 + +discriminator_dim: 32 +discriminator_in_channel: 2 + +compress_factor: 0.3 + +batch_size: 64 +learning_rate: 0.0005 +adam_b1: 0.8 +adam_b2: 0.99 +lr_decay: 0.99 +seed: 1234 diff --git a/examples/nx_denoise/run.sh b/examples/nx_denoise/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..910d7c4c5046b08900bc797aad0786857d785906 --- /dev/null +++ b/examples/nx_denoise/run.sh @@ -0,0 +1,154 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-denoise-aishell-20250228 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \ +--max_epochs 100 + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +max_count=10000000 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --max_count "${max_count}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/nx_denoise/step_1_prepare_data.py b/examples/nx_denoise/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfbdf266a33498229c6003bb4858c51ba9e28c0 --- /dev/null +++ b/examples/nx_denoise/step_1_prepare_data.py @@ -0,0 +1,201 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--max_count", default=10000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + if count >= args.max_count: + break + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/nx_denoise/step_2_train_model.py b/examples/nx_denoise/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..87fa224b7ff90fe17c6025ac1098526afb5d5ff8 --- /dev/null +++ b/examples/nx_denoise/step_2_train_model.py @@ -0,0 +1,440 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/yxlu-0102/MP-SENet/blob/main/train.py +""" +import argparse +import json +import logging 
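One detail of `target_second_signal_generator` in the data-preparation script above is worth spelling out: windows are non-overlapping and come from `range(0, signal_length - win_size, win_size)`, so a file whose length is an exact multiple of the window size never yields its final window, and a file exactly `duration` seconds long yields nothing at all. A small check in plain Python (no project code involved):

```python
# Quick check of the windowing used in target_second_signal_generator above.
duration = 2.0
sample_rate = 8000
win_size = int(duration * sample_rate)          # 16000 samples

signal_length = 5 * win_size                    # a file of exactly 10 s
offsets = [begin / sample_rate for begin in range(0, signal_length - win_size, win_size)]
print(offsets)                                  # [0.0, 2.0, 4.0, 6.0] -> the last full window (at 8.0 s) is dropped

signal_length = win_size                        # a file of exactly 2 s
print(list(range(0, signal_length - win_size, win_size)))   # [] -> no window at all
```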
+from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig +from toolbox.torchaudio.models.nx_denoise.discriminator import MetricDiscriminator, MetricDiscriminatorPretrainedModel +from toolbox.torchaudio.models.nx_denoise.modeling_nx_denoise import NXDenoise, NXDenoisePretrainedModel +from toolbox.torchaudio.models.nx_denoise.metrics import run_batch_pesq, run_pesq_score +from toolbox.torchaudio.models.nx_denoise.utils import mag_pha_stft, mag_pha_istft +from toolbox.torchaudio.models.nx_denoise.loss import phase_losses + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = NXDenoiseConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + 
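Note on the `DataLoader` construction just below: it passes `prefetch_factor=16` even when `num_workers` is forced to 0 on Windows, and current PyTorch releases reject that combination with a `ValueError`. A defensive way to build the keyword arguments, sketched with the same values the script uses:

```python
# Sketch: only pass worker/prefetch options when multiprocessing is actually used.
import os
import platform


def make_loader_kwargs(batch_size: int, collate_fn) -> dict:
    kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=False,
    )
    if platform.system() == "Windows":
        # Windows: single-process loading; prefetch_factor must be left unset.
        kwargs.update(num_workers=0)
    else:
        kwargs.update(num_workers=os.cpu_count() // 2, prefetch_factor=16)
    return kwargs

# usage: DataLoader(dataset=train_dataset, **make_loader_kwargs(config.batch_size, collate_fn))
```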
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=16, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=16, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + generator = NXDenoisePretrainedModel(config).to(device) + discriminator = MetricDiscriminatorPretrainedModel(config).to(device) + + # optimizer + logger.info("prepare optimizer, lr_scheduler") + num_params = 0 + for p in generator.parameters(): + num_params += p.numel() + logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6)) + + optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + + # resume training + last_epoch = -1 + for epoch_i in serialization_dir.glob("epoch-*"): + epoch_i = Path(epoch_i) + epoch_idx = epoch_i.stem.split("-")[1] + epoch_idx = int(epoch_idx) + if epoch_idx > last_epoch: + last_epoch = epoch_idx + + if last_epoch != -1: + logger.info(f"resume from epoch-{last_epoch}.") + generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt" + discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt" + optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth" + optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth" + + logger.info(f"load state dict for generator.") + with open(generator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + generator.load_state_dict(state_dict, strict=True) + logger.info(f"load state dict for discriminator.") + with open(discriminator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + discriminator.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optim_g.") + with open(optim_g_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_g.load_state_dict(state_dict) + logger.info(f"load state dict for optim_d.") + with open(optim_d_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_d.load_state_dict(state_dict) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch) + + # training loop + + # state + loss_d = 10000000000 + loss_g = 10000000000 + pesq_metric = 10000000000 + mag_err = 
10000000000 + pha_err = 10000000000 + com_err = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + logger.info("training") + for idx_epoch in range(max(0, last_epoch+1), args.max_epochs): + # train + generator.train() + discriminator.train() + + total_loss_d = 0. + total_loss_g = 0. + total_batches = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + for batch in train_data_loader: + clean_audios, noisy_audios = batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + one_labels = torch.ones(clean_audios.shape[0]).to(device) + + audio_g = generator.forward(noisy_audios) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + + clean_audio_list = torch.split(clean_audios, 1, dim=0) + enhanced_audio_list = torch.split(audio_g, 1, dim=0) + clean_audio_list = [t.squeeze().detach().cpu().numpy() for t in clean_audio_list] + enhanced_audio_list = [t.squeeze().detach().cpu().numpy() for t in enhanced_audio_list] + + pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb") + + # Discriminator + optim_d.zero_grad() + metric_r = discriminator.forward(clean_audios, clean_audios) + metric_g = discriminator.forward(clean_audios, audio_g.detach()) + loss_disc_r = F.mse_loss(one_labels, metric_r.flatten()) + + if -1 in pesq_score_list: + # print("-1 in batch_pesq_score!") + loss_disc_g = 0 + else: + pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32) + loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten()) + + loss_disc_all = loss_disc_r + loss_disc_g + loss_disc_all.backward() + optim_d.step() + + # Generator + optim_g.zero_grad() + # L2 Magnitude Loss + loss_mag = F.mse_loss(clean_mag, mag_g) + # Anti-wrapping Phase Loss + loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g) + loss_pha = loss_ip + loss_gd + loss_iaf + # L2 Complex Loss + loss_com = F.mse_loss(clean_com, com_g) * 2 + # L2 Consistency Loss + # Time Loss + loss_time = F.l1_loss(clean_audios, audio_g) + # Metric Loss + metric_g = discriminator.forward(clean_audios, audio_g.detach()) + loss_metric = F.mse_loss(metric_g.flatten(), one_labels) + + loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2 + # loss_gen_all = loss_mag * 0.1 + loss_pha * 0.1 + loss_com * 0.1 + loss_metric * 0.9 + loss_time * 0.9 + # 2.02 + + loss_gen_all.backward() + optim_g.step() + + total_loss_d += loss_disc_all.item() + total_loss_g += loss_gen_all.item() + total_batches += 1 + + loss_d = round(total_loss_d / total_batches, 4) + loss_g = round(total_loss_g / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "loss_d": loss_d, + "loss_g": loss_g, + }) + + # evaluation + generator.eval() + discriminator.eval() + + torch.cuda.empty_cache() + total_pesq_score = 0. + total_mag_err = 0. + total_pha_err = 0. + total_com_err = 0. + total_batches = 0. 
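`phase_losses`, used for `loss_pha` above and `pha_err` below, is imported from the toolbox but not shown in this diff. Judging from the MP-SENet training script referenced at the top of this file, it is the anti-wrapping phase loss with instantaneous-phase, group-delay and instantaneous-angular-frequency terms; a sketch along those lines (a reconstruction, not the repository's code):

```python
# Hypothetical reconstruction of phase_losses, following MP-SENet
# (https://github.com/yxlu-0102/MP-SENet); the toolbox version may differ.
import numpy as np
import torch


def anti_wrapping_function(x: torch.Tensor) -> torch.Tensor:
    # Wrap phase differences back into (-pi, pi] before taking their magnitude.
    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)


def phase_losses(phase_r: torch.Tensor, phase_g: torch.Tensor):
    # Phase tensors are assumed to be [batch, freq_bins, time_frames].
    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
    gd_loss = torch.mean(anti_wrapping_function(
        torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))   # group delay: difference along frequency
    iaf_loss = torch.mean(anti_wrapping_function(
        torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))   # instantaneous angular frequency: difference along time
    return ip_loss, gd_loss, iaf_loss
```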
+ + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + with torch.no_grad(): + for batch in valid_data_loader: + clean_audios, noisy_audios = batch + clean_audios = clean_audios.to(device) + noisy_audios = noisy_audios.to(device) + + audio_g = generator.forward(noisy_audios) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_length, config.win_length, config.compress_factor) + + clean_audio_list = torch.split(clean_audios, 1, dim=0) + enhanced_audio_list = torch.split(audio_g, 1, dim=0) + clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list] + enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list] + pesq_score = run_pesq_score( + clean_audio_list, + enhanced_audio_list, + sample_rate = config.sample_rate, + mode = "nb", + ) + total_pesq_score += pesq_score + total_mag_err += F.mse_loss(clean_mag, mag_g).item() + val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g) + total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item() + total_com_err += F.mse_loss(clean_com, com_g).item() + + total_batches += 1 + + pesq_metric = round(total_pesq_score / total_batches, 4) + mag_err = round(total_mag_err / total_batches, 4) + pha_err = round(total_pha_err / total_batches, 4) + com_err = round(total_com_err / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + }) + + # scheduler + scheduler_g.step() + scheduler_d.step() + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + generator.save_pretrained(epoch_dir.as_posix()) + discriminator.save_pretrained(epoch_dir.as_posix()) + + # save optim + torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix()) + torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = pesq_metric + elif pesq_metric > best_metric: + # great is better. 
+ best_idx_epoch = idx_epoch + best_metric = pesq_metric + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "loss_d": loss_d, + "loss_g": loss_g, + + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + + return + + +if __name__ == "__main__": + main() diff --git a/examples/nx_denoise/step_3_evaluation.py b/examples/nx_denoise/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..71d8e1569814cb5b357cf003aa929d228a785773 --- /dev/null +++ b/examples/nx_denoise/step_3_evaluation.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def main(): + return + + +if __name__ == '__main__': + main() diff --git a/examples/nx_denoise/yaml/config.yaml b/examples/nx_denoise/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dff945235785d0494709281c6aa9836f23cf0f2e --- /dev/null +++ b/examples/nx_denoise/yaml/config.yaml @@ -0,0 +1,42 @@ +model_name: "nx_denoise" + +sample_rate: 8000 +segment_size: 16000 +n_fft: 512 +win_size: 200 +hop_size: 80 + +down_sampling_num_layers: 6 +down_sampling_in_channels: 1 +down_sampling_hidden_channels: 64 +down_sampling_kernel_size: 4 +down_sampling_stride: 2 + +causal_in_channels: 1 +causal_out_channels: 64 +causal_kernel_size: 3 +causal_bias: false +causal_separable: true +causal_f_stride: 1 +causal_num_layers: 5 + +tsfm_hidden_size: 256 +tsfm_attention_heads: 8 +tsfm_num_blocks: 6 +tsfm_dropout_rate: 0.1 +tsfm_max_length: 512 +tsfm_chunk_size: 1 +tsfm_num_left_chunks: 128 +tsfm_num_right_chunks: 4 + +discriminator_dim: 32 +discriminator_in_channel: 2 + +compress_factor: 0.3 + +batch_size: 4 +learning_rate: 0.0005 +adam_b1: 0.8 +adam_b2: 0.99 +lr_decay: 0.99 +seed: 1234 diff 
--git a/examples/nx_mpnet/run.sh b/examples/nx_mpnet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..08fba1aaa0de09b4ead7b1e8d37e74ddead389d4 --- /dev/null +++ b/examples/nx_mpnet/run.sh @@ -0,0 +1,155 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-mpnet-aishell \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \ +--max_epochs 100 \ +--duration 2 \ + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train +duration=2 + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --duration "${duration}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + 
--model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/evaluation_audio" "${final_model_dir}" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/nx_mpnet/step_1_prepare_data.py b/examples/nx_mpnet/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f7aaf5f627b867b8710c47c3d86013955b26cf36 --- /dev/null +++ b/examples/nx_mpnet/step_1_prepare_data.py @@ -0,0 +1,202 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + parser.add_argument("--scale", default=1, type=float) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + flag = random.random() + if flag > args.scale: + continue + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/nx_mpnet/step_2_train_model.py b/examples/nx_mpnet/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e416384f271fc3371a666590b2d9a1c897fcd0 --- /dev/null +++ b/examples/nx_mpnet/step_2_train_model.py @@ -0,0 +1,447 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/yxlu-0102/MP-SENet/blob/main/train.py +""" +import argparse +import json 
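Unlike the other two recipes, this data-preparation script has no `--max_count`; each candidate window is kept with probability `--scale` (`if flag > args.scale: continue`), so the resulting dataset is roughly `scale` times the total number of windows. A quick sizing helper with made-up numbers:

```python
# Rough sizing helper for the --scale option above (illustrative numbers only).
def expected_hours(num_windows: int, duration_s: float = 2.0, scale: float = 1.0) -> float:
    # Each surviving row contributes duration_s seconds of paired audio;
    # with keep-probability `scale`, about num_windows * scale rows survive.
    return num_windows * scale * duration_s / 3600.0

print(expected_hours(num_windows=1_000_000, duration_s=2.0, scale=0.1))  # ~55.6 hours
```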
+import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig +from toolbox.torchaudio.models.nx_mpnet.discriminator import MetricDiscriminatorPretrainedModel +from toolbox.torchaudio.models.nx_mpnet.modeling_nx_mpnet import NXMPNet, NXMPNetPretrainedModel +from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft +from toolbox.torchaudio.models.nx_mpnet.metrics import run_batch_pesq, run_pesq_score +from toolbox.torchaudio.models.nx_mpnet.loss import phase_losses + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = NXMPNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(config.seed) + np.random.seed(config.seed) + torch.manual_seed(config.seed) + logger.info(f"set seed: {config.seed}") + + device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info(f"GPU available count: {n_gpu}; device: {device}") + + # datasets + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + generator = NXMPNetPretrainedModel(config).to(device) + discriminator = MetricDiscriminatorPretrainedModel(config).to(device) + + # optimizer + logger.info("prepare optimizer, lr_scheduler") + num_params = 0 + for p in generator.parameters(): + num_params += p.numel() + logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6)) + + optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2]) + + # resume training + last_epoch = -1 + for epoch_i in serialization_dir.glob("epoch-*"): + epoch_i = Path(epoch_i) + epoch_idx = epoch_i.stem.split("-")[1] + epoch_idx = int(epoch_idx) + if epoch_idx > last_epoch: + last_epoch = epoch_idx + + if last_epoch != -1: + logger.info(f"resume from epoch-{last_epoch}.") + generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt" + discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt" + optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth" + optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth" + + logger.info(f"load state dict for generator.") + with open(generator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + generator.load_state_dict(state_dict, strict=True) + logger.info(f"load state dict for discriminator.") + with open(discriminator_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + discriminator.load_state_dict(state_dict, strict=True) + + logger.info(f"load state dict for optim_g.") + with open(optim_g_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_g.load_state_dict(state_dict) + logger.info(f"load state dict for optim_d.") + with open(optim_d_pth.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + optim_d.load_state_dict(state_dict) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch) + + # training loop + + # state + loss_d = 10000000000 + loss_g = 10000000000 + pesq_metric = 10000000000 + mag_err = 10000000000 + pha_err = 
10000000000 + com_err = 10000000000 + stft_err = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + logger.info("training") + for idx_epoch in range(max(0, last_epoch+1), args.max_epochs): + # train + generator.train() + discriminator.train() + + total_loss_d = 0. + total_loss_g = 0. + total_batches = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + for batch in train_data_loader: + clean_audio, noisy_audio = batch + clean_audio = clean_audio.to(device) + noisy_audio = noisy_audio.to(device) + one_labels = torch.ones(clean_audio.shape[0]).to(device) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha) + + audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy()) + pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb") + + # Discriminator + optim_d.zero_grad() + metric_r = discriminator.forward(clean_mag, clean_mag) + metric_g = discriminator.forward(clean_mag, mag_g_hat.detach()) + loss_disc_r = F.mse_loss(one_labels, metric_r.flatten()) + + if -1 in pesq_score_list: + # print("-1 in batch_pesq_score!") + loss_disc_g = 0 + else: + pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32) + loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten()) + + loss_disc_all = loss_disc_r + loss_disc_g + loss_disc_all.backward() + optim_d.step() + + # Generator + optim_g.zero_grad() + # L2 Magnitude Loss + loss_mag = F.mse_loss(clean_mag, mag_g) + # Anti-wrapping Phase Loss + loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g) + loss_pha = loss_ip + loss_gd + loss_iaf + # L2 Complex Loss + loss_com = F.mse_loss(clean_com, com_g) * 2 + # L2 Consistency Loss + loss_stft = F.mse_loss(com_g, com_g_hat) * 2 + # Time Loss + loss_time = F.l1_loss(clean_audio, audio_g) + # Metric Loss + metric_g = discriminator.forward(clean_mag, mag_g_hat) + loss_metric = F.mse_loss(metric_g.flatten(), one_labels) + + loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2 + + loss_gen_all.backward() + optim_g.step() + + total_loss_d += loss_disc_all.item() + total_loss_g += loss_gen_all.item() + total_batches += 1 + + loss_d = round(total_loss_d / total_batches, 4) + loss_g = round(total_loss_g / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "loss_d": loss_d, + "loss_g": loss_g, + }) + + # evaluation + generator.eval() + discriminator.eval() + + torch.cuda.empty_cache() + total_pesq_score = 0. + total_mag_err = 0. + total_pha_err = 0. + total_com_err = 0. + total_stft_err = 0. + total_batches = 0. 
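+        # Validation pass: average narrow-band PESQ together with the magnitude, phase,
+        # complex and STFT-consistency errors over the validation batches; the averaged
+        # PESQ is what drives best-epoch selection further below.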
+ + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + with torch.no_grad(): + for batch in valid_data_loader: + clean_audio, noisy_audio = batch + clean_audio = clean_audio.to(device) + noisy_audio = noisy_audio.to(device) + + clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha) + + audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor) + + clean_audio_list = torch.split(clean_audio, 1, dim=0) + enhanced_audio_list = torch.split(audio_g, 1, dim=0) + clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list] + enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list] + pesq_score = run_pesq_score( + clean_audio_list, + enhanced_audio_list, + sample_rate = config.sample_rate, + mode = "nb", + ) + total_pesq_score += pesq_score + total_mag_err += F.mse_loss(clean_mag, mag_g).item() + val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g) + total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item() + total_com_err += F.mse_loss(clean_com, com_g).item() + total_stft_err += F.mse_loss(com_g, com_g_hat).item() + + total_batches += 1 + + pesq_metric = round(total_pesq_score / total_batches, 4) + mag_err = round(total_mag_err / total_batches, 4) + pha_err = round(total_pha_err / total_batches, 4) + com_err = round(total_com_err / total_batches, 4) + stft_err = round(total_stft_err / total_batches, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + "stft_err": stft_err, + }) + + # scheduler + scheduler_g.step() + scheduler_d.step() + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + generator.save_pretrained(epoch_dir.as_posix()) + discriminator.save_pretrained(epoch_dir.as_posix()) + + # save optim + torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix()) + torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = pesq_metric + elif pesq_metric > best_metric: + # great is better. 
+ best_idx_epoch = idx_epoch + best_metric = pesq_metric + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "loss_d": loss_d, + "loss_g": loss_g, + + "pesq_metric": pesq_metric, + "mag_err": mag_err, + "pha_err": pha_err, + "com_err": com_err, + "stft_err": stft_err, + + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + + return + + +if __name__ == "__main__": + main() diff --git a/examples/nx_mpnet/step_3_evaluation.py b/examples/nx_mpnet/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4c1ba29b83196dca3f3c52e0521383b4aa6b21 --- /dev/null +++ b/examples/nx_mpnet/step_3_evaluation.py @@ -0,0 +1,187 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/yxlu-0102/MP-SENet/blob/main/inference.py +""" +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel +from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). 
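+    # Scale the noise to hit the requested SNR: since SNR_dB = 10 * log10(P_speech / P_noise),
+    # the target noise power is P_speech / 10 ** (snr_db / 10); the noise is normalised to
+    # unit power and rescaled to that target before being added to the speech.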
+ + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +def save_audios(noise_audio: torch.Tensor, + clean_audio: torch.Tensor, + noisy_audio: torch.Tensor, + enhanced_audio: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_audio.wav" + torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate, bits_per_sample=16) + filename = output_dir / "clean_audio.wav" + torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate, bits_per_sample=16) + filename = output_dir / "noisy_audio.wav" + torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate, bits_per_sample=16) + + filename = output_dir / "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate, bits_per_sample=16) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + config = MPNetConfig.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + generator = MPNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + generator.to(device) + generator.eval() + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_audio, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + clean_audio, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + noisy_audio: np.ndarray = mix_speech_and_noise( + speech=clean_audio, + noise=noise_audio, + snr_db=snr_db, + ) + noise_audio = torch.tensor(noise_audio, dtype=torch.float32) + clean_audio = torch.tensor(clean_audio, dtype=torch.float32) + noisy_audio: torch.Tensor = torch.tensor(noisy_audio, dtype=torch.float32) + + noise_audio = noise_audio.unsqueeze(dim=0) + clean_audio = clean_audio.unsqueeze(dim=0) + noisy_audio: torch.Tensor = noisy_audio.unsqueeze(dim=0) + + # inference + clean_audio = clean_audio.to(device) + noisy_audio = noisy_audio.to(device) + with torch.no_grad(): + noisy_mag, noisy_pha, noisy_com = mag_pha_stft( + noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor + ) + mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha) + audio_g = mag_pha_istft( + mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor + ) + enhanced_audio = audio_g.detach() + + save_audios( + noise_audio, clean_audio, noisy_audio, + enhanced_audio, + args.evaluation_audio_dir + ) + + progress_bar.update(1) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/nx_mpnet/yaml/config.yaml 
b/examples/nx_mpnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6297143f1994c14f24ad9e640c0347dbef02684 --- /dev/null +++ b/examples/nx_mpnet/yaml/config.yaml @@ -0,0 +1,38 @@ +model_name: "nx_denoise" + +sample_rate: 8000 +segment_size: 16000 +n_fft: 512 +win_size: 200 +hop_size: 80 + +dense_num_blocks: 4 +dense_hidden_size: 64 + +mask_num_blocks: 4 +mask_hidden_size: 64 + +phase_num_blocks: 4 +phase_hidden_size: 64 + +tsfm_hidden_size: 64 +tsfm_attention_heads: 4 +tsfm_num_blocks: 4 +tsfm_dropout_rate: 0.0 +tsfm_max_time_relative_position: 2048 +tsfm_max_freq_relative_position: 256 +tsfm_chunk_size: 1 +tsfm_num_left_chunks: 128 +tsfm_num_right_chunks: 64 + +discriminator_dim: 32 +discriminator_in_channel: 2 + +compress_factor: 0.3 + +batch_size: 4 +learning_rate: 0.0005 +adam_b1: 0.8 +adam_b2: 0.99 +lr_decay: 0.99 +seed: 1234 diff --git a/examples/rnnoise/run.sh b/examples/rnnoise/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..bc16482132e3c9b4828709f9fbb30c09a69b71cd --- /dev/null +++ b/examples/rnnoise/run.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name rnnoise-nx-dns3 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \ + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +train_dataset="${file_dir}/train.jsonl" +valid_dataset="${file_dir}/valid.jsonl" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + --sparse + +fi diff --git a/examples/rnnoise/step_1_prepare_data.py b/examples/rnnoise/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..16398c263d39e9accfff6ffb4a64650037eba385 --- /dev/null +++ b/examples/rnnoise/step_1_prepare_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_nsr_db", default=-20, type=float) + parser.add_argument("--max_nsr_db", default=5, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/rnnoise/step_2_train_model.py b/examples/rnnoise/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3f86b25983549c0bd01a686fa48b31ad206c6f84 --- /dev/null +++ b/examples/rnnoise/step_2_train_model.py @@ -0,0 +1,444 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json +import logging +from logging.handlers import 
TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset +from toolbox.torchaudio.losses.snr import NegativeSISNRLoss +from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss +from toolbox.torchaudio.metrics.pesq import run_pesq_score +from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig +from toolbox.torchaudio.models.rnnoise.modeling_rnnoise import RNNoisePretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.jsonl", type=str) + parser.add_argument("--valid_dataset", default="valid.jsonl", type=str) + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-3, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=15, type=int) + parser.add_argument("--patience", default=10, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + parser.add_argument("--sparse", action="store_true") + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self): + pass + + def __call__(self, batch: List[dict]): + clean_audios = list() + noisy_audios = list() + snr_db_list = list() + + for sample in batch: + # noise_wave: torch.Tensor = sample["noise_wave"] + clean_audio: torch.Tensor = sample["speech_wave"] + noisy_audio: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + clean_audios.append(clean_audio) + noisy_audios.append(noisy_audio) + + clean_audios = torch.stack(clean_audios) + noisy_audios = torch.stack(noisy_audios) + + # assert + if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)): + raise AssertionError("nan or inf in clean_audios") + if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)): + raise AssertionError("nan or inf in noisy_audios") + return clean_audios, noisy_audios + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + config = RNNoiseConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + ) + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + 
logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseJsonlDataset( + jsonl_file=args.train_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + # skip=225000, + ) + valid_dataset = DenoiseJsonlDataset( + jsonl_file=args.valid_dataset, + expected_sample_rate=config.sample_rate, + max_wave_value=32768.0, + min_snr_db=config.min_snr_db, + max_snr_db=config.max_snr_db, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=config.batch_size, + # shuffle=True, + sampler=None, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + prefetch_factor=None if platform.system() == "Windows" else 2, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + model = RNNoisePretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric") + optimizer = torch.optim.AdamW(model.parameters(), config.lr) + + # resume training + last_step_idx = -1 + last_epoch = -1 + for step_idx_str in serialization_dir.glob("steps-*"): + step_idx_str = Path(step_idx_str) + step_idx = step_idx_str.stem.split("-")[1] + step_idx = int(step_idx) + if step_idx > last_step_idx: + last_step_idx = step_idx + # last_epoch = 1 + + if last_step_idx != -1: + logger.info(f"resume from steps-{last_step_idx}.") + model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt" + + logger.info(f"load state dict for model.") + with open(model_pt.as_posix(), "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + + if config.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + last_epoch=last_epoch, + # T_max=10 * config.eval_steps, + # eta_min=0.01 * config.lr, + **config.lr_scheduler_kwargs, + ) + elif config.lr_scheduler == "MultiStepLR": + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + last_epoch=last_epoch, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + else: + raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}") + + neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device) + mr_stft_loss_fn = MultiResolutionSTFTLoss( + fft_size_list=[256, 512, 1024], + win_size_list=[256, 512, 1024], + hop_size_list=[128, 256, 512], + factor_sc=1.5, + factor_mag=1.0, + reduction="mean" + ).to(device) + + # training loop + logger.info("training") + + average_pesq_score = 1000000000 + average_loss = 1000000000 + average_mr_stft_loss = 1000000000 + average_neg_si_snr_loss = 1000000000 + + model_list = list() + best_epoch_idx = None + best_step_idx = None + best_metric = None + patience_count = 0 + + step_idx = 0 
if last_step_idx == -1 else last_step_idx + + early_stop_flag = False + for epoch_idx in range(max(0, last_epoch+1), config.max_epochs): + if early_stop_flag: + break + + # train + model.train() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. + + progress_bar_train = tqdm( + initial=step_idx, + desc="Training; epoch-{}".format(epoch_idx), + ) + for train_batch in train_data_loader: + clean_audios, noisy_audios = train_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios, _, _ = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm) + optimizer.step() + if args.sparse: + model.sparsify() + lr_scheduler.step() + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_train.update(1) + progress_bar_train.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + }) + + # evaluation + step_idx += 1 + if step_idx % config.eval_steps == 0: + model.eval() + with torch.no_grad(): + torch.cuda.empty_cache() + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. 
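+                    # Periodic evaluation (every config.eval_steps training steps): average PESQ and the
+                    # combined multi-resolution STFT + negative SI-SNR loss over the validation set;
+                    # the averaged PESQ is used below to pick the best checkpoint.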
+ + progress_bar_train.close() + progress_bar_eval = tqdm( + desc="Evaluation; steps-{}k".format(int(step_idx / 1000)), + ) + + for eval_batch in valid_data_loader: + clean_audios, noisy_audios = eval_batch + clean_audios: torch.Tensor = clean_audios.to(device) + noisy_audios: torch.Tensor = noisy_audios.to(device) + + denoise_audios, _, _ = model.forward(noisy_audios) + denoise_audios = torch.squeeze(denoise_audios, dim=1) + + mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios) + neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios) + + loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.info(f"find nan or inf in loss.") + continue + + denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy()) + clean_audios_list_r = list(clean_audios.detach().cpu().numpy()) + pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb") + + total_pesq_score += pesq_score + total_loss += loss.item() + total_mr_stft_loss += mr_stft_loss.item() + total_neg_si_snr_loss += neg_si_snr_loss.item() + total_batches += 1 + + average_pesq_score = round(total_pesq_score / total_batches, 4) + average_loss = round(total_loss / total_batches, 4) + average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4) + average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4) + + progress_bar_eval.update(1) + progress_bar_eval.set_postfix({ + "lr": lr_scheduler.get_last_lr()[0], + "pesq_score": average_pesq_score, + "loss": average_loss, + "mr_stft_loss": average_mr_stft_loss, + "neg_si_snr_loss": average_neg_si_snr_loss, + }) + + total_pesq_score = 0. + total_loss = 0. + total_mr_stft_loss = 0. + total_neg_si_snr_loss = 0. + total_batches = 0. + + progress_bar_eval.close() + progress_bar_train = tqdm( + initial=progress_bar_train.n, + postfix=progress_bar_train.postfix, + desc=progress_bar_train.desc, + ) + + # save path + save_dir = serialization_dir / "steps-{}".format(step_idx) + save_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(save_dir.as_posix()) + + model_list.append(save_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + elif average_pesq_score >= best_metric: + # great is better. 
+ best_epoch_idx = epoch_idx + best_step_idx = step_idx + best_metric = average_pesq_score + else: + pass + + metrics = { + "epoch_idx": epoch_idx, + "best_epoch_idx": best_epoch_idx, + "best_step_idx": best_step_idx, + "pesq_score": average_pesq_score, + "loss": average_loss, + } + metrics_filename = save_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_epoch_idx == epoch_idx: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(save_dir, best_dir) + + # early stop + early_stop_flag = False + if best_epoch_idx == epoch_idx and best_step_idx == step_idx: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + model.train() + + return + + +if __name__ == "__main__": + main() diff --git a/examples/rnnoise/yaml/config.yaml b/examples/rnnoise/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b1e5fd7b66220585736aa8e20bb5fc3d482b7e87 --- /dev/null +++ b/examples/rnnoise/yaml/config.yaml @@ -0,0 +1,35 @@ +model_name: "rnnoise" + +# spec +sample_rate: 8000 +segment_size: 32000 +nfft: 160 +win_size: 160 +hop_size: 80 +win_type: hann + +erb_bins: 32 +min_freq_bins_for_erb: 2 + +# model +conv_size: 256 +gru_size: 256 + +# data +max_snr_db: 20 +min_snr_db: -10 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +batch_size: 64 +num_workers: 4 +eval_steps: 15000 diff --git a/examples/simple_linear_irm_aishell/run.sh b/examples/simple_linear_irm_aishell/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..52d9691c1c3ced86192dfae6da6bed1106306390 --- /dev/null +++ b/examples/simple_linear_irm_aishell/run.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that 
Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/simple_linear_irm_aishell/step_1_prepare_data.py b/examples/simple_linear_irm_aishell/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f1ff4f3b60ae19116d3716f2c234733f053e95 --- /dev/null +++ b/examples/simple_linear_irm_aishell/step_1_prepare_data.py @@ -0,0 +1,196 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_nsr_db", default=-20, type=float) + parser.add_argument("--max_nsr_db", default=5, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/simple_linear_irm_aishell/step_2_train_model.py b/examples/simple_linear_irm_aishell/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..27742b223497c49b314d41d9e8ff058e7e2f2e64 --- /dev/null +++ b/examples/simple_linear_irm_aishell/step_2_train_model.py @@ -0,0 +1,348 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json 
+import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +from torch import dtype + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig +from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-3, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + window_fn: str = "hamming", + irm_beta: float = 1.0, + epsilon: float = 1e-8, + ): + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.window_fn = window_fn + self.irm_beta = irm_beta + self.epsilon = epsilon + + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=2.0, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + def __call__(self, batch: List[dict]): + mix_spec_list = list() + speech_irm_list = list() + snr_db_list = list() + for sample in batch: + noise_wave: torch.Tensor = sample["noise_wave"] + speech_wave: torch.Tensor = sample["speech_wave"] + mix_wave: torch.Tensor = sample["mix_wave"] + snr_db: float = sample["snr_db"] + + noise_spec = self.transform.forward(noise_wave) + speech_spec = self.transform.forward(speech_wave) + mix_spec = self.transform.forward(mix_wave) + + # noise_irm = noise_spec / (noise_spec + speech_spec) + speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon) + speech_irm = torch.pow(speech_irm, self.irm_beta) + + mix_spec_list.append(mix_spec) + speech_irm_list.append(speech_irm) + snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32)) + + mix_spec_list = 
torch.stack(mix_spec_list) + speech_irm_list = torch.stack(speech_irm_list) + snr_db_list = torch.stack(snr_db_list) # shape: (batch_size,) + + # assert + if torch.any(torch.isnan(mix_spec_list)): + raise AssertionError("nan in mix_spec Tensor") + if torch.any(torch.isnan(speech_irm_list)): + raise AssertionError("nan in speech_irm Tensor") + if torch.any(torch.isnan(snr_db_list)): + raise AssertionError("nan in snr_db Tensor") + + return mix_spec_list, speech_irm_list, snr_db_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + config = SimpleLinearIRMConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + # num_labels=vocabulary.get_vocab_size(namespace="labels") + ) + model = SimpleLinearIRMPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + param_optimizer = model.parameters() + optimizer = torch.optim.Adam( + param_optimizer, + lr=args.learning_rate, + ) + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=2000 + # ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + mse_loss = nn.MSELoss( + reduction="mean", + ) + + # training loop + logger.info("training") + + training_loss = 10000000000 + evaluation_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + for idx_epoch in range(args.max_epochs): + total_loss = 0. + total_examples = 0. 
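+        # Training objective: MSE between the IRM predicted from the noisy spectrogram and the
+        # target IRM speech_spec / (noise_spec + speech_spec) computed in the collate function.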
+ progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + + for batch in train_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + training_loss = total_loss / total_examples + training_loss = round(training_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "training_loss": training_loss, + }) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + for batch in valid_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + elif evaluation_loss < best_metric: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "training_loss": training_loss, + "evaluation_loss": evaluation_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + return + + +if __name__ == '__main__': + main() diff --git a/examples/simple_linear_irm_aishell/step_3_evaluation.py b/examples/simple_linear_irm_aishell/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..fe60fead3d247509e3916f435dc722f9cedad8f9 --- /dev/null +++ b/examples/simple_linear_irm_aishell/step_3_evaluation.py @@ -0,0 +1,239 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd 
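+# scipy.io.wavfile is imported for parity with the other example scripts but is not used
+# below; the enhanced audio is written with torchaudio.save instead.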
+from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU 
available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + model = SimpleLinearIRMPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/simple_linear_irm_aishell/yaml/config.yaml b/examples/simple_linear_irm_aishell/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f269122b6f589c012376cda288e9a256f1453de2 --- /dev/null +++ b/examples/simple_linear_irm_aishell/yaml/config.yaml @@ -0,0 +1,13 @@ +model_name: "simple_linear_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +# model +num_bins: 257 +hidden_size: 2048 +lookback: 3 +lookahead: 3 diff --git a/examples/spectrum_dfnet_aishell/run.sh b/examples/spectrum_dfnet_aishell/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..e1f45cd98ea43badeb9986bb37d374fa5bfd8bf6 --- /dev/null +++ b/examples/spectrum_dfnet_aishell/run.sh @@ -0,0 +1,178 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \ +--noise_dir 
"E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + +sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir 
"${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/spectrum_dfnet_aishell/step_1_prepare_data.py b/examples/spectrum_dfnet_aishell/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..cf114513869c474b2588d6d81067ec801673b5e7 --- /dev/null +++ b/examples/spectrum_dfnet_aishell/step_1_prepare_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than 
{duration} s. skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/spectrum_dfnet_aishell/step_2_train_model.py b/examples/spectrum_dfnet_aishell/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8f443d6a67231c3fb20043cd3cb4d9fec431cd --- /dev/null +++ b/examples/spectrum_dfnet_aishell/step_2_train_model.py @@ -0,0 +1,440 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import 
json +import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.spectrum_dfnet.configuration_spectrum_dfnet import SpectrumDfNetConfig +from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=16, type=int) + parser.add_argument("--learning_rate", default=1e-4, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + window_fn: str = "hamming", + irm_beta: float = 1.0, + epsilon: float = 1e-8, + ): + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.window_fn = window_fn + self.irm_beta = irm_beta + self.epsilon = epsilon + + self.complex_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=None, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=2.0, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + @staticmethod + def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3): + batch_size, channels, freq_dim, time_steps = x.shape + + # kernel: [freq_dim, n_time_step] + kernel_size = (freq_dim, n_time_steps) + + # pad + pad = n_time_steps // 2 + x = torch.concat(tensors=[ + x[:, :, :, :pad], + x, + x[:, :, :, -pad:], + ], dim=-1) + + x = F.unfold( + input=x, + kernel_size=kernel_size, + ) + # x shape: [batch_size, fold, time_steps] + return x + + def __call__(self, batch: List[dict]): + speech_complex_spec_list = list() + mix_complex_spec_list = list() + 
speech_irm_list = list() + snr_db_list = list() + for sample in batch: + noise_wave: torch.Tensor = sample["noise_wave"] + speech_wave: torch.Tensor = sample["speech_wave"] + mix_wave: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + noise_spec = self.transform.forward(noise_wave) + speech_spec = self.transform.forward(speech_wave) + + speech_complex_spec = self.complex_transform.forward(speech_wave) + mix_complex_spec = self.complex_transform.forward(mix_wave) + + # noise_irm = noise_spec / (noise_spec + speech_spec) + speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon) + speech_irm = torch.pow(speech_irm, self.irm_beta) + + # noise_spec, speech_spec, mix_spec, speech_irm + # shape: [freq_dim, time_steps] + + snr_db: torch.Tensor = 10 * torch.log10( + speech_spec / (noise_spec + self.epsilon) + ) + snr_db = torch.clamp(snr_db, min=self.epsilon) + + snr_db_ = torch.unsqueeze(snr_db, dim=0) + snr_db_ = torch.unsqueeze(snr_db_, dim=0) + snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3) + snr_db_ = torch.squeeze(snr_db_, dim=0) + # snr_db_ shape: [fold, time_steps] + + snr_db = torch.mean(snr_db_, dim=0, keepdim=True) + # snr_db shape: [1, time_steps] + + speech_complex_spec_list.append(speech_complex_spec) + mix_complex_spec_list.append(mix_complex_spec) + speech_irm_list.append(speech_irm) + snr_db_list.append(snr_db) + + speech_complex_spec_list = torch.stack(speech_complex_spec_list) + mix_complex_spec_list = torch.stack(mix_complex_spec_list) + speech_irm_list = torch.stack(speech_irm_list) + snr_db_list = torch.stack(snr_db_list) # shape: (batch_size, time_steps, 1) + + speech_complex_spec_list = speech_complex_spec_list[:, :-1, :] + mix_complex_spec_list = mix_complex_spec_list[:, :-1, :] + speech_irm_list = speech_irm_list[:, :-1, :] + + # speech_complex_spec_list shape: [batch_size, freq_dim, time_steps] + # mix_complex_spec_list shape: [batch_size, freq_dim, time_steps] + # speech_irm_list shape: [batch_size, freq_dim, time_steps] + # snr_db shape: [batch_size, 1, time_steps] + + # assert + if torch.any(torch.isnan(speech_complex_spec_list)) or torch.any(torch.isinf(speech_complex_spec_list)): + raise AssertionError("nan or inf in speech_complex_spec_list") + if torch.any(torch.isnan(mix_complex_spec_list)) or torch.any(torch.isinf(mix_complex_spec_list)): + raise AssertionError("nan or inf in mix_complex_spec_list") + if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)): + raise AssertionError("nan or inf in speech_irm_list") + if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)): + raise AssertionError("nan or inf in snr_db_list") + + return speech_complex_spec_list, mix_complex_spec_list, speech_irm_list, snr_db_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + 
excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + config = SpectrumDfNetConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + # num_labels=vocabulary.get_vocab_size(namespace="labels") + ) + model = SpectrumDfNetPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + param_optimizer = model.parameters() + optimizer = torch.optim.Adam( + param_optimizer, + lr=args.learning_rate, + ) + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=2000 + # ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + + speech_mse_loss = nn.MSELoss( + reduction="mean", + ) + irm_mse_loss = nn.MSELoss( + reduction="mean", + ) + snr_mse_loss = nn.MSELoss( + reduction="mean", + ) + + # training loop + logger.info("training") + + training_loss = 10000000000 + evaluation_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + for idx_epoch in range(args.max_epochs): + total_loss = 0. + total_examples = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + + for batch in train_data_loader: + speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch + speech_complex_spec = speech_complex_spec.to(device) + mix_complex_spec = mix_complex_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec) + if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)): + raise AssertionError("nan or inf in speech_spec_prediction") + if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)): + raise AssertionError("nan or inf in speech_irm_prediction") + if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)): + raise AssertionError("nan or inf in lsnr_prediction") + + speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec)) + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + + loss = speech_loss + irm_loss + snr_loss + + total_loss += loss.item() + total_examples += mix_complex_spec.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + training_loss = total_loss / total_examples + training_loss = round(training_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "training_loss": training_loss, + }) + + total_loss = 0. 
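+        # Evaluation pass for this epoch. The model is left in train() mode here; wrapping
+        # this block in model.eval()/model.train() would only matter if the network used
+        # dropout or batch normalization.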
+ total_examples = 0. + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + for batch in valid_data_loader: + speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch + speech_complex_spec = speech_complex_spec.to(device) + mix_complex_spec = mix_complex_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec) + if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)): + raise AssertionError("nan or inf in speech_spec_prediction") + if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)): + raise AssertionError("nan or inf in speech_irm_prediction") + if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)): + raise AssertionError("nan or inf in lsnr_prediction") + + speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec)) + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + + loss = speech_loss + irm_loss + snr_loss + + total_loss += loss.item() + total_examples += mix_complex_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + elif evaluation_loss < best_metric: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "training_loss": training_loss, + "evaluation_loss": evaluation_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_dfnet_aishell/step_3_evaluation.py b/examples/spectrum_dfnet_aishell/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..53482b2e6d8432f4c0c0c88bdadeaef828ce5168 --- /dev/null +++ b/examples/spectrum_dfnet_aishell/step_3_evaluation.py @@ -0,0 +1,302 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) 
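+# make the repository root importable so the toolbox package resolves when this script
+# is run from the example directory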
+sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, + speech_spec_prediction: torch.Tensor, + speech_irm_prediction: torch.Tensor, + ): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_spec_prediction = speech_spec_prediction.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + # print(f"speech_spec_prediction: {speech_spec_prediction.shape}") + # print(f"noise_spec: {noise_spec.shape}") + + speech_wave = istft.forward(speech_spec_prediction) + # speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, 
speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + model = SpectrumDfNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + irm_mse_loss = nn.MSELoss( + reduction="mean", + ) + snr_mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + + speech_spec_complex: torch.Tensor = stft_complex.forward(speech_wave) + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2] + + noise_spec = noise_spec[:, :-1, :] + speech_spec = speech_spec[:, :-1, :] + mix_spec = mix_spec[:, :-1, :] + speech_spec_complex = speech_spec_complex[:, :-1, :] + mix_spec_complex = mix_spec_complex[:, :-1, :] + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + snr_db: torch.Tensor = 10 * torch.log10( + speech_spec / (noise_spec + 1e-8) + ) + snr_db = torch.clamp(snr_db, min=1e-8) + snr_db = torch.mean(snr_db, dim=1, keepdim=True) + # snr_db shape: [batch_size, 1, time_steps] + + speech_spec_complex = speech_spec_complex.to(device) + mix_spec_complex = mix_spec_complex.to(device) + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex) + speech_spec_prediction = torch.view_as_complex(speech_spec_prediction) + + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + # loss = irm_loss + 0.1 * snr_loss + loss = irm_loss + + # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2] + # 
speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps] + batch_size, _, time_steps = speech_irm_prediction.shape + + mix_spec_complex = torch.concat( + [ + mix_spec_complex, + torch.zeros(size=(batch_size, 1, time_steps), dtype=mix_spec_complex.dtype).to(device) + ], + dim=1, + ) + speech_spec_prediction = torch.concat( + [ + speech_spec_prediction, + torch.zeros(size=(batch_size, 1, time_steps), dtype=speech_spec_prediction.dtype).to(device) + ], + dim=1, + ) + speech_irm_prediction = torch.concat( + [ + speech_irm_prediction, + 0.5 * torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device) + ], + dim=1, + ) + + # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps] + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_dfnet_aishell/yaml/config.yaml b/examples/spectrum_dfnet_aishell/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f77bd9156f416d885a86d1f484a5767dd0e146c3 --- /dev/null +++ b/examples/spectrum_dfnet_aishell/yaml/config.yaml @@ -0,0 +1,53 @@ +model_name: "spectrum_unet_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +embedding_hidden_size: 256 +encoder_combine_op: "concat" + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +encoder_linear_groups: 32 + +lsnr_max: 30 +lsnr_min: -15 +norm_tau: 1. 
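+
+# decoder / deep filtering (df)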
+ +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +df_decoder_hidden_size: 256 +df_num_layers: 2 +df_order: 5 +df_bins: 96 +df_gru_skip: "grouped_linear" +df_decoder_linear_groups: 16 +df_pathway_kernel_size_t: 5 +df_lookahead: 2 + +# runtime +use_post_filter: true diff --git a/examples/spectrum_unet_irm_aishell/run.sh b/examples/spectrum_unet_irm_aishell/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..c24e65d89cd449f7d8ea883e8d2a4b4e37db2309 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/run.sh @@ -0,0 +1,178 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + +sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/spectrum_unet_irm_aishell/step_1_prepare_data.py b/examples/spectrum_unet_irm_aishell/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..cf114513869c474b2588d6d81067ec801673b5e7 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/step_1_prepare_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_snr_db", default=-10, type=float) + parser.add_argument("--max_snr_db", default=20, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_snr_db, args.max_snr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/spectrum_unet_irm_aishell/step_2_train_model.py b/examples/spectrum_unet_irm_aishell/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..95f6dc31a94d99c31b4e43472bda078639e8a02a --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/step_2_train_model.py @@ -0,0 +1,420 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json 
+import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig +from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-4, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + window_fn: str = "hamming", + irm_beta: float = 1.0, + epsilon: float = 1e-8, + ): + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.window_fn = window_fn + self.irm_beta = irm_beta + self.epsilon = epsilon + + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=2.0, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + @staticmethod + def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3): + batch_size, channels, freq_dim, time_steps = x.shape + + # kernel: [freq_dim, n_time_step] + kernel_size = (freq_dim, n_time_steps) + + # pad + pad = n_time_steps // 2 + x = torch.concat(tensors=[ + x[:, :, :, :pad], + x, + x[:, :, :, -pad:], + ], dim=-1) + + x = F.unfold( + input=x, + kernel_size=kernel_size, + ) + # x shape: [batch_size, fold, time_steps] + return x + + def __call__(self, batch: List[dict]): + mix_spec_list = list() + speech_irm_list = list() + snr_db_list = list() + for sample in batch: + noise_wave: torch.Tensor = sample["noise_wave"] + speech_wave: torch.Tensor = sample["speech_wave"] + mix_wave: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + noise_spec = 
self.transform.forward(noise_wave) + speech_spec = self.transform.forward(speech_wave) + mix_spec = self.transform.forward(mix_wave) + + # noise_irm = noise_spec / (noise_spec + speech_spec) + speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon) + speech_irm = torch.pow(speech_irm, self.irm_beta) + + # noise_spec, speech_spec, mix_spec, speech_irm + # shape: [freq_dim, time_steps] + + snr_db: torch.Tensor = 10 * torch.log10( + speech_spec / (noise_spec + self.epsilon) + ) + snr_db = torch.clamp(snr_db, min=self.epsilon) + + snr_db_ = torch.unsqueeze(snr_db, dim=0) + snr_db_ = torch.unsqueeze(snr_db_, dim=0) + snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3) + snr_db_ = torch.squeeze(snr_db_, dim=0) + # snr_db_ shape: [fold, time_steps] + + snr_db = torch.mean(snr_db_, dim=0, keepdim=True) + # snr_db shape: [1, time_steps] + + mix_spec_list.append(mix_spec) + speech_irm_list.append(speech_irm) + snr_db_list.append(snr_db) + + mix_spec_list = torch.stack(mix_spec_list) + speech_irm_list = torch.stack(speech_irm_list) + snr_db_list = torch.stack(snr_db_list) # shape: (batch_size, time_steps, 1) + + mix_spec_list = mix_spec_list[:, :-1, :] + speech_irm_list = speech_irm_list[:, :-1, :] + + # mix_spec_list shape: [batch_size, freq_dim, time_steps] + # speech_irm_list shape: [batch_size, freq_dim, time_steps] + # snr_db shape: [batch_size, 1, time_steps] + + # assert + if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)): + raise AssertionError("nan or inf in mix_spec_list") + if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)): + raise AssertionError("nan or inf in speech_irm_list") + if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)): + raise AssertionError("nan or inf in snr_db_list") + + return mix_spec_list, speech_irm_list, snr_db_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. 
config_file: {args.config_file}") + config = SpectrumUnetIRMConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + # num_labels=vocabulary.get_vocab_size(namespace="labels") + ) + model = SpectrumUnetIRMPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + param_optimizer = model.parameters() + optimizer = torch.optim.Adam( + param_optimizer, + lr=args.learning_rate, + ) + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=2000 + # ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + irm_mse_loss = nn.MSELoss( + reduction="mean", + ) + snr_mse_loss = nn.MSELoss( + reduction="mean", + ) + + # training loop + logger.info("training") + + training_loss = 10000000000 + evaluation_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + for idx_epoch in range(args.max_epochs): + total_loss = 0. + total_examples = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + + for batch in train_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + speech_irm_prediction, lsnr_prediction = model.forward(mix_spec) + if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)): + raise AssertionError("nan or inf in speech_irm_prediction") + if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)): + raise AssertionError("nan or inf in lsnr_prediction") + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min) + if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0: + raise AssertionError(f"expected lsnr_prediction between 0 and 1.") + snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)): + raise AssertionError("nan or inf in snr_loss") + # loss = irm_loss + 0.1 * snr_loss + loss = 10.0 * irm_loss + 0.05 * snr_loss + # loss = irm_loss + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + training_loss = total_loss / total_examples + training_loss = round(training_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "training_loss": training_loss, + }) + + total_loss = 0. + total_examples = 0. 
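+ # evaluation pass: reset the running loss / example counters before iterating the validation loader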
+ progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + for batch in valid_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction, lsnr_prediction = model.forward(mix_spec) + if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)): + raise AssertionError("nan or inf in speech_irm_prediction") + if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)): + raise AssertionError("nan or inf in lsnr_prediction") + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min) + if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0: + raise AssertionError(f"expected lsnr_prediction between 0 and 1.") + snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + # loss = irm_loss + 0.1 * snr_loss + loss = 10.0 * irm_loss + 0.05 * snr_loss + # loss = irm_loss + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + elif evaluation_loss < best_metric: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "training_loss": training_loss, + "evaluation_loss": evaluation_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_unet_irm_aishell/step_3_evaluation.py b/examples/spectrum_unet_irm_aishell/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..38ef3054e4688d1b94f8d14cb3f3f3ca6b8b8296 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/step_3_evaluation.py @@ -0,0 +1,270 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn 
as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + 
logger.info("prepare model") + model = SpectrumUnetIRMPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + irm_mse_loss = nn.MSELoss( + reduction="mean", + ) + snr_mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + + noise_spec = noise_spec[:, :-1, :] + speech_spec = speech_spec[:, :-1, :] + mix_spec = mix_spec[:, :-1, :] + + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2] + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + snr_db: torch.Tensor = 10 * torch.log10( + speech_spec / (noise_spec + 1e-8) + ) + snr_db = torch.mean(snr_db, dim=1, keepdim=True) + # snr_db shape: [batch_size, 1, time_steps] + + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction, lsnr_prediction = model.forward(mix_spec) + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + # loss = irm_loss + 0.1 * snr_loss + loss = irm_loss + + # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2] + # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps] + batch_size, _, time_steps = speech_irm_prediction.shape + speech_irm_prediction = torch.concat( + [ + speech_irm_prediction, + 0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device) + ], + dim=1, + ) + # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps] + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > 
args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_unet_irm_aishell/yaml/config.yaml b/examples/spectrum_unet_irm_aishell/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..756e93747642a4c1bdfebc82c739c341c85a5855 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/yaml/config.yaml @@ -0,0 +1,38 @@ +model_name: "spectrum_unet_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +lsnr_max: 30 +lsnr_min: -15 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +# runtime +use_post_filter: true diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..826e360a65034f9b489f2ff0e9feb1a286fede9b --- /dev/null +++ b/install.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash + +# bash install.sh --stage 2 --stop_stage 2 --system_version centos + + +python_version=3.8.10 +system_version="centos"; + +verbose=true; +stage=-1 +stop_stage=0 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +work_dir="$(pwd)" + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: install python" + cd "${work_dir}" || exit 1; + + sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}" +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: create virtualenv" + + # /usr/local/python-3.9.9/bin/virtualenv nx_denoise + # source /data/local/bin/nx_denoise/bin/activate + /usr/local/python-${python_version}/bin/pip3 install virtualenv + mkdir -p /data/local/bin + cd /data/local/bin || exit 1; + /usr/local/python-${python_version}/bin/virtualenv nx_denoise + +fi diff --git a/log.py b/log.py new file mode 100644 index 0000000000000000000000000000000000000000..068100f8d298a3ca8f4814cd212a91f61efa3066 --- /dev/null +++ b/log.py @@ -0,0 +1,229 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler +import os + + +def setup_size_rotating(log_directory: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + # main + main_logger = logging.getLogger("main") + main_logger.addHandler(stream_handler) + main_info_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "main.log"), + maxBytes=100*1024*1024, # 100MB + encoding="utf-8", + backupCount=2, + ) + main_info_file_handler.setLevel(logging.INFO) + main_info_file_handler.setFormatter(logging.Formatter(fmt)) + main_logger.addHandler(main_info_file_handler) + + # http + http_logger = logging.getLogger("http") + http_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "http.log"), + maxBytes=100*1024*1024, # 100MB + encoding="utf-8", + backupCount=2, + ) + http_file_handler.setLevel(logging.DEBUG) + http_file_handler.setFormatter(logging.Formatter(fmt)) + http_logger.addHandler(http_file_handler) + + # api + api_logger = logging.getLogger("api") + api_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "api.log"), + maxBytes=10*1024*1024, # 10MB + encoding="utf-8", + backupCount=2, + ) + api_file_handler.setLevel(logging.DEBUG) + api_file_handler.setFormatter(logging.Formatter(fmt)) + api_logger.addHandler(api_file_handler) + + # toolbox + toolbox_logger = logging.getLogger("toolbox") + toolbox_logger.addHandler(stream_handler) + toolbox_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "toolbox.log"), + maxBytes=10*1024*1024, # 10MB + encoding="utf-8", + backupCount=2, + ) + toolbox_file_handler.setLevel(logging.DEBUG) + toolbox_file_handler.setFormatter(logging.Formatter(fmt)) + toolbox_logger.addHandler(toolbox_file_handler) + + # alarm + alarm_logger = logging.getLogger("alarm") + alarm_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "alarm.log"), + maxBytes=1*1024*1024, # 1MB + encoding="utf-8", + backupCount=2, + ) + alarm_file_handler.setLevel(logging.DEBUG) + alarm_file_handler.setFormatter(logging.Formatter(fmt)) + alarm_logger.addHandler(alarm_file_handler) + + debug_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "debug.log"), + maxBytes=1*1024*1024, # 
1MB + encoding="utf-8", + backupCount=2, + ) + debug_file_handler.setLevel(logging.DEBUG) + debug_file_handler.setFormatter(logging.Formatter(fmt)) + + info_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "info.log"), + maxBytes=1*1024*1024, # 1MB + encoding="utf-8", + backupCount=2, + ) + info_file_handler.setLevel(logging.INFO) + info_file_handler.setFormatter(logging.Formatter(fmt)) + + error_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "error.log"), + maxBytes=1*1024*1024, # 1MB + encoding="utf-8", + backupCount=2, + ) + error_file_handler.setLevel(logging.ERROR) + error_file_handler.setFormatter(logging.Formatter(fmt)) + + logging.basicConfig( + level=logging.DEBUG, + datefmt="%a, %d %b %Y %H:%M:%S", + handlers=[ + debug_file_handler, + info_file_handler, + error_file_handler, + ] + ) + + +def setup_time_rotating(log_directory: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + # main + main_logger = logging.getLogger("main") + main_logger.addHandler(stream_handler) + main_info_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "main.log"), + encoding="utf-8", + when="midnight", + interval=1, + backupCount=7 + ) + main_info_file_handler.setLevel(logging.INFO) + main_info_file_handler.setFormatter(logging.Formatter(fmt)) + main_logger.addHandler(main_info_file_handler) + + # http + http_logger = logging.getLogger("http") + http_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "http.log"), + encoding='utf-8', + when="midnight", + interval=1, + backupCount=7 + ) + http_file_handler.setLevel(logging.DEBUG) + http_file_handler.setFormatter(logging.Formatter(fmt)) + http_logger.addHandler(http_file_handler) + + # api + api_logger = logging.getLogger("api") + api_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "api.log"), + encoding='utf-8', + when="midnight", + interval=1, + backupCount=7 + ) + api_file_handler.setLevel(logging.DEBUG) + api_file_handler.setFormatter(logging.Formatter(fmt)) + api_logger.addHandler(api_file_handler) + + # toolbox + toolbox_logger = logging.getLogger("toolbox") + toolbox_file_handler = RotatingFileHandler( + filename=os.path.join(log_directory, "toolbox.log"), + maxBytes=10*1024*1024, # 10MB + encoding="utf-8", + backupCount=2, + ) + toolbox_file_handler.setLevel(logging.DEBUG) + toolbox_file_handler.setFormatter(logging.Formatter(fmt)) + toolbox_logger.addHandler(toolbox_file_handler) + + # alarm + alarm_logger = logging.getLogger("alarm") + alarm_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "alarm.log"), + encoding="utf-8", + when="midnight", + interval=1, + backupCount=7 + ) + alarm_file_handler.setLevel(logging.DEBUG) + alarm_file_handler.setFormatter(logging.Formatter(fmt)) + alarm_logger.addHandler(alarm_file_handler) + + debug_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "debug.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + debug_file_handler.setLevel(logging.DEBUG) + debug_file_handler.setFormatter(logging.Formatter(fmt)) + + info_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "info.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + info_file_handler.setLevel(logging.INFO) 
+ info_file_handler.setFormatter(logging.Formatter(fmt)) + + error_file_handler = TimedRotatingFileHandler( + filename=os.path.join(log_directory, "error.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + error_file_handler.setLevel(logging.ERROR) + error_file_handler.setFormatter(logging.Formatter(fmt)) + + logging.basicConfig( + level=logging.DEBUG, + datefmt="%a, %d %b %Y %H:%M:%S", + handlers=[ + debug_file_handler, + info_file_handler, + error_file_handler, + ] + ) + + +if __name__ == "__main__": + pass diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..fd7445287b90cfe59b17938b81a42dcfd3d3d1bd --- /dev/null +++ b/main.py @@ -0,0 +1,292 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +docker build -t denoise:v20250609_1919 . +docker stop denoise_7865 && docker rm denoise_7865 +docker run -itd \ +--name denoise_7865 \ +--restart=always \ +--network host \ +-e server_port=7865 \ +-e hf_token=hf_coRVvzwAzCwGHKRK***********EX \ +denoise:v20250609_1919 /bin/bash + +""" +import argparse +import json +from functools import lru_cache +import logging +from pathlib import Path +import platform +import shutil +import tempfile +import time +from typing import Dict, Tuple +import zipfile + +import gradio as gr +from huggingface_hub import snapshot_download +import librosa +import librosa.display +import matplotlib.pyplot as plt +import numpy as np + +import log +from project_settings import environment, project_path, log_directory +from toolbox.os.command import Command +from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet +from toolbox.torchaudio.models.dfnet2.inference_dfnet2 import InferenceDfNet2 +from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN +from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN +from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet + + +log.setup_size_rotating(log_directory=log_directory) + +logger = logging.getLogger("main") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--examples_dir", + # default=(project_path / "data").as_posix(), + default=(project_path / "data/examples").as_posix(), + type=str + ) + parser.add_argument( + "--models_repo_id", + default="qgyd2021/nx_denoise", + type=str + ) + parser.add_argument( + "--trained_model_dir", + default=(project_path / "trained_models").as_posix(), + type=str + ) + parser.add_argument( + "--hf_token", + default=environment.get("hf_token"), + type=str, + ) + parser.add_argument( + "--server_port", + default=environment.get("server_port", 7860), + type=int + ) + + args = parser.parse_args() + return args + + +def shell(cmd: str): + return Command.popen(cmd) + + +def get_infer_cls_by_model_name(model_name: str): + if model_name.__contains__("dtln"): + infer_cls = InferenceDTLN + elif model_name.__contains__("dfnet2"): + infer_cls = InferenceDfNet2 + elif model_name.__contains__("frcrn"): + infer_cls = InferenceFRCRN + elif model_name.__contains__("mpnet"): + infer_cls = InferenceMPNet + else: + raise AssertionError + return infer_cls + + +denoise_engines: Dict[str, dict] = None + + +@lru_cache(maxsize=1) +def load_denoise_model(infer_cls, **kwargs): + infer_engine = infer_cls(**kwargs) + + return infer_engine + + +def generate_spectrogram(signal: np.ndarray, sample_rate: int = 8000, title: str = "Spectrogram"): + mag = np.abs(librosa.stft(signal)) + # mag_db = librosa.amplitude_to_db(mag, ref=np.max) + mag_db = 
librosa.amplitude_to_db(mag, ref=20) + + plt.figure(figsize=(10, 4)) + librosa.display.specshow(mag_db, sr=sample_rate) + plt.title(title) + + temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + plt.savefig(temp_file.name, bbox_inches="tight") + plt.close() + return temp_file.name + + +def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_t = None, engine: str = None): + if noisy_audio_file_t is None and noisy_audio_microphone_t is None: + raise gr.Error(f"audio file and microphone is null.") + if noisy_audio_file_t is not None and noisy_audio_microphone_t is not None: + gr.Warning(f"both audio file and microphone file is provided, audio file taking priority.") + + noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t + + sample_rate, signal = noisy_audio_t + audio_duration = signal.shape[-1] // 8000 + + # Test: 使用 microphone 时,显示采样率是 44100,但 signal 实际是按 8000 的采样率的。 + logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}") + + noisy_audio = np.array(signal / (1 << 15), dtype=np.float32) + + infer_engine_param = denoise_engines.get(engine) + if infer_engine_param is None: + raise gr.Error(f"invalid denoise engine: {engine}.") + + try: + infer_cls = infer_engine_param["infer_cls"] + kwargs = infer_engine_param["kwargs"] + infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs) + + begin = time.time() + denoise_audio = infer_engine.enhancement_by_ndarray(noisy_audio) + time_cost = time.time() - begin + + fpr = time_cost / audio_duration + info = { + "time_cost": round(time_cost, 4), + "audio_duration": round(audio_duration, 4), + "fpr": round(fpr, 4) + } + message = json.dumps(info, ensure_ascii=False, indent=4) + + noise_audio = noisy_audio - denoise_audio + + noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy") + denoise_mag_db = generate_spectrogram(denoise_audio, title="denoise") + noise_mag_db = generate_spectrogram(noise_audio, title="noise") + + denoise_audio = np.array(denoise_audio * (1 << 15), dtype=np.int16) + noise_audio = np.array(noise_audio * (1 << 15), dtype=np.int16) + + except Exception as e: + raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.") + + denoise_audio_t = (sample_rate, denoise_audio) + noise_audio_t = (sample_rate, noise_audio) + return denoise_audio_t, noise_audio_t, message, noisy_mag_db, denoise_mag_db, noise_mag_db + + +def main(): + args = get_args() + + examples_dir = Path(args.examples_dir) + trained_model_dir = Path(args.trained_model_dir) + + # download models + if not trained_model_dir.exists(): + trained_model_dir.mkdir(parents=True, exist_ok=True) + _ = snapshot_download( + repo_id=args.models_repo_id, + local_dir=trained_model_dir.as_posix(), + token=args.hf_token, + ) + + # engines + global denoise_engines + denoise_engines = { + filename.stem: { + "infer_cls": get_infer_cls_by_model_name(filename.stem), + "kwargs": { + "pretrained_model_path_or_zip_file": filename.as_posix() + } + } + for filename in (project_path / "trained_models").glob("*.zip") + if filename.name != "examples.zip" + } + + # choices + denoise_engine_choices = list(denoise_engines.keys()) + + # examples + if not examples_dir.exists(): + example_zip_file = trained_model_dir / "examples.zip" + with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip: + out_root = examples_dir + if out_root.exists(): + shutil.rmtree(out_root.as_posix()) + out_root.mkdir(parents=True, exist_ok=True) + 
f_zip.extractall(path=out_root) + + # examples + examples = list() + for filename in examples_dir.glob("**/*.wav"): + examples.append([ + filename.as_posix(), + None, + denoise_engine_choices[0], + ]) + + # ui + with gr.Blocks() as blocks: + gr.Markdown(value="denoise.") + with gr.Tabs(): + with gr.TabItem("denoise"): + with gr.Row(): + with gr.Column(variant="panel", scale=5): + with gr.Tabs(): + with gr.TabItem("file"): + dn_noisy_audio_file = gr.Audio(label="noisy_audio") + with gr.TabItem("microphone"): + dn_noisy_audio_microphone = gr.Audio(sources="microphone", label="noisy_audio") + + dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine") + dn_button = gr.Button(variant="primary") + with gr.Column(variant="panel", scale=5): + with gr.Tabs(): + with gr.TabItem("audio"): + dn_denoise_audio = gr.Audio(label="denoise_audio") + dn_noise_audio = gr.Audio(label="noise_audio") + dn_message = gr.Textbox(lines=1, max_lines=20, label="message") + with gr.TabItem("mag_db"): + dn_noisy_mag_db = gr.Image(label="noisy_mag_db") + dn_denoise_mag_db = gr.Image(label="denoise_mag_db") + dn_noise_mag_db = gr.Image(label="noise_mag_db") + + dn_button.click( + when_click_denoise_button, + inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine], + outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db] + ) + gr.Examples( + examples=examples, + inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine], + outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db], + fn=when_click_denoise_button, + # cache_examples=True, + # cache_mode="lazy", + ) + + with gr.TabItem("shell"): + shell_text = gr.Textbox(label="cmd") + shell_button = gr.Button("run") + shell_output = gr.Textbox(label="output") + + shell_button.click( + shell, + inputs=[shell_text,], + outputs=[shell_output], + ) + + # http://127.0.0.1:7865/ + # http://10.75.27.247:7865/ + blocks.queue().launch( + # share=True, + share=False if platform.system() == "Windows" else False, + server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0", + server_port=args.server_port + ) + return + + +if __name__ == "__main__": + main() diff --git a/project_settings.py b/project_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..ee94aa843ecc32ed9c6a0e17bf31750e0ef8c7ee --- /dev/null +++ b/project_settings.py @@ -0,0 +1,25 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +from pathlib import Path + +from toolbox.os.environment import EnvironmentManager + + +project_path = os.path.abspath(os.path.dirname(__file__)) +project_path = Path(project_path) + +log_directory = project_path / "logs" +log_directory.mkdir(parents=True, exist_ok=True) + +# temp_directory = project_path / "temp" +# temp_directory.mkdir(parents=True, exist_ok=True) + +environment = EnvironmentManager( + path=os.path.join(project_path, "dotenv"), + env=os.environ.get("environment", "dev"), +) + + +if __name__ == '__main__': + pass diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..455d67c396b1ade6c10f9f8153820fa947c9fb8e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,16 @@ +gradio==5.33.0 +gradio_client==1.10.2 +datasets==3.2.0 +python-dotenv==1.0.1 +scipy==1.15.1 +librosa==0.10.2.post1 +pandas==2.2.3 +openpyxl==3.1.5 +torch==2.5.1 +torchaudio==2.5.1 +overrides==7.7.0 +torch-pesq==0.1.2 +torchmetrics==1.6.1 
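+# note: the [audio] extra below is assumed to pull in the optional audio-metric dependencies (e.g. pystoi for the STOI metric)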
+torchmetrics[audio]==1.6.1 +einops==0.8.1 +torch-stoi==0.2.3 diff --git a/toolbox/__init__.py b/toolbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/json/__init__.py b/toolbox/json/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/json/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/json/misc.py b/toolbox/json/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..56022e111a555aa370ff833f4ee68f880d183ed9 --- /dev/null +++ b/toolbox/json/misc.py @@ -0,0 +1,63 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Callable + + +def traverse(js, callback: Callable, *args, **kwargs): + if isinstance(js, list): + result = list() + for l in js: + l = traverse(l, callback, *args, **kwargs) + result.append(l) + return result + elif isinstance(js, tuple): + result = list() + for l in js: + l = traverse(l, callback, *args, **kwargs) + result.append(l) + return tuple(result) + elif isinstance(js, dict): + result = dict() + for k, v in js.items(): + k = traverse(k, callback, *args, **kwargs) + v = traverse(v, callback, *args, **kwargs) + result[k] = v + return result + elif isinstance(js, int): + return callback(js, *args, **kwargs) + elif isinstance(js, str): + return callback(js, *args, **kwargs) + else: + return js + + +def demo1(): + d = { + "env": "ppe", + "mysql_connect": { + "host": "$mysql_connect_host", + "port": 3306, + "user": "callbot", + "password": "NxcloudAI2021!", + "database": "callbot_ppe", + "charset": "utf8" + }, + "es_connect": { + "hosts": ["10.20.251.8"], + "http_auth": ["elastic", "ElasticAI2021!"], + "port": 9200 + } + } + + def callback(s): + if isinstance(s, str) and s.startswith('$'): + return s[1:] + return s + + result = traverse(d, callback=callback) + print(result) + return + + +if __name__ == '__main__': + demo1() diff --git a/toolbox/os/__init__.py b/toolbox/os/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/os/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/os/command.py b/toolbox/os/command.py new file mode 100644 index 0000000000000000000000000000000000000000..40a1880fb08d21be79f31a13e51e389adb18c283 --- /dev/null +++ b/toolbox/os/command.py @@ -0,0 +1,59 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os + + +class Command(object): + custom_command = [ + "cd" + ] + + @staticmethod + def _get_cmd(command): + command = str(command).strip() + if command == "": + return None + cmd_and_args = command.split(sep=" ") + cmd = cmd_and_args[0] + args = " ".join(cmd_and_args[1:]) + return cmd, args + + @classmethod + def popen(cls, command): + cmd, args = cls._get_cmd(command) + if cmd in cls.custom_command: + method = getattr(cls, cmd) + return method(args) + else: + resp = os.popen(command) + result = resp.read() + resp.close() + return result + + @classmethod + def cd(cls, args): + if args.startswith("/"): + os.chdir(args) + else: + pwd = os.getcwd() + path = os.path.join(pwd, args) + os.chdir(path) + + @classmethod + def 
system(cls, command): + return os.system(command) + + def __init__(self): + pass + + +def ps_ef_grep(keyword: str): + cmd = "ps -ef | grep {}".format(keyword) + rows = Command.popen(cmd) + rows = str(rows).split("\n") + rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")] + return rows + + +if __name__ == "__main__": + pass diff --git a/toolbox/os/environment.py b/toolbox/os/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9bc004fc8777ae906c1caccd47bb257b126d55 --- /dev/null +++ b/toolbox/os/environment.py @@ -0,0 +1,114 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import json +import os + +from dotenv import load_dotenv +from dotenv.main import DotEnv + +from toolbox.json.misc import traverse + + +class EnvironmentManager(object): + def __init__(self, path, env, override=False): + filename = os.path.join(path, '{}.env'.format(env)) + self.filename = filename + + load_dotenv( + dotenv_path=filename, + override=override + ) + + self._environ = dict() + + def open_dotenv(self, filename: str = None): + filename = filename or self.filename + dotenv = DotEnv( + dotenv_path=filename, + stream=None, + verbose=False, + interpolate=False, + override=False, + encoding="utf-8", + ) + result = dotenv.dict() + return result + + def get(self, key, default=None, dtype=str): + result = os.environ.get(key) + if result is None: + if default is None: + result = None + else: + result = default + else: + result = dtype(result) + self._environ[key] = result + return result + + +_DEFAULT_DTYPE_MAP = { + 'int': int, + 'float': float, + 'str': str, + 'json.loads': json.loads +} + + +class JsonConfig(object): + """ + 将 json 中, 形如 `$float:threshold` 的值, 处理为: + 从环境变量中查到 threshold, 再将其转换为 float 类型. + """ + def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None): + self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP + self.environment = environment or os.environ + + def sanitize_by_filename(self, filename: str): + with open(filename, 'r', encoding='utf-8') as f: + js = json.load(f) + + return self.sanitize_by_json(js) + + def sanitize_by_json(self, js): + js = traverse( + js, + callback=self.sanitize, + environment=self.environment + ) + return js + + def sanitize(self, string, environment): + """支持 $ 符开始的, 环境变量配置""" + if isinstance(string, str) and string.startswith('$'): + dtype, key = string[1:].split(':') + dtype = self.dtype_map[dtype] + + value = environment.get(key) + if value is None: + raise AssertionError('environment not exist. 
key: {}'.format(key)) + + value = dtype(value) + result = value + else: + result = string + return result + + +def demo1(): + import json + + from project_settings import project_path + + environment = EnvironmentManager( + path=os.path.join(project_path, 'server/callbot_server/dotenv'), + env='dev', + ) + init_scenes = environment.get(key='init_scenes', dtype=json.loads) + print(init_scenes) + print(environment._environ) + return + + +if __name__ == '__main__': + demo1() diff --git a/toolbox/os/other.py b/toolbox/os/other.py new file mode 100644 index 0000000000000000000000000000000000000000..f215505eedfd714442d2fab241c8b1aff871d18a --- /dev/null +++ b/toolbox/os/other.py @@ -0,0 +1,9 @@ +import os +import inspect + + +def pwd(): + """你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标""" + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + return os.path.dirname(os.path.abspath(module.__file__)) diff --git a/toolbox/torch/__init__.py b/toolbox/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/sparsification/__init__.py b/toolbox/torch/sparsification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torch/sparsification/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torch/sparsification/common.py b/toolbox/torch/sparsification/common.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6e66e3937c2da44adb509fea4b338a56d006b2 --- /dev/null +++ b/toolbox/torch/sparsification/common.py @@ -0,0 +1,131 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +""" +import torch + + +""" +https://github.com/xiph/rnnoise/blob/main/torch/sparsification/common.py +""" + +def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False): + """ sparsifies matrix with specified block size + + Parameters: + ----------- + matrix : torch.tensor + matrix to sparsify + density : int + target density + block_size : [int, int] + block size dimensions + keep_diagonal : bool + If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False + """ + + m, n = matrix.shape + m1, n1 = block_size + + if m % m1 or n % n1: + raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}") + + # extract diagonal if keep_diagonal = True + if keep_diagonal: + if m != n: + raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True") + + to_spare = torch.diag(torch.diag(matrix)) + matrix = matrix - to_spare + else: + to_spare = torch.zeros_like(matrix) + + # calculate energy in sub-blocks + x = torch.reshape(matrix, (m // m1, m1, n // n1, n1)) + x = x ** 2 + block_energies = torch.sum(torch.sum(x, dim=3), dim=1) + + number_of_blocks = (m * n) // (m1 * n1) + number_of_survivors = round(number_of_blocks * density) + + # masking threshold + if number_of_survivors == 0: + threshold = 0 + else: + threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors] + + # create mask + mask = torch.ones_like(block_energies) + mask[block_energies < threshold] = 0 + mask = torch.repeat_interleave(mask, m1, dim=0) + mask = torch.repeat_interleave(mask, n1, dim=1) + + # perform masking + masked_matrix = mask * matrix + to_spare + + if return_mask: + return masked_matrix, mask + else: + return masked_matrix + +def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False): + input_size = gru.input_size + hidden_size = gru.hidden_size + flops = 0 + + input_density = ( + sparsification_dict.get('W_ir', [1])[0] + + sparsification_dict.get('W_in', [1])[0] + + sparsification_dict.get('W_iz', [1])[0] + ) / 3 + + recurrent_density = ( + sparsification_dict.get('W_hr', [1])[0] + + sparsification_dict.get('W_hn', [1])[0] + + sparsification_dict.get('W_hz', [1])[0] + ) / 3 + + # input matrix vector multiplications + if not drop_input: + flops += 2 * 3 * input_size * hidden_size * input_density + + # recurrent matrix vector multiplications + flops += 2 * 3 * hidden_size * hidden_size * recurrent_density + + # biases + flops += 6 * hidden_size + + # activations estimated by 10 flops per activation + flops += 30 * hidden_size + + return flops + + +if __name__ == "__main__": + pass diff --git a/toolbox/torch/sparsification/gru_sparsifier.py b/toolbox/torch/sparsification/gru_sparsifier.py new file mode 100644 index 0000000000000000000000000000000000000000..91ddbbaa60ec1e011a0bd9042b4c8bdfc1e6d671 --- /dev/null +++ b/toolbox/torch/sparsification/gru_sparsifier.py @@ -0,0 +1,190 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +/* Copyright (c) 2023 Amazon + Written by Jan Buethe */ +/* + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER + OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +""" +import torch + +from toolbox.torch.sparsification.common import sparsify_matrix + + +""" +https://github.com/xiph/rnnoise/blob/main/torch/sparsification/gru_sparsifier.py +""" + +class GRUSparsifier: + def __init__(self, task_list, start, stop, interval, exponent=3): + """ Sparsifier for torch.nn.GRUs + + Parameters: + ----------- + task_list : list + task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance + of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in', + 'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset, + update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal), + where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which + sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal + should be kept. + + start : int + training step after which sparsification will be started. + + stop : int + training step after which sparsification will be completed. + + interval : int + sparsification interval for steps between start and stop. After stop sparsification will be + carried out after every call to GRUSparsifier.step() + + exponent : float + Interpolation exponent for sparsification interval. In step i sparsification will be carried out + with density (alpha + target_density * (1 * alpha)), where + alpha = ((stop - i) / (start - stop)) ** exponent + + Example: + -------- + >>> import torch + >>> gru = torch.nn.GRU(10, 20) + >>> sparsify_dict = { + ... 'W_ir' : (0.5, [2, 2], False), + ... 'W_iz' : (0.6, [2, 2], False), + ... 'W_in' : (0.7, [2, 2], False), + ... 'W_hr' : (0.1, [4, 4], True), + ... 'W_hz' : (0.2, [4, 4], True), + ... 'W_hn' : (0.3, [4, 4], True), + ... } + >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50) + >>> for i in range(100): + ... sparsifier.step() + """ + # just copying parameters... + self.start = start + self.stop = stop + self.interval = interval + self.exponent = exponent + self.task_list = task_list + + # ... and setting counter to 0 + self.step_counter = 0 + + self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']} + + def step(self, verbose=False): + """ carries out sparsification step + + Call this function after optimizer.step in your + training loop. 
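+ Sparsification is skipped for the first `start` steps, applied on every `interval`-th step until `stop`, and applied on every call after that.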
+ + Parameters: + ---------- + verbose : bool + if true, densities are printed out + + Returns: + -------- + None + + """ + # compute current interpolation factor + self.step_counter += 1 + + if self.step_counter < self.start: + return + elif self.step_counter < self.stop: + # update only every self.interval-th interval + if self.step_counter % self.interval: + return + + alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent + else: + alpha = 0 + + with torch.no_grad(): + for gru, params in self.task_list: + hidden_size = gru.hidden_size + + # input weights + for i, key in enumerate(['W_ir', 'W_iz', 'W_in']): + if key in params: + density = alpha + (1 - alpha) * params[key][0] + if verbose: + print(f"[{self.step_counter}]: {key} density: {density}") + + gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( + gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ], + density, # density + params[key][1], # block_size + params[key][2], # keep_diagonal (might want to set this to False) + return_mask=True + ) + + if type(self.last_masks[key]) != type(None): + if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop: + print(f"sparsification mask {key} changed for gru {gru}") + + self.last_masks[key] = new_mask + + # recurrent weights + for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']): + if key in params: + density = alpha + (1 - alpha) * params[key][0] + if verbose: + print(f"[{self.step_counter}]: {key} density: {density}") + gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( + gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ], + density, + params[key][1], # block_size + params[key][2], # keep_diagonal (might want to set this to False) + return_mask=True + ) + + if type(self.last_masks[key]) != type(None): + if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop: + print(f"sparsification mask {key} changed for gru {gru}") + + self.last_masks[key] = new_mask + + +if __name__ == "__main__": + print("Testing sparsifier") + + gru = torch.nn.GRU(10, 20) + sparsify_dict = { + 'W_ir' : (0.5, [2, 2], False), + 'W_iz' : (0.6, [2, 2], False), + 'W_in' : (0.7, [2, 2], False), + 'W_hr' : (0.1, [4, 4], True), + 'W_hz' : (0.2, [4, 4], True), + 'W_hn' : (0.3, [4, 4], True), + } + + sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10) + + for i in range(100): + sparsifier.step(verbose=True) diff --git a/toolbox/torch/training/__init__.py b/toolbox/torch/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/training/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/training/metrics/__init__.py b/toolbox/torch/training/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/training/metrics/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/training/metrics/pesq.py b/toolbox/torch/training/metrics/pesq.py new file mode 100644 index 0000000000000000000000000000000000000000..67f503f5c4242464f77f87cdd8d164d98a108c69 --- /dev/null +++ b/toolbox/torch/training/metrics/pesq.py @@ -0,0 +1,108 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- 
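+# Metric helpers: a small PESQ demo built on torch_pesq and a standalone CategoricalAccuracy implementation.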
+from typing import Optional + +import torch +from torch_pesq import PesqLoss + + +class Pesq(object): + def __init__(self): + pass + + +class CategoricalAccuracy(object): + def __init__(self, top_k: int = 1, tie_break: bool = False) -> None: + if top_k > 1 and tie_break: + raise AssertionError("Tie break in Categorical Accuracy " + "can be done only for maximum (top_k = 1)") + if top_k <= 0: + raise AssertionError("top_k passed to Categorical Accuracy must be > 0") + self._top_k = top_k + self._tie_break = tie_break + self.correct_count = 0. + self.total_count = 0. + + def __call__(self, + predictions: torch.Tensor, + gold_labels: torch.Tensor, + mask: Optional[torch.Tensor] = None): + + # predictions, gold_labels, mask = self.unwrap_to_tensors(predictions, gold_labels, mask) + + # Some sanity checks. + num_classes = predictions.size(-1) + if gold_labels.dim() != predictions.dim() - 1: + raise AssertionError("gold_labels must have dimension == predictions.size() - 1 but " + "found tensor of shape: {}".format(predictions.size())) + if (gold_labels >= num_classes).any(): + raise AssertionError("A gold label passed to Categorical Accuracy contains an id >= {}, " + "the number of classes.".format(num_classes)) + + predictions = predictions.view((-1, num_classes)) + gold_labels = gold_labels.view(-1).long() + if not self._tie_break: + # Top K indexes of the predictions (or fewer, if there aren't K of them). + # Special case topk == 1, because it's common and .max() is much faster than .topk(). + if self._top_k == 1: + top_k = predictions.max(-1)[1].unsqueeze(-1) + else: + top_k = predictions.topk(min(self._top_k, predictions.shape[-1]), -1)[1] + + # This is of shape (batch_size, ..., top_k). + correct = top_k.eq(gold_labels.unsqueeze(-1)).float() + else: + # prediction is correct if gold label falls on any of the max scores. distribute score by tie_counts + max_predictions = predictions.max(-1)[0] + max_predictions_mask = predictions.eq(max_predictions.unsqueeze(-1)) + # max_predictions_mask is (rows X num_classes) and gold_labels is (batch_size) + # ith entry in gold_labels points to index (0-num_classes) for ith row in max_predictions + # For each row check if index pointed by gold_label is was 1 or not (among max scored classes) + correct = max_predictions_mask[torch.arange(gold_labels.numel()).long(), gold_labels].float() + tie_counts = max_predictions_mask.sum(-1) + correct /= tie_counts.float() + correct.unsqueeze_(-1) + + if mask is not None: + correct *= mask.view(-1, 1).float() + self.total_count += mask.sum() + else: + self.total_count += gold_labels.numel() + self.correct_count += correct.sum() + + def get_metric(self, reset: bool = False): + """ + Returns + ------- + The accumulated accuracy. 
+ """ + if self.total_count > 1e-12: + accuracy = float(self.correct_count) / float(self.total_count) + else: + accuracy = 0.0 + if reset: + self.reset() + return {'accuracy': accuracy} + + def reset(self): + self.correct_count = 0.0 + self.total_count = 0.0 + + +def main(): + pesq = PesqLoss(0.5, + sample_rate=8000, + ) + + reference = torch.randn(1, 44100) + degraded = torch.randn(1, 44100) + + mos = pesq.mos(reference, degraded) + loss = pesq(reference, degraded) + + print(mos, loss) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torch/training/metrics/stoi.py b/toolbox/torch/training/metrics/stoi.py new file mode 100644 index 0000000000000000000000000000000000000000..24e5d335267ffd1e92db5de6c8c4dfb1097e3d31 --- /dev/null +++ b/toolbox/torch/training/metrics/stoi.py @@ -0,0 +1,19 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import torch +from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility + + +# 假设 reference 和 degraded 是两个音频信号的张量 +reference = torch.randn(1, 16000) # 参考信号 +degraded = torch.randn(1, 16000) # 降质信号 + + +# 计算 STOI 分数 +stoi_score = short_time_objective_intelligibility(reference, degraded, fs=16000) + +print(f"STOI 分数: {stoi_score}") + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/__init__.py b/toolbox/torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/utils/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/__init__.py b/toolbox/torch/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/utils/data/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/dataset/__init__.py b/toolbox/torch/utils/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/utils/data/dataset/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/dataset/denoise_excel_dataset.py b/toolbox/torch/utils/data/dataset/denoise_excel_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a62b918b64a0f5eb79be9d6eb8e36270a8ef3156 --- /dev/null +++ b/toolbox/torch/utils/data/dataset/denoise_excel_dataset.py @@ -0,0 +1,133 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torchaudio +from torch.utils.data import Dataset +from tqdm import tqdm + + +class DenoiseExcelDataset(Dataset): + def __init__(self, + excel_file: str, + expected_sample_rate: int, + resample: bool = False, + max_wave_value: float = 1.0, + eps: float = 1e-8, + ): + self.excel_file = excel_file + self.expected_sample_rate = expected_sample_rate + self.resample = resample + self.max_wave_value = max_wave_value + self.eps = eps + + self.samples = self.load_samples(excel_file) + + @staticmethod + def load_samples(filename: str): + df = pd.read_excel(filename) + samples = list() + for i, row in tqdm(df.iterrows(), total=len(df)): + noise_filename = row["noise_filename"] + noise_raw_duration = 
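`CategoricalAccuracy` accumulates correct/total counts across calls and reports them through `get_metric`. A minimal usage sketch with synthetic logits, not tied to any model in this repo:

```python
import torch

from toolbox.torch.training.metrics.pesq import CategoricalAccuracy

accuracy = CategoricalAccuracy(top_k=1)

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]])   # shape: [batch_size, num_classes]
labels = torch.tensor([1, 2])              # the second example is deliberately wrong

accuracy(logits, labels)
print(accuracy.get_metric(reset=True))     # {'accuracy': 0.5}
```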
row["noise_raw_duration"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_raw_duration = row["speech_raw_duration"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": snr_db, + } + samples.append(row) + return samples + + def __getitem__(self, index): + sample = self.samples[index] + noise_filename = sample["noise_filename"] + noise_offset = sample["noise_offset"] + noise_duration = sample["noise_duration"] + + speech_filename = sample["speech_filename"] + speech_offset = sample["speech_offset"] + speech_duration = sample["speech_duration"] + + snr_db = sample["snr_db"] + + noise_wave = self.filename_to_waveform(noise_filename, noise_offset, noise_duration) + speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration) + + mix_wave, noise_wave_adjusted = self.mix_speech_and_noise( + speech=speech_wave.numpy(), + noise=noise_wave.numpy(), + snr_db=snr_db, eps=self.eps, + ) + mix_wave = torch.tensor(mix_wave, dtype=torch.float32) + noise_wave_adjusted = torch.tensor(noise_wave_adjusted, dtype=torch.float32) + + result = { + "noise_wave": noise_wave_adjusted, + "speech_wave": speech_wave, + "mix_wave": mix_wave, + "snr_db": snr_db, + } + return result + + def __len__(self): + return len(self.samples) + + def filename_to_waveform(self, filename: str, offset: float, duration: float): + try: + waveform, sample_rate = librosa.load( + filename, + sr=self.expected_sample_rate, + offset=offset, + duration=duration, + ) + except ValueError as e: + print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}") + raise e + waveform = torch.tensor(waveform, dtype=torch.float32) + return waveform + + @staticmethod + def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). 
+ + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + eps) + + noisy_signal = speech + noise_adjusted + + return noisy_signal, noise_adjusted + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py b/toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ea30406b6c998f77cb492a3a5c0eb01daad77be3 --- /dev/null +++ b/toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py @@ -0,0 +1,176 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import json +import random +from typing import List + +import librosa +import numpy as np +import torch +from torch.utils.data import Dataset, IterableDataset + + +class DenoiseJsonlDataset(IterableDataset): + def __init__(self, + jsonl_file: str, + expected_sample_rate: int, + resample: bool = False, + max_wave_value: float = 1.0, + buffer_size: int = 1000, + min_snr_db: float = None, + max_snr_db: float = None, + eps: float = 1e-8, + skip: int = 0, + ): + self.jsonl_file = jsonl_file + self.expected_sample_rate = expected_sample_rate + self.resample = resample + self.max_wave_value = max_wave_value + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + self.eps = eps + self.skip = skip + + self.buffer_size = buffer_size + self.buffer_samples: List[dict] = list() + + def __iter__(self): + self.buffer_samples = list() + + iterable_source = self.iterable_source() + + try: + for _ in range(self.skip): + next(iterable_source) + except StopIteration: + pass + + # 初始填充缓冲区 + try: + for _ in range(self.buffer_size): + self.buffer_samples.append(next(iterable_source)) + except StopIteration: + pass + + # 动态替换逻辑 + while True: + try: + item = next(iterable_source) + # 随机替换缓冲区元素 + replace_idx = random.randint(0, len(self.buffer_samples) - 1) + sample = self.buffer_samples[replace_idx] + self.buffer_samples[replace_idx] = item + yield self.convert_sample(sample) + except StopIteration: + break + + # 清空剩余元素 + random.shuffle(self.buffer_samples) + for sample in self.buffer_samples: + yield self.convert_sample(sample) + + def iterable_source(self): + last_sample = None + with open(self.jsonl_file, "r", encoding="utf-8") as f: + for row in f: + row = json.loads(row) + noise_filename = row["noise_filename"] + noise_raw_duration = row["noise_raw_duration"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_raw_duration = row["speech_raw_duration"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + if self.min_snr_db is None or self.max_snr_db is None: + snr_db = row["snr_db"] + else: + snr_db = random.uniform(self.min_snr_db, self.max_snr_db) + + sample = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": snr_db, + } + if last_sample is None: + last_sample = sample + continue + yield sample + yield last_sample + + def convert_sample(self, sample: dict): + noise_filename = sample["noise_filename"] + noise_offset = sample["noise_offset"] + noise_duration = sample["noise_duration"] + + speech_filename = sample["speech_filename"] + 
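`mix_speech_and_noise` rescales the noise so that the speech-to-noise power ratio of the mixture matches `snr_db`. A quick numpy sanity check of that property on random signals:

```python
import numpy as np

# Verify that the noise scaling used above realises the requested SNR.
rng = np.random.default_rng(0)
speech = rng.normal(size=16000).astype(np.float32)
noise = rng.normal(size=16000).astype(np.float32)
snr_db = 10.0

speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (snr_db / 10))
noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + 1e-8)

realised_snr_db = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
print(round(float(realised_snr_db), 2))   # ~10.0
```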
speech_offset = sample["speech_offset"] + speech_duration = sample["speech_duration"] + + snr_db = sample["snr_db"] + + noise_wave = self.filename_to_waveform(noise_filename, noise_offset, noise_duration) + speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration) + + mix_wave, noise_wave_adjusted = self.mix_speech_and_noise( + speech=speech_wave.numpy(), + noise=noise_wave.numpy(), + snr_db=snr_db, eps=self.eps, + ) + mix_wave = torch.tensor(mix_wave, dtype=torch.float32) + noise_wave_adjusted = torch.tensor(noise_wave_adjusted, dtype=torch.float32) + + result = { + "noise_wave": noise_wave_adjusted, + "speech_wave": speech_wave, + "mix_wave": mix_wave, + "snr_db": snr_db, + } + return result + + def filename_to_waveform(self, filename: str, offset: float, duration: float): + try: + waveform, sample_rate = librosa.load( + filename, + sr=self.expected_sample_rate, + offset=offset, + duration=duration, + ) + except ValueError as e: + print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}") + raise e + waveform = torch.tensor(waveform, dtype=torch.float32) + return waveform + + @staticmethod + def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / (np.sqrt(np.mean(noise ** 2)) + eps) + + noisy_signal = speech + noise_adjusted + + return noisy_signal, noise_adjusted + + +if __name__ == "__main__": + pass diff --git a/toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py b/toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f8880d06f209e5a09d2d12504e718938130b6db3 --- /dev/null +++ b/toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import json +import os +import random +from typing import List +from pathlib import Path +import tempfile +import uuid + +from pydub import AudioSegment +from scipy.io import wavfile +import librosa +import numpy as np +import torch +from torch.utils.data import Dataset, IterableDataset + + +class Mp3ToWavJsonlDataset(IterableDataset): + def __init__(self, + jsonl_file: str, + expected_sample_rate: int, + resample: bool = False, + max_wave_value: float = 1.0, + buffer_size: int = 1000, + eps: float = 1e-8, + skip: int = 0, + ): + self.jsonl_file = jsonl_file + self.expected_sample_rate = expected_sample_rate + self.resample = resample + self.max_wave_value = max_wave_value + self.eps = eps + self.skip = skip + + self.buffer_size = buffer_size + self.buffer_samples: List[dict] = list() + + def __iter__(self): + self.buffer_samples = list() + + iterable_source = self.iterable_source() + + try: + for _ in range(self.skip): + next(iterable_source) + except StopIteration: + pass + + # 初始填充缓冲区 + try: + for _ in range(self.buffer_size): + self.buffer_samples.append(next(iterable_source)) + except StopIteration: + pass + + # 动态替换逻辑 + while True: + try: + item = next(iterable_source) + # 随机替换缓冲区元素 + replace_idx = random.randint(0, len(self.buffer_samples) - 1) + sample = self.buffer_samples[replace_idx] + self.buffer_samples[replace_idx] = item + yield self.convert_sample(sample) + except StopIteration: + break + + # 清空剩余元素 + 
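Both jsonl datasets stream samples through a fixed-size buffer and replace a random slot with each incoming item, which approximates shuffling without reading the whole file into memory. A stripped-down sketch of that pattern with plain integers in place of audio samples:

```python
import random

def buffered_shuffle(source, buffer_size: int = 4, seed: int = 0):
    """Yield items from `source` in a pseudo-shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    iterator = iter(source)
    buffer = []

    # initial fill
    for item in iterator:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            break

    # replace a random slot with each new item and yield the evicted one
    for item in iterator:
        idx = rng.randint(0, len(buffer) - 1)
        evicted = buffer[idx]
        buffer[idx] = item
        yield evicted

    # flush whatever is left in the buffer
    rng.shuffle(buffer)
    yield from buffer


if __name__ == "__main__":
    print(list(buffered_shuffle(range(10), buffer_size=4)))
```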
random.shuffle(self.buffer_samples) + for sample in self.buffer_samples: + yield self.convert_sample(sample) + + def iterable_source(self): + last_sample = None + with open(self.jsonl_file, "r", encoding="utf-8") as f: + for row in f: + row = json.loads(row) + filename = row["filename"] + raw_duration = row["raw_duration"] + offset = row["offset"] + duration = row["duration"] + + sample = { + "filename": filename, + "raw_duration": raw_duration, + "offset": offset, + "duration": duration, + } + if last_sample is None: + last_sample = sample + continue + yield sample + yield last_sample + + def convert_sample(self, sample: dict): + filename = sample["filename"] + offset = sample["offset"] + duration = sample["duration"] + + wav_waveform = self.filename_to_waveform(filename, offset, duration) + mp3_waveform = self.filename_to_mp3_waveform(filename, offset, duration) + + if wav_waveform.shape != mp3_waveform.shape: + raise AssertionError(f"wav_waveform: {wav_waveform.shape}, mp3_waveform: {mp3_waveform.shape}") + + result = { + "mp3_waveform": mp3_waveform, + "wav_waveform": wav_waveform, + } + return result + + @staticmethod + def filename_to_waveform(filename: str, offset: float, duration: float, expected_sample_rate: int = 8000): + try: + waveform, sample_rate = librosa.load( + filename, + sr=expected_sample_rate, + offset=offset, + duration=duration, + ) + except ValueError as e: + print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}") + raise e + waveform = torch.tensor(waveform, dtype=torch.float32) + return waveform + + @staticmethod + def get_temporary_file(suffix: str = ".wav"): + temp_audio_dir = Path(tempfile.gettempdir()) / "mp3_to_wav_jsonl_dataset" + temp_audio_dir.mkdir(parents=True, exist_ok=True) + filename = temp_audio_dir / f"{uuid.uuid4()}{suffix}" + filename = filename.as_posix() + return filename + + @staticmethod + def filename_to_mp3_waveform(filename: str, offset: float, duration: float, expected_sample_rate: int = 8000): + try: + waveform, sample_rate = librosa.load( + filename, + sr=expected_sample_rate, + offset=offset, + duration=duration, + ) + waveform = np.array(waveform * (1 << 15), dtype=np.int16) + except ValueError as e: + print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}") + raise e + + wav_temporary_file = Mp3ToWavJsonlDataset.get_temporary_file(suffix=".wav") + wavfile.write( + wav_temporary_file, + rate=sample_rate, + data=waveform, + ) + + mp3_temporary_file = Mp3ToWavJsonlDataset.get_temporary_file(suffix=".mp3") + + audio = AudioSegment.from_wav(wav_temporary_file) + audio.export(mp3_temporary_file, + format="mp3", + bitrate="64k", # 8kHz建议使用64kbps + # parameters=["-ar", "8000"] + parameters=["-ar", f"{expected_sample_rate}"] + ) + + try: + waveform, sample_rate = librosa.load(mp3_temporary_file, sr=expected_sample_rate) + except ValueError as e: + print(f"load failed. 
error type: {type(e)}, error text: {str(e)}, filename: {filename}") + raise e + + os.remove(wav_temporary_file) + os.remove(mp3_temporary_file) + + waveform = torch.tensor(waveform, dtype=torch.float32) + return waveform + + +def main(): + filename = r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-PH\2025-06-13\active_media_r_2e6e6303-4a2e-4bc9-b814-98ceddc59e9d_23.wav" + + waveform = Mp3ToWavJsonlDataset.filename_to_mp3_waveform(filename, offset=0, duration=15) + print(waveform.shape) + + signal = np.array(waveform.numpy() * (1 << 15), dtype=np.int16) + + wavfile.write( + "temp.wav", + 8000, + signal, + ) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/__init__.py b/toolbox/torchaudio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4aad738e112896111c38ae6624c8632aee62a234 --- /dev/null +++ b/toolbox/torchaudio/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/configuration_utils.py b/toolbox/torchaudio/configuration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33f85f3217aa622bdf6447d3491dc43207a66631 --- /dev/null +++ b/toolbox/torchaudio/configuration_utils.py @@ -0,0 +1,64 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import copy +import os +from typing import Any, Dict, Union + +import yaml + + +CONFIG_FILE = "config.yaml" +DISCRIMINATOR_CONFIG_FILE = "discriminator_config.yaml" + + +class PretrainedConfig(object): + def __init__(self, **kwargs): + pass + + @classmethod + def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]): + with open(yaml_file, encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + return config_dict + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike] + ) -> Dict[str, Any]: + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE) + else: + config_file = pretrained_model_name_or_path + config_dict = cls._dict_from_yaml_file(config_file) + return config_dict + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs): + for k, v in kwargs.items(): + if k in config_dict.keys(): + config_dict[k] = v + config = cls(**config_dict) + return config + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ): + config_dict = cls.get_config_dict(pretrained_model_name_or_path) + return cls.from_dict(config_dict, **kwargs) + + def to_dict(self): + output = copy.deepcopy(self.__dict__) + return output + + def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]): + config_dict = self.to_dict() + + with open(yaml_file_path, "w", encoding="utf-8") as writer: + yaml.safe_dump(config_dict, writer) + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/losses/__init__.py b/toolbox/torchaudio/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/losses/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/losses/irm.py b/toolbox/torchaudio/losses/irm.py new file mode 100644 index 0000000000000000000000000000000000000000..d14683dcdcc77ae360a3af982a15c8ca6d25fd89 --- /dev/null +++ b/toolbox/torchaudio/losses/irm.py @@ -0,0 +1,174 @@ 
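`PretrainedConfig` round-trips configurations through plain YAML files. A short sketch with a hypothetical `DemoConfig` subclass showing save, reload, and keyword overrides:

```python
import tempfile
from pathlib import Path

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DemoConfig(PretrainedConfig):
    # hypothetical subclass, only to illustrate the YAML round trip
    def __init__(self, sample_rate: int = 8000, n_fft: int = 512, **kwargs):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.n_fft = n_fft


if __name__ == "__main__":
    config = DemoConfig(sample_rate=16000)

    yaml_file = Path(tempfile.gettempdir()) / "demo_config.yaml"
    config.to_yaml_file(yaml_file)

    # from_pretrained accepts a directory containing config.yaml or a direct file path;
    # extra keyword arguments override values loaded from the file.
    reloaded = DemoConfig.from_pretrained(yaml_file.as_posix(), n_fft=1024)
    print(reloaded.to_dict())   # {'sample_rate': 16000, 'n_fft': 1024}
```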
+#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import List + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class CIRMLoss(nn.Module): + def __init__(self, + n_fft: int = 512, + win_size: int = 512, + hop_size: int = 256, + center: bool = True, + eps: float = 1e-8, + reduction: str = "mean", + ): + super(CIRMLoss, self).__init__() + self.n_fft = n_fft + self.win_size = win_size + self.hop_size = hop_size + self.center = center + self.eps = eps + self.reduction = reduction + + self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False) + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + def forward(self, clean: torch.Tensor, noisy: torch.Tensor, mask_real: torch.Tensor, mask_imag: torch.Tensor): + """ + :param clean: waveform + :param noisy: waveform + :param mask_real: shape: [b, f, t] + :param mask_imag: shape: [b, f, t] + :return: + """ + if noisy.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + + # clean_stft, noisy_stft shape: [b, f, t] + clean_stft = torch.stft( + clean, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + noisy_stft = torch.stft( + noisy, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + + # [b, f, t] + clean_stft_spec_real = torch.real(clean_stft) + clean_stft_spec_imag = torch.imag(clean_stft) + noisy_stft_spec_real = torch.real(noisy_stft) + noisy_stft_spec_imag = torch.imag(noisy_stft) + noisy_power = noisy_stft_spec_real ** 2 + noisy_stft_spec_imag ** 2 + + sr = clean_stft_spec_real + yr = noisy_stft_spec_real + si = clean_stft_spec_imag + yi = noisy_stft_spec_imag + y_pow = noisy_power + # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8) + gth_mask_real = (sr * yr + si * yi) / (y_pow + self.eps) + # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8) + gth_mask_imag = (sr * yr - si * yi) / (y_pow + self.eps) + + gth_mask_real[gth_mask_real > 2] = 1 + gth_mask_real[gth_mask_real < -2] = -1 + gth_mask_imag[gth_mask_imag > 2] = 1 + gth_mask_imag[gth_mask_imag < -2] = -1 + + amp_loss = F.mse_loss(gth_mask_real, mask_real) + phase_loss = F.mse_loss(gth_mask_imag, mask_imag) + + loss = amp_loss + phase_loss + return loss + + +class IRMLoss(nn.Module): + """ + https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L25 + """ + def __init__(self, + n_fft: int = 512, + win_size: int = 512, + hop_size: int = 256, + center: bool = True, + eps: float = 1e-8, + reduction: str = "mean", + ): + super(IRMLoss, self).__init__() + self.n_fft = n_fft + self.win_size = win_size + self.hop_size = hop_size + self.center = center + self.eps = eps + self.reduction = reduction + + self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False) + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + def forward(self, mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + if noisy.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + noise = noisy - clean + + # clean_stft, noisy_stft shape: [b, f, t] + stft_clean = torch.stft( + clean, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + 
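For reference, the complex ideal ratio mask that `CIRMLoss` builds its training target around is M = S · conj(Y) / |Y|², the mask that maps each noisy STFT bin back to the clean bin. A small numpy check of that definition, independent of the loss module itself:

```python
import numpy as np

# cIRM reference: M = S * conj(Y) / |Y|^2, so that M * Y recovers the clean bin S.
rng = np.random.default_rng(0)
S = rng.normal(size=4) + 1j * rng.normal(size=4)   # clean STFT bins
N = rng.normal(size=4) + 1j * rng.normal(size=4)   # noise STFT bins
Y = S + N                                          # noisy STFT bins

eps = 1e-8
mask_real = (S.real * Y.real + S.imag * Y.imag) / (np.abs(Y) ** 2 + eps)
mask_imag = (S.imag * Y.real - S.real * Y.imag) / (np.abs(Y) ** 2 + eps)
mask = mask_real + 1j * mask_imag

print(np.allclose(mask * Y, S, atol=1e-5))   # True: the mask reconstructs S from Y
```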
pad_mode="reflect", + normalized=False, + return_complex=True + ) + stft_noise = torch.stft( + noise, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + + mag_clean = torch.abs(stft_clean) + mag_noise = torch.abs(stft_noise) + + gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1) + + loss = F.l1_loss(gth_irm_mask, mask, reduction=self.reduction) + return loss + + +def main(): + batch_size = 2 + signal_length = 16000 + estimated_signal = torch.randn(batch_size, signal_length) + target_signal = torch.randn(batch_size, signal_length) + + loss_fn = CIRMLoss() + + loss = loss_fn.forward(estimated_signal, target_signal) + print(f"loss: {loss.item()}") + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/losses/perceptual.py b/toolbox/torchaudio/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..b4368d7a605dd97bc665ce13093497e86ed7bea8 --- /dev/null +++ b/toolbox/torchaudio/losses/perceptual.py @@ -0,0 +1,122 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://zhuanlan.zhihu.com/p/627039860 +""" +import torch +import torch.nn as nn +from torch_stoi import NegSTOILoss as TorchNegSTOILoss +from torch_pesq import PesqLoss as TorchPesqLoss + + +class PMSQELoss(object): + """ + A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality + https://sigmat.ugr.es/PMSQE/ + + On Loss Functions for Supervised Monaural Time-Domain Speech Enhancement + https://arxiv.org/abs/1909.01019 + + https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/pmsqe.py + """ + + +class NegSTOILoss(nn.Module): + """ + STOI短时客观可懂度(Short-Time Objective Intelligibility), + 通过计算语音信号的时域和频域特征之间的相关性来预测语音的可理解度, + 范围从0到1,分数越高可懂度越高。 + 它适用于评估噪声环境下的语音可懂度改善效果。 + + https://github.com/mpariente/pytorch_stoi + https://github.com/mpariente/pystoi + https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/stoi_loss.py + """ + def __init__(self, + sample_rate: int, + reduction: str = "mean", + ): + super(NegSTOILoss, self).__init__() + self.loss_fn = TorchNegSTOILoss(sample_rate=sample_rate) + self.reduction = reduction + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + + batch_loss = self.loss_fn.forward(denoise, clean) + + if self.reduction == "mean": + loss = torch.mean(batch_loss) + elif self.reduction == "sum": + loss = torch.sum(batch_loss) + else: + raise AssertionError + return loss + + +class PesqLoss(nn.Module): + def __init__(self, + factor: float, + sample_rate: int = 48000, + nbarks: int = 49, + win_length: int = 512, + n_fft: int = 512, + hop_length: int = 256, + reduction: str = "mean", + ): + super(PesqLoss, self).__init__() + self.factor = factor + self.sample_rate = sample_rate + self.nbarks = nbarks + self.win_length = win_length + self.n_fft = n_fft + self.hop_length = hop_length + self.reduction = reduction + + self.loss_fn = TorchPesqLoss( + factor=factor, + sample_rate=sample_rate, + nbarks=nbarks, + win_length=win_length, + n_fft=n_fft, + hop_length=hop_length, + ) + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + + batch_loss = self.loss_fn.forward(clean, denoise) + + # mask = ~(torch.isnan(batch_loss) | torch.isinf(batch_loss)) + # batch_loss = batch_loss[mask] + # if 
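`IRMLoss` compares a predicted magnitude-domain mask against the ideal ratio mask |S| / (|S| + |N|). A usage sketch with the default 512-point STFT; for 16000 samples at hop 256 the mask grid is [batch, 257, 63]:

```python
import torch

from toolbox.torchaudio.losses.irm import IRMLoss

loss_fn = IRMLoss(n_fft=512, win_size=512, hop_size=256)

clean = torch.randn(2, 16000)
noisy = clean + 0.3 * torch.randn(2, 16000)

# predicted mask on the same grid as the STFT: [batch, n_fft // 2 + 1, num_frames]
mask = torch.rand(2, 257, 63)

loss = loss_fn(mask, clean, noisy)
print(loss.item())
```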
len(batch_loss) == 0: + # raise AssertionError + + if self.reduction == "mean": + loss = torch.mean(batch_loss) + elif self.reduction == "sum": + loss = torch.sum(batch_loss) + else: + raise AssertionError + return loss + + +def main(): + sample_rate = 16000 + + loss_func = NegSTOILoss( + sample_rate=sample_rate, + reduction="mean", + ) + + denoise = torch.randn(2, sample_rate) + clean = torch.randn(2, sample_rate) + + loss_batch = loss_func.forward(denoise, clean) + print(loss_batch) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/losses/snr.py b/toolbox/torchaudio/losses/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..5a483ca0ab47237150cb591bd1e45a732852a6fe --- /dev/null +++ b/toolbox/torchaudio/losses/snr.py @@ -0,0 +1,185 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://zhuanlan.zhihu.com/p/627039860 +""" +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget + + +class NegativeSNRLoss(nn.Module): + """ + Signal-to-Noise Ratio + """ + def __init__(self, eps: float = 1e-8): + super(NegativeSNRLoss, self).__init__() + self.eps = eps + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + """ + Compute the SI-SNR loss between the estimated signal and the target signal. + + :param denoise: The estimated signal (batch_size, signal_length) + :param clean: The target signal (batch_size, signal_length) + :return: The SI-SNR loss (batch_size,) + """ + if denoise.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + + denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True) + clean = clean - torch.mean(clean, dim=-1, keepdim=True) + + noise = denoise - clean + + clean_power = torch.norm(clean, p=2, dim=-1) ** 2 + noise_power = torch.norm(noise, p=2, dim=-1) ** 2 + + snr = 10 * torch.log10((clean_power + self.eps) / (noise_power + self.eps)) + + return -snr.mean() + + +class NegativeSISNRLoss(nn.Module): + """ + Scale-Invariant Source-to-Noise Ratio + + https://arxiv.org/abs/2206.07293 + """ + def __init__(self, + reduction: str = "mean", + eps: float = 1e-8, + ): + super(NegativeSISNRLoss, self).__init__() + self.reduction = reduction + self.eps = eps + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + """ + Compute the SI-SNR loss between the estimated signal and the target signal. 
+ + :param denoise: The estimated signal (batch_size, signal_length) + :param clean: The target signal (batch_size, signal_length) + :return: The SI-SNR loss (batch_size,) + """ + if denoise.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + + denoise = denoise - torch.mean(denoise, dim=-1, keepdim=True) + clean = clean - torch.mean(clean, dim=-1, keepdim=True) + + s_target = torch.sum(denoise * clean, dim=-1, keepdim=True) * clean / (torch.norm(clean, p=2, dim=-1, keepdim=True) ** 2 + self.eps) + + e_noise = denoise - s_target + + batch_si_snr = 10 * torch.log10(torch.norm(s_target, p=2, dim=-1) ** 2 / (torch.norm(e_noise, p=2, dim=-1) ** 2 + self.eps) + self.eps) + # si_snr shape: [batch_size,] + + if self.reduction == "mean": + loss = torch.mean(batch_si_snr) + elif self.reduction == "sum": + loss = torch.sum(batch_si_snr) + else: + raise AssertionError + return -loss + + +class LocalSNRLoss(nn.Module): + """ + https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816 + + """ + def __init__(self, + sample_rate: int = 8000, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 256, + n_frame: int = 3, + min_local_snr: int = -15, + max_local_snr: int = 30, + db: bool = True, + factor: float = 1, + reduction: str = "mean", + eps: float = 1e-8, + ): + super(LocalSNRLoss, self).__init__() + self.sample_rate = sample_rate + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + + self.factor = factor + self.reduction = reduction + self.eps = eps + + self.lsnr_fn = LocalSnrTarget( + sample_rate=sample_rate, + nfft=nfft, + win_size=win_size, + hop_size=hop_size, + n_frame=n_frame, + min_local_snr=min_local_snr, + max_local_snr=max_local_snr, + db=db, + ) + + self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False) + + def forward(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + if clean.shape != noisy.shape: + raise AssertionError("Input signals must have the same shape") + noise = noisy - clean + + stft_clean = torch.stft( + clean, + n_fft=self.nfft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + stft_noise = torch.stft( + noise, + n_fft=self.nfft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + + # lsnr shape: [b, 1, t] + lsnr = lsnr.squeeze(1) + # lsnr shape: [b, t] + + lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise) + # lsnr_gth shape: [b, t] + + loss = F.mse_loss(lsnr, lsnr_gth) * self.factor + return loss + + +def main(): + batch_size = 2 + signal_length = 16000 + estimated_signal = torch.randn(batch_size, signal_length) + # target_signal = torch.randn(batch_size, signal_length) + target_signal = torch.zeros(batch_size, signal_length) + + si_snr_loss = NegativeSISNRLoss() + + loss = si_snr_loss.forward(estimated_signal, target_signal) + print(f"loss: {loss.item()}") + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/losses/spectral.py b/toolbox/torchaudio/losses/spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..c374cbed97f40793430b12cbd411417a4acb47d4 --- /dev/null +++ b/toolbox/torchaudio/losses/spectral.py @@ -0,0 +1,437 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://zhuanlan.zhihu.com/p/627039860 + 
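`NegativeSISNRLoss` returns the negated scale-invariant SNR, which by construction is unaffected by a global gain applied to the estimate. A quick check of that invariance:

```python
import torch

from toolbox.torchaudio.losses.snr import NegativeSISNRLoss

loss_fn = NegativeSISNRLoss()

clean = torch.randn(2, 16000)
denoise = clean + 0.1 * torch.randn(2, 16000)

loss_a = loss_fn(denoise, clean)
loss_b = loss_fn(3.0 * denoise, clean)   # a global gain should not change SI-SNR

print(loss_a.item(), loss_b.item())      # the two values are (nearly) identical
```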
+https://github.com/facebookresearch/denoiser/blob/main/denoiser/stft_loss.py +""" +from typing import List + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class LSDLoss(nn.Module): + """ + Log Spectral Distance + + Mean square error of power spectrum + """ + def __init__(self, + n_fft: int = 512, + win_size: int = 512, + hop_size: int = 256, + center: bool = True, + eps: float = 1e-8, + reduction: str = "mean", + ): + super(LSDLoss, self).__init__() + self.n_fft = n_fft + self.win_size = win_size + self.hop_size = hop_size + self.center = center + self.eps = eps + self.reduction = reduction + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + def forward(self, denoise_power: torch.Tensor, clean_power: torch.Tensor): + """ + :param denoise_power: power spectrum of the estimated signal power spectrum (batch_size, ...) + :param clean_power: power spectrum of the target signal (batch_size, ...) + :return: + """ + denoise_power = denoise_power + self.eps + clean_power = clean_power + self.eps + + log_denoise_power = torch.log10(denoise_power) + log_clean_power = torch.log10(clean_power) + + # mean_square_error shape: [b, f] + mean_square_error = torch.mean(torch.square(log_denoise_power - log_clean_power), dim=-1) + + if self.reduction == "mean": + lsd_loss = torch.mean(mean_square_error) + elif self.reduction == "sum": + lsd_loss = torch.sum(mean_square_error) + else: + raise AssertionError + return lsd_loss + + +class ComplexSpectralLoss(nn.Module): + def __init__(self, + n_fft: int = 512, + win_size: int = 512, + hop_size: int = 256, + center: bool = True, + eps: float = 1e-8, + reduction: str = "mean", + factor_mag: float = 0.5, + factor_pha: float = 0.3, + factor_gra: float = 0.2, + ): + super().__init__() + self.n_fft = n_fft + self.win_size = win_size + self.hop_size = hop_size + self.center = center + self.eps = eps + self.reduction = reduction + + self.factor_mag = factor_mag + self.factor_pha = factor_pha + self.factor_gra = factor_gra + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False) + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + """ + :param denoise: The estimated signal (batch_size, signal_length) + :param clean: The target signal (batch_size, signal_length) + :return: + """ + if denoise.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + + # denoise_stft, clean_stft shape: [b, f, t] + denoise_stft = torch.stft( + denoise, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + clean_stft = torch.stft( + clean, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + + # complex_diff shape: [b, f, t], dtype: torch.complex64 + complex_diff = denoise_stft - clean_stft + + # magnitude_diff, phase_diff shape: [b, f, t], dtype: torch.float32 + magnitude_diff = torch.abs(complex_diff) + phase_diff = torch.angle(complex_diff) + + # magnitude_loss, phase_loss shape: [b,] + magnitude_loss = torch.norm(magnitude_diff, p=2, dim=(-1, -2)) + phase_loss = torch.norm(phase_diff, p=1, dim=(-1, -2)) + + # phase_grad shape: [b, f, t-1], dtype: torch.float32 
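Note that `LSDLoss` consumes power spectra rather than waveforms, unlike the other losses in this file. A usage sketch that derives the power spectra with `torch.stft` first, assuming the same 512-point STFT settings:

```python
import torch

from toolbox.torchaudio.losses.spectral import LSDLoss

n_fft, win_size, hop_size = 512, 512, 256
window = torch.hann_window(win_size)

def power_spectrum(signal: torch.Tensor) -> torch.Tensor:
    stft = torch.stft(signal, n_fft=n_fft, hop_length=hop_size, win_length=win_size,
                      window=window, center=True, return_complex=True)
    return stft.abs() ** 2   # shape: [batch, freq, time]

denoise = torch.randn(2, 16000)
clean = torch.randn(2, 16000)

loss_fn = LSDLoss(n_fft=n_fft, win_size=win_size, hop_size=hop_size)
loss = loss_fn(power_spectrum(denoise), power_spectrum(clean))
print(loss.item())
```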
+ phase_grad = torch.diff(torch.angle(denoise_stft), dim=-1) + grad_loss = torch.mean(torch.abs(phase_grad), dim=(-1, -2)) + + # loss, grad_loss shape: [b,] + batch_loss = self.factor_mag * magnitude_loss + self.factor_pha * phase_loss + self.factor_gra * grad_loss + # print(f"magnitude_loss: {magnitude_loss}") + # print(f"phase_loss: {phase_loss}") + # print(f"grad_loss: {grad_loss}") + + if self.reduction == "mean": + loss = torch.mean(batch_loss) + elif self.reduction == "sum": + loss = torch.sum(batch_loss) + else: + raise AssertionError + return loss + + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self, + reduction: str = "mean", + eps: float = 1e-8, + ): + super(SpectralConvergenceLoss, self).__init__() + self.reduction = reduction + self.eps = eps + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + def forward(self, + denoise_magnitude: torch.Tensor, + clean_magnitude: torch.Tensor, + ): + """ + :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins] + :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins] + :return: + """ + error_norm = torch.norm(denoise_magnitude - clean_magnitude, p="fro", dim=(-1, -2)) + truth_norm = torch.norm(clean_magnitude, p="fro", dim=(-1, -2)) + + batch_loss = error_norm / (truth_norm + self.eps) + + if self.reduction == "mean": + loss = torch.mean(batch_loss) + elif self.reduction == "sum": + loss = torch.sum(batch_loss) + else: + raise AssertionError + + return loss + + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self, + reduction: str = "mean", + eps: float = 1e-8, + ): + super(LogSTFTMagnitudeLoss, self).__init__() + self.reduction = reduction + self.eps = eps + + if reduction not in ("sum", "mean"): + raise AssertionError(f"param reduction must be sum or mean.") + + def forward(self, + denoise_magnitude: torch.Tensor, + clean_magnitude: torch.Tensor, + ): + """ + :param denoise_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins] + :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins] + :return: + """ + + loss = F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps)) + + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + print("LogSTFTMagnitudeLoss, nan or inf in loss") + + return loss + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, + n_fft: int = 1024, + win_size: int = 600, + hop_size: int = 120, + center: bool = True, + reduction: str = "mean", + ): + super(STFTLoss, self).__init__() + self.n_fft = n_fft + self.win_size = win_size + self.hop_size = hop_size + self.center = center + self.reduction = reduction + + self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False) + + self.spectral_convergence_loss = SpectralConvergenceLoss(reduction=reduction) + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(reduction=reduction) + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + """ + :param denoise: + :param clean: + :return: + """ + if denoise.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + + # denoise_stft, clean_stft shape: [b, f, t] + denoise_stft = torch.stft( + denoise, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + 
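`SpectralConvergenceLoss` is the per-item Frobenius-norm ratio between the magnitude error and the clean magnitude, averaged over the batch. A short check against `torch.linalg.matrix_norm`:

```python
import torch

from toolbox.torchaudio.losses.spectral import SpectralConvergenceLoss

denoise_mag = torch.rand(2, 257, 63)   # [batch, freq, time] magnitudes
clean_mag = torch.rand(2, 257, 63)

loss_fn = SpectralConvergenceLoss(reduction="mean")
loss = loss_fn(denoise_mag, clean_mag)

# same quantity computed directly
direct = (torch.linalg.matrix_norm(denoise_mag - clean_mag, ord="fro")
          / (torch.linalg.matrix_norm(clean_mag, ord="fro") + 1e-8)).mean()

print(torch.allclose(loss, direct))    # True
```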
return_complex=True + ) + clean_stft = torch.stft( + clean, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + + denoise_magnitude = torch.abs(denoise_stft) + clean_magnitude = torch.abs(clean_stft) + + sc_loss = self.spectral_convergence_loss.forward(denoise_magnitude, clean_magnitude) + mag_loss = self.log_stft_magnitude_loss.forward(denoise_magnitude, clean_magnitude) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_size_list: List[int] = None, + win_size_list: List[int] = None, + hop_size_list: List[int] = None, + factor_sc=0.1, + factor_mag=0.1, + reduction: str = "mean", + ): + super(MultiResolutionSTFTLoss, self).__init__() + fft_size_list = fft_size_list or [512, 1024, 2048] + win_size_list = win_size_list or [240, 600, 1200] + hop_size_list = hop_size_list or [50, 120, 240] + + if not len(fft_size_list) == len(win_size_list) == len(hop_size_list): + raise AssertionError + + loss_fn_list = nn.ModuleList([]) + for n_fft, win_size, hop_size in zip(fft_size_list, win_size_list, hop_size_list): + loss_fn_list.append( + STFTLoss( + n_fft=n_fft, + win_size=win_size, + hop_size=hop_size, + reduction=reduction, + ) + ) + + self.loss_fn_list = loss_fn_list + self.factor_sc = factor_sc + self.factor_mag = factor_mag + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + """ + :param denoise: + :param clean: + :return: + """ + if denoise.shape != clean.shape: + raise AssertionError(f"Input signals must have the same shape. denoise_audios: {denoise.shape}, clean_audios: {clean.shape}") + + sc_loss = 0.0 + mag_loss = 0.0 + for loss_fn in self.loss_fn_list: + sc_l, mag_l = loss_fn.forward(denoise, clean) + sc_loss += sc_l + mag_loss += mag_l + sc_loss = sc_loss / len(self.loss_fn_list) + mag_loss = mag_loss / len(self.loss_fn_list) + + sc_loss = self.factor_sc * sc_loss + mag_loss = self.factor_mag * mag_loss + + loss = sc_loss + mag_loss + return loss + + +class WeightedMagnitudePhaseLoss(nn.Module): + def __init__(self, + n_fft: int = 1024, + win_size: int = 600, + hop_size: int = 120, + center: bool = True, + reduction: str = "mean", + mag_weight: float = 0.9, + pha_weight: float = 0.3, + ): + super(WeightedMagnitudePhaseLoss, self).__init__() + self.n_fft = n_fft + self.win_size = win_size + self.hop_size = hop_size + self.center = center + self.reduction = reduction + + self.mag_weight = mag_weight + self.pha_weight = pha_weight + + self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False) + + def forward(self, denoise: torch.Tensor, clean: torch.Tensor): + """ + :param denoise: + :param clean: + :return: + """ + if denoise.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + + # denoise_stft, clean_stft shape: [b, f, t] + denoise_stft = torch.stft( + denoise, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + clean_stft = torch.stft( + clean, + n_fft=self.n_fft, + win_length=self.win_size, + hop_length=self.hop_size, + window=self.window, + center=self.center, + pad_mode="reflect", + normalized=False, + return_complex=True + ) + + denoise_stft_spec = torch.view_as_real(denoise_stft) + denoise_mag = torch.sqrt(denoise_stft_spec.pow(2).sum(-1) + 
1e-9) + denoise_pha = torch.atan2(denoise_stft_spec[:, :, :, 1] + 1e-10, denoise_stft_spec[:, :, :, 0] + 1e-5) + + clean_stft_spec = torch.view_as_real(clean_stft) + clean_mag = torch.sqrt(clean_stft_spec.pow(2).sum(-1) + 1e-9) + clean_pha = torch.atan2(clean_stft_spec[:, :, :, 1] + 1e-10, clean_stft_spec[:, :, :, 0] + 1e-5) + + mag_loss = F.mse_loss(denoise_mag, clean_mag, reduction=self.reduction) + pha_loss = F.mse_loss(denoise_pha, clean_pha, reduction=self.reduction) + + loss = self.mag_weight * mag_loss + self.pha_weight * pha_loss + return loss + + +def main(): + batch_size = 2 + signal_length = 16000 + estimated_signal = torch.randn(batch_size, signal_length) + target_signal = torch.randn(batch_size, signal_length) + + # loss_fn = LSDLoss() + # loss_fn = ComplexSpectralLoss() + # loss_fn = MultiResolutionSTFTLoss() + loss_fn = WeightedMagnitudePhaseLoss() + + loss = loss_fn.forward(estimated_signal, target_signal) + print(f"loss: {loss.item()}") + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/metrics/__init__.py b/toolbox/torchaudio/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/metrics/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/metrics/pesq.py b/toolbox/torchaudio/metrics/pesq.py new file mode 100644 index 0000000000000000000000000000000000000000..a17bd069763de0843f8f173b85bc68da3355b323 --- /dev/null +++ b/toolbox/torchaudio/metrics/pesq.py @@ -0,0 +1,80 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from joblib import Parallel, delayed +import numpy as np +from pesq import pesq +from typing import List + +from pesq import cypesq + + +def run_pesq(clean_audio: np.ndarray, + noisy_audio: np.ndarray, + sample_rate: int = 16000, + mode: str = "wb", + ) -> float: + if sample_rate == 8000 and mode == "wb": + raise AssertionError(f"mode should be `nb` when sample_rate is 8000") + try: + pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode) + except cypesq.NoUtterancesError as e: + pesq_score = -1 + except Exception as e: + print(f"pesq failed. 
error type: {type(e)}, error text: {str(e)}") + pesq_score = -1 + return pesq_score + + +def run_batch_pesq(clean_audio_list: List[np.ndarray], + noisy_audio_list: List[np.ndarray], + sample_rate: int = 16000, + mode: str = "wb", + n_jobs: int = 4, + ) -> List[float]: + parallel = Parallel(n_jobs=n_jobs) + + parallel_tasks = list() + for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list): + parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode) + parallel_tasks.append(parallel_task) + + pesq_score_list = parallel.__call__(parallel_tasks) + return pesq_score_list + + +def run_pesq_score(clean_audio_list: List[np.ndarray], + noisy_audio_list: List[np.ndarray], + sample_rate: int = 16000, + mode: str = "wb", + n_jobs: int = 4, + ) -> float: + + pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list, + noisy_audio_list=noisy_audio_list, + sample_rate=sample_rate, + mode=mode, + n_jobs=n_jobs, + ) + + pesq_score = np.mean(pesq_score_list) + return pesq_score + + +def main(): + clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) + noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) + + clean_audio_list = list(clean_audio) + noisy_audio_list = list(noisy_audio) + + pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list) + print(pesq_score_list) + + pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list) + print(pesq_score) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/__init__.py b/toolbox/torchaudio/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4aad738e112896111c38ae6624c8632aee62a234 --- /dev/null +++ b/toolbox/torchaudio/models/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/__init__.py b/toolbox/torchaudio/models/clean_unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/configuration_clean_unet.py b/toolbox/torchaudio/models/clean_unet/configuration_clean_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..9aaa380ee44246a0295f04889ee27958e59eee75 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/configuration_clean_unet.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class CleanUNetConfig(PretrainedConfig): + def __init__(self, + channels_input: int = 1, + channels_output: int = 1, + + channels_h: int = 64, + max_h: int = 768, + + encoder_n_layers: int = 8, + kernel_size: int = 4, + stride: int = 2, + tsfm_n_layers: int = 5, + tsfm_n_head: int = 8, + tsfm_d_model: int = 512, + tsfm_d_inner: int = 2048, + + **kwargs + ): + super(CleanUNetConfig, self).__init__(**kwargs) + self.channels_input = channels_input + self.channels_output = channels_output + + self.channels_h = channels_h + self.max_h = max_h + + self.encoder_n_layers = encoder_n_layers + self.kernel_size = kernel_size + self.stride = stride + self.tsfm_n_layers = tsfm_n_layers + self.tsfm_n_head = tsfm_n_head + self.tsfm_d_model = tsfm_d_model + self.tsfm_d_inner = tsfm_d_inner + + +if __name__ == "__main__": + pass diff --git 
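`CleanUNetConfig` is read from a plain `config.yaml` through `PretrainedConfig.from_pretrained`. A minimal round-trip sketch, written to the current working directory, with one keyword override on reload:

```python
from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig

# Save the default configuration, then load it back from the directory.
config = CleanUNetConfig()
config.to_yaml_file("config.yaml")

reloaded = CleanUNetConfig.from_pretrained(".", tsfm_d_model=256)
print(reloaded.encoder_n_layers, reloaded.tsfm_d_model)   # 8 256
```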
a/toolbox/torchaudio/models/clean_unet/inference_clean_unet.py b/toolbox/torchaudio/models/clean_unet/inference_clean_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..36c40705fbf32c23c8759e78beb0be92f25e01c9 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/inference_clean_unet.py @@ -0,0 +1,105 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +from project_settings import project_path +from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig +from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel, MODEL_FILE + +logger = logging.getLogger("toolbox") + + +class InferenceCleanUNet(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, model = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.model = model + self.model.to(device) + self.model.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = CleanUNetConfig.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model = CleanUNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model.to(self.device) + model.eval() + + shutil.rmtree(model_path) + return config, model + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.enhancement_by_tensor(noisy_audio) + # noisy_audio shape: [channels, n_samples] + return enhanced_audio.cpu().numpy() + + def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + enhanced_audios = self.model.forward(noisy_audios) + # enhanced_audio shape: [batch_size, channels, num_samples] + # enhanced_audios = torch.squeeze(enhanced_audios, dim=1) + + enhanced_audio = enhanced_audios[0] + + # enhanced_audio shape: [channels, num_samples] + return enhanced_audio + + +def main(): + model_zip_file = project_path / "trained_models/clean-unet-aishell-18-epoch.zip" + infer_mpnet = InferenceCleanUNet(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav" + noisy_audio, _ = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = 
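Besides the tensor path used in `main()` below, `InferenceCleanUNet` also accepts raw numpy audio through `enhancement_by_ndarray`. A sketch using the same zip checkpoint, assuming it exists under `trained_models/`:

```python
import numpy as np

from project_settings import project_path
from toolbox.torchaudio.models.clean_unet.inference_clean_unet import InferenceCleanUNet

model_zip_file = project_path / "trained_models/clean-unet-aishell-18-epoch.zip"
infer = InferenceCleanUNet(model_zip_file.as_posix())

# one second of fake 8 kHz audio; real input must already be scaled to [-1, 1]
noisy = np.random.uniform(low=-0.5, high=0.5, size=(8000,)).astype(np.float32)
enhanced = infer.enhancement_by_ndarray(noisy)
print(enhanced.shape)   # (channels, num_samples)
```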
noisy_audio.unsqueeze(dim=0) + + enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio) + + filename = "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/clean_unet/loss.py b/toolbox/torchaudio/models/clean_unet/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..073ccdf966a4200a281383c2c2f8b99104e26c55 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/loss.py @@ -0,0 +1,158 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import torch +import torch +import torch.nn.functional as F + + +def stft(x, fft_size, hop_size, win_length, window): + """ + Perform STFT and convert to magnitude spectrogram. + :param x: Tensor, Input signal tensor (B, T). + :param fft_size: int, FFT size. + :param hop_size: int, Hop size. + :param win_length: int, Window length. + :param window: str, Window function type. + :return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + """ + + x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True) + + return x_stft.abs() + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + super(SpectralConvergenceLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """ + Calculate forward propagation. + :param x_mag: Tensor, Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + :return: Tensor, Spectral convergence loss value. + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self): + super(LogSTFTMagnitudeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """ + Calculate forward propagation. + :param x_mag: Tensor, Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + :return: Tensor, Log STFT magnitude loss value. + """ + y_mag = torch.clamp(y_mag, min=1e-8) + x_mag = torch.clamp(x_mag, min=1e-8) + return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__( + self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", + band="full" + ): + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.band = band + + self.spectral_convergence_loss = SpectralConvergenceLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + # NOTE(kan-bayashi): Use register_buffer to fix #223 + self.register_buffer("window", getattr(torch, window)(win_length)) + + def forward(self, x, y): + """ + Calculate forward propagation. + :param x: Tensor, Predicted signal (B, T). + :param y: Tensor, Groundtruth signal (B, T). + :return: + Tensor, Spectral convergence loss value. + Tensor, Log STFT magnitude loss value. 
+ """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + + if self.band == "high": + freq_mask_ind = x_mag.shape[1] // 2 # only select high frequency bands + sc_loss = self.spectral_convergence_loss(x_mag[:,freq_mask_ind:,:], y_mag[:,freq_mask_ind:,:]) + mag_loss = self.log_stft_magnitude_loss(x_mag[:,freq_mask_ind:,:], y_mag[:,freq_mask_ind:,:]) + elif self.band == "full": + sc_loss = self.spectral_convergence_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + else: + raise NotImplementedError + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_sizes=None, hop_sizes=None, win_lengths=None, + window="hann_window", sc_lambda=0.1, mag_lambda=0.1, band="full", + ): + """ + Initialize Multi resolution STFT loss module. + :param fft_sizes: list, List of FFT sizes. + :param hop_sizes: list, List of hop sizes. + :param win_lengths: list, List of window lengths. + :param window: str, Window function type. + :param sc_lambda: float, a balancing factor across different losses. + :param mag_lambda: float, a balancing factor across different losses. + :param band: str, high-band or full-band loss + """ + super(MultiResolutionSTFTLoss, self).__init__() + fft_sizes = fft_sizes or [1024, 2048, 512] + hop_sizes = hop_sizes or [120, 240, 50] + win_lengths = win_lengths or [600, 1200, 240] + + self.sc_lambda = sc_lambda + self.mag_lambda = mag_lambda + + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window, band)] + + def forward(self, x, y): + """ + Calculate forward propagation. + :param x: Tensor, Predicted signal (B, T) or (B, #subband, T). + :param y: Tensor, Groundtruth signal (B, T) or (B, #subband, T). + :return: + Tensor, Multi resolution spectral convergence loss value. + Tensor, Multi resolution log STFT magnitude loss value. + """ + if len(x.shape) == 3: + x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T) + y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T) + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + + sc_loss *= self.sc_lambda + sc_loss /= len(self.stft_losses) + mag_loss *= self.mag_lambda + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/metrics.py b/toolbox/torchaudio/models/clean_unet/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..78468894a56d4488021e83ea47e07c785a385269 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/metrics.py @@ -0,0 +1,80 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from joblib import Parallel, delayed +import numpy as np +from pesq import pesq +from typing import List + +from pesq import cypesq + + +def run_pesq(clean_audio: np.ndarray, + noisy_audio: np.ndarray, + sample_rate: int = 16000, + mode: str = "wb", + ) -> float: + if sample_rate == 8000 and mode == "wb": + raise AssertionError(f"mode should be `nb` when sample_rate is 8000") + try: + pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode) + except cypesq.NoUtterancesError as e: + pesq_score = -1 + except Exception as e: + print(f"pesq failed. 
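This `MultiResolutionSTFTLoss` variant returns the spectral-convergence and log-magnitude terms separately, so the training loop combines them itself. A minimal sketch with the default resolutions; equal weighting of the two terms here is an assumption, on top of the built-in `sc_lambda`/`mag_lambda` factors:

```python
import torch

from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss

mrstft_loss_fn = MultiResolutionSTFTLoss()   # defaults: sc_lambda = mag_lambda = 0.1, full band

denoise = torch.randn(2, 16000)
clean = torch.randn(2, 16000)

sc_loss, mag_loss = mrstft_loss_fn(denoise, clean)
loss = sc_loss + mag_loss                    # combined objective for backprop
print(sc_loss.item(), mag_loss.item(), loss.item())
```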
error type: {type(e)}, error text: {str(e)}") + pesq_score = -1 + return pesq_score + + +def run_batch_pesq(clean_audio_list: List[np.ndarray], + noisy_audio_list: List[np.ndarray], + sample_rate: int = 16000, + mode: str = "wb", + n_jobs: int = 4, + ) -> List[float]: + parallel = Parallel(n_jobs=n_jobs) + + parallel_tasks = list() + for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list): + parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode) + parallel_tasks.append(parallel_task) + + pesq_score_list = parallel.__call__(parallel_tasks) + return pesq_score_list + + +def run_pesq_score(clean_audio_list: List[np.ndarray], + noisy_audio_list: List[np.ndarray], + sample_rate: int = 16000, + mode: str = "wb", + n_jobs: int = 4, + ) -> List[float]: + + pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list, + noisy_audio_list=noisy_audio_list, + sample_rate=sample_rate, + mode=mode, + n_jobs=n_jobs, + ) + + pesq_score = np.mean(pesq_score_list) + return pesq_score + + +def main(): + clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) + noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) + + clean_audio_list = list(clean_audio) + noisy_audio_list = list(noisy_audio) + + pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list) + print(pesq_score_list) + + pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list) + print(pesq_score) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py b/toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ab39dfe448d19640f7f3a2df9c0ef4a6a80d2d --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py @@ -0,0 +1,292 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2202.07790 + +https://github.com/nvidia/cleanunet + +https://huggingface.co/spaces/fsoft-ai-center/Speech-Enhancement/blob/main/src/model.py + +支持流式改造。 + +https://github.com/francislr/clean-unet-inference + +""" +import os +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.clean_unet.transformer import TransformerEncoder +from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig + + +def weight_scaling_init(layer): + """ + weight rescaling initialization from https://arxiv.org/abs/1911.13254 + """ + w = layer.weight.detach() + alpha = 10.0 * w.std() + layer.weight.data /= torch.sqrt(alpha) + layer.bias.data /= torch.sqrt(alpha) + + +def print_size(net, keyword=None): + """ + Print the number of parameters of a network + """ + + if net is not None and isinstance(net, torch.nn.Module): + module_parameters = filter(lambda p: p.requires_grad, net.parameters()) + params = sum([np.prod(p.size()) for p in module_parameters]) + + print("{} Parameters: {:.6f}M".format( + net.__class__.__name__, params / 1e6), flush=True, end="; ") + + if keyword is not None: + keyword_parameters = [p for name, p in net.named_parameters() if p.requires_grad and keyword in name] + params = sum([np.prod(p.size()) for p in keyword_parameters]) + print("{} Parameters: {:.6f}M".format( + keyword, params / 1e6), flush=True, end="; ") + + print(" ") + + +# CleanUNet architecture + +def padding(x, D, K, S): + """padding zeroes to x so that 
denoised audio has the same length""" + + L = x.shape[-1] + for _ in range(D): + if L < K: + L = 1 + else: + L = 1 + np.ceil((L - K) / S) + + for _ in range(D): + L = (L - 1) * S + K + + L = int(L) + x = F.pad(x, (0, L - x.shape[-1])) + return x + + +class CleanUNet(nn.Module): + """ + CleanUNet architecture. + """ + + def __init__(self, + channels_input=1, channels_output=1, + channels_h=64, max_h=768, + encoder_n_layers=8, kernel_size=4, stride=2, + tsfm_n_layers=3, + tsfm_n_head=8, + tsfm_d_model=512, + tsfm_d_inner=2048): + """ + Parameters: + channels_input (int): input channels + channels_output (int): output channels + channels_H (int): middle channels H that controls capacity + max_H (int): maximum H + encoder_n_layers (int): number of encoder/decoder layers D + kernel_size (int): kernel size K + stride (int): stride S + tsfm_n_layers (int): number of self attention blocks N + tsfm_n_head (int): number of heads in each self attention block + tsfm_d_model (int): d_model of self attention + tsfm_d_inner (int): d_inner of self attention + """ + + super(CleanUNet, self).__init__() + + self.channels_input = channels_input + self.channels_output = channels_output + self.channels_h = channels_h + self.max_h = max_h + self.encoder_n_layers = encoder_n_layers + self.kernel_size = kernel_size + self.stride = stride + + self.tsfm_n_layers = tsfm_n_layers + self.tsfm_n_head = tsfm_n_head + self.tsfm_d_model = tsfm_d_model + self.tsfm_d_inner = tsfm_d_inner + + # encoder and decoder + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for i in range(encoder_n_layers): + self.encoder.append(nn.Sequential( + nn.Conv1d(channels_input, channels_h, kernel_size, stride), + nn.ReLU(inplace=False), + nn.Conv1d(channels_h, channels_h * 2, 1), + nn.GLU(dim=1) + )) + channels_input = channels_h + + if i == 0: + # no relu at end + self.decoder.append(nn.Sequential( + nn.Conv1d(channels_h, channels_h * 2, 1), + nn.GLU(dim=1), + nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride) + )) + else: + self.decoder.insert(0, nn.Sequential( + nn.Conv1d(channels_h, channels_h * 2, 1), + nn.GLU(dim=1), + nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride), + nn.ReLU() + )) + channels_output = channels_h + + # double H but keep below max_H + channels_h *= 2 + channels_h = min(channels_h, max_h) + + # self attention block + self.tsfm_conv1 = nn.Conv1d(channels_output, tsfm_d_model, kernel_size=1) + self.tsfm_encoder = TransformerEncoder(d_word_vec=tsfm_d_model, + n_layers=tsfm_n_layers, + n_head=tsfm_n_head, + d_k=tsfm_d_model // tsfm_n_head, + d_v=tsfm_d_model // tsfm_n_head, + d_model=tsfm_d_model, + d_inner=tsfm_d_inner, + dropout=0.0, + n_position=0, + scale_emb=False) + self.tsfm_conv2 = nn.Conv1d(tsfm_d_model, channels_output, kernel_size=1) + + # weight scaling initialization + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + weight_scaling_init(layer) + + def forward(self, noisy_audio): + # (B, L) -> (B, C, L) + if len(noisy_audio.shape) == 2: + noisy_audio = noisy_audio.unsqueeze(1) + B, C, L = noisy_audio.shape + assert C == 1 + + # normalization and padding + std = noisy_audio.std(dim=2, keepdim=True) + 1e-3 + noisy_audio /= std + x = padding(noisy_audio, self.encoder_n_layers, self.kernel_size, self.stride) + + # encoder + skip_connections = [] + for downsampling_block in self.encoder: + x = downsampling_block(x) + skip_connections.append(x) + skip_connections = skip_connections[::-1] + + # attention mask for causal 
inference; for non-causal, set attn_mask to None + len_s = x.shape[-1] # length at bottleneck + attn_mask = (1 - torch.triu(torch.ones((1, len_s, len_s), device=x.device), diagonal=1)).bool() + + x = self.tsfm_conv1(x) # C 1024 -> 512 + x = x.permute(0, 2, 1) + x = self.tsfm_encoder.forward(x, src_mask=attn_mask) + x = x.permute(0, 2, 1) + x = self.tsfm_conv2(x) # C 512 -> 1024 + + # decoder + for i, upsampling_block in enumerate(self.decoder): + skip_i = skip_connections[i] + x = x + skip_i[:, :, :x.shape[-1]] + x = upsampling_block(x) + + x = x[:, :, :L] * std + return x + + +MODEL_FILE = "model.pt" + + +class CleanUNetPretrainedModel(CleanUNet): + def __init__(self, + config: CleanUNetConfig, + ): + super(CleanUNetPretrainedModel, self).__init__( + channels_input=config.channels_input, + channels_output=config.channels_output, + channels_h=config.channels_h, + max_h=config.max_h, + encoder_n_layers=config.encoder_n_layers, + kernel_size=config.kernel_size, + stride=config.stride, + tsfm_n_layers=config.tsfm_n_layers, + tsfm_n_head=config.tsfm_n_head, + tsfm_d_model=config.tsfm_d_model, + tsfm_d_inner=config.tsfm_d_inner, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = CleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + + config = CleanUNetConfig() + model = CleanUNetPretrainedModel(config) + + print_size(model, keyword="tsfm") + + input_data = torch.ones([4, 1, int(4.5 * 16000)]) + output = model.forward(input_data) + print(output.shape) + + # y = torch.rand([4, 1, int(4.5 * 16000)]) + # loss = torch.nn.MSELoss()(y, output) + # loss.backward() + # print(loss.item()) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/clean_unet/training.py b/toolbox/torchaudio/models/clean_unet/training.py new file mode 100644 index 0000000000000000000000000000000000000000..ed39c88b5e9b4f8e9bb877dce95646ba837478d7 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/training.py @@ -0,0 +1,85 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import math + + +def anneal_linear(start, end, proportion): + return start + proportion * (end - start) + + +def anneal_cosine(start, end, proportion): + cos_val = math.cos(math.pi * proportion) + 1 + return end + (start - end) / 2 * cos_val + + +class Phase: + def __init__(self, start, end, n_iter, cur_iter, anneal_fn): + self.start, self.end = start, end + self.n_iter = n_iter + self.anneal_fn = anneal_fn + self.n = cur_iter + + def step(self): + self.n += 1 + + return self.anneal_fn(self.start, self.end, self.n / self.n_iter) + + def reset(self): + self.n = 0 + + 
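+
+    # Added note (illustrative, not from the original code): a Phase anneals a
+    # scalar from `start` to `end` over `n_iter` calls to step(), using one of
+    # the anneal functions defined above, e.g. with start=0.0 and end=1.0:
+    #   anneal_linear(0.0, 1.0, 0.25) == 0.25
+    #   anneal_cosine(0.0, 1.0, 0.25) == 1.0 - 0.5 * (math.cos(math.pi * 0.25) + 1) ~= 0.146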
@property + def is_done(self): + return self.n >= self.n_iter + + +class LinearWarmupCosineDecay(object): + def __init__( + self, + optimizer, + lr_max, + n_iter, + iteration=0, + divider=25, + warmup_proportion=0.3, + phase=('linear', 'cosine'), + ): + self.optimizer = optimizer + + phase1 = int(n_iter * warmup_proportion) + phase2 = n_iter - phase1 + lr_min = lr_max / divider + + phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine} + + cur_iter_phase1 = iteration + cur_iter_phase2 = max(0, iteration - phase1) + self.lr_phase = [ + Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]), + Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]), + ] + + if iteration < phase1: + self.phase = 0 + else: + self.phase = 1 + + def step(self): + lr = self.lr_phase[self.phase].step() + + for group in self.optimizer.param_groups: + group['lr'] = lr + + if self.lr_phase[self.phase].is_done: + self.phase += 1 + + if self.phase >= len(self.lr_phase): + for phase in self.lr_phase: + phase.reset() + + self.phase = 0 + + return lr + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/transformer.py b/toolbox/torchaudio/models/clean_unet/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..467c3505f759f6789111161aaae5efd77c3d35c9 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/transformer.py @@ -0,0 +1,216 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Transformer (encoder) https://github.com/jadore801120/attention-is-all-you-need-pytorch +# Original Copyright 2017 Victor Huang +# MIT License (https://opensource.org/licenses/MIT) + +class ScaledDotProductAttention(nn.Module): + """ + Scaled Dot-Product Attention + """ + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) + + def forward(self, q, k, v, mask=None): + + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + if mask is not None: + attn = attn.masked_fill(mask == 0, -1e9) + + attn = self.dropout(F.softmax(attn, dim=-1)) + output = torch.matmul(attn, v) + + return output, attn + + +class MultiHeadAttention(nn.Module): + """ + Multi-Head Attention module + """ + + def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): + super().__init__() + + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) + self.fc = nn.Linear(n_head * d_v, d_model, bias=False) + + self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) + + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward(self, q, k, v, mask=None): + + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + residual = q + + # Pass through the pre-attention projection: b x lq x (n*dv) + # Separate different heads: b x lq x n x dv + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + + # Transpose for attention dot product: b x n x lq x dv + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + if mask is not None: + mask = mask.unsqueeze(1) # For head axis 
broadcasting. + + q, attn = self.attention(q, k, v, mask=mask) + + # Transpose to move the head dimension back: b x lq x n x dv + # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) + q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + q = self.dropout(self.fc(q)) + q += residual + + q = self.layer_norm(q) + + return q, attn + + +class PositionwiseFeedForward(nn.Module): + """ + A two-feed-forward-layer module + """ + + def __init__(self, d_in, d_hid, dropout=0.1): + super().__init__() + self.w_1 = nn.Linear(d_in, d_hid) # position-wise + self.w_2 = nn.Linear(d_hid, d_in) # position-wise + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + + residual = x + + x = self.w_2(F.relu(self.w_1(x))) + x = self.dropout(x) + x += residual + + x = self.layer_norm(x) + + return x + + +def get_subsequent_mask(seq): + """ + For masking out the subsequent info. + """ + sz_b, len_s = seq.size() + subsequent_mask = (1 - torch.triu( + torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool() + return subsequent_mask + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_hid, n_position=200): + super(PositionalEncoding, self).__init__() + + # Not a parameter + self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) + + def _get_sinusoid_encoding_table(self, n_position, d_hid): + """ + Sinusoid position encoding table + """ + # TODO: make it with torch instead of numpy + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.FloatTensor(sinusoid_table).unsqueeze(0) + + def forward(self, x): + return x + self.pos_table[:, :x.size(1)].clone().detach() + + +class EncoderLayer(nn.Module): + """ + Compose with two layers + """ + + def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0): + super(EncoderLayer, self).__init__() + self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) + self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) + + def forward(self, enc_input, slf_attn_mask=None): + enc_output, enc_slf_attn = self.slf_attn( + enc_input, enc_input, enc_input, mask=slf_attn_mask) + enc_output = self.pos_ffn(enc_output) + return enc_output, enc_slf_attn + + +class TransformerEncoder(nn.Module): + """ + A encoder model with self attention mechanism. 
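+
+    A minimal usage sketch (the sizes below are illustrative assumptions; in this
+    repository the encoder is instantiated by CleanUNet with its own settings):
+        >>> enc = TransformerEncoder(d_word_vec=512, n_layers=3, n_head=8,
+        ...                          d_k=64, d_v=64, d_model=512, d_inner=2048,
+        ...                          dropout=0.0, n_position=0)
+        >>> x = torch.randn(4, 100, 512)     # (batch, time, d_model)
+        >>> out = enc(x, src_mask=None)      # -> shape (4, 100, 512)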
+ """ + + def __init__( + self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64, + d_model=512, d_inner=2048, dropout=0.1, n_position=624, scale_emb=False): + + super().__init__() + + # self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx) + if n_position > 0: + self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) + else: + self.position_enc = lambda x: x + self.dropout = nn.Dropout(p=dropout) + self.layer_stack = nn.ModuleList([ + EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) + for _ in range(n_layers)]) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.scale_emb = scale_emb + self.d_model = d_model + + def forward(self, src_seq, src_mask, return_attns=False): + + enc_slf_attn_list = [] + + # -- Forward + # enc_output = self.src_word_emb(src_seq) + enc_output = src_seq + if self.scale_emb: + enc_output *= self.d_model ** 0.5 + enc_output = self.dropout(self.position_enc(enc_output)) + enc_output = self.layer_norm(enc_output) + + for enc_layer in self.layer_stack: + enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask) + enc_slf_attn_list += [enc_slf_attn] if return_attns else [] + + if return_attns: + return enc_output, enc_slf_attn_list + return enc_output + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/yaml/config.yaml b/toolbox/torchaudio/models/clean_unet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e18d8ff87d142144b34d2d2627e7530982c54830 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/yaml/config.yaml @@ -0,0 +1,13 @@ +model_name: "clean_unet" + +channels_input: 1 +channels_output: 1 +channels_h: 64 +max_h: 768 +encoder_n_layers: 8 +kernel_size: 4 +stride: 2 +tsfm_n_layers: 5 +tsfm_n_head: 8 +tsfm_d_model: 512 +tsfm_d_inner: 2048 diff --git a/toolbox/torchaudio/models/conv_tasnet/__init__.py b/toolbox/torchaudio/models/conv_tasnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/conv_tasnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py b/toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..16e44db2aa49849a9b9868efe8af940d1eb58c48 --- /dev/null +++ b/toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py @@ -0,0 +1,76 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class ConvTasNetConfig(PretrainedConfig): + """ + https://github.com/kaituoxu/Conv-TasNet/blob/master/src/train.py + """ + def __init__(self, + sample_rate: int = 8000, + segment_size: int = 4, + + win_size: int = 20, + + freq_bins: int = 256, + bottleneck_channels: int = 256, + num_speakers: int = 2, + num_blocks: int = 4, + num_sub_blocks: int = 8, + sub_blocks_channels: int = 512, + sub_blocks_kernel_size: int = 3, + + norm_type: str = "gLN", + causal: bool = False, + mask_nonlinear: str = "relu", + + min_snr_db: float = -10, + max_snr_db: float = 20, + + lr: float = 1e-3, + adam_b1: float = 0.8, + adam_b2: float = 0.99, + + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + eval_steps: int = 25000, + + **kwargs + ): + super(ConvTasNetConfig, 
self).__init__(**kwargs) + self.sample_rate = sample_rate + self.segment_size = segment_size + + self.win_size = win_size + + self.freq_bins = freq_bins + self.bottleneck_channels = bottleneck_channels + self.num_speakers = num_speakers + self.num_blocks = num_blocks + self.num_sub_blocks = num_sub_blocks + self.sub_blocks_channels = sub_blocks_channels + self.sub_blocks_kernel_size = sub_blocks_kernel_size + + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + self.lr = lr + self.adam_b1 = adam_b1 + self.adam_b2 = adam_b2 + + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.eval_steps = eval_steps + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py b/toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2adef21c6a06705f032170af4ad4426ea7852131 --- /dev/null +++ b/toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py @@ -0,0 +1,112 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile, time +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +torch.set_num_threads(1) + +from project_settings import project_path +from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig +from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNetPretrainedModel, MODEL_FILE + +logger = logging.getLogger("toolbox") + + +class InferenceConvTasNet(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, model = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.model = model + self.model.to(device) + self.model.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = ConvTasNetConfig.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model = ConvTasNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model.to(self.device) + model.eval() + + shutil.rmtree(model_path) + return config, model + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.enhancement_by_tensor(noisy_audio) + # noisy_audio shape: [n_samples,] + return enhanced_audio.cpu().numpy() + + def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio 
shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + enhanced_audios = self.model.forward(noisy_audios) + # enhanced_audio shape: [batch_size, channels, num_samples] + # enhanced_audios = torch.squeeze(enhanced_audios, dim=1) + + enhanced_audio = enhanced_audios[0] + + # enhanced_audio shape: [channels, num_samples] + return enhanced_audio + + +def main(): + model_zip_file = project_path / "trained_models/conv-tasnet-dns3-1025k-steps.zip" + infer_conv_tasnet = InferenceConvTasNet(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav" + noisy_audio, sample_rate = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + duration = librosa.get_duration(y=noisy_audio, sr=sample_rate) + # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + begin = time.time() + enhanced_audio = infer_conv_tasnet.enhancement_by_tensor(noisy_audio) + time_cost = time.time() - begin + print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py b/toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..407ef1e82b5061bb7ceb56bea7675ec0e528147c --- /dev/null +++ b/toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py @@ -0,0 +1,485 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py + +https://pytorch.org/audio/2.5.0/generated/torchaudio.models.ConvTasNet.html +""" +import os +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.conv_tasnet.utils import overlap_and_add +from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig + + +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN)""" + def __init__(self, + channels: int, + eps: float = 1e-8 + ): + super(ChannelwiseLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channels, 1)) + self.beta = nn.Parameter(torch.Tensor(1, channels,1 )) + self.reset_parameters() + + self.eps = eps + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + :param y: Tensor, shape: [batch_size, channels, time_steps] + :return: gln_y: Tensor, shape: [batch_size, channels, time_steps] + """ + # mean, var shape: [batch_size, 1, time_steps] + mean = torch.mean(y, dim=1, keepdim=True) + var = torch.var(y, dim=1, keepdim=True, unbiased=False) + + cln_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta + return cln_y + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN)""" + def __init__(self, + channels: int, + eps: float = 1e-8 + ): + super(GlobalLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channels, 1)) + self.beta = nn.Parameter(torch.Tensor(1, channels,1 )) + self.reset_parameters() + + self.eps = eps + + def 
reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """ + :param y: Tensor, shape: [batch_size, channels, time_steps] + :return: gln_y: Tensor, shape: [batch_size, channels, time_steps] + """ + # mean, var shape: [batch_size, 1, 1] + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + + gln_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta + return gln_y + + +def choose_norm(norm_type: str, channels: int): + """ + The input of normalization will be (M, C, K), where M is batch size, + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channels) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channels) + else: # norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channels) + + +class Chomp1d(nn.Module): + """ + To ensure the output length is the same as the input. + """ + def __init__(self, chomp_size: int): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x: torch.Tensor): + """ + :param x: Tensor, shape: [batch_size, hidden_size, k_pad] + :return: Tensor, shape: [batch_size, hidden_size, k] + """ + return x[:, :, :-self.chomp_size].contiguous() + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + dilation: int, + norm_type="gLN", + causal=False + ): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + self.depthwise_conv = nn.Conv1d( + in_channels=in_channels, out_channels=in_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + groups=in_channels, bias=False, + ) + + self.chomp = None + if causal: + self.chomp = Chomp1d(padding) + + self.prelu = nn.PReLU() + self.norm = choose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + self.pointwise_conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor): + """ + :param x: Tensor, shape: [batch_size, hidden_size, k] + :return: Tensor, shape: [batch_size, b, k] + """ + x = self.depthwise_conv.forward(x) + if self.chomp is not None: + x = self.chomp.forward(x) + x = self.prelu.forward(x) + x = self.norm.forward(x) + x = self.pointwise_conv.forward(x) + + return x + + +class Encoder(nn.Module): + def __init__(self, win_size: int, freq_bins: int): + super(Encoder, self).__init__() + self.win_size = win_size + self.freq_bins = freq_bins + + self.conv1d_U = nn.Conv1d( + in_channels=1, + out_channels=freq_bins, + kernel_size=win_size, + stride=win_size // 2, + bias=False + ) + + def forward(self, mixture): + """ + :param mixture: Tensor, shape: [batch_size, num_samples] + :return: mixture_w, Tensor, shape: [batch_size, freq_bins, time_steps], + where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1 + """ + mixture = torch.unsqueeze(mixture, 1) # [M, 1, T] + mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] + return mixture_w + + +class Decoder(nn.Module): + def __init__(self, win_size: int, freq_bins: int): + super(Decoder, self).__init__() + self.win_size = win_size + self.freq_bins = freq_bins + + self.basis_signals = nn.Linear( + 
in_features=freq_bins, + out_features=win_size, + bias=False + ) + + def forward(self, + mixture_w: torch.Tensor, + est_mask: torch.Tensor, + ): + """ + :param mixture_w: Tensor, shape: [batch_size, freq_bins, time_steps], + where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1 + :param est_mask: Tensor, shape: [batch_size, c, freq_bins, time_steps], + :return: Tensor, shape: [batch_size, c, num_samples], + """ + source_w = torch.unsqueeze(mixture_w, 1) * est_mask + source_w = torch.transpose(source_w, 2, 3) + est_source = self.basis_signals(source_w) + est_source = overlap_and_add(est_source, self.win_size//2) + return est_source + + +class TemporalBlock(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + dilation: int, + norm_type="gLN", + causal=False + ): + super(TemporalBlock, self).__init__() + self.conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + self.prelu = nn.PReLU() + self.norm = choose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + self.dsconv = DepthwiseSeparableConv( + out_channels, in_channels, + kernel_size, stride, + padding, dilation, + norm_type, causal, + ) + + def forward(self, x): + residual = x + + x = self.conv1x1.forward(x) + x = self.prelu.forward(x) + x = self.norm.forward(x) + x = self.dsconv.forward(x) + + out = x + residual + return out + + +class TemporalConvNet(nn.Module): + def __init__(self, + freq_bins: int = 256, + bottleneck_channels: int = 256, + num_speakers: int = 2, + num_blocks: int = 4, + num_sub_blocks: int = 8, + sub_blocks_channels: int = 512, + sub_blocks_kernel_size: int = 3, + norm_type: str = "gLN", + causal: bool = False, + mask_nonlinear: str = "relu", + + ): + super(TemporalConvNet, self).__init__() + self.freq_bins = freq_bins + self.bottleneck_channels = bottleneck_channels + self.num_speakers = num_speakers + + self.num_blocks = num_blocks + self.num_sub_blocks = num_sub_blocks + self.sub_blocks_channels = sub_blocks_channels + self.sub_blocks_kernel_size = sub_blocks_kernel_size + + self.mask_nonlinear = mask_nonlinear + + self.layer_norm = ChannelwiseLayerNorm(freq_bins) + self.bottleneck_conv1x1 = nn.Conv1d(freq_bins, bottleneck_channels, 1, bias=False) + + self.temporal_conv_list = nn.ModuleList([]) + for num_block_idx in range(num_blocks): + sub_blocks = list() + for num_sub_block_idx in range(num_sub_blocks): + dilation = 2 ** num_sub_block_idx + padding = (sub_blocks_kernel_size - 1) * dilation + if not causal: + padding = padding // 2 + temporal_block = TemporalBlock( + bottleneck_channels, sub_blocks_channels, + sub_blocks_kernel_size, stride=1, + padding=padding, dilation=dilation, + norm_type=norm_type, causal=causal, + ) + sub_blocks.append(temporal_block) + self.temporal_conv_list.extend(sub_blocks) + + self.mask_conv1x1 = nn.Conv1d( + in_channels=bottleneck_channels, + out_channels=num_speakers * freq_bins, + kernel_size=1, + bias=False, + ) + + def forward(self, mixture_w: torch.Tensor): + """ + :param mixture_w: Tensor, shape: [batch_size, freq_bins, time_steps] + :return: est_mask: Tensor, shape: [batch_size, freq_bins, time_steps] + """ + batch_size, freq_bins, time_steps = mixture_w.size() + + x = self.layer_norm.forward(mixture_w) + x = self.bottleneck_conv1x1.forward(x) + + for temporal_conv in self.temporal_conv_list: + x = temporal_conv.forward(x) + + score = self.mask_conv1x1.forward(x) + + # [M, C*N, K] -> [M, C, N, K] + score = score.view(batch_size, self.num_speakers, 
freq_bins, time_steps) + + if self.mask_nonlinear == "softmax": + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == "relu": + est_mask = F.relu(score) + else: + raise ValueError("Unsupported mask non-linear function") + + return est_mask + + +class ConvTasNet(nn.Module): + def __init__(self, + win_size: int = 20, + freq_bins: int = 256, + bottleneck_channels: int = 256, + num_speakers: int = 2, + num_blocks: int = 4, + num_sub_blocks: int = 8, + sub_blocks_channels: int = 512, + sub_blocks_kernel_size: int = 3, + norm_type: str = "gLN", + causal: bool = False, + mask_nonlinear: str = "relu", + + ): + super(ConvTasNet, self).__init__() + self.win_size = win_size + + self.freq_bins = freq_bins + self.bottleneck_channels = bottleneck_channels + self.num_speakers = num_speakers + + self.num_blocks = num_blocks + self.num_sub_blocks = num_sub_blocks + self.sub_blocks_channels = sub_blocks_channels + self.sub_blocks_kernel_size = sub_blocks_kernel_size + + self.norm_type = norm_type + self.causal = causal + self.mask_nonlinear = mask_nonlinear + + self.encoder = Encoder(win_size, freq_bins) + self.separator = TemporalConvNet( + freq_bins=freq_bins, + bottleneck_channels=bottleneck_channels, + sub_blocks_channels=sub_blocks_channels, + sub_blocks_kernel_size=sub_blocks_kernel_size, + num_sub_blocks=num_sub_blocks, + num_blocks=num_blocks, + num_speakers=num_speakers, + norm_type=norm_type, + causal=causal, + mask_nonlinear=mask_nonlinear, + ) + self.decoder = Decoder(win_size=win_size, freq_bins=freq_bins) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def forward(self, mixture: torch.Tensor): + """ + :param mixture: Tensor, shape: [batch_size, num_samples] + :return: est_source: Tensor, shape: [batch_size, c, num_samples] + """ + # mixture shape: [batch_size, num_samples] + mixture_w = self.encoder.forward(mixture) + # mixture_w shape: [batch_size, freq_bins, time_steps] + est_mask = self.separator.forward(mixture_w) + # est_mask shape: [batch_size, num_speakers, freq_bins, time_steps] + est_source = self.decoder.forward(mixture_w, est_mask) + # est_source shape: [batch_size, num_speakers, num_samples] + + num_samples1 = mixture.size(-1) + num_samples2 = est_source.size(-1) + est_source = F.pad(est_source, (0, num_samples1 - num_samples2)) + return est_source + + +MODEL_FILE = "model.pt" + + +class ConvTasNetPretrainedModel(ConvTasNet): + def __init__(self, + config: ConvTasNetConfig, + ): + super(ConvTasNetPretrainedModel, self).__init__( + win_size=config.win_size, + freq_bins=config.freq_bins, + bottleneck_channels=config.bottleneck_channels, + sub_blocks_channels=config.sub_blocks_channels, + sub_blocks_kernel_size=config.sub_blocks_kernel_size, + num_sub_blocks=config.num_sub_blocks, + num_blocks=config.num_blocks, + num_speakers=config.num_speakers, + norm_type=config.norm_type, + causal=config.causal, + mask_nonlinear=config.mask_nonlinear, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = ConvTasNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + 
save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + config = ConvTasNetConfig() + tas_net = ConvTasNet( + win_size=config.win_size, + freq_bins=config.freq_bins, + bottleneck_channels=config.bottleneck_channels, + sub_blocks_channels=config.sub_blocks_channels, + sub_blocks_kernel_size=config.sub_blocks_kernel_size, + num_sub_blocks=config.num_sub_blocks, + num_blocks=config.num_blocks, + num_speakers=config.num_speakers, + norm_type=config.norm_type, + causal=config.causal, + mask_nonlinear=config.mask_nonlinear, + ) + + print(tas_net) + + mixture = torch.rand(size=(1, 8000*4), dtype=torch.float32) + + outputs = tas_net.forward(mixture) + print(outputs.shape) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/conv_tasnet/utils.py b/toolbox/torchaudio/models/conv_tasnet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfe77e68ec9c6254d4334187a28a3e014a04c8e --- /dev/null +++ b/toolbox/torchaudio/models/conv_tasnet/utils.py @@ -0,0 +1,55 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py +""" +import math +import torch + + +def overlap_and_add(signal: torch.Tensor, frame_step: int): + """ + Reconstructs a signal from a framed representation. + + Adds potentially overlapping frames of a signal with shape + `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. + The resulting tensor has shape `[..., output_size]` where + + output_size = (frames - 1) * frame_step + frame_length + + Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py + + :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2. + :param frame_step: int, overlap offsets. Must be less than or equal to frame_length. + :return: Tensor, shape: [..., output_size]. + containing the overlap-added frames of signal's inner-most two dimensions. 
+ output_size = (frames - 1) * frame_step + frame_length + """ + outer_dimensions = signal.size()[:-2] + frames, frame_length = signal.size()[-2:] + + subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor + subframe_step = frame_step // subframe_length + subframes_per_frame = frame_length // subframe_length + + output_size = frame_step * (frames - 1) + frame_length + output_subframes = output_size // subframe_length + + subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) + + frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) + + frame = frame.clone().detach() + frame = frame.to(signal.device) + frame = frame.long() + + frame = frame.contiguous().view(-1) + + result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) + result.index_add_(-2, frame, subframe_signal) + result = result.view(*outer_dimensions, -1) + return result + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml b/toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fab1b0e8a7264e9f5d97a176fa545a757b4e8ad9 --- /dev/null +++ b/toolbox/torchaudio/models/conv_tasnet/yaml/config.yaml @@ -0,0 +1,17 @@ +model_name: "conv_tasnet" + +sample_rate: 8000 +segment_size: 4 + +win_size: 20 +freq_bins: 256 +bottleneck_channels: 256 +num_speakers: 2 +num_blocks: 4 +num_sub_blocks: 8 +sub_blocks_channels: 512 +sub_blocks_kernel_size: 3 + +norm_type: "gLN" +causal: false +mask_nonlinear: "relu" diff --git a/toolbox/torchaudio/models/dccrn/__init__.py b/toolbox/torchaudio/models/dccrn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/models/dccrn/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dccrn/modeling_dccrn.py b/toolbox/torchaudio/models/dccrn/modeling_dccrn.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6306616753aca2fc7379a8f2406b819adb830a --- /dev/null +++ b/toolbox/torchaudio/models/dccrn/modeling_dccrn.py @@ -0,0 +1,12 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" + +https://arxiv.org/abs/2008.00264 + +https://github.com/huyanxin/DeepComplexCRN + +""" + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/demucs/__init__.py b/toolbox/torchaudio/models/demucs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/demucs/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/demucs/configuration_demucs.py b/toolbox/torchaudio/models/demucs/configuration_demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..5c148f04a3fbe29e7b7b306e60228314e9c6b2e1 --- /dev/null +++ b/toolbox/torchaudio/models/demucs/configuration_demucs.py @@ -0,0 +1,51 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class DemucsConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + + in_channels: int = 1, + out_channels: int = 1, + hidden_channels: int = 48, + + depth: int = 5, + kernel_size: int = 8, + stride: int = 4, + + causal: bool = True, + 
resample: int = 4, + growth: int = 2, + + max_hidden: int = 10_000, + do_normalize: bool = True, + rescale: float = 0.1, + floor: float = 1e-3, + + **kwargs + ): + super(DemucsConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.depth = depth + self.kernel_size = kernel_size + self.stride = stride + + self.causal = causal + self.resample = resample + self.growth = growth + + self.max_hidden = max_hidden + self.do_normalize = do_normalize + self.rescale = rescale + self.floor = floor + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/demucs/modeling_demucs.py b/toolbox/torchaudio/models/demucs/modeling_demucs.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d6b3f049bed708ccb640b83d9bb0489f261e1a --- /dev/null +++ b/toolbox/torchaudio/models/demucs/modeling_demucs.py @@ -0,0 +1,299 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2006.12847 + +https://github.com/facebookresearch/denoiser +""" +import math +import os +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.demucs.configuration_demucs import DemucsConfig +from toolbox.torchaudio.models.demucs.resample import upsample2, downsample2 + + +activation_layer_dict = { + "glu": nn.GLU, + "relu": nn.ReLU, + "identity": nn.Identity, + "sigmoid": nn.Sigmoid, +} + + +class BLSTM(nn.Module): + def __init__(self, + hidden_size: int, + num_layers: int = 2, + bidirectional: bool = True, + ): + super().__init__() + self.lstm = nn.LSTM(bidirectional=bidirectional, + num_layers=num_layers, + hidden_size=hidden_size, + input_size=hidden_size + ) + self.linear = None + if bidirectional: + self.linear = nn.Linear(2 * hidden_size, hidden_size) + + def forward(self, + x: torch.Tensor, + hx: torch.Tensor = None + ): + x, hx = self.lstm.forward(x, hx) + if self.linear: + x = self.linear(x) + return x, hx + + +def rescale_conv(conv, reference): + std = conv.weight.std().detach() + scale = (std / reference)**0.5 + conv.weight.data /= scale + if conv.bias is not None: + conv.bias.data /= scale + + +def rescale_module(module, reference): + for sub in module.modules(): + if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): + rescale_conv(sub, reference) + + +class DemucsModel(nn.Module): + def __init__(self, + in_channels: int = 1, + out_channels: int = 1, + hidden_channels: int = 48, + depth: int = 5, + kernel_size: int = 8, + stride: int = 4, + causal: bool = True, + resample: int = 4, + growth: int = 2, + max_hidden: int = 10_000, + do_normalize: bool = True, + rescale: float = 0.1, + floor: float = 1e-3, + ): + super(DemucsModel, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.depth = depth + self.kernel_size = kernel_size + self.stride = stride + + self.causal = causal + + self.resample = resample + self.growth = growth + self.max_hidden = max_hidden + self.do_normalize = do_normalize + self.rescale = rescale + self.floor = floor + + if resample not in [1, 2, 4]: + raise ValueError("Resample should be 1, 2 or 4.") + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for index in range(depth): + encode = [] + encode += [ + nn.Conv1d(in_channels, hidden_channels, kernel_size, 
stride), + nn.ReLU(), + nn.Conv1d(hidden_channels, hidden_channels * 2, 1), + nn.GLU(1), + ] + self.encoder.append(nn.Sequential(*encode)) + + decode = [] + decode += [ + nn.Conv1d(hidden_channels, 2 * hidden_channels, 1), + nn.GLU(1), + nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride), + ] + if index > 0: + decode.append(nn.ReLU()) + self.decoder.insert(0, nn.Sequential(*decode)) + out_channels = hidden_channels + in_channels = hidden_channels + hidden_channels = min(int(growth * hidden_channels), max_hidden) + + self.lstm = BLSTM(in_channels, bidirectional=not causal) + + if rescale: + rescale_module(self, reference=rescale) + + @staticmethod + def valid_length(length: int, depth: int, kernel_size: int, stride: int, resample: int): + """ + Return the nearest valid length to use with the model so that + there is no time steps left over in a convolutions, e.g. for all + layers, size of the input - kernel_size % stride = 0. + + If the mixture has a valid length, the estimated sources + will have exactly the same length. + """ + length = math.ceil(length * resample) + for idx in range(depth): + length = math.ceil((length - kernel_size) / stride) + 1 + length = max(length, 1) + for idx in range(depth): + length = (length - 1) * stride + kernel_size + length = int(math.ceil(length / resample)) + return int(length) + + def forward(self, noisy: torch.Tensor): + """ + :param noisy: Tensor, shape: [batch_size, num_samples] or [batch_size, channels, num_samples] + :return: + """ + if noisy.dim() == 2: + noisy = noisy.unsqueeze(1) + # noisy shape: [batch_size, channels, num_samples] + + if self.do_normalize: + mono = noisy.mean(dim=1, keepdim=True) + std = mono.std(dim=-1, keepdim=True) + noisy = noisy / (self.floor + std) + else: + std = 1 + + _, _, length = noisy.shape + x = noisy + + length_ = self.valid_length(length, self.depth, self.kernel_size, self.stride, self.resample) + x = F.pad(x, (0, length_ - length)) + + if self.resample == 2: + x = upsample2(x) + elif self.resample == 4: + x = upsample2(x) + x = upsample2(x) + + skips = [] + for encode in self.encoder: + x = encode(x) + skips.append(x) + x = x.permute(2, 0, 1) + x, _ = self.lstm(x) + x = x.permute(1, 2, 0) + + for decode in self.decoder: + skip = skips.pop(-1) + x = x + skip[..., :x.shape[-1]] + x = decode(x) + + if self.resample == 2: + x = downsample2(x) + elif self.resample == 4: + x = downsample2(x) + x = downsample2(x) + + x = x[..., :length] + return std * x + + +MODEL_FILE = "model.pt" + + +class DemucsPretrainedModel(DemucsModel): + def __init__(self, + config: DemucsConfig, + ): + super(DemucsPretrainedModel, self).__init__( + # sample_rate=config.sample_rate, + in_channels=config.in_channels, + out_channels=config.out_channels, + hidden_channels=config.hidden_channels, + depth=config.depth, + kernel_size=config.kernel_size, + stride=config.stride, + causal=config.causal, + resample=config.resample, + growth=config.growth, + max_hidden=config.max_hidden, + do_normalize=config.do_normalize, + rescale=config.rescale, + floor=config.floor, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = DemucsConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, 
map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + config = DemucsConfig() + model = DemucsModel( + in_channels=config.in_channels, + out_channels=config.out_channels, + hidden_channels=config.hidden_channels, + depth=config.depth, + kernel_size=config.kernel_size, + stride=config.stride, + causal=config.causal, + resample=config.resample, + growth=config.growth, + max_hidden=config.max_hidden, + do_normalize=config.do_normalize, + rescale=config.rescale, + floor=config.floor, + ) + + print(model) + + noisy = torch.rand(size=(1, 8000*4), dtype=torch.float32) + + denoise = model.forward(noisy) + print(denoise.shape) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/demucs/resample.py b/toolbox/torchaudio/models/demucs/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..ccf56d41827f843aa0a3b836a85402b8775d5781 --- /dev/null +++ b/toolbox/torchaudio/models/demucs/resample.py @@ -0,0 +1,81 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# author: adefossez + +import math + +import torch as th +from torch.nn import functional as F + + +def sinc(t): + """sinc. + + :param t: the input tensor + """ + return th.where(t == 0, th.tensor(1., device=t.device, dtype=t.dtype), th.sin(t) / t) + + +def kernel_upsample2(zeros=56): + """kernel_upsample2. + + """ + win = th.hann_window(4 * zeros + 1, periodic=False) + winodd = win[1::2] + t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t *= math.pi + kernel = (sinc(t) * winodd).view(1, 1, -1) + return kernel + + +def upsample2(x, zeros=56): + """ + Upsampling the input by 2 using sinc interpolation. + Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." + ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. + Vol. 9. IEEE, 1984. + """ + *other, time = x.shape + kernel = kernel_upsample2(zeros).to(x) + out = F.conv1d(x.view(-1, 1, time), kernel, padding=zeros)[..., 1:].view(*other, time) + y = th.stack([x, out], dim=-1) + return y.view(*other, -1) + + +def kernel_downsample2(zeros=56): + """kernel_downsample2. + + """ + win = th.hann_window(4 * zeros + 1, periodic=False) + winodd = win[1::2] + t = th.linspace(-zeros + 0.5, zeros - 0.5, 2 * zeros) + t.mul_(math.pi) + kernel = (sinc(t) * winodd).view(1, 1, -1) + return kernel + + +def downsample2(x, zeros=56): + """ + Downsampling the input by 2 using sinc interpolation. + Smith, Julius, and Phil Gossett. "A flexible sampling-rate conversion method." + ICASSP'84. IEEE International Conference on Acoustics, Speech, and Signal Processing. + Vol. 9. IEEE, 1984. 
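+
+    Added note (illustrative): an input of shape (B, C, T) maps to
+    (B, C, ceil(T / 2)); paired with upsample2 above, downsample2(upsample2(x))
+    approximately recovers x, up to sinc-interpolation error.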
+ """ + if x.shape[-1] % 2 != 0: + x = F.pad(x, (0, 1)) + xeven = x[..., ::2] + xodd = x[..., 1::2] + *other, time = xodd.shape + kernel = kernel_downsample2(zeros).to(x) + out = xeven + F.conv1d(xodd.view(-1, 1, time), kernel, padding=zeros)[..., :-1].view( + *other, time) + return out.view(*other, -1).mul(0.5) + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dfnet/__init__.py b/toolbox/torchaudio/models/dfnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/dfnet/configuration_dfnet.py b/toolbox/torchaudio/models/dfnet/configuration_dfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9c469e472428b45682b4204281a82af128235221 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet/configuration_dfnet.py @@ -0,0 +1,149 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class DfNetConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + nfft: int = 512, + win_size: int = 200, + hop_size: int = 80, + win_type: str = "hann", + + spec_bins: int = 256, + erb_bins: int = 32, + min_freq_bins_for_erb: int = 2, + + conv_channels: int = 64, + conv_kernel_size_input: Tuple[int, int] = (3, 3), + conv_kernel_size_inner: Tuple[int, int] = (1, 3), + conv_lookahead: int = 0, + + convt_kernel_size_inner: Tuple[int, int] = (1, 3), + + embedding_hidden_size: int = 256, + encoder_combine_op: str = "concat", + + encoder_emb_skip_op: str = "none", + encoder_emb_linear_groups: int = 16, + encoder_emb_hidden_size: int = 256, + + encoder_linear_groups: int = 32, + + decoder_emb_num_layers: int = 3, + decoder_emb_skip_op: str = "none", + decoder_emb_linear_groups: int = 16, + decoder_emb_hidden_size: int = 256, + + df_decoder_hidden_size: int = 256, + df_num_layers: int = 2, + df_order: int = 5, + df_bins: int = 96, + df_gru_skip: str = "grouped_linear", + df_decoder_linear_groups: int = 16, + df_pathway_kernel_size_t: int = 5, + df_lookahead: int = 2, + + n_frame: int = 3, + max_local_snr: int = 30, + min_local_snr: int = -15, + norm_tau: float = 1., + + min_snr_db: float = -10, + max_snr_db: float = 20, + + lr: float = 0.001, + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + max_epochs: int = 100, + clip_grad_norm: float = 10., + seed: int = 1234, + + num_workers: int = 4, + batch_size: int = 4, + eval_steps: int = 25000, + + use_post_filter: bool = False, + + **kwargs + ): + super(DfNetConfig, self).__init__(**kwargs) + # transform + self.sample_rate = sample_rate + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + # spectrum + self.spec_bins = spec_bins + self.erb_bins = erb_bins + self.min_freq_bins_for_erb = min_freq_bins_for_erb + + # conv + self.conv_channels = conv_channels + self.conv_kernel_size_input = conv_kernel_size_input + self.conv_kernel_size_inner = conv_kernel_size_inner + self.conv_lookahead = conv_lookahead + + self.convt_kernel_size_inner = convt_kernel_size_inner + + self.embedding_hidden_size = embedding_hidden_size + + # encoder + self.encoder_emb_skip_op = encoder_emb_skip_op + self.encoder_emb_linear_groups = encoder_emb_linear_groups + self.encoder_emb_hidden_size = 
encoder_emb_hidden_size + + self.encoder_linear_groups = encoder_linear_groups + self.encoder_combine_op = encoder_combine_op + + # decoder + self.decoder_emb_num_layers = decoder_emb_num_layers + self.decoder_emb_skip_op = decoder_emb_skip_op + self.decoder_emb_linear_groups = decoder_emb_linear_groups + self.decoder_emb_hidden_size = decoder_emb_hidden_size + + # df decoder + self.df_decoder_hidden_size = df_decoder_hidden_size + self.df_num_layers = df_num_layers + self.df_order = df_order + self.df_bins = df_bins + self.df_gru_skip = df_gru_skip + self.df_decoder_linear_groups = df_decoder_linear_groups + self.df_pathway_kernel_size_t = df_pathway_kernel_size_t + self.df_lookahead = df_lookahead + + # lsnr + self.n_frame = n_frame + self.max_local_snr = max_local_snr + self.min_local_snr = min_local_snr + self.norm_tau = norm_tau + + # data snr + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + # train + self.lr = lr + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.max_epochs = max_epochs + self.clip_grad_norm = clip_grad_norm + self.seed = seed + + self.num_workers = num_workers + self.batch_size = batch_size + self.eval_steps = eval_steps + + # runtime + self.use_post_filter = use_post_filter + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dfnet/inference_dfnet.py b/toolbox/torchaudio/models/dfnet/inference_dfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9f06a941b14076813d2f03ff1ab354621834d50c --- /dev/null +++ b/toolbox/torchaudio/models/dfnet/inference_dfnet.py @@ -0,0 +1,112 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile, time +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +torch.set_num_threads(1) + +from project_settings import project_path +from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig +from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNetPretrainedModel, MODEL_FILE + +logger = logging.getLogger("toolbox") + + +class InferenceDfNet(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, model = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.model = model + self.model.to(device) + self.model.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = DfNetConfig.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model = DfNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model.to(self.device) + model.eval() + + shutil.rmtree(model_path) + return config, model + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = 
noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.enhancement_by_tensor(noisy_audio) + # enhanced_audio shape: [channels, num_samples] + enhanced_audio = enhanced_audio[0] + # enhanced_audio shape: [num_samples] + return enhanced_audio.cpu().numpy() + + def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios) + + # shape: [batch_size, 1, num_samples] + denoise = est_wav[0] + # shape: [channels, num_samples] + return denoise + + +def main(): + model_zip_file = project_path / "trained_models/dfnet-nx-dns3.zip" + infer_model = InferenceDfNet(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav" + noisy_audio, sample_rate = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + duration = librosa.get_duration(y=noisy_audio, sr=sample_rate) + # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + begin = time.time() + enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio) + time_cost = time.time() - begin + print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dfnet/modeling_dfnet.py b/toolbox/torchaudio/models/dfnet/modeling_dfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..61057fa34b0cbf50207e61ab67667f9fa8737e55 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet/modeling_dfnet.py @@ -0,0 +1,1091 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +DeepFilterNet 的原生实现不直接支持流式推理 + +社区开发者(如 Rikorose)提供了基于 Torch 的流式推理实现 +https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF +""" +import os +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig +from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT +from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget +from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands + + +MODEL_FILE = "model.pt" + + +norm_layer_dict = { + "batch_norm_2d": torch.nn.BatchNorm2d +} + + +activation_layer_dict = { + "relu": torch.nn.ReLU, + "identity": torch.nn.Identity, + "sigmoid": torch.nn.Sigmoid, +} + + +class CausalConv2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal Conv2d by delaying the signal 
for any lookahead. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + + :param in_channels: + :param out_channels: + :param kernel_size: + :param fstride: + :param dilation: + :param fpad: + """ + super(CausalConv2d, self).__init__() + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + + if fpad: + fpad_ = kernel_size[1] // 2 + dilation - 1 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). + pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = list() + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + + layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + def forward(self, inputs): + for module in self: + inputs = module(inputs) + return inputs + + +class CausalConvTranspose2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal ConvTranspose2d. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + """ + super(CausalConvTranspose2d, self).__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if fpad: + fpad_ = kernel_size[1] // 2 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
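+        # pad only along the time axis: `kernel_size[0] - 1 - lookahead` frames of
+        # past context plus `lookahead` frames of future context, matching the
+        # "delay the signal for any lookahead" behaviour described above.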
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + layers.append( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad_ + dilation - 1), + output_padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + +class GroupedLinear(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" + assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + torch.nn.Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., I] + b, t, f = x.shape + if f != self.input_size: + raise AssertionError + + # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] + new_shape = (b, t, self.groups, self.ws) + x = x.view(new_shape) + # The better way, but not supported by torchscript + # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) + # x: [b, t, h] + return x + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class SqueezedGRU_S(nn.Module): + """ + SGE net: Video object detection with squeezed GRU and information entropy map + https://arxiv.org/abs/2106.07224 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + skip_op: str = "none", + activation_layer: str = "identity", + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.linear_in = nn.Sequential( + GroupedLinear( + input_size=input_size, + hidden_size=hidden_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + + # gru skip operator + self.gru_skip_op = None + + if skip_op == "none": + self.gru_skip_op = None + elif skip_op == "identity": + if not input_size != output_size: + raise AssertionError("Dimensions do not match") + self.gru_skip_op = nn.Identity() + elif skip_op == "grouped_linear": + self.gru_skip_op = GroupedLinear( + input_size=hidden_size, + hidden_size=hidden_size, + groups=linear_groups, + ) + else: + raise NotImplementedError() + + self.gru = nn.GRU( + 
input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=batch_first, + bidirectional=False, + ) + + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinear( + input_size=hidden_size, + hidden_size=output_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: + # inputs: shape: [b, t, h] + x = self.linear_in.forward(inputs) + + x, h = self.gru.forward(x, h) + + x = self.linear_out(x) + + if self.gru_skip_op is not None: + x = x + self.gru_skip_op(inputs) + + return x, h + + +class Add(nn.Module): + def forward(self, a, b): + return a + b + + +class Concat(nn.Module): + def forward(self, a, b): + return torch.cat((a, b), dim=-1) + + +class Encoder(nn.Module): + def __init__(self, config: DfNetConfig): + super(Encoder, self).__init__() + self.embedding_input_size = config.conv_channels * config.erb_bins // 4 + self.embedding_output_size = config.conv_channels * config.erb_bins // 4 + self.embedding_hidden_size = config.embedding_hidden_size + + self.spec_conv0 = CausalConv2d( + in_channels=1, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.spec_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.spec_conv2 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.spec_conv3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + self.df_conv0 = CausalConv2d( + in_channels=2, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + ) + self.df_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.df_fc_emb = nn.Sequential( + GroupedLinear( + config.conv_channels * config.df_bins // 2, + self.embedding_input_size, + groups=config.encoder_linear_groups + ), + nn.ReLU(inplace=True) + ) + + if config.encoder_combine_op == "concat": + self.embedding_input_size *= 2 + self.combine = Concat() + else: + self.combine = Add() + + # emb_gru + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_gru = SqueezedGRU_S( + self.embedding_input_size, + self.embedding_hidden_size, + output_size=self.embedding_output_size, + num_layers=1, + batch_first=True, + skip_op=config.encoder_emb_skip_op, + linear_groups=config.encoder_emb_linear_groups, + activation_layer="relu", + ) + + # lsnr + self.lsnr_fc = nn.Sequential( + nn.Linear(self.embedding_output_size, 1), + nn.Sigmoid() + ) + self.lsnr_scale = config.max_local_snr - config.min_local_snr + self.lsnr_offset = config.min_local_snr + + def forward(self, + feat_erb: torch.Tensor, + feat_spec: torch.Tensor, + hidden_state: torch.Tensor = None, + ): + # feat_erb 
shape: (b, 1, t, erb_bins) + e0 = self.spec_conv0.forward(feat_erb) + e1 = self.spec_conv1.forward(e0) + e2 = self.spec_conv2.forward(e1) + e3 = self.spec_conv3.forward(e2) + # e0 shape: [b, c, t, erb_bins] + # e1 shape: [b, c, t, erb_bins // 2] + # e2 shape: [b, c, t, erb_bins // 4] + # e3 shape: [b, c, t, erb_bins // 4] + # e3 shape: [b, 64, t, 32/4=8] + + # feat_spec, shape: (b, 2, t, df_bins) + c0 = self.df_conv0(feat_spec) + c1 = self.df_conv1(c0) + # c0 shape: [b, c, t, df_bins] + # c1 shape: [b, c, t, df_bins // 2] + # c1 shape: [b, 64, t, 96/2=48] + + cemb = c1.permute(0, 2, 3, 1) + # cemb shape: [b, t, df_bins // 2, c] + cemb = cemb.flatten(2) + # cemb shape: [b, t, df_bins // 2 * c] + # cemb shape: [b, t, 96/2*64=3072] + cemb = self.df_fc_emb.forward(cemb) + # cemb shape: [b, t, erb_bins // 4 * c] + # cemb shape: [b, t, 32/4*64=512] + + # e3 shape: [b, c, t, erb_bins // 4] + emb = e3.permute(0, 2, 3, 1) + # emb shape: [b, t, erb_bins // 4, c] + emb = emb.flatten(2) + # emb shape: [b, t, erb_bins // 4 * c] + # emb shape: [b, t, 32/4*64=512] + + emb = self.combine(emb, cemb) + # if concat; emb shape: [b, t, spec_bins // 4 * c * 2] + # if add; emb shape: [b, t, spec_bins // 4 * c] + + emb, h = self.emb_gru.forward(emb, hidden_state) + + # emb shape: [b, t, spec_dim // 4 * c] + # h shape: [b, 1, spec_dim] + + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + # lsnr shape: [b, t, 1] + + return e0, e1, e2, e3, emb, c0, lsnr, h + + +class Decoder(nn.Module): + """ErbDecoder""" + def __init__(self, config: DfNetConfig): + super(Decoder, self).__init__() + + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * config.erb_bins // 4 + self.emb_out_dim = config.conv_channels * config.erb_bins // 4 + self.emb_hidden_dim = config.decoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=config.decoder_emb_num_layers - 1, + batch_first=True, + skip_op=config.decoder_emb_skip_op, + linear_groups=config.decoder_emb_linear_groups, + activation_layer="relu", + ) + self.conv3p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv2p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt2 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv1p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt1 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv0p = CausalConv2d( + 
in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv0_out = CausalConv2d( + in_channels=config.conv_channels, + out_channels=1, + kernel_size=config.conv_kernel_size_inner, + activation_layer="sigmoid", + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: + # Estimates erb mask + b, _, t, f8 = e3.shape + + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + emb, _ = self.emb_gru.forward(emb) + # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) + e3 = self.convt3(self.conv3p(e3) + emb) + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + e2 = self.convt2(self.conv2p(e2) + e3) + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + e1 = self.convt1(self.conv1p(e1) + e2) + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] + mask = self.conv0_out(self.conv0p(e0) + e1) + # mask shape: [batch_size, 1, time_steps, freq_dim] + return mask + + +class DfDecoder(nn.Module): + def __init__(self, config: DfNetConfig): + super(DfDecoder, self).__init__() + + self.embedding_input_size = config.conv_channels * config.erb_bins // 4 + self.df_decoder_hidden_size = config.df_decoder_hidden_size + self.df_num_layers = config.df_num_layers + + self.df_order = config.df_order + + self.df_bins = config.df_bins + self.df_out_ch = config.df_order * 2 + + self.df_convp = CausalConv2d( + config.conv_channels, + self.df_out_ch, + fstride=1, + kernel_size=(config.df_pathway_kernel_size_t, 1), + separable=True, + bias=False, + ) + self.df_gru = SqueezedGRU_S( + self.embedding_input_size, + self.df_decoder_hidden_size, + num_layers=self.df_num_layers, + batch_first=True, + skip_op="none", + activation_layer="relu", + ) + + if config.df_gru_skip == "none": + self.df_skip = None + elif config.df_gru_skip == "identity": + if config.embedding_hidden_size != config.df_decoder_hidden_size: + raise AssertionError("Dimensions do not match") + self.df_skip = nn.Identity() + elif config.df_gru_skip == "grouped_linear": + self.df_skip = GroupedLinear( + self.embedding_input_size, + self.df_decoder_hidden_size, + groups=config.df_decoder_linear_groups + ) + else: + raise NotImplementedError() + + self.df_out: nn.Module + out_dim = self.df_bins * self.df_out_ch + + self.df_out = nn.Sequential( + GroupedLinear( + input_size=self.df_decoder_hidden_size, + hidden_size=out_dim, + groups=config.df_decoder_linear_groups, + # groups = self.df_bins // 5, + ), + nn.Tanh() + ) + self.df_fc_a = nn.Sequential( + nn.Linear(self.df_decoder_hidden_size, 1), + nn.Sigmoid() + ) + + def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor: + # emb shape: [batch_size, time_steps, df_bins // 4 * channels] + b, t, _ = emb.shape + df_coefs, _ = self.df_gru(emb) + if self.df_skip is not None: + df_coefs = df_coefs + self.df_skip(emb) + # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] + + # c0 shape: [batch_size, channels, time_steps, df_bins] + c0 = self.df_convp(c0) + # c0 shape: [batch_size, df_order * 2, time_steps, df_bins] + c0 = c0.permute(0, 2, 3, 1) + # c0 shape: [batch_size, time_steps, df_bins, df_order * 2] + + df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order + # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] + df_coefs = 
df_coefs.view(b, t, self.df_bins, self.df_out_ch) + # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] + df_coefs = df_coefs + c0 + # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] + return df_coefs + + +class DfOutputReshapeMF(nn.Module): + """Coefficients output reshape for multiframe/MultiFrameModule + + Requires input of shape B, C, T, F, 2. + """ + + def __init__(self, df_order: int, df_bins: int): + super().__init__() + self.df_order = df_order + self.df_bins = df_bins + + def forward(self, coefs: torch.Tensor) -> torch.Tensor: + # [B, T, F, O*2] -> [B, O, T, F, 2] + new_shape = list(coefs.shape) + new_shape[-1] = -1 + new_shape.append(2) + coefs = coefs.view(new_shape) + coefs = coefs.permute(0, 3, 1, 2, 4) + return coefs + + +class Mask(nn.Module): + def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): + super().__init__() + self.use_post_filter = use_post_filter + self.eps = eps + + def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: + """ + Post-Filter + + A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + https://arxiv.org/abs/2008.04259 + + :param mask: Real valued mask, typically of shape [B, C, T, F]. + :param beta: Global gain factor. + :return: + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # spec shape: [b, 1, t, spec_bins, 2] + + if not self.training and self.use_post_filter: + mask = self.post_filter(mask) + + # mask shape: [b, 1, t, spec_bins] + mask = mask.unsqueeze(4) + # mask shape: [b, 1, t, spec_bins, 1] + return spec * mask + + +class DeepFiltering(nn.Module): + def __init__(self, + df_bins: int, + df_order: int, + lookahead: int = 0, + ): + super(DeepFiltering, self).__init__() + self.df_bins = df_bins + self.df_order = df_order + self.need_unfold = df_order > 1 + self.lookahead = lookahead + + self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) + + def spec_unfold(self, spec: torch.Tensor): + """ + Pads and unfolds the spectrogram according to frame_size. + :param spec: complex Tensor, Spectrogram of shape [B, C, T, F]. + :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size. 
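+        The time axis is padded with `df_order - 1 - lookahead` past frames and
+        `lookahead` future frames, then unfolded with a window of `df_order` and
+        step 1, so every frame carries its `df_order` neighbouring frames for the
+        deep-filtering step.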
+ """ + if self.need_unfold: + # spec shape: [batch_size, spec_bins, time_steps] + spec_pad = self.pad(spec) + # spec_pad shape: [batch_size, 1, time_steps_pad, spec_bins] + spec_unfold = spec_pad.unfold(2, self.df_order, 1) + # spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order] + return spec_unfold + else: + return spec.unsqueeze(-1) + + def forward(self, + spec: torch.Tensor, + coefs: torch.Tensor, + ): + # spec shape: [batch_size, 1, time_steps, spec_bins, 2] + spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous())) + # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order] + + # coefs shape: [batch_size, df_order, time_steps, df_bins, 2] + coefs = torch.view_as_complex(coefs.contiguous()) + # coefs shape: [batch_size, df_order, time_steps, df_bins] + spec_f = spec_u.narrow(-2, 0, self.df_bins) + # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order] + + coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:]) + # coefs shape: [batch_size, 1, df_order, time_steps, df_bins] + + spec_f = self.df(spec_f, coefs) + # spec_f shape: [batch_size, 1, time_steps, df_bins] + + if self.training: + spec = spec.clone() + spec[..., :self.df_bins, :] = torch.view_as_real(spec_f) + # spec shape: [batch_size, 1, time_steps, spec_bins, 2] + return spec + + @staticmethod + def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: + """ + Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. + :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N]. + :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F]. + :return: (complex Tensor). Spectrogram of shape [B, C, T, F]. + """ + return torch.einsum("...tfn,...ntf->...tf", spec, coefs) + + +class DfNet(nn.Module): + """ + 我感觉这个模型没办法实现完全一致的流式推理。 + """ + def __init__(self, config: DfNetConfig): + super(DfNet, self).__init__() + self.config = config + self.eps = 1e-12 + + self.freq_bins = self.config.nfft // 2 + 1 + + self.nfft = config.nfft + self.win_size = config.win_size + self.hop_size = config.hop_size + self.win_type = config.win_type + + self.erb_bands = ErbBands( + sample_rate=config.sample_rate, + nfft=config.nfft, + erb_bins=config.erb_bins, + min_freq_bins_for_erb=config.min_freq_bins_for_erb, + ) + + self.stft = ConvSTFT( + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + power=None, + requires_grad=False + ) + self.istft = ConviSTFT( + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + requires_grad=False + ) + + self.encoder = Encoder(config) + self.decoder = Decoder(config) + + self.df_decoder = DfDecoder(config) + self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) + self.df_op = DeepFiltering( + df_bins=config.df_bins, + df_order=config.df_order, + lookahead=config.df_lookahead, + ) + + self.mask = Mask(use_post_filter=config.use_post_filter) + + self.lsnr_fn = LocalSnrTarget( + sample_rate=config.sample_rate, + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + n_frame=config.n_frame, + min_local_snr=config.min_local_snr, + max_local_snr=config.max_local_snr, + db=True, + ) + + def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: + if signal.dim() == 2: + signal = torch.unsqueeze(signal, dim=1) + _, _, n_samples = signal.shape + remainder = (n_samples - self.win_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + signal = 
F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) + return signal + + def feature_prepare(self, signal: torch.Tensor): + # noisy shape: [b, num_samples_pad] + spec_cmp = self.stft.forward(signal) + # spec_complex shape: [b, f, t], torch.complex64 + spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2) + # spec_complex shape: [b, t, f], torch.complex64 + spec_cmp_real = torch.view_as_real(spec_cmp) + # spec_cmp_real shape: [b, t, f, 2] + spec_mag = torch.abs(spec_cmp) + spec_pow = torch.square(spec_mag) + # shape: [b, t, f] + + spec = torch.unsqueeze(spec_cmp_real, dim=1) + # spec shape: [b, 1, t, f, 2] + + feat_erb = self.erb_bands.erb_scale(spec_pow, db=True) + # feat_erb shape: [b, t, erb_bins] + feat_erb = torch.unsqueeze(feat_erb, dim=1) + # feat_erb shape: [b, 1, t, erb_bins] + + feat_spec = spec_cmp_real.permute(0, 3, 1, 2) + # feat_spec shape: [b, 2, t, f] + feat_spec = feat_spec[..., :self.df_decoder.df_bins] + # feat_spec shape: [b, 2, t, df_bins] + + return spec, feat_erb, feat_spec + + def forward(self, + noisy: torch.Tensor, + ): + """ + :param noisy: + :return: + est_spec: shape: [b, 257*2, t] + est_wav: shape: [b, num_samples] + est_mask: shape: [b, 257, t] + lsnr: shape: [b, 1, t] + """ + n_samples = noisy.shape[-1] + noisy = self.signal_prepare(noisy) + + spec, feat_erb, feat_spec = self.feature_prepare(noisy) + + e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_erb, feat_spec) + + mask = self.decoder.forward(emb, e3, e2, e1, e0) + # mask shape: [b, 1, t, erb_bins] + mask = self.erb_bands.erb_scale_inv(mask) + # mask shape: [b, 1, t, f] + if torch.any(mask > 1) or torch.any(mask < 0): + raise AssertionError + + spec_m = self.mask.forward(spec, mask) + # spec_m shape: [b, 1, t, f, 2] + spec_m = spec_m[:, :, :, :self.config.spec_bins, :] + # spec_m shape: [b, 1, t, spec_bins, 2] + + # lsnr shape: [b, t, 1] + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + # lsnr shape: [b, 1, t] + + df_coefs = self.df_decoder.forward(emb, c0) + df_coefs = self.df_out_transform(df_coefs) + # df_coefs shape: [b, df_order, t, df_bins, 2] + + spec_ = spec[:, :, :, :self.config.spec_bins, :] + # spec shape: [b, 1, t, spec_bins, 2] + spec_e = self.df_op.forward(spec_, df_coefs) + # spec_e shape: [b, 1, t, spec_bins, 2] + + spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :] + + spec_e = torch.squeeze(spec_e, dim=1) + spec_e = spec_e.permute(0, 2, 1, 3) + # spec_e shape: [b, spec_bins, t, 2] + + # spec_e shape: [b, spec_bins, t, 2] + est_spec = torch.complex(real=spec_e[..., 0], imag=spec_e[..., 1]) + # est_spec shape: [b, spec_bins, t], torch.complex64 + est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) + # est_spec shape: [b, f, t], torch.complex64 + + est_wav = self.istft.forward(est_spec) + est_wav = est_wav[:, :, :n_samples] + # est_wav shape: [b, 1, n_samples] + + est_mask = torch.squeeze(mask, dim=1) + est_mask = est_mask.permute(0, 2, 1) + # est_mask shape: [b, f, t] + + return est_spec, est_wav, est_mask, lsnr + + def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + """ + :param est_mask: torch.Tensor, shape: [b, 257, t] + :param clean: + :param noisy: + :return: + """ + if noisy.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + noise = noisy - clean + + clean = self.signal_prepare(clean) + noise = self.signal_prepare(noise) + + stft_clean = self.stft.forward(clean) + mag_clean = torch.abs(stft_clean) + + stft_noise = 
self.stft.forward(noise) + mag_noise = torch.abs(stft_noise) + + gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1) + + loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean") + + return loss + + def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + if noisy.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + noise = noisy - clean + + clean = self.signal_prepare(clean) + noise = self.signal_prepare(noise) + + stft_clean = self.stft.forward(clean) + stft_noise = self.stft.forward(noise) + # shape: [b, f, t] + stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2) + stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2) + # shape: [b, t, f] + stft_clean = torch.unsqueeze(stft_clean, dim=1) + stft_noise = torch.unsqueeze(stft_noise, dim=1) + # shape: [b, 1, t, f] + + # lsnr shape: [b, 1, t] + lsnr = lsnr.squeeze(1) + # lsnr shape: [b, t] + + lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise) + # lsnr_gth shape: [b, t] + + loss = F.mse_loss(lsnr, lsnr_gth) + return loss + + +class DfNetPretrainedModel(DfNet): + def __init__(self, + config: DfNetConfig, + ): + super(DfNetPretrainedModel, self).__init__( + config=config, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + + config = DfNetConfig() + model = DfNetPretrainedModel(config=config) + + noisy = torch.randn(size=(1, 16000), dtype=torch.float32) + + est_spec, est_wav, est_mask, lsnr = model.forward(noisy) + print(f"est_spec.shape: {est_spec.shape}") + print(f"est_wav.shape: {est_wav.shape}") + print(f"est_mask.shape: {est_mask.shape}") + print(f"lsnr.shape: {lsnr.shape}") + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dfnet/yaml/config.yaml b/toolbox/torchaudio/models/dfnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51369c4da056b8edde7fd84c4fde30b30b3122b9 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet/yaml/config.yaml @@ -0,0 +1,74 @@ +model_name: "dfnet" + +# spec +sample_rate: 8000 +nfft: 512 +win_size: 200 +hop_size: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +embedding_hidden_size: 256 +encoder_combine_op: "concat" + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +encoder_linear_groups: 32 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" 
+decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +df_decoder_hidden_size: 256 +df_num_layers: 2 +df_order: 5 +df_bins: 96 +df_gru_skip: "grouped_linear" +df_decoder_linear_groups: 16 +df_pathway_kernel_size_t: 5 +df_lookahead: 2 + +# lsnr +n_frame: 3 +lsnr_max: 30 +lsnr_min: -15 +norm_tau: 1. + +# data +min_snr_db: -10 +max_snr_db: 20 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 8 +batch_size: 64 +eval_steps: 10000 + +# runtime +use_post_filter: true diff --git a/toolbox/torchaudio/models/dfnet2/__init__.py b/toolbox/torchaudio/models/dfnet2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet2/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/dfnet2/configuration_dfnet2.py b/toolbox/torchaudio/models/dfnet2/configuration_dfnet2.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b336ef411748e8c74a8ea97e12ce652beb21e1 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet2/configuration_dfnet2.py @@ -0,0 +1,150 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class DfNet2Config(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + nfft: int = 512, + win_size: int = 200, + hop_size: int = 80, + win_type: str = "hann", + + spec_bins: int = 256, + erb_bins: int = 32, + min_freq_bins_for_erb: int = 2, + use_ema_norm: bool = True, + + conv_channels: int = 64, + conv_kernel_size_input: Tuple[int, int] = (3, 3), + conv_kernel_size_inner: Tuple[int, int] = (1, 3), + + convt_kernel_size_inner: Tuple[int, int] = (1, 3), + + embedding_hidden_size: int = 256, + encoder_combine_op: str = "concat", + + encoder_emb_skip_op: str = "none", + encoder_emb_linear_groups: int = 16, + encoder_emb_hidden_size: int = 256, + + encoder_linear_groups: int = 32, + + decoder_emb_num_layers: int = 3, + decoder_emb_skip_op: str = "none", + decoder_emb_linear_groups: int = 16, + decoder_emb_hidden_size: int = 256, + + df_decoder_hidden_size: int = 256, + df_num_layers: int = 2, + df_order: int = 5, + df_bins: int = 96, + df_gru_skip: str = "grouped_linear", + df_decoder_linear_groups: int = 16, + df_pathway_kernel_size_t: int = 5, + df_lookahead: int = 2, + + n_frame: int = 3, + max_local_snr: int = 30, + min_local_snr: int = -15, + norm_tau: float = 1., + + min_snr_db: float = -10, + max_snr_db: float = 20, + + lr: float = 0.001, + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + max_epochs: int = 100, + clip_grad_norm: float = 10., + seed: int = 1234, + + num_workers: int = 4, + batch_size: int = 4, + eval_steps: int = 25000, + + use_post_filter: bool = False, + + **kwargs + ): + super(DfNet2Config, self).__init__(**kwargs) + # transform + self.sample_rate = sample_rate + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + # spectrum + self.spec_bins = spec_bins + self.erb_bins = erb_bins + self.min_freq_bins_for_erb = min_freq_bins_for_erb + + self.use_ema_norm = use_ema_norm + + # conv + self.conv_channels = conv_channels + self.conv_kernel_size_input = conv_kernel_size_input + self.conv_kernel_size_inner = conv_kernel_size_inner + + 
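+        # note: unlike DfNetConfig there is no conv_lookahead here; the dfnet2
+        # convolutions pad only past frames and carry a per-layer cache, which is
+        # what enables the chunk-by-chunk (streaming) inference in modeling_dfnet2.py.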
self.convt_kernel_size_inner = convt_kernel_size_inner + + self.embedding_hidden_size = embedding_hidden_size + + # encoder + self.encoder_emb_skip_op = encoder_emb_skip_op + self.encoder_emb_linear_groups = encoder_emb_linear_groups + self.encoder_emb_hidden_size = encoder_emb_hidden_size + + self.encoder_linear_groups = encoder_linear_groups + self.encoder_combine_op = encoder_combine_op + + # decoder + self.decoder_emb_num_layers = decoder_emb_num_layers + self.decoder_emb_skip_op = decoder_emb_skip_op + self.decoder_emb_linear_groups = decoder_emb_linear_groups + self.decoder_emb_hidden_size = decoder_emb_hidden_size + + # df decoder + self.df_decoder_hidden_size = df_decoder_hidden_size + self.df_num_layers = df_num_layers + self.df_order = df_order + self.df_bins = df_bins + self.df_gru_skip = df_gru_skip + self.df_decoder_linear_groups = df_decoder_linear_groups + self.df_pathway_kernel_size_t = df_pathway_kernel_size_t + self.df_lookahead = df_lookahead + + # lsnr + self.n_frame = n_frame + self.max_local_snr = max_local_snr + self.min_local_snr = min_local_snr + self.norm_tau = norm_tau + + # data snr + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + # train + self.lr = lr + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.max_epochs = max_epochs + self.clip_grad_norm = clip_grad_norm + self.seed = seed + + self.num_workers = num_workers + self.batch_size = batch_size + self.eval_steps = eval_steps + + # runtime + self.use_post_filter = use_post_filter + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dfnet2/inference_dfnet2.py b/toolbox/torchaudio/models/dfnet2/inference_dfnet2.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef00ec482b8d835b728773d5477e03c8602f21f --- /dev/null +++ b/toolbox/torchaudio/models/dfnet2/inference_dfnet2.py @@ -0,0 +1,135 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile, time +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +torch.set_num_threads(1) + +from project_settings import project_path +from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config +from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2PretrainedModel, MODEL_FILE + +logger = logging.getLogger("toolbox") + + +class InferenceDfNet2(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, model = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.model = model + self.model.to(device) + self.model.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = DfNet2Config.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model = DfNet2PretrainedModel.from_pretrained( + 
pretrained_model_name_or_path=model_path.as_posix(), + ) + model.to(self.device) + model.eval() + + shutil.rmtree(model_path) + return config, model + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.denoise_offline(noisy_audio) + # enhanced_audio shape: [channels, num_samples] + enhanced_audio = enhanced_audio[0] + # enhanced_audio shape: [num_samples] + return enhanced_audio.cpu().numpy() + + def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios) + + # shape: [batch_size, 1, num_samples] + denoise = est_wav[0] + # shape: [channels, num_samples] + return denoise + + def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + denoise = self.model.forward_chunk_by_chunk(noisy_audios) + + # shape: [batch_size, 1, num_samples] + denoise = denoise[0] + # shape: [channels, num_samples] + return denoise + + +def main(): + model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip" + infer_model = InferenceDfNet2(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav" + noisy_audio, sample_rate = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + duration = librosa.get_duration(y=noisy_audio, sr=sample_rate) + # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + begin = time.time() + enhanced_audio = infer_model.denoise_offline(noisy_audio) + time_cost = time.time() - begin + print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio_offline.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + begin = time.time() + enhanced_audio = infer_model.denoise_online(noisy_audio) + time_cost = time.time() - begin + print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio_online.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py b/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py new file mode 100644 index 0000000000000000000000000000000000000000..9054e4704b4cd0076f7b181b4b28a74adbd5c2e8 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py @@ -0,0 +1,1535 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +DeepFilterNet 的原生实现不直接支持流式推理 + +社区开发者(如 Rikorose)提供了基于 Torch 的流式推理实现 
+https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF + +此文件试图实现一个支持流式推理的 dfnet + +""" +import os +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config +from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT +from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget +from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands +from toolbox.torchaudio.modules.utils.ema import ErbEMA, SpecEMA + + +MODEL_FILE = "model.pt" + + +norm_layer_dict = { + "batch_norm_2d": torch.nn.BatchNorm2d +} + + +activation_layer_dict = { + "relu": torch.nn.ReLU, + "identity": torch.nn.Identity, + "sigmoid": torch.nn.Sigmoid, +} + + +class CausalConv2d(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + pad_f_dim: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + ): + super(CausalConv2d, self).__init__() + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + + if pad_f_dim: + fpad = kernel_size[1] // 2 + dilation - 1 + else: + fpad = 0 + + # for last 2 dim, pad (left, right, top, bottom). + self.lookback = kernel_size[0] - 1 + if self.lookback > 0: + self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0) + else: + self.tpad = nn.Identity() + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + + if separable: + self.convp = nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + else: + self.convp = nn.Identity() + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + self.norm = norm_layer(out_channels) + else: + self.norm = nn.Identity() + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + self.activation = activation_layer() + else: + self.activation = nn.Identity() + + def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): + """ + :param inputs: shape: [b, c, t, f] + :param cache: shape: [b, c, lookback, f]; + :return: + """ + x = inputs + + if cache is None: + x = self.tpad(x) + else: + x = torch.concat(tensors=[cache, x], dim=2) + + new_cache = None + if self.lookback > 0: + new_cache = x[:, :, -self.lookback:, :] + + x = self.conv(x) + + x = self.convp(x) + x = self.norm(x) + x = self.activation(x) + + return x, new_cache + + +class CausalConvTranspose2dErrorCase(nn.Module): + """ + 错误的缓存方法。 + """ + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + pad_f_dim: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + ): + super(CausalConvTranspose2dErrorCase, self).__init__() + + 
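+        # the class docstring "错误的缓存方法。" means "incorrect caching method":
+        # this implementation is kept only for reference, and the
+        # CausalConvTranspose2d defined below is the intended replacement.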
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if pad_f_dim: + fpad = kernel_size[1] // 2 + else: + fpad = 0 + + # for last 2 dim, pad (left, right, top, bottom). + self.lookback = kernel_size[0] - 1 + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + self.convt = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad), + output_padding=(0, fpad), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + + if separable: + self.convp = nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + else: + self.convp = nn.Identity() + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + self.norm = norm_layer(out_channels) + else: + self.norm = nn.Identity() + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + self.activation = activation_layer() + else: + self.activation = nn.Identity() + + def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): + """ + :param inputs: shape: [b, c, t, f] + :param cache: shape: [b, c, lookback, f]; + :return: + """ + x = inputs + + # x shape: [b, c, t, f] + x = self.convt(x) + # x shape: [b, c, t+lookback, f] + + new_cache = None + if self.lookback > 0: + if cache is not None: + x = torch.concat(tensors=[ + x[:, :, :self.lookback, :] + cache, + x[:, :, self.lookback:, :] + ], dim=2) + + x = x[:, :, :-self.lookback, :] + new_cache = x[:, :, -self.lookback:, :] + + x = self.convp(x) + x = self.norm(x) + x = self.activation(x) + + return x, new_cache + + +class CausalConvTranspose2d(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + pad_f_dim: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + ): + super(CausalConvTranspose2d, self).__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if pad_f_dim: + fpad = kernel_size[1] // 2 + else: + fpad = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
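+        # forward() keeps the last `lookback` (= kernel_size[0] - 1) frames of its
+        # output as a cache and concatenates them in front of the transposed-conv
+        # output of the next chunk.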
+ self.lookback = kernel_size[0] - 1 + if self.lookback > 0: + self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0) + else: + self.tpad = nn.Identity() + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + self.convt = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad + dilation - 1), + output_padding=(0, fpad), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + + if separable: + self.convp = nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + else: + self.convp = nn.Identity() + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + self.norm = norm_layer(out_channels) + else: + self.norm = nn.Identity() + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + self.activation = activation_layer() + else: + self.activation = nn.Identity() + + def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): + """ + :param inputs: shape: [b, c, t, f] + :param cache: shape: [b, c, lookback, f]; + :return: + """ + x = inputs + + # x shape: [b, c, t, f] + x = self.convt(x) + # x shape: [b, c, t+lookback, f] + + if cache is None: + x = self.tpad(x) + else: + x = torch.concat(tensors=[cache, x], dim=2) + + new_cache = None + if self.lookback > 0: + new_cache = x[:, :, -self.lookback:, :] + + x = self.convp(x) + x = self.norm(x) + x = self.activation(x) + + return x, new_cache + + +class GroupedLinear(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" + assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + torch.nn.Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., I] + b, t, f = x.shape + if f != self.input_size: + raise AssertionError + + # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] + new_shape = (b, t, self.groups, self.ws) + x = x.view(new_shape) + # The better way, but not supported by torchscript + # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) + # x: [b, t, h] + return x + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class SqueezedGRU_S(nn.Module): + """ + SGE net: Video object detection with squeezed GRU and information entropy map + https://arxiv.org/abs/2106.07224 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + skip_op: str = "none", + activation_layer: str = "identity", + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.linear_in = 
nn.Sequential( + GroupedLinear( + input_size=input_size, + hidden_size=hidden_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + + # gru skip operator + self.gru_skip_op = None + + if skip_op == "none": + self.gru_skip_op = None + elif skip_op == "identity": + if not input_size != output_size: + raise AssertionError("Dimensions do not match") + self.gru_skip_op = nn.Identity() + elif skip_op == "grouped_linear": + self.gru_skip_op = GroupedLinear( + input_size=hidden_size, + hidden_size=hidden_size, + groups=linear_groups, + ) + else: + raise NotImplementedError() + + self.gru = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=batch_first, + bidirectional=False, + ) + + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinear( + input_size=hidden_size, + hidden_size=output_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, inputs: torch.Tensor, hx: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + # inputs: shape: [b, t, h] + x = self.linear_in.forward(inputs) + + x, hx = self.gru.forward(x, hx) + + x = self.linear_out(x) + + if self.gru_skip_op is not None: + x = x + self.gru_skip_op(inputs) + + return x, hx + + +class Add(nn.Module): + def forward(self, a, b): + return a + b + + +class Concat(nn.Module): + def forward(self, a, b): + return torch.cat((a, b), dim=-1) + + +class Encoder(nn.Module): + def __init__(self, config: DfNet2Config): + super(Encoder, self).__init__() + self.embedding_input_size = config.conv_channels * config.erb_bins // 4 + self.embedding_output_size = config.conv_channels * config.erb_bins // 4 + self.embedding_hidden_size = config.embedding_hidden_size + + self.spec_conv0 = CausalConv2d( + in_channels=1, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + ) + self.spec_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.spec_conv2 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.spec_conv3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + ) + + self.df_conv0 = CausalConv2d( + in_channels=2, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + ) + self.df_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.df_fc_emb = nn.Sequential( + GroupedLinear( + config.conv_channels * config.df_bins // 2, + self.embedding_input_size, + groups=config.encoder_linear_groups + ), + nn.ReLU(inplace=True) + ) + + if config.encoder_combine_op == "concat": + self.embedding_input_size *= 2 + self.combine = Concat() + else: + self.combine = Add() + + # emb_gru + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_gru = SqueezedGRU_S( + self.embedding_input_size, + self.embedding_hidden_size, + 
output_size=self.embedding_output_size, + num_layers=1, + batch_first=True, + skip_op=config.encoder_emb_skip_op, + linear_groups=config.encoder_emb_linear_groups, + activation_layer="relu", + ) + + # lsnr + self.lsnr_fc = nn.Sequential( + nn.Linear(self.embedding_output_size, 1), + nn.Sigmoid() + ) + self.lsnr_scale = config.max_local_snr - config.min_local_snr + self.lsnr_offset = config.min_local_snr + + def forward(self, + feat_erb: torch.Tensor, + feat_spec: torch.Tensor, + cache_dict: dict = None, + ): + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + cache2 = cache_dict.get("cache2", None) + cache3 = cache_dict.get("cache3", None) + cache4 = cache_dict.get("cache4", None) + cache5 = cache_dict.get("cache5", None) + cache6 = cache_dict.get("cache6", None) + + # feat_erb shape: (b, 1, t, erb_bins) + e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0) + e1, new_cache1 = self.spec_conv1.forward(e0, cache=cache1) + e2, new_cache2 = self.spec_conv2.forward(e1, cache=cache2) + e3, new_cache3 = self.spec_conv3.forward(e2, cache=cache3) + # e0 shape: [b, c, t, erb_bins] + # e1 shape: [b, c, t, erb_bins // 2] + # e2 shape: [b, c, t, erb_bins // 4] + # e3 shape: [b, c, t, erb_bins // 4] + # e3 shape: [b, 64, t, 32/4=8] + + # feat_spec, shape: (b, 2, t, df_bins) + c0, new_cache4 = self.df_conv0.forward(feat_spec, cache=cache4) + c1, new_cache5 = self.df_conv1.forward(c0, cache=cache5) + # c0 shape: [b, c, t, df_bins] + # c1 shape: [b, c, t, df_bins // 2] + # c1 shape: [b, 64, t, 96/2=48] + + cemb = c1.permute(0, 2, 3, 1) + # cemb shape: [b, t, df_bins // 2, c] + cemb = cemb.flatten(2) + # cemb shape: [b, t, df_bins // 2 * c] + # cemb shape: [b, t, 96/2*64=3072] + cemb = self.df_fc_emb.forward(cemb) + # cemb shape: [b, t, erb_bins // 4 * c] + # cemb shape: [b, t, 32/4*64=512] + + # e3 shape: [b, c, t, erb_bins // 4] + emb = e3.permute(0, 2, 3, 1) + # emb shape: [b, t, erb_bins // 4, c] + emb = emb.flatten(2) + # emb shape: [b, t, erb_bins // 4 * c] + # emb shape: [b, t, 32/4*64=512] + + emb = self.combine(emb, cemb) + # if concat; emb shape: [b, t, spec_bins // 4 * c * 2] + # if add; emb shape: [b, t, spec_bins // 4 * c] + + emb, new_cache6 = self.emb_gru.forward(emb, hx=cache6) + + # emb shape: [b, t, spec_dim // 4 * c] + # h shape: [b, 1, spec_dim] + + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + # lsnr shape: [b, t, 1] + + new_cache_dict = { + "cache0": new_cache0, + "cache1": new_cache1, + "cache2": new_cache2, + "cache3": new_cache3, + "cache4": new_cache4, + "cache5": new_cache5, + "cache6": new_cache6, + } + return e0, e1, e2, e3, emb, c0, lsnr, new_cache_dict + + +class ErbDecoder(nn.Module): + def __init__(self, config: DfNet2Config): + super(ErbDecoder, self).__init__() + + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * config.erb_bins // 4 + self.emb_out_dim = config.conv_channels * config.erb_bins // 4 + self.emb_hidden_dim = config.decoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=config.decoder_emb_num_layers - 1, + batch_first=True, + skip_op=config.decoder_emb_skip_op, + linear_groups=config.decoder_emb_linear_groups, + activation_layer="relu", + ) + self.conv3p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, 
+ fstride=1, + ) + self.convt3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + ) + self.conv2p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + ) + self.convt2 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.conv1p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + ) + self.convt1 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.conv0p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + ) + self.conv0_out = CausalConv2d( + in_channels=config.conv_channels, + out_channels=1, + kernel_size=config.conv_kernel_size_inner, + activation_layer="sigmoid", + bias=False, + separable=True, + fstride=1, + ) + + def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> torch.Tensor: + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + cache2 = cache_dict.get("cache2", None) + cache3 = cache_dict.get("cache3", None) + cache4 = cache_dict.get("cache4", None) + + # Estimates erb mask + b, _, t, f8 = e3.shape + + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + emb, new_cache0 = self.emb_gru.forward(emb, hx=cache0) + # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) + + e3, new_cache1 = self.convt3.forward(self.conv3p(e3)[0] + emb, cache=cache1) + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + e2, new_cache2 = self.convt2.forward(self.conv2p(e2)[0] + e3, cache=cache2) + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + e1, new_cache3 = self.convt1.forward(self.conv1p(e1)[0] + e2, cache=cache3) + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] + mask, new_cache4 = self.conv0_out.forward(self.conv0p(e0)[0] + e1, cache=cache4) + # mask shape: [batch_size, 1, time_steps, freq_dim] + + new_cache_dict = { + "cache0": new_cache0, + "cache1": new_cache1, + "cache2": new_cache2, + "cache3": new_cache3, + "cache4": new_cache4, + } + return mask, new_cache_dict + + +class DfDecoder(nn.Module): + def __init__(self, config: DfNet2Config): + super(DfDecoder, self).__init__() + + self.embedding_input_size = config.conv_channels * config.erb_bins // 4 + self.df_decoder_hidden_size = config.df_decoder_hidden_size + self.df_num_layers = config.df_num_layers + + self.df_order = config.df_order + + self.df_bins = config.df_bins + self.df_out_ch = config.df_order * 2 + + self.df_convp = CausalConv2d( + config.conv_channels, + self.df_out_ch, + fstride=1, + kernel_size=(config.df_pathway_kernel_size_t, 1), + separable=True, + bias=False, + ) + self.df_gru = SqueezedGRU_S( + self.embedding_input_size, + self.df_decoder_hidden_size, + num_layers=self.df_num_layers, + batch_first=True, + skip_op="none", + activation_layer="relu", + ) + + if config.df_gru_skip == "none": + 
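+ # (note) df_gru_skip selects how the DF decoder's input is routed around its GRU:
+ # "none" disables the skip, "identity" adds the input unchanged (sizes must match),
+ # and "grouped_linear" adds a grouped linear projection of the input to the GRU output.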
self.df_skip = None + elif config.df_gru_skip == "identity": + if config.embedding_hidden_size != config.df_decoder_hidden_size: + raise AssertionError("Dimensions do not match") + self.df_skip = nn.Identity() + elif config.df_gru_skip == "grouped_linear": + self.df_skip = GroupedLinear( + self.embedding_input_size, + self.df_decoder_hidden_size, + groups=config.df_decoder_linear_groups + ) + else: + raise NotImplementedError() + + self.df_out: nn.Module + out_dim = self.df_bins * self.df_out_ch + + self.df_out = nn.Sequential( + GroupedLinear( + input_size=self.df_decoder_hidden_size, + hidden_size=out_dim, + groups=config.df_decoder_linear_groups, + # groups = self.df_bins // 5, + ), + nn.Tanh() + ) + self.df_fc_a = nn.Sequential( + nn.Linear(self.df_decoder_hidden_size, 1), + nn.Sigmoid() + ) + + def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> torch.Tensor: + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + + # emb shape: [batch_size, time_steps, df_bins // 4 * channels] + b, t, _ = emb.shape + df_coefs, new_cache0 = self.df_gru.forward(emb, hx=cache0) + if self.df_skip is not None: + df_coefs = df_coefs + self.df_skip(emb) + # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] + + # c0 shape: [batch_size, channels, time_steps, df_bins] + c0, new_cache1 = self.df_convp.forward(c0, cache=cache1) + # c0 shape: [batch_size, df_order * 2, time_steps, df_bins] + c0 = c0.permute(0, 2, 3, 1) + # c0 shape: [batch_size, time_steps, df_bins, df_order * 2] + + df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order + # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] + df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch) + # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] + df_coefs = df_coefs + c0 + # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] + + new_cache_dict = { + "cache0": new_cache0, + "cache1": new_cache1, + } + return df_coefs, new_cache_dict + + +class DfOutputReshapeMF(nn.Module): + """Coefficients output reshape for multiframe/MultiFrameModule + + Requires input of shape B, C, T, F, 2. + """ + + def __init__(self, df_order: int, df_bins: int): + super().__init__() + self.df_order = df_order + self.df_bins = df_bins + + def forward(self, coefs: torch.Tensor) -> torch.Tensor: + # [B, T, F, O*2] -> [B, O, T, F, 2] + new_shape = list(coefs.shape) + new_shape[-1] = -1 + new_shape.append(2) + coefs = coefs.view(new_shape) + coefs = coefs.permute(0, 3, 1, 2, 4) + return coefs + + +class Mask(nn.Module): + def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): + super().__init__() + self.use_post_filter = use_post_filter + self.eps = eps + + def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: + """ + Post-Filter + + A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + https://arxiv.org/abs/2008.04259 + + :param mask: Real valued mask, typically of shape [B, C, T, F]. + :param beta: Global gain factor. 
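+ The gain applied below works out to
+ mask_pf = (1 + beta) * mask / (1 + beta / sin(pi * mask / 2) ** 2),
+ which leaves mask values near 1 essentially unchanged and further attenuates small
+ mask values (a summary of the code below, not a formula quoted from the paper).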
+ :return: + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # spec shape: [b, 1, t, spec_bins, 2] + + if not self.training and self.use_post_filter: + mask = self.post_filter(mask) + + # mask shape: [b, 1, t, spec_bins] + mask = mask.unsqueeze(4) + # mask shape: [b, 1, t, spec_bins, 1] + return spec * mask + + +class DeepFiltering(nn.Module): + def __init__(self, + df_bins: int, + df_order: int, + lookahead: int = 0, + ): + super(DeepFiltering, self).__init__() + self.df_bins = df_bins + self.df_order = df_order + self.lookahead = lookahead + + self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) + + def forward(self, *args, **kwargs): + raise AssertionError("use `forward_offline` or `forward_online` stead.") + + def spec_unfold_offline(self, spec: torch.Tensor) -> torch.Tensor: + """ + Pads and unfolds the spectrogram according to frame_size. + :param spec: shape: [b, c, t, f], dtype: torch.complex64 + :return: shape: [b, c, t, f, df_order] + """ + if self.df_order <= 1: + return spec.unsqueeze(-1) + + # spec shape: [b, 1, t, f], dtype: torch.complex64 + spec = self.pad(spec) + # spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64 + spec_unfold = spec.unfold(dimension=2, size=self.df_order, step=1) + # spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64 + return spec_unfold + + def forward_offline(self, + spec: torch.Tensor, + coefs: torch.Tensor, + ): + # spec shape: [b, 1, t, spec_bins, 2] + spec_c = torch.view_as_complex(spec.contiguous()) + # spec_c shape: [b, 1, t, spec_bins] + spec_u = self.spec_unfold_offline(spec_c) + # spec_u shape: [b, 1, t, spec_bins, df_order] + spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins) + # spec_f shape: [b, 1, t, df_bins, df_order] + + # coefs shape: [b, df_order, t, df_bins, 2] + coefs = torch.view_as_complex(coefs.contiguous()) + # coefs shape: [b, df_order, t, df_bins] + coefs = coefs.unsqueeze(dim=1) + # coefs shape: [b, 1, df_order, t, df_bins] + + spec_f = self.df_offline(spec_f, coefs) + # spec_f shape: [b, 1, t, df_bins] + + spec_f = torch.view_as_real(spec_f) + # spec_f shape: [b, 1, t, df_bins, 2] + return spec_f + + def df_offline(self, spec: torch.Tensor, coefs: torch.Tensor): + """ + Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. + :param spec: [b, 1, t, df_bins, df_order] complex. + :param coefs: [b, 1, df_order, t, df_bins] complex. + :return: [b, 1, t, df_bins] complex. + """ + spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs) + return spec_f + + def spec_unfold_online(self, spec: torch.Tensor, cache_spec: torch.Tensor = None): + """ + Pads and unfolds the spectrogram according to frame_size. 
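+ In the streaming case the caller passes the last df_order - 1 frames of the previous
+ chunk as `cache_spec`; they are prepended so every new frame sees a full df_order
+ history, and the updated cache is returned alongside the unfolded spectrogram.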
+ :param spec: shape: [b, c, t, f], dtype: torch.complex64 + :param cache_spec: shape: [b, c, df_order-1, f], dtype: torch.complex64 + :return: shape: [b, c, t, f, df_order] + """ + if self.df_order <= 1: + return spec.unsqueeze(-1) + + if cache_spec is None: + b, c, _, f = spec.shape + cache_spec = spec.new_zeros(size=(b, c, self.df_order-1, f)) + spec_pad = torch.concat(tensors=[ + cache_spec, spec + ], dim=2) + new_cache_spec = spec_pad[:, :, -(self.df_order-1):, :] + + # spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64 + spec_unfold = spec_pad.unfold(dimension=2, size=self.df_order, step=1) + # spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64 + return spec_unfold, new_cache_spec + + def forward_online(self, + spec: torch.Tensor, + coefs: torch.Tensor, + cache_dict: dict = None, + ): + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + + # spec shape: [b, 1, t, spec_bins, 2] + spec_c = torch.view_as_complex(spec.contiguous()) + # spec_c shape: [b, 1, t, spec_bins] + spec_u, new_cache0 = self.spec_unfold_online(spec_c, cache_spec=cache0) + # spec_u shape: [b, 1, t, spec_bins, df_order] + spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins) + # spec_f shape: [b, 1, t, df_bins, df_order] + + # coefs shape: [b, df_order, t, df_bins, 2] + coefs = torch.view_as_complex(coefs.contiguous()) + # coefs shape: [b, df_order, t, df_bins] + coefs = coefs.unsqueeze(dim=1) + # coefs shape: [b, 1, df_order, t, df_bins] + + spec_f, new_cache1 = self.df_online(spec_f, coefs, cache_coefs=cache1) + # spec_f shape: [b, 1, t, df_bins] + + spec_f = torch.view_as_real(spec_f) + # spec_f shape: [b, 1, t, df_bins, 2] + + new_cache_dict = { + "cache0": new_cache0, + "cache1": new_cache1, + } + return spec_f, new_cache_dict + + def df_online(self, spec: torch.Tensor, coefs: torch.Tensor, cache_coefs: torch.Tensor = None) -> torch.Tensor: + """ + Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. + :param spec: [b, 1, 1, df_bins, df_order] complex. + :param coefs: [b, 1, df_order, 1, df_bins] complex. + :param cache_coefs: [b, 1, df_order, lookahead, df_bins] complex. + :return: [b, 1, 1, df_bins] complex. + """ + + if cache_coefs is None: + b, c, _, _, f = coefs.shape + cache_coefs = coefs.new_zeros(size=(b, c, self.df_order, self.lookahead, f)) + coefs_pad = torch.concat(tensors=[ + cache_coefs, coefs + ], dim=3) + + # coefs_pad shape: [b, 1, df_order, 1+lookahead, df_bins], torch.complex64. + coefs = coefs_pad[:, :, :, :-self.lookahead, :] + # coefs shape: [b, 1, df_order, 1, df_bins], torch.complex64. + new_cache_coefs = coefs_pad[:, :, :, -self.lookahead:, :] + # new_cache_coefs shape: [b, 1, df_order, lookahead, df_bins], torch.complex64. 
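+ # For intuition: the einsum below computes, per time step t and frequency f,
+ # spec_f[..., t, f] = sum_n coefs[..., n, t, f] * spec[..., t, f, n],
+ # i.e. a complex FIR filter of length df_order applied over the unfolded time taps.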
+ spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs) + return spec_f, new_cache_coefs + + +class DfNet2(nn.Module): + def __init__(self, config: DfNet2Config): + super(DfNet2, self).__init__() + self.config = config + self.eps = 1e-12 + + self.freq_bins = self.config.nfft // 2 + 1 + + self.nfft = config.nfft + self.win_size = config.win_size + self.hop_size = config.hop_size + self.win_type = config.win_type + + self.stft = ConvSTFT( + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + power=None, + requires_grad=False + ) + self.istft = ConviSTFT( + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + requires_grad=False + ) + + self.erb_bands = ErbBands( + sample_rate=config.sample_rate, + nfft=config.nfft, + erb_bins=config.erb_bins, + min_freq_bins_for_erb=config.min_freq_bins_for_erb, + ) + + self.erb_ema = ErbEMA( + sample_rate=config.sample_rate, + hop_size=config.hop_size, + erb_bins=config.erb_bins, + ) + self.spec_ema = SpecEMA( + sample_rate=config.sample_rate, + hop_size=config.hop_size, + df_bins=config.df_bins, + ) + + self.encoder = Encoder(config) + self.erb_decoder = ErbDecoder(config) + + self.df_decoder = DfDecoder(config) + self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) + self.df_op = DeepFiltering( + df_bins=config.df_bins, + df_order=config.df_order, + lookahead=config.df_lookahead, + ) + + self.mask = Mask(use_post_filter=config.use_post_filter) + + self.lsnr_fn = LocalSnrTarget( + sample_rate=config.sample_rate, + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + n_frame=config.n_frame, + min_local_snr=config.min_local_snr, + max_local_snr=config.max_local_snr, + db=True, + ) + + def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: + if signal.dim() == 2: + signal = torch.unsqueeze(signal, dim=1) + _, _, n_samples = signal.shape + remainder = (n_samples - self.win_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) + return signal + + def feature_prepare(self, signal: torch.Tensor): + # noisy shape: [b, num_samples_pad] + spec_cmp = self.stft.forward(signal) + # spec_complex shape: [b, f, t], torch.complex64 + spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2) + # spec_complex shape: [b, t, f], torch.complex64 + spec_cmp_real = torch.view_as_real(spec_cmp) + # spec_cmp_real shape: [b, t, f, 2] + spec_mag = torch.abs(spec_cmp) + spec_pow = torch.square(spec_mag) + # shape: [b, t, f] + + spec = torch.unsqueeze(spec_cmp_real, dim=1) + # spec shape: [b, 1, t, f, 2] + + feat_erb = self.erb_bands.erb_scale(spec_pow, db=True) + # feat_erb shape: [b, t, erb_bins] + feat_erb = torch.unsqueeze(feat_erb, dim=1) + # feat_erb shape: [b, 1, t, erb_bins] + + feat_spec = spec_cmp_real.permute(0, 3, 1, 2) + # feat_spec shape: [b, 2, t, f] + feat_spec = feat_spec[..., :self.df_decoder.df_bins] + # feat_spec shape: [b, 2, t, df_bins] + + spec = spec.detach() + feat_erb = feat_erb.detach() + feat_spec = feat_spec.detach() + return spec, feat_erb, feat_spec + + def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None): + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + + feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0) + feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1) + + new_cache_dict = 
{ + "cache0": new_cache0, + "cache1": new_cache1, + } + + feat_erb = feat_erb.detach() + feat_spec = feat_spec.detach() + return feat_erb, feat_spec, new_cache_dict + + def forward(self, + noisy: torch.Tensor, + ): + """ + :param noisy: + :return: + est_spec: shape: [b, 257*2, t] + est_wav: shape: [b, num_samples] + est_mask: shape: [b, 257, t] + lsnr: shape: [b, 1, t] + """ + n_samples = noisy.shape[-1] + noisy = self.signal_prepare(noisy) + + spec, feat_erb, feat_spec = self.feature_prepare(noisy) + if self.config.use_ema_norm: + feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec) + + e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec) + + mask, _ = self.erb_decoder.forward(emb, e3, e2, e1, e0) + # mask shape: [b, 1, t, erb_bins] + mask = self.erb_bands.erb_scale_inv(mask) + # mask shape: [b, 1, t, f] + if torch.any(mask > 1) or torch.any(mask < 0): + raise AssertionError + + spec_m = self.mask.forward(spec, mask) + # spec_m shape: [b, 1, t, f, 2] + spec_m = spec_m[:, :, :, :self.config.spec_bins, :] + # spec_m shape: [b, 1, t, spec_bins, 2] + + # lsnr shape: [b, t, 1] + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + # lsnr shape: [b, 1, t] + + df_coefs, _ = self.df_decoder.forward(emb, c0) + df_coefs = self.df_out_transform(df_coefs) + # df_coefs shape: [b, df_order, t, df_bins, 2] + + spec_ = spec[:, :, :, :self.config.spec_bins, :] + # spec shape: [b, 1, t, spec_bins, 2] + spec_f = self.df_op.forward_offline(spec_, df_coefs) + # spec_f shape: [b, 1, t, df_bins, 2], torch.float32 + + spec_e = torch.concat(tensors=[ + spec_f, spec_m[..., self.df_decoder.df_bins:, :] + ], dim=3) + + spec_e = torch.squeeze(spec_e, dim=1) + spec_e = spec_e.permute(0, 2, 1, 3) + # spec_e shape: [b, spec_bins, t, 2] + + # spec_e shape: [b, spec_bins, t, 2] + est_spec = torch.view_as_complex(spec_e.contiguous()) + # est_spec shape: [b, spec_bins, t], torch.complex64 + est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) + # est_spec shape: [b, f, t], torch.complex64 + + est_wav = self.istft.forward(est_spec) + est_wav = est_wav[:, :, :n_samples] + # est_wav shape: [b, 1, n_samples] + + est_mask = torch.squeeze(mask, dim=1) + est_mask = est_mask.permute(0, 2, 1) + # est_mask shape: [b, f, t] + + return est_spec, est_wav, est_mask, lsnr + + def forward_chunk(self, + sub_noisy: torch.Tensor, + cache_dict0: dict = None, + cache_dict1: dict = None, + cache_dict2: dict = None, + cache_dict3: dict = None, + cache_dict4: dict = None, + cache_dict5: dict = None, + cache_dict6: dict = None, + ): + + spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy) + # spec shape: [b, 1, t, f, 2] + # feat_erb shape: [b, 1, t, erb_bins] + # feat_spec shape: [b, 2, t, df_bins] + if self.config.use_ema_norm: + feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0) + + e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1) + + mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2) + # mask shape: [b, 1, t, erb_bins] + mask = self.erb_bands.erb_scale_inv(mask) + # mask shape: [b, 1, t, f] + + spec_m = self.mask.forward(spec, mask) + # spec_m shape: [b, 1, t, f, 2] + spec_m = spec_m[:, :, :, :self.config.spec_bins, :] + # spec_m shape: [b, 1, t, spec_bins, 2] + + # lsnr shape: [b, t, 1] + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + # lsnr shape: [b, 1, t] + + df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3) + 
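+ # Each cache_dict holds the temporal state of one sub-module (conv lookback frames,
+ # GRU hidden states, EMA statistics), so streaming hop-sized chunks through
+ # forward_chunk is intended to match the offline forward pass frame by frame.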
df_coefs = self.df_out_transform(df_coefs) + # df_coefs shape: [b, df_order, t, df_bins, 2] + + spec_ = spec[:, :, :, :self.config.spec_bins, :] + # spec shape: [b, 1, t, spec_bins, 2] + spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4) + # spec_f shape: [b, 1, t, df_bins, 2], torch.float32 + + spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5) + + spec_e = torch.squeeze(spec_e, dim=1) + spec_e = spec_e.permute(0, 2, 1, 3) + # spec_e shape: [b, spec_bins, t, 2] + + # spec_e shape: [b, spec_bins, t, 2] + est_spec = torch.view_as_complex(spec_e.contiguous()) + # est_spec shape: [b, spec_bins, t], torch.complex64 + est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) + # est_spec shape: [b, f, t], torch.complex64 + + est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6) + # est_wav shape: [b, 1, hop_size] + return est_wav, cache_dict0, cache_dict1, cache_dict2, cache_dict3, cache_dict4, cache_dict5, cache_dict6 + + def forward_chunk_by_chunk(self, + noisy: torch.Tensor, + ): + noisy = self.signal_prepare(noisy) + b, _, _ = noisy.shape + noisy = torch.concat(tensors=[ + noisy, noisy.new_zeros(size=(b, 1, (self.config.df_lookahead+1)*self.hop_size)) + ], dim=2) + b, _, num_samples = noisy.shape + + t = (num_samples - self.win_size) // self.hop_size + 1 + + cache_dict0 = None + cache_dict1 = None + cache_dict2 = None + cache_dict3 = None + cache_dict4 = None + cache_dict5 = None + cache_dict6 = None + + waveform_list = list() + for i in range(int(t)): + begin = i * self.hop_size + end = begin + self.win_size + sub_noisy = noisy[:, :, begin: end] + + (est_wav, + cache_dict0, cache_dict1, cache_dict2, cache_dict3, + cache_dict4, cache_dict5, cache_dict6) = self.forward_chunk( + sub_noisy, + cache_dict0, cache_dict1, cache_dict2, cache_dict3, + cache_dict4, cache_dict5, cache_dict6 + ) + + waveform_list.append(est_wav) + + waveform = torch.concat(tensors=waveform_list, dim=-1) + # waveform shape: [b, 1, n] + return waveform + + def spec_e_m_combine_online(self, spec_f: torch.Tensor, spec_m: torch.Tensor, cache_dict: dict = None): + """ + :param spec_f: shape: [b, 1, t, df_bins, 2], torch.float32 + :param spec_m: shape: [b, 1, t, spec_bins, 2] + :param cache_dict: + :return: + """ + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + + if cache0 is None: + b, c, t, f, _ = spec_m.shape + cache0 = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2)) + # cache0 shape: [b, 1, lookahead, f, 2] + spec_m_cat = torch.concat(tensors=[ + cache0, spec_m, + ], dim=2) + + spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :] + new_cache0 = spec_m_cat[:, :, -self.config.df_lookahead:, :, :] + + spec_e = torch.concat(tensors=[ + spec_f, spec_m[..., self.df_decoder.df_bins:, :] + ], dim=3) + + new_cache_dict = { + "cache0": new_cache0, + } + return spec_e, new_cache_dict + + def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + """ + :param est_mask: torch.Tensor, shape: [b, 257, t] + :param clean: + :param noisy: + :return: + """ + if noisy.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + noise = noisy - clean + + clean = self.signal_prepare(clean) + noise = self.signal_prepare(noise) + + stft_clean = self.stft.forward(clean) + mag_clean = torch.abs(stft_clean) + + stft_noise = self.stft.forward(noise) + mag_noise = torch.abs(stft_noise) + + gth_irm_mask = (mag_clean / 
(mag_clean + mag_noise + self.eps)).clamp(0, 1) + + loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean") + + return loss + + def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + if noisy.shape != clean.shape: + raise AssertionError("Input signals must have the same shape") + noise = noisy - clean + + clean = self.signal_prepare(clean) + noise = self.signal_prepare(noise) + + stft_clean = self.stft.forward(clean) + stft_noise = self.stft.forward(noise) + # shape: [b, f, t] + stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2) + stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2) + # shape: [b, t, f] + stft_clean = torch.unsqueeze(stft_clean, dim=1) + stft_noise = torch.unsqueeze(stft_noise, dim=1) + # shape: [b, 1, t, f] + + # lsnr shape: [b, 1, t] + lsnr = lsnr.squeeze(1) + # lsnr shape: [b, t] + + lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise) + # lsnr_gth shape: [b, t] + + loss = F.mse_loss(lsnr, lsnr_gth) + return loss + + +class DfNet2PretrainedModel(DfNet2): + def __init__(self, + config: DfNet2Config, + ): + super(DfNet2PretrainedModel, self).__init__( + config=config, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = DfNet2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + import time + # torch.set_num_threads(1) + + config = DfNet2Config( + # nfft=512, + # win_size=200, + # hop_size=80, + nfft=512, + win_size=512, + hop_size=128, + ) + model = DfNet2PretrainedModel(config=config) + model.eval() + + num_samples = 16000 + noisy = torch.randn(size=(1, num_samples), dtype=torch.float32) + duration = num_samples / config.sample_rate + + begin = time.time() + with torch.no_grad(): + est_spec, est_wav, est_mask, lsnr = model.forward(noisy) + time_cost = time.time() - begin + print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + # print(f"est_spec.shape: {est_spec.shape}") + # print(f"est_wav.shape: {est_wav.shape}") + # print(f"est_mask.shape: {est_mask.shape}") + # print(f"lsnr.shape: {lsnr.shape}") + + waveform = est_wav + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + print(waveform[:, :, 1000: 1002]) + print(waveform[:, :, 8000: 8002]) + print(waveform[:, :, 14000: 14002]) + print(waveform[:, :, 15680: 15682]) + print(waveform[:, :, 15760: 15762]) + print(waveform[:, :, 15840: 15842]) + + begin = time.time() + waveform = model.forward_chunk_by_chunk(noisy) + time_cost = time.time() - begin + print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / 
duration:.4f}") + + waveform = waveform[:, :, (config.df_lookahead*config.hop_size):] + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + print(waveform[:, :, 1000: 1002]) + print(waveform[:, :, 8000: 8002]) + print(waveform[:, :, 14000: 14002]) + print(waveform[:, :, 15680: 15682]) + print(waveform[:, :, 15760: 15762]) + print(waveform[:, :, 15840: 15842]) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dfnet2/yaml/config-200.yaml b/toolbox/torchaudio/models/dfnet2/yaml/config-200.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39ec8fd98e426c585879d181974c35df88358497 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet2/yaml/config-200.yaml @@ -0,0 +1,75 @@ +model_name: "dfnet2" + +# spec +sample_rate: 8000 +nfft: 512 +win_size: 200 +hop_size: 80 + +spec_bins: 256 +erb_bins: 32 +min_freq_bins_for_erb: 2 +use_ema_norm: true + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +convt_kernel_size_inner: + - 1 + - 3 + +embedding_hidden_size: 256 +encoder_combine_op: "concat" + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +encoder_linear_groups: 32 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +df_decoder_hidden_size: 256 +df_num_layers: 2 +df_order: 5 +df_bins: 96 +df_gru_skip: "grouped_linear" +df_decoder_linear_groups: 16 +df_pathway_kernel_size_t: 5 +df_lookahead: 2 + +# lsnr +n_frame: 3 +lsnr_max: 30 +lsnr_min: -15 +norm_tau: 1. + +# data +min_snr_db: -10 +max_snr_db: 20 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 8 +batch_size: 96 +eval_steps: 10000 + +# runtime +use_post_filter: true diff --git a/toolbox/torchaudio/models/dfnet2/yaml/config-512.yaml b/toolbox/torchaudio/models/dfnet2/yaml/config-512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c160e032aea3d8eae575a544b9275ce5608f8903 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet2/yaml/config-512.yaml @@ -0,0 +1,75 @@ +model_name: "dfnet" + +# spec +sample_rate: 8000 +nfft: 512 +win_size: 512 +hop_size: 128 + +spec_bins: 256 +erb_bins: 32 +min_freq_bins_for_erb: 2 +use_ema_norm: true + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +convt_kernel_size_inner: + - 1 + - 3 + +embedding_hidden_size: 256 +encoder_combine_op: "concat" + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +encoder_linear_groups: 32 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +df_decoder_hidden_size: 256 +df_num_layers: 2 +df_order: 5 +df_bins: 96 +df_gru_skip: "grouped_linear" +df_decoder_linear_groups: 16 +df_pathway_kernel_size_t: 5 +df_lookahead: 2 + +# lsnr +n_frame: 3 +lsnr_max: 30 +lsnr_min: -15 +norm_tau: 1. 
+ +# data +min_snr_db: -10 +max_snr_db: 20 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 8 +batch_size: 96 +eval_steps: 10000 + +# runtime +use_post_filter: true diff --git a/toolbox/torchaudio/models/discriminators/__init__.py b/toolbox/torchaudio/models/discriminators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/models/discriminators/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/__init__.py b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/configuration_waveform_metric_discriminator.py b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/configuration_waveform_metric_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1d56ef4f84ea9fe7a4cce208749e29fd96e0e7 --- /dev/null +++ b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/configuration_waveform_metric_discriminator.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class WaveformMetricDiscriminatorConfig(PretrainedConfig): + """ + https://github.com/yxlu-0102/MP-SENet/blob/main/config.json + """ + def __init__(self, + sample_rate: int = 8000, + segment_size: int = 4, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + + discriminator_dim: int = 16, + discriminator_in_channel: int = 2, + + **kwargs + ): + super(WaveformMetricDiscriminatorConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.segment_size = segment_size + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + self.discriminator_dim = discriminator_dim + self.discriminator_in_channel = discriminator_in_channel + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/modeling_waveform_metric_discriminator.py b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/modeling_waveform_metric_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..843b0a4ac60a8a01b173d07ffc192b5cd1accc95 --- /dev/null +++ b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/modeling_waveform_metric_discriminator.py @@ -0,0 +1,145 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig + + +class LearnableSigmoid1d(nn.Module): + def __init__(self, in_features, beta=1): + super().__init__() + self.beta = beta + self.slope = nn.Parameter(torch.ones(in_features)) + 
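+ # note: nn.Parameter already defaults to requires_grad=True, so the `requiresGrad`
+ # attribute set on the next line is not a real tensor flag and has no effect on training.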
self.slope.requiresGrad = True + + def forward(self, x): + # x shape: [batch_size, time_steps, spec_bins] + return self.beta * torch.sigmoid(self.slope * x) + + +class WaveformMetricDiscriminator(nn.Module): + def __init__(self, config: WaveformMetricDiscriminatorConfig): + super(WaveformMetricDiscriminator, self).__init__() + dim = config.discriminator_dim + self.in_channel = config.discriminator_in_channel + + self.n_fft = config.n_fft + self.win_length = config.win_length + self.hop_length = config.hop_length + + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=1.0, + window_fn=torch.hann_window, + # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + self.layers = nn.Sequential( + nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim, affine=True), + nn.PReLU(dim), + nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim*2, affine=True), + nn.PReLU(dim*2), + nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim*4, affine=True), + nn.PReLU(dim*4), + nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim*8, affine=True), + nn.PReLU(dim*8), + nn.AdaptiveMaxPool2d(1), + nn.Flatten(), + nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)), + nn.Dropout(0.3), + nn.PReLU(dim*4), + nn.utils.spectral_norm(nn.Linear(dim*4, 1)), + LearnableSigmoid1d(1) + ) + + def forward(self, denoise_audios, clean_audios): + x = denoise_audios + y = clean_audios + x = self.transform.forward(x) + y = self.transform.forward(y) + + xy = torch.stack((x, y), dim=1) + return self.layers(xy) + + +CONFIG_FILE = "discriminator_config.yaml" +MODEL_FILE = "discriminator.pt" + + +class WaveformMetricDiscriminatorPretrainedModel(WaveformMetricDiscriminator): + def __init__(self, + config: WaveformMetricDiscriminatorConfig, + ): + super(WaveformMetricDiscriminatorPretrainedModel, self).__init__( + config=config, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = WaveformMetricDiscriminatorPretrainedModel.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + config = WaveformMetricDiscriminatorConfig() + discriminator = WaveformMetricDiscriminator(config=config) + + # shape: [batch_size, num_samples] + # x = torch.ones([4, int(4.5 * 16000)]) + # y = torch.ones([4, int(4.5 * 16000)]) + x = torch.ones([4, 16000]) + y = torch.ones([4, 16000]) + + output = 
discriminator.forward(x, y) + print(output.shape) + print(output) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/yaml/discriminator_config.yaml b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/yaml/discriminator_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c43e51fd8bd329dee7fdc657b40e73d8138e233 --- /dev/null +++ b/toolbox/torchaudio/models/discriminators/waveform_metric_discriminator/yaml/discriminator_config.yaml @@ -0,0 +1,10 @@ +model_name: "waveform_metric_discriminator" + +sample_rate: 8000 +segment_size: 4 +n_fft: 512 +win_size: 200 +hop_size: 80 + +discriminator_dim: 16 +discriminator_in_channel: 2 diff --git a/toolbox/torchaudio/models/dtln/__init__.py b/toolbox/torchaudio/models/dtln/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/models/dtln/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dtln/configuration_dtln.py b/toolbox/torchaudio/models/dtln/configuration_dtln.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc91b38b660b9a2edfc6c65d54f885c8f89af3a --- /dev/null +++ b/toolbox/torchaudio/models/dtln/configuration_dtln.py @@ -0,0 +1,66 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class DTLNConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + fft_size: int = 200, + hop_size: int = 80, + win_type: str = "hann", + + encoder_size: int = 256, + + min_snr_db: float = -10, + max_snr_db: float = 20, + + lr: float = 0.001, + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + max_epochs: int = 100, + clip_grad_norm: float = 10., + seed: int = 1234, + + num_workers: int = 4, + batch_size: int = 4, + eval_steps: int = 25000, + **kwargs + ): + super(DTLNConfig, self).__init__(**kwargs) + # transform + self.sample_rate = sample_rate + self.fft_size = fft_size + self.hop_size = hop_size + self.win_type = win_type + + # model params + self.encoder_size = encoder_size + + # data snr + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + # train + self.lr = lr + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.max_epochs = max_epochs + self.clip_grad_norm = clip_grad_norm + self.seed = seed + + self.num_workers = num_workers + self.batch_size = batch_size + self.eval_steps = eval_steps + + +def main(): + config = DTLNConfig() + config.to_yaml_file("config.yaml") + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dtln/inference_dtln.py b/toolbox/torchaudio/models/dtln/inference_dtln.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb8875592c6801229a5b417325b30161183ec50 --- /dev/null +++ b/toolbox/torchaudio/models/dtln/inference_dtln.py @@ -0,0 +1,137 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile, time +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +torch.set_num_threads(1) + +from project_settings import project_path +from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig +from toolbox.torchaudio.models.dtln.modeling_dtln 
import DTLNPretrainedModel, MODEL_FILE + +logger = logging.getLogger("toolbox") + + +class InferenceDTLN(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, model = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.model = model + self.model.to(device) + self.model.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = DTLNConfig.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model = DTLNPretrainedModel.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model.to(self.device) + model.eval() + + shutil.rmtree(model_path) + return config, model + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.denoise_offline(noisy_audio) + # enhanced_audio shape: [channels, num_samples] + enhanced_audio = enhanced_audio[0] + # enhanced_audio shape: [num_samples] + return enhanced_audio.cpu().numpy() + + def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + denoise = self.model.forward(noisy_audios) + + # denoise shape: [batch_size, 1, num_samples] + denoise = denoise[0] + # shape: [channels, num_samples] + return denoise + + def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + denoise = self.model.forward_chunk_by_chunk(noisy_audios) + + # denoise shape: [batch_size, 1, num_samples] + denoise = denoise[0] + # shape: [channels, num_samples] + return denoise + + +def main(): + model_zip_file = project_path / "trained_models/dtln-nx-dns3.zip" + infer_model = InferenceDTLN(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav" + noisy_audio, sample_rate = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + duration = librosa.get_duration(y=noisy_audio, sr=sample_rate) + # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # offline + begin = time.time() + enhanced_audio = infer_model.denoise_offline(noisy_audio) + time_cost = time.time() - begin + print(f"enhanced_audio.shape: 
{enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio_offline.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + # online + begin = time.time() + enhanced_audio = infer_model.denoise_online(noisy_audio) + time_cost = time.time() - begin + print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio_online.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dtln/modeling_dtln.py b/toolbox/torchaudio/models/dtln/modeling_dtln.py new file mode 100644 index 0000000000000000000000000000000000000000..99a7c8b8f5607faaadace5013045878d1130ebb2 --- /dev/null +++ b/toolbox/torchaudio/models/dtln/modeling_dtln.py @@ -0,0 +1,390 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://www.isca-archive.org/interspeech_2020/westhausen20_interspeech.pdf + +https://github.com/AkenoSyuRi/DTLNPytorch + +https://github.com/breizhn/DTLN +
+Dataset: DNS3 (DNS-Challenge).
+The SNR range is changed from the DNS3 default of [0, 40] dB to [-5, 25] dB,
+and the number of SNR levels is increased from 5 to 30, i.e. from:
+[0 dB, 10 dB, 20 dB, 30 dB, 40 dB]
+to:
+[-5 dB, -4 dB, -3 dB, ..., 22 dB, 23 dB, 24 dB, 25 dB]
+
+Window length 32 ms, hop size 8 ms.
+
+Trained on roughly 500 hours of DNS3 data, it reaches PESQ 3.04 on the DNS3 test set.
+
+""" +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT +from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig + + +class InstantLayerNormalization(nn.Module): + """ + Class implementing instant layer normalization.
It can also be called + channel-wise layer normalization and was proposed by + Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2) + """ + + def __init__(self, channels): + super(InstantLayerNormalization, self).__init__() + self.epsilon = 1e-7 + self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True) + self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True) + self.register_parameter("gamma", self.gamma) + self.register_parameter("beta", self.beta) + + def forward(self, inputs: torch.Tensor): + # calculate mean of each frame + mean = torch.mean(inputs, dim=-1, keepdim=True) + + # calculate variance of each frame + variance = torch.mean(torch.square(inputs - mean), dim=-1, keepdim=True) + # calculate standard deviation + std = torch.sqrt(variance + self.epsilon) + outputs = (inputs - mean) / std + # scale with gamma + outputs = outputs * self.gamma + # add the bias beta + outputs = outputs + self.beta + # return output + return outputs + + +class SeperationBlock(nn.Module): + def __init__(self, + input_size: int = 257, + hidden_size: int = 128, + dropout: float = 0.25, + ): + super(SeperationBlock, self).__init__() + self.rnn1 = nn.LSTM(input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + batch_first=True, + dropout=0.0, + bidirectional=False, + ) + self.rnn2 = nn.LSTM(input_size=hidden_size, + hidden_size=hidden_size, + num_layers=1, + batch_first=True, + dropout=0.0, + bidirectional=False, + ) + self.drop = nn.Dropout(dropout) + + self.dense = nn.Linear(hidden_size, input_size) + self.sigmoid = nn.Sigmoid() + + def forward(self, x: torch.Tensor, in_states: torch.Tensor = None): + if in_states is None: + hx1 = None + hx2 = None + else: + h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1] + h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1] + hx1 = (h1_in, c1_in) + hx2 = (h2_in, c2_in) + + x1, (h1, c1) = self.rnn1.forward(x, hx=hx1) + x1 = self.drop(x1) + x2, (h2, c2) = self.rnn2.forward(x1, hx=hx2) + x2 = self.drop(x2) + + mask = self.dense(x2) + mask = self.sigmoid(mask) + + h = torch.cat((h1, h2), dim=0) + c = torch.cat((c1, c2), dim=0) + out_states = torch.stack((h, c), dim=-1) + return mask, out_states + + +MODEL_FILE = "model.pt" + + +class DTLNModel(nn.Module): + def __init__(self, + fft_size: int = 512, + hop_size: int = 128, + win_type: str = "hamming", + encoder_size: int = 256, + ): + super(DTLNModel, self).__init__() + self.fft_size = fft_size + self.hop_size = hop_size + self.encoder_size = encoder_size + + self.stft = ConvSTFT( + nfft=fft_size, + win_size=fft_size, + hop_size=hop_size, + win_type=win_type, + power=None, + requires_grad=False + ) + self.istft = ConviSTFT( + nfft=fft_size, + win_size=fft_size, + hop_size=hop_size, + win_type=win_type, + requires_grad=False + ) + + self.sep1 = SeperationBlock(input_size=(fft_size // 2 + 1), + # hidden_size=128, + hidden_size=self.encoder_size // 2, + dropout=0.25, + ) + + self.encoder_conv1 = nn.Conv1d(in_channels=fft_size, + out_channels=self.encoder_size, + kernel_size=1, + stride=1, + bias=False, + ) + + # self.encoder_norm1 = nn.InstanceNorm1d(num_features=self.encoder_size, eps=1e-7, affine=True) + self.encoder_norm1 = InstantLayerNormalization(channels=self.encoder_size) + + self.sep2 = SeperationBlock(input_size=self.encoder_size, + # hidden_size=128, + hidden_size=self.encoder_size // 2, + dropout=0.25, + ) + + self.decoder_conv1 = nn.Conv1d(in_channels=self.encoder_size, + out_channels=fft_size, + kernel_size=1, + stride=1, + bias=False, + ) + + 
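+ # Orientation note (summarising the layers defined above): stage 1 (sep1) predicts a
+ # magnitude mask in the STFT domain; the masked spectrum is brought back to time-domain
+ # frames, stage 2 (encoder_conv1 + instant layer norm + sep2) predicts a second mask in
+ # a learned feature space, and decoder_conv1 maps the result back to frames for overlap-add.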
def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: + if signal.dim() == 2: + signal = torch.unsqueeze(signal, dim=1) + _, _, n_samples = signal.shape + remainder = (n_samples - self.fft_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) + return signal + + def forward(self, + noisy: torch.Tensor, + ): + num_samples = noisy.shape[-1] + noisy = self.signal_prepare(noisy) + batch_size, _, num_samples_pad = noisy.shape + # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}") + + denoise_frame, _, _ = self.forward_chunk(noisy) + denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad) + # denoise shape: [b, num_samples_pad] + + denoise = denoise[:, :num_samples] + # denoise shape: [b, num_samples] + denoise = torch.unsqueeze(denoise, dim=1) + # denoise shape: [b, 1, num_samples] + return denoise + + def forward_chunk(self, + noisy: torch.Tensor, + in_state1: torch.Tensor = None, + in_state2: torch.Tensor = None, + ): + # noisy shape: [b, 1, num_samples] + spec = self.stft.forward(noisy) + # spec shape: [b, f, t], torch.complex64 + # t = (num_samples - win_size) / hop_size + 1 + spec = torch.view_as_real(spec) + # spec shape: [b, f, t, 2] + real = spec[..., 0] + imag = spec[..., 1] + mag = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag, real) + # shape: [b, f, t] + mag = mag.permute(0, 2, 1) + phase = phase.permute(0, 2, 1) + # shape: [b, t, f] + + mask, out_state1 = self.sep1.forward(mag, in_state1) + # mask shape: [b, t, f] + estimated_mag = mask * mag + + s1_stft = estimated_mag * torch.exp((1j * phase)) + # s1_stft shape: [b, t, f], torch.complex64 + y1 = torch.fft.irfft2(s1_stft, dim=-1) + # y1 shape: [b, t, fft_size], torch.float32 + y1 = y1.permute(0, 2, 1) + # y1 shape: [b, fft_size, t] + + encoded_f = self.encoder_conv1.forward(y1) + # shape: [b, c, t] + encoded_f = encoded_f.permute(0, 2, 1) + # shape: [b, t, c] + encoded_f_norm = self.encoder_norm1.forward(encoded_f) + # shape: [b, t, c] + + mask_2, out_state2 = self.sep2.forward(encoded_f_norm, in_state2) + # shape: [b, t, c] + estimated = mask_2 * encoded_f + estimated = estimated.permute(0, 2, 1) + # shape: [b, c, t] + + denoise_frame = self.decoder_conv1.forward(estimated) + # shape: [b, fft_size, t] + + return denoise_frame, out_state1, out_state2 + + def forward_chunk_by_chunk(self, noisy: torch.Tensor): + noisy = self.signal_prepare(noisy) + # noisy shape: [b, 1, num_samples] + batch_size, _, num_samples_pad = noisy.shape + # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}") + + t = (num_samples_pad - self.fft_size) // self.hop_size + 1 + overlap_size = self.fft_size - self.hop_size + + denoise_list = list() + out_state1 = None + out_state2 = None + denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype) + for i in range(t): + begin = i * self.hop_size + end = begin + self.fft_size + sub_noisy = noisy[:, :, begin: end] + # noisy shape: [b, 1, frame_size] + with torch.no_grad(): + sub_denoise_frame, out_state1, out_state2 = self.forward_chunk(sub_noisy, out_state1, out_state2) + # sub_denoise_frame shape: [b, fft_size, 1] + sub_denoise_frame = sub_denoise_frame[:, :, 0] + # sub_denoise_frame shape: [b, fft_size] + + sub_denoise_frame[:, :overlap_size] += denoise_cache + denoise_out = sub_denoise_frame[:, :self.hop_size] + denoise_cache = sub_denoise_frame[:, self.hop_size:] + # denoise_cache shape: [b, 
hop_size] + + denoise_list.append(denoise_out) + + denoise = torch.concat(denoise_list, dim=-1) + # denoise shape: [b, num_samples] + denoise = torch.unsqueeze(denoise, dim=1) + # denoise shape: [b, 1, num_samples] + return denoise + + def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int): + # overlap and add + + # denoise_frame shape: [b, fft_size, t] + denoise = torch.nn.functional.fold( + denoise_frame, + output_size=(num_samples, 1), + kernel_size=(self.fft_size, 1), + padding=(0, 0), + stride=(self.hop_size, 1), + ) + # denoise shape: [b, 1, num_samples, 1] + denoise = denoise.reshape(batch_size, -1) + # denoise shape: [b, num_samples] + return denoise + + +class DTLNPretrainedModel(DTLNModel): + def __init__(self, + config: DTLNConfig, + ): + super(DTLNPretrainedModel, self).__init__( + fft_size=config.fft_size, + hop_size=config.hop_size, + win_type=config.win_type, + encoder_size=config.encoder_size, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = DTLNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + config = DTLNConfig(fft_size=512, + hop_size=128, + ) + model = DTLNPretrainedModel(config) + model.eval() + + noisy = torch.randn(size=(1, 16000), dtype=torch.float32) + + with torch.no_grad(): + denoise = model.forward(noisy) + print(f"denoise.shape: {denoise.shape}") + print(denoise[:, :, 300: 302]) + print(denoise[:, :, 8000: 8002]) + print(denoise[:, :, 15600: 15602]) + print(denoise[:, :, 15680: 15682]) + print(denoise[:, :, 15760: 15762]) + print(denoise[:, :, 15840: 15842]) + + denoise = model.forward_chunk_by_chunk(noisy) + print(f"denoise.shape: {denoise.shape}") + # denoise = denoise[:, :, (config.fft_size - config.hop_size):] + print(denoise[:, :, 300: 302]) + print(denoise[:, :, 8000: 8002]) + print(denoise[:, :, 15600: 15602]) + print(denoise[:, :, 15680: 15682]) + print(denoise[:, :, 15760: 15762]) + print(denoise[:, :, 15840: 15842]) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/dtln/yaml/config-160.yaml b/toolbox/torchaudio/models/dtln/yaml/config-160.yaml new file mode 100644 index 0000000000000000000000000000000000000000..00e27a12d5d4b20c17cbc9cd0265146339c91847 --- /dev/null +++ b/toolbox/torchaudio/models/dtln/yaml/config-160.yaml @@ -0,0 +1,23 @@ +model_name: "DTLN" + +sample_rate: 8000 +fft_size: 160 +hop_size: 80 +win_type: hann + +max_snr_db: 20 +min_snr_db: -10 + +encoder_size: 256 + +max_epochs: 100 +batch_size: 64 +num_workers: 4 +seed: 1234 +eval_steps: 25000 + +lr: 0.001 +lr_scheduler: CosineAnnealingLR 
+lr_scheduler_kwargs: {} + +clip_grad_norm: 10.0 diff --git a/toolbox/torchaudio/models/dtln/yaml/config-256.yaml b/toolbox/torchaudio/models/dtln/yaml/config-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fc901ad9f1be59d3915da55ba7a06279e7850c02 --- /dev/null +++ b/toolbox/torchaudio/models/dtln/yaml/config-256.yaml @@ -0,0 +1,23 @@ +model_name: "DTLN" + +sample_rate: 8000 +fft_size: 256 +hop_size: 128 +win_type: hann + +max_snr_db: 20 +min_snr_db: -10 + +encoder_size: 256 + +max_epochs: 100 +batch_size: 64 +num_workers: 4 +seed: 1234 +eval_steps: 25000 + +lr: 0.001 +lr_scheduler: CosineAnnealingLR +lr_scheduler_kwargs: {} + +clip_grad_norm: 10.0 diff --git a/toolbox/torchaudio/models/dtln/yaml/config-512.yaml b/toolbox/torchaudio/models/dtln/yaml/config-512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f69393612197a648b714fcdc0c53d7dc82ba235 --- /dev/null +++ b/toolbox/torchaudio/models/dtln/yaml/config-512.yaml @@ -0,0 +1,29 @@ +model_name: "DTLN" + +# spec +sample_rate: 8000 +fft_size: 512 +hop_size: 128 +win_type: hann + +# data +max_snr_db: 20 +min_snr_db: -10 + +# model +encoder_size: 512 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +num_workers: 4 +batch_size: 64 +eval_steps: 15000 diff --git a/toolbox/torchaudio/models/ehnet/__init__.py b/toolbox/torchaudio/models/ehnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/ehnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/ehnet/modeling_ehnet.py b/toolbox/torchaudio/models/ehnet/modeling_ehnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5afecbfcdff63bddb1df740d6ee324a986622141 --- /dev/null +++ b/toolbox/torchaudio/models/ehnet/modeling_ehnet.py @@ -0,0 +1,131 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/1805.00579 + +https://github.com/haoxiangsnr/A-Convolutional-Recurrent-Neural-Network-for-Real-Time-Speech-Enhancement + +""" +import torch +import torch.nn as nn + + +class CausalConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 2), + stride=(2, 1), + padding=(0, 1) + ) + self.norm = nn.BatchNorm2d(num_features=out_channels) + self.activation = nn.ELU() + + def forward(self, x): + """ + 2D Causal convolution. + Args: + x: [B, C, F, T] + + Returns: + [B, C, F, T] + """ + x = self.conv(x) + x = x[:, :, :, :-1] # chomp size + x = self.norm(x) + x = self.activation(x) + return x + + +class CausalTransConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)): + super().__init__() + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 2), + stride=(2, 1), + output_padding=output_padding + ) + self.norm = nn.BatchNorm2d(num_features=out_channels) + if is_last: + self.activation = nn.ReLU() + else: + self.activation = nn.ELU() + + def forward(self, x): + """ + 2D Causal convolution. 
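+        The transposed convolution adds one trailing time frame; it is chomped (x[..., :-1]) so the block stays causal.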
+ Args: + x: [B, C, F, T] + + Returns: + [B, C, F, T] + """ + x = self.conv(x) + x = x[:, :, :, :-1] # chomp size + x = self.norm(x) + x = self.activation(x) + return x + + +class CRN(nn.Module): + """ + Input: [batch size, channels=1, T, n_fft] + Output: [batch size, T, n_fft] + """ + + def __init__(self): + super(CRN, self).__init__() + # Encoder + self.conv_block_1 = CausalConvBlock(1, 16) + self.conv_block_2 = CausalConvBlock(16, 32) + self.conv_block_3 = CausalConvBlock(32, 64) + self.conv_block_4 = CausalConvBlock(64, 128) + self.conv_block_5 = CausalConvBlock(128, 256) + + # LSTM + self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True) + + self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128) + self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64) + self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32) + self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0)) + self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True) + + def forward(self, x): + self.lstm_layer.flatten_parameters() + + e_1 = self.conv_block_1(x) + e_2 = self.conv_block_2(e_1) + e_3 = self.conv_block_3(e_2) + e_4 = self.conv_block_4(e_3) + e_5 = self.conv_block_5(e_4) # [2, 256, 4, 200] + + batch_size, n_channels, n_f_bins, n_frame_size = e_5.shape + + # [2, 256, 4, 200] = [2, 1024, 200] => [2, 200, 1024] + lstm_in = e_5.reshape(batch_size, n_channels * n_f_bins, n_frame_size).permute(0, 2, 1) + lstm_out, _ = self.lstm_layer(lstm_in) # [2, 200, 1024] + lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size) # [2, 256, 4, 200] + + d_1 = self.tran_conv_block_1(torch.cat((lstm_out, e_5), 1)) + d_2 = self.tran_conv_block_2(torch.cat((d_1, e_4), 1)) + d_3 = self.tran_conv_block_3(torch.cat((d_2, e_3), 1)) + d_4 = self.tran_conv_block_4(torch.cat((d_3, e_2), 1)) + d_5 = self.tran_conv_block_5(torch.cat((d_4, e_1), 1)) + + return d_5 + + +def main(): + layer = CRN() + a = torch.rand(2, 1, 161, 200) + print(layer(a).shape) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/frcrn/__init__.py b/toolbox/torchaudio/models/frcrn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/frcrn/complex_nn.py b/toolbox/torchaudio/models/frcrn/complex_nn.py new file mode 100644 index 0000000000000000000000000000000000000000..49f248e501c28721cf540fc64d89e8c05fb1c95f --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/complex_nn.py @@ -0,0 +1,258 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Union, Tuple + +import torch +import torch.nn as nn + +from toolbox.torchaudio.models.frcrn.uni_deep_fsmn import UniDeepFsmn + + +class ComplexUniDeepFsmn(nn.Module): + def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20): + super(ComplexUniDeepFsmn, self).__init__() + + self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder) + self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder) + self.fsmn_re_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder) + self.fsmn_im_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder) + + def forward(self, x: torch.Tensor): + """ + :param x: torch.Tensor, shape: [b, c, h, t, 2] + :return: torch.Tensor, shape: [b, h, 
t, 2] + """ + b, c, h, t, d = x.size() + x = torch.reshape(x, shape=(b, c * h, t, d)) + # x shape: [b, h', t, 2] + x = torch.transpose(x, dim0=1, dim1=2) + # x shape: [b, t, h', 2] + + real_l1 = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1]) + imaginary_l1 = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0]) + # real, image shape: [b, t, h'] + + real = self.fsmn_re_l2(real_l1) - self.fsmn_im_l2(imaginary_l1) + imaginary = self.fsmn_re_l2(imaginary_l1) + self.fsmn_im_l2(real_l1) + # real, image shape: [b, t, h'] + + output = torch.stack(tensors=(real, imaginary), dim=-1) + # output shape: [b, t, h', 2] + output = torch.transpose(output, dim0=1, dim1=2) + # output shape: [b, h', t, 2] + output = torch.reshape(output, shape=(b, c, h, t, d)) + # output shape: [b, c, h, t, 2] + return output + + +class ComplexUniDeepFsmnL1(nn.Module): + def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20): + super(ComplexUniDeepFsmnL1, self).__init__() + self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder) + self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder) + + def forward(self, x: torch.Tensor): + b, c, h, t, d = x.size() + x = torch.transpose(x, dim0=1, dim1=3) + # x shape: [b, t, h, c, 2] + x = torch.reshape(x, shape=(b * t, h, c, d)) + # x shape: [b*t, h, c, 2] + + real = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1]) + imaginary = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0]) + # real, image shape: [b*t, h, c] + + output = torch.stack(tensors=(real, imaginary), dim=-1) + # output shape: [b*t, h, c, 2] + output = torch.reshape(output, shape=(b, t, h, c, d)) + # output shape: [b, t, h, c, 2] + output = torch.transpose(output, dim0=1, dim1=3) + # output shape: [b, c, h, t, 2] + return output + + +class ComplexConv2d(nn.Module): + # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + **kwargs + ): + super().__init__() + + # Model components + self.conv_re = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs + ) + self.conv_im = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs + ) + + def forward(self, x: torch.Tensor): + """ + + :param x: torch.Tensor, shape: [b, c, h, w, 2] + :return: + """ + real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1]) + imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0]) + + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexConvTranspose2d(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias=True, + **kwargs + ): + super().__init__() + + # Model components + self.tconv_re = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs + ) 
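+        # a second transposed conv handles the imaginary path; forward() combines the two as
+        # real = tconv_re(x_re) - tconv_im(x_im), imag = tconv_re(x_im) + tconv_im(x_re)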
+ self.tconv_im = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + **kwargs + ) + + def forward(self, x: torch.Tensor): + """ + :param x: torch.Tensor, shape: [b, c, h, w, 2] + :return: + """ + real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1]) + imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0]) + + output = torch.stack((real, imaginary), dim=-1) + return output + + +class ComplexBatchNorm2d(nn.Module): + def __init__(self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + **kwargs + ): + super().__init__() + self.bn_re = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs + ) + self.bn_im = nn.BatchNorm2d( + num_features=num_features, + momentum=momentum, + affine=affine, + eps=eps, + track_running_stats=track_running_stats, + **kwargs + ) + + def forward(self, x: torch.Tensor): + real = self.bn_re(x[..., 0]) + imag = self.bn_im(x[..., 1]) + + output = torch.stack((real, imag), dim=-1) + return output + + +def main(): + # x = torch.rand(size=(1, 1, 32, 200, 2)) + # fsmn = ComplexUniDeepFsmn( + # input_dim=32, + # hidden_size=64, + # ) + # result = fsmn.forward(x) + # print(result.shape) + + # x = torch.rand(size=(1, 32, 32, 200, 2)) + # fsmn = ComplexUniDeepFsmnL1( + # input_dim=32, + # hidden_size=64, + # ) + # result = fsmn.forward(x) + # print(result.shape) + + # x = torch.rand(size=(1, 32, 200, 200, 2)) + x = torch.rand(size=(1, 1, 320, 200, 2)) + conv2d = ComplexConv2d( + in_channels=1, + out_channels=128, + kernel_size=(5, 2), + stride=(2, 1), + padding=(0, 1), + ) + result = conv2d.forward(x) + print(result.shape) + + # x = torch.rand(size=(1, 32, 200, 200, 2)) + # x = torch.rand(size=(1, 64, 15, 2000, 2)) + # tconv = ComplexConvTranspose2d( + # in_channels=64, + # out_channels=32, + # kernel_size=(3, 3), + # stride=(2, 1), + # padding=(0, 1), + # ) + # result = tconv.forward(x) + # print(result.shape) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/frcrn/configuration_frcrn.py b/toolbox/torchaudio/models/frcrn/configuration_frcrn.py new file mode 100644 index 0000000000000000000000000000000000000000..08292097bf31a1f62438382f33d28fe33fbf8e14 --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/configuration_frcrn.py @@ -0,0 +1,80 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/checkpoints/FRCRN_SE_16K/config.yaml +https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/config/inference/FRCRN_SE_16K.yaml + +""" +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class FRCRNConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + segment_size: int = 32000, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 128, + win_type: str = "hann", + + use_complex_networks: bool = True, + model_depth: int = 20, + model_complexity: int = 45, + + min_snr_db: float = -10, + max_snr_db: float = 20, + + num_workers: int = 4, + batch_size: int = 4, + eval_steps: int = 25000, + + lr: float = 0.001, + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + max_epochs: int = 100, + weight_decay: float = 0.00001, + clip_grad_norm: float = 10., + seed: int = 1234, + num_gpus: 
int = -1, + + **kwargs + ): + super(FRCRNConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.segment_size = segment_size + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.use_complex_networks = use_complex_networks + self.model_depth = model_depth + self.model_complexity = model_complexity + + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + self.num_workers = num_workers + self.batch_size = batch_size + self.eval_steps = eval_steps + + self.lr = lr + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.max_epochs = max_epochs + self.weight_decay = weight_decay + self.clip_grad_norm = clip_grad_norm + self.seed = seed + self.num_gpus = num_gpus + + +def main(): + config = FRCRNConfig() + config.to_yaml_file("config.yaml") + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/frcrn/conv_stft.py b/toolbox/torchaudio/models/frcrn/conv_stft.py new file mode 100644 index 0000000000000000000000000000000000000000..8dcfc5937c80a3823146ced2629efc8971857969 --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/conv_stft.py @@ -0,0 +1,147 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py +""" +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.signal import get_window + + +def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False): + if win_type == "None" or win_type is None: + window = np.ones(win_size) + else: + window = get_window(win_type, win_size, fftbins=True)**0.5 + + fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size] + real_kernel = np.real(fourier_basis) + image_kernel = np.imag(fourier_basis) + kernel = np.concatenate([real_kernel, image_kernel], 1).T + + if inverse: + kernel = np.linalg.pinv(kernel).T + + kernel = kernel * window + kernel = kernel[:, None, :] + result = ( + torch.from_numpy(kernel.astype(np.float32)), + torch.from_numpy(window[None, :, None].astype(np.float32)) + ) + return result + + +class ConvSTFT(nn.Module): + + def __init__(self, + nfft: int, + win_size: int, + hop_size: int, + win_type: str = "hamming", + feature_type: str = "real", + requires_grad: bool = False): + super(ConvSTFT, self).__init__() + + if nfft is None: + self.nfft = int(2**np.ceil(np.log2(win_size))) + else: + self.nfft = nfft + + kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type) + self.weight = nn.Parameter(kernel, requires_grad=requires_grad) + + self.win_size = win_size + self.hop_size = hop_size + + self.stride = hop_size + self.dim = self.nfft + self.feature_type = feature_type + + def forward(self, inputs: torch.Tensor): + if inputs.dim() == 2: + inputs = torch.unsqueeze(inputs, 1) + + outputs = F.conv1d(inputs, self.weight, stride=self.stride) + + if self.feature_type == "complex": + return outputs + else: + dim = self.dim // 2 + 1 + real = outputs[:, :dim, :] + imag = outputs[:, dim:, :] + mags = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag, real) + return mags, phase + + +class ConviSTFT(nn.Module): + + def __init__(self, + win_size: int, + hop_size: int, + nfft: int = None, + win_type: str = "hamming", + feature_type: str = "real", + requires_grad: bool = False): + super(ConviSTFT, self).__init__() + if nfft is None: + self.nfft = int(2**np.ceil(np.log2(win_size))) + else: + self.nfft = nfft + + kernel, 
window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True) + self.weight = nn.Parameter(kernel, requires_grad=requires_grad) + + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.stride = hop_size + self.dim = self.nfft + self.feature_type = feature_type + + self.register_buffer("window", window) + self.register_buffer("enframe", torch.eye(win_size)[:, None, :]) + + def forward(self, + inputs: torch.Tensor, + phase: torch.Tensor = None): + """ + :param inputs: torch.Tensor, shape: [b, n+2, t] (complex spec) or [b, n//2+1, t] (mags) + :param phase: torch.Tensor, shape: [b, n//2+1, t] + :return: + """ + if phase is not None: + real = inputs * torch.cos(phase) + imag = inputs * torch.sin(phase) + inputs = torch.cat([real, imag], 1) + outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) + + # this is from torch-stft: https://github.com/pseeth/torch-stft + t = self.window.repeat(1, 1, inputs.size(-1))**2 + coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) + outputs = outputs / (coff + 1e-8) + return outputs + + +def main(): + stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex") + istft = ConviSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex") + + mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32) + + spec = stft.forward(mixture) + # shape: [batch_size, freq_bins, time_steps] + print(spec.shape) + + waveform = istft.forward(spec) + # shape: [batch_size, channels, num_samples] + print(waveform.shape) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/frcrn/inference_frcrn.py b/toolbox/torchaudio/models/frcrn/inference_frcrn.py new file mode 100644 index 0000000000000000000000000000000000000000..516fc0ebbb07cde8799fc625c6725958a812a341 --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/inference_frcrn.py @@ -0,0 +1,115 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile, time +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +torch.set_num_threads(1) + +from project_settings import project_path +from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig +from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRNPretrainedModel, MODEL_FILE + +logger = logging.getLogger("toolbox") + + +class InferenceFRCRN(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, model = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.model = model + self.model.to(device) + self.model.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = FRCRNConfig.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + model = FRCRNPretrainedModel.from_pretrained( + 
pretrained_model_name_or_path=model_path.as_posix(), + ) + model.to(self.device) + model.eval() + + shutil.rmtree(model_path) + return config, model + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.enhancement_by_tensor(noisy_audio) + # enhanced_audio shape: [channels, num_samples] + enhanced_audio = enhanced_audio[0] + # enhanced_audio shape: [num_samples] + return enhanced_audio.cpu().numpy() + + def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + # noisy_audio shape: [batch_size, num_samples] + noisy_audios = noisy_audio.to(self.device) + + with torch.no_grad(): + est_spec, est_wav, est_mask = self.model.forward(noisy_audios) + + # shape: [batch_size, num_samples] + enhanced_audio = torch.unsqueeze(est_wav, dim=1) + # shape: [batch_size, 1, num_samples] + + enhanced_audio = enhanced_audio[0] + # shape: [channels, num_samples] + return enhanced_audio + + +def main(): + model_zip_file = project_path / "trained_models/frcrn-dns3.zip" + infer_model = InferenceFRCRN(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_5.wav" + noisy_audio, sample_rate = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + duration = librosa.get_duration(y=noisy_audio, sr=sample_rate) + # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + begin = time.time() + enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio) + time_cost = time.time() - begin + print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") + + filename = "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/frcrn/modeling_frcrn.py b/toolbox/torchaudio/models/frcrn/modeling_frcrn.py new file mode 100644 index 0000000000000000000000000000000000000000..7d1debbd9b789550af3c20108a225e82cc044994 --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/modeling_frcrn.py @@ -0,0 +1,345 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2206.07293 + +https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py +https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py + +https://github.com/modelscope/ClearerVoice-Studio/tree/main/clearvoice/clearvoice/models/frcrn_se + +""" +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig +from toolbox.torchaudio.models.frcrn.conv_stft import ConviSTFT, ConvSTFT +from toolbox.torchaudio.models.frcrn.unet import UNet + + +class FRCRN(nn.Module): + """ Frequency Recurrent CRN """ + + def __init__(self, + use_complex_networks: bool = True, + model_complexity: int = 45, + model_depth: int = 14, 
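+                 # supported depths: 10, 14, 20 (each selects the matching UNetConfig in unet.py)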
+ padding_mode: str = "zeros", + nfft: int = 640, + win_size: int = 640, + hop_size: int = 320, + win_type: str = "hann", + ): + """ + :param use_complex_networks: bool, Whether to use complex networks. + :param model_complexity: int, define the model complexity with the number of layers + :param model_depth: int, Only two options are available : 10, 20 + :param padding_mode: str, Encoder's convolution filter. 'zeros', 'reflect' + :param nfft: int, number of Short Time Fourier Transform (STFT) points + :param win_size: int, length of window used for defining one frame of sample points + :param hop_size: int, length of window shifting (equivalent to hop_size) + :param win_type: str, windowing type used in STFT, eg. 'hanning', 'hamming' + """ + super().__init__() + self.freq_bins = nfft // 2 + 1 + + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.eps = 1e-8 + + self.stft = ConvSTFT( + nfft=self.nfft, + win_size=self.win_size, + hop_size=self.hop_size, + win_type=self.win_type, + feature_type="complex", + requires_grad=False + ) + self.istft = ConviSTFT( + nfft=self.nfft, + win_size=self.win_size, + hop_size=self.hop_size, + win_type=self.win_type, + feature_type="complex", + requires_grad=False + ) + self.unet = UNet( + in_channels=1, + use_complex_networks=use_complex_networks, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode + ) + self.unet2 = UNet( + in_channels=1, + use_complex_networks=use_complex_networks, + model_complexity=model_complexity, + model_depth=model_depth, + padding_mode=padding_mode + ) + + def forward(self, noisy: torch.Tensor): + """ + :param noisy: torch.Tensor, shape: [b, n_samples] or [b, c, n_samples] + :return: + """ + if noisy.dim() == 2: + noisy = torch.unsqueeze(noisy, dim=1) + _, _, n_samples = noisy.shape + remainder = (n_samples - self.win_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0) + + # [batch_size, freq_bins * 2, num_samples] + cmp_spec = self.stft.forward(noisy) + # [batch_size, 1, freq_bins * 2, time_steps] + # time_steps = (num_samples - win_size) / hop_size + 1 + cmp_spec = torch.unsqueeze(cmp_spec, 1) + + # [batch_size, 2, freq_bins, time_steps] + cmp_spec = torch.cat([ + cmp_spec[:, :, :self.freq_bins, :], + cmp_spec[:, :, self.freq_bins:, :], + ], dim=1) + + # [batch_size, 2, freq_bins, time_steps, 1] + cmp_spec = torch.unsqueeze(cmp_spec, dim=4) + + cmp_spec = torch.transpose(cmp_spec, 1, 4) + # [batch_size, 1, freq_bins, time_steps, 2] + + unet1_out = self.unet.forward(cmp_spec) + cmp_mask1 = torch.tanh(unet1_out) + unet2_out = self.unet2.forward(unet1_out) + cmp_mask2 = torch.tanh(unet2_out) + + # est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1) + + cmp_mask2 = cmp_mask2 + cmp_mask1 + est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2) + # est_wav shape: [b, n_samples] + + est_wav = est_wav[:, :n_samples] + return est_spec, est_wav, est_mask + + def apply_mask(self, + cmp_spec: torch.Tensor, + cmp_mask: torch.Tensor, + ): + """ + :param cmp_spec: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2] + :param cmp_mask: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2] + :return: + """ + est_spec = torch.cat( + tensors=[ + cmp_spec[..., 0] * cmp_mask[..., 0] - cmp_spec[..., 1] * cmp_mask[..., 1], + cmp_spec[..., 0] * cmp_mask[..., 1] + cmp_spec[..., 1] * cmp_mask[..., 0] + ], dim=1 
+ ) + # est_spec shape: [b, 2, n//2+1, t] + est_spec = torch.cat(tensors=[est_spec[:, 0, :, :], est_spec[:, 1, :, :]], dim=1) + # est_spec shape: [b, n+2, t] + + # cmp_mask shape: [b, 1, n//2+1, t, 2] + cmp_mask = torch.squeeze(cmp_mask, dim=1) + # cmp_mask shape: [b, n//2+1, t, 2] + cmp_mask = torch.cat(tensors=[cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], dim=1) + # cmp_mask shape: [b, n+2, t] + + # est_spec shape: [b, n+2, t] + est_wav = self.istft(est_spec) + # est_wav shape: [b, 1, n_samples] + est_wav = torch.squeeze(est_wav, 1) + # est_wav shape: [b, n_samples] + return est_spec, est_wav, cmp_mask + + def get_params(self, weight_decay=0.0): + """ + 为可训练参数配置 weight_decay (权重衰减) 的作用是实现 L2 正则化。 + 1. 防止过拟合: 通过向损失函数添加参数的 L2 范数 (平方和) 作为惩罚项, weight_decay 会限制模型权重的大小. + 这使得模型倾向于学习更小的权重值, 降低对训练数据的过度敏感, 从而提高泛化能力. + 2. 控制模型复杂度: 权重衰减直接作用于优化过程, 在梯度更新时对权重进行衰减, + 公式: weight = weight - lr * (gradient + weight_decay * weight). + 这相当于在梯度下降中额外引入了一个与当前权重值成正比的衰减力, 抑制权重快速增长. + 3. 与优化器的具体实现相关 + 在 SGD 等传统优化器中, weight_decay 直接等价于 L2 正则化. + 在 Adam 优化器中, 权重衰减的实现与参数更新耦合, 可能因学习率调整而效果减弱. + 在 AdamW 优化器改进了这一点, 将权重衰减与学习率解耦, 使其更符合 L2 正则化的理论效果. + + 注意: + 值过大会导致欠拟合, 过小则正则化效果弱, 常用范围是 1e-4到 1e-2. + 某些场景 (如 BatchNorm 层) 可能需要通过参数分组对不同层设置不同的 weight_decay. + :param weight_decay: + :return: + """ + weights, biases = [], [] + for name, param in self.named_parameters(): + if "bias" in name: + biases += [param] + else: + weights += [param] + + params = [{ + 'params': weights, + 'weight_decay': weight_decay, + }, { + 'params': biases, + 'weight_decay': 0.0, + }] + return params + + def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): + """ + + :param est_mask: torch.Tensor, shape: [b, n+2, t] + :param clean: + :param noisy: + :return: + """ + clean_stft = self.stft(clean) + clean_re = clean_stft[:, :self.freq_bins, :] + clean_im = clean_stft[:, self.freq_bins:, :] + + noisy_stft = self.stft(noisy) + noisy_re = noisy_stft[:, :self.freq_bins, :] + noisy_im = noisy_stft[:, self.freq_bins:, :] + + noisy_power = noisy_re ** 2 + noisy_im ** 2 + + sr = clean_re + yr = noisy_re + si = clean_im + yi = noisy_im + y_pow = noisy_power + # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8) + gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps) + # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8) + gth_mask_im = (sr * yr - si * yi) / (y_pow + self.eps) + + gth_mask_re[gth_mask_re > 2] = 1 + gth_mask_re[gth_mask_re < -2] = -1 + gth_mask_im[gth_mask_im > 2] = 1 + gth_mask_im[gth_mask_im < -2] = -1 + + mask_re = est_mask[:, :self.freq_bins, :] + mask_im = est_mask[:, self.freq_bins:, :] + + loss_re = F.mse_loss(gth_mask_re, mask_re) + loss_im = F.mse_loss(gth_mask_im, mask_im) + + loss = loss_re + loss_im + return loss + + +MODEL_FILE = "model.pt" + + +class FRCRNPretrainedModel(FRCRN): + def __init__(self, + config: FRCRNConfig, + ): + super(FRCRNPretrainedModel, self).__init__( + use_complex_networks=config.use_complex_networks, + model_complexity=config.model_complexity, + model_depth=config.model_depth, + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = FRCRNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + 
state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + model = FRCRN( + use_complex_networks=True, + model_complexity=-1, + model_depth=10, + padding_mode="zeros", + nfft=128, + win_size=128, + hop_size=64, + win_type="hann", + ) + + # model = FRCRN( + # use_complex_networks=True, + # model_complexity=-1, + # model_depth=14, + # padding_mode="zeros", + # nfft=640, + # win_size=640, + # hop_size=320, + # win_type="hann", + # ) + + # model = FRCRN( + # use_complex_networks=True, + # model_complexity=45, + # model_depth=20, + # padding_mode="zeros", + # nfft=512, + # win_size=512, + # hop_size=256, + # win_type="hann", + # ) + + mixture = torch.rand(size=(1, 32000), dtype=torch.float32) + + est_spec, est_wav, est_mask = model.forward(mixture) + print(est_spec.shape) + print(est_wav.shape) + print(est_mask.shape) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/frcrn/unet.py b/toolbox/torchaudio/models/frcrn/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcaf9a3dd03f3ac5b5db046e954edefdba8df4b --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/unet.py @@ -0,0 +1,374 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Union, Tuple + +import torch +import torch.nn as nn + +from toolbox.torchaudio.models.frcrn import complex_nn + + +class SELayer(nn.Module): + def __init__(self, channels: int, reduction: int = 16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + self.fc_r = nn.Sequential( + nn.Linear(channels, channels // reduction), + nn.ReLU(inplace=True), + nn.Linear(channels // reduction, channels), + nn.Sigmoid() + ) + self.fc_i = nn.Sequential( + nn.Linear(channels, channels // reduction), + nn.ReLU(inplace=True), + nn.Linear(channels // reduction, channels), + nn.Sigmoid() + ) + + def forward(self, x: torch.Tensor): + b, c, _, _, _ = x.size() + x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c) + x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c) + + y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1) + y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1) + + y = torch.cat(tensors=[y_r, y_i], dim=4) + return x * y + + +class Encoder(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int]] = None, + use_complex_networks: bool = False, + padding_mode: str = "zeros" + ): + super().__init__() + if padding is None: + padding = [(k - 1) // 2 for k in kernel_size] # 'SAME' padding + + if use_complex_networks: + conv = complex_nn.ComplexConv2d + bn = complex_nn.ComplexBatchNorm2d + else: + conv = nn.Conv2d + bn = nn.BatchNorm2d + + self.conv = conv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + padding_mode=padding_mode + ) + self.bn = bn(out_channels) + self.relu = 
nn.LeakyReLU(inplace=True) + + def forward(self, x: torch.Tensor): + # x shape: [b, c, f, t, 2] + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Decoder(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int]] = (0, 0), + use_complex_networks: bool = False, + ): + super().__init__() + if use_complex_networks: + tconv = complex_nn.ComplexConvTranspose2d + bn = complex_nn.ComplexBatchNorm2d + else: + tconv = nn.ConvTranspose2d + bn = nn.BatchNorm2d + + self.transconv = tconv( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + self.bn = bn(out_channels) + self.relu = nn.LeakyReLU(inplace=True) + + def forward(self, x): + x = self.transconv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class UNetConfig14(object): + """ + inputs x shape: [1, 1, 321, 2000, 2] + + sample rate: 16000 + nfft: 640 + win_size: 640 + hop_size: 320 (200ms) + """ + def __init__(self, in_channels: int): + self.enc_channels = [in_channels, 128, 128, 128, 128, 128, 128, 128] + self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (2, 2)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1] + self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), (5, 2), (5, 2)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + +class UNetConfig10(object): + """ + inputs x shape: [1, 1, 65, 200, 2] + + sample rate: 8000 + nfft: 128 + win_size: 128 + hop_size: 64 (8ms) + + """ + def __init__(self, in_channels: int): + self.enc_channels = [in_channels, 16, 32, 64, 128, 256] + self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)] + self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + self.dec_channels = [128, 128, 64, 32, 16, 1] + self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)] + self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)] + self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)] + + +class UNetConfig20(object): + """ + inputs x shape: [1, 1, 257, 2000, 2] + + sample rate: 8000 + nfft: 512 + win_size: 512 + hop_size: 256 (32ms) + + """ + def __init__(self, in_channels: int, model_complexity: int): + self.enc_channels = [ + in_channels, + model_complexity, model_complexity, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, + 128 + ] + + self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3), + (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)] + + self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), + (2, 1), (2, 2), (2, 1), (2, 2), (2, 1)] + + self.enc_paddings = [ + (3, 0), + (0, 3), + None, # (0, 2), + None, + None, # (3,1), + None, # (3,1), + None, # (1,2), + None, + None, + None + ] + + self.dec_channels = [ + 64, + model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity * 2, model_complexity * 2, + model_complexity, model_complexity, + 1 + ] + + self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 
3), + (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)] + + self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), + (2, 2), (2, 1), (2, 2), (1, 1), (1, 1)] + + self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1), + (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)] + + +class UNet(nn.Module): + def __init__(self, + in_channels: int = 1, + use_complex_networks: bool = False, + model_complexity: int = 45, + model_depth: int = 20, + padding_mode: str = "zeros" + ): + super().__init__() + if use_complex_networks: + model_complexity = int(model_complexity // 1.414) + + # config + if model_depth == 14: + config = UNetConfig14(in_channels) + elif model_depth == 10: + config = UNetConfig10(in_channels) + elif model_depth == 20: + config = UNetConfig20(in_channels, model_complexity) + else: + raise AssertionError(f"Unknown model depth : {model_depth}") + + self.model_length = model_depth // 2 + + self.fsmn = complex_nn.ComplexUniDeepFsmn( + config.enc_channels[-1], + config.enc_channels[-1] + ) + + # go down + self.encoder_layers = nn.ModuleList(modules=[]) + for i in range(self.model_length): + encoder_layer = nn.Sequential( + complex_nn.ComplexUniDeepFsmnL1( + config.enc_channels[i], + config.enc_channels[i] + ) + if i != 0 else nn.Identity(), + Encoder( + config.enc_channels[i], + config.enc_channels[i + 1], + kernel_size=config.enc_kernel_sizes[i], + stride=config.enc_strides[i], + padding=config.enc_paddings[i], + use_complex_networks=use_complex_networks, + padding_mode=padding_mode + ), + SELayer(config.enc_channels[i + 1], reduction=8) + ) + self.encoder_layers.append(encoder_layer) + + self.decoder_layers = nn.ModuleList(modules=[]) + for i in range(self.model_length): + decoder_layer = nn.Sequential( + Decoder( + config.dec_channels[i] * 2, + config.dec_channels[i + 1], + kernel_size=config.dec_kernel_sizes[i], + stride=config.dec_strides[i], + padding=config.dec_paddings[i], + use_complex_networks=use_complex_networks + ), + complex_nn.ComplexUniDeepFsmnL1( + config.dec_channels[i + 1], + config.dec_channels[i + 1] + ) + if i < (self.model_length - 1) else nn.Identity(), + SELayer( + config.dec_channels[i + 1], + reduction=8 + ) + if i < (self.model_length - 2) else nn.Identity() + ) + self.decoder_layers.append(decoder_layer) + + if use_complex_networks: + conv = complex_nn.ComplexConv2d + else: + conv = nn.Conv2d + + self.linear = conv( + in_channels=config.dec_channels[-1], + out_channels=1, + kernel_size=1, + ) + + def forward(self, inputs: torch.Tensor): + """ + :param inputs: torch.Tensor, shape: [b, c, f, t, 2] + :return: + """ + x = inputs + # print(f"inputs: {x.shape}") + + # go down + xs = list() + xs_se = list() + xs_se.append(x) + for encoder_layer in self.encoder_layers: + xs.append(x) + # print(f"x: {x.shape}") + x = encoder_layer.forward(x) + # print(f"x: {x.shape}") + xs_se.append(x) + + # x shape: [b, c, 1, t', 2] + x = self.fsmn.forward(x) + # x shape: [b, c, 1, t', 2] + # print(f"fsmn") + + p = x + for i, decoder_layers in enumerate(self.decoder_layers): + p = decoder_layers.forward(p) + # print(f"p: {p.shape}") + if i == self.model_length - 1: + break + p = torch.cat(tensors=[p, xs_se[self.model_length - 1 - i]], dim=1) + + # cmp_spec: [1, 1, 321, 200, 2] + # cmp_spec: [1, 1, 513, 200, 2] + cmp_spec = self.linear.forward(p) + return cmp_spec + + +def main10(): + # [batch_size, 1, freq_bins, time_steps, 2] + # x = torch.rand(size=(1, 1, 65, 2000, 2)) + x = torch.rand(size=(1, 1, 65, 200, 2)) + unet = UNet( + in_channels=1, + model_complexity=-1, + model_depth=10, + 
use_complex_networks=True + ) + print(unet) + result = unet.forward(x) + print(result.shape) + return + + +def main20(): + # [batch_size, 1, freq_bins, time_steps, 2] + x = torch.rand(size=(1, 1, 257, 2000, 2)) + unet = UNet( + in_channels=1, + model_complexity=45, + model_depth=20, + use_complex_networks=True + ) + print(unet) + result = unet.forward(x) + print(result.shape) + return + + +if __name__ == "__main__": + main20() diff --git a/toolbox/torchaudio/models/frcrn/uni_deep_fsmn.py b/toolbox/torchaudio/models/frcrn/uni_deep_fsmn.py new file mode 100644 index 0000000000000000000000000000000000000000..7310bdc2a41b7b334c5c5073d4e252b980b8755d --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/uni_deep_fsmn.py @@ -0,0 +1,71 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/layers/uni_deep_fsmn.py +https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/mossformer2_se/fsmn.py +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UniDeepFsmn(nn.Module): + + def __init__(self, + input_dim: int, + hidden_size: int, + lorder: int = 1, + ): + super(UniDeepFsmn, self).__init__() + self.input_dim = input_dim + self.hidden_size = hidden_size + self.lorder = lorder + + self.linear = nn.Linear(input_dim, hidden_size) + self.project = nn.Linear(hidden_size, input_dim, bias=False) + self.conv1 = nn.Conv2d( + input_dim, + input_dim, + kernel_size=(lorder, 1), + stride=(1, 1), + groups=input_dim, + bias=False + ) + + def forward(self, inputs: torch.Tensor): + """ + :param inputs: torch.Tensor, shape: [b, t, h] + :return: torch.Tensor, shape: [b, t, h] + """ + x = F.relu(self.linear(inputs)) + x = self.project(x) + x = torch.unsqueeze(x, 1) + # x shape: [b, 1, t, h] + + x = x.permute(0, 3, 2, 1) + # x shape: [b, h, t, 1] + y = F.pad(x, [0, 0, self.lorder - 1, 0]) + + x = x + self.conv1(y) + x = x.permute(0, 3, 2, 1) + # x shape: [b, 1, t, h] + x = x.squeeze() + + result = inputs + x + return result + + +def main(): + x = torch.rand(size=(1, 200, 32)) + fsmn = UniDeepFsmn( + input_dim=32, + hidden_size=64, + lorder=3, + ) + result = fsmn.forward(x) + print(result.shape) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/frcrn/yaml/config-10.yaml b/toolbox/torchaudio/models/frcrn/yaml/config-10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf4f013ee348189a4aabe3c89ac129ab7beb591b --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/yaml/config-10.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 128 +win_size: 128 +hop_size: 64 +win_type: hann + +use_complex_networks: true +model_depth: 10 +model_complexity: -1 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/toolbox/torchaudio/models/frcrn/yaml/config-14.yaml b/toolbox/torchaudio/models/frcrn/yaml/config-14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..707a9506ae459ba555c17b187d6631bbf63d681f --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/yaml/config-14.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 640 +win_size: 640 +hop_size: 320 +win_type: hann + +use_complex_networks: true +model_depth: 14 
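+# model_complexity is only used by the depth-20 UNet config; -1 marks it unused here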
+model_complexity: -1 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/toolbox/torchaudio/models/frcrn/yaml/config-20.yaml b/toolbox/torchaudio/models/frcrn/yaml/config-20.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3182a4024e561560d8688bdc92617dfcb000d6e7 --- /dev/null +++ b/toolbox/torchaudio/models/frcrn/yaml/config-20.yaml @@ -0,0 +1,31 @@ +model_name: "frcrn" + +sample_rate: 8000 +segment_size: 32000 +nfft: 512 +win_size: 512 +hop_size: 256 +win_type: hann + +use_complex_networks: true +model_depth: 20 +model_complexity: 45 + +min_snr_db: -10 +max_snr_db: 20 + +num_workers: 8 +batch_size: 32 +eval_steps: 10000 + +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +weight_decay: 1.0e-05 +clip_grad_norm: 10.0 +seed: 1234 +num_gpus: -1 diff --git a/toolbox/torchaudio/models/gtcrn/__init__.py b/toolbox/torchaudio/models/gtcrn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/models/gtcrn/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/gtcrn/modeling_gtcrn.py b/toolbox/torchaudio/models/gtcrn/modeling_gtcrn.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e3a0f6a11fb95ec8c626432cf569ca331bc438 --- /dev/null +++ b/toolbox/torchaudio/models/gtcrn/modeling_gtcrn.py @@ -0,0 +1,15 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://blog.csdn.net/gitblog_00478/article/details/141522595 + +https://github.com/Xiaobin-Rong/gtcrn/blob/main/gtcrn.py +https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/gtcrn_stream.py +""" +import torch +import torch.nn as nn +from typing import List, Tuple, Union + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/lstm/__init__.py b/toolbox/torchaudio/models/lstm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/lstm/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/lstm/configuration_lstm.py b/toolbox/torchaudio/models/lstm/configuration_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..193c89ff9aa52c71e53d1e8563fdd346fdf24a4a --- /dev/null +++ b/toolbox/torchaudio/models/lstm/configuration_lstm.py @@ -0,0 +1,73 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class LstmConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + segment_size: int = 32000, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 256, + win_type: str = "hann", + + hidden_size: int = 1024, + num_layers: int = 2, + dropout: float = 0.2, + + min_snr_db: float = -10, + max_snr_db: float = 20, + + max_epochs: int = 100, + batch_size: int = 4, + num_workers: int = 4, + seed: int = 1234, + + lr: float = 0.001, + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + weight_decay: float = 0.00001, + clip_grad_norm: float = 
10., + eval_steps: int = 25000, + + **kwargs + ): + super(LstmConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.segment_size = segment_size + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.hidden_size = hidden_size + self.num_layers = num_layers + self.dropout = dropout + + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + self.max_epochs = max_epochs + self.batch_size = batch_size + self.num_workers = num_workers + self.seed = seed + + self.lr = lr + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.weight_decay = weight_decay + self.clip_grad_norm = clip_grad_norm + self.eval_steps = eval_steps + + +def main(): + config = LstmConfig() + config.to_yaml_file("config.yaml") + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/lstm/modeling_lstm.py b/toolbox/torchaudio/models/lstm/modeling_lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..88b168edf91741f1022eeaa1be30cd9b367441a5 --- /dev/null +++ b/toolbox/torchaudio/models/lstm/modeling_lstm.py @@ -0,0 +1,260 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py +""" +import os +from typing import Optional, Union, Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F +import torchaudio + +from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT + + +MODEL_FILE = "model.pt" + + +class Transpose(nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, inputs: torch.Tensor): + inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1) + return inputs + + +class LstmModel(nn.Module): + def __init__(self, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 256, + win_type: str = "hann", + hidden_size=1024, + num_layers: int = 2, + batch_first: bool = True, + dropout: float = 0.2, + ): + super(LstmModel, self).__init__() + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.spec_bins = nfft // 2 + 1 + self.hidden_size = hidden_size + + self.eps = 1e-8 + + self.stft = ConvSTFT( + nfft=self.nfft, + win_size=self.win_size, + hop_size=self.hop_size, + win_type=self.win_type, + power=None, + requires_grad=False + ) + self.istft = ConviSTFT( + nfft=self.nfft, + win_size=self.win_size, + hop_size=self.hop_size, + win_type=self.win_type, + requires_grad=False + ) + + self.lstm = nn.LSTM(input_size=self.spec_bins, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=batch_first, + dropout=dropout, + ) + self.linear = nn.Linear(in_features=hidden_size, out_features=self.spec_bins) + self.activation = nn.Sigmoid() + + def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: + if signal.dim() == 2: + signal = torch.unsqueeze(signal, dim=1) + _, _, n_samples = signal.shape + remainder = (n_samples - self.win_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) + return signal + + def forward(self, + noisy: torch.Tensor, + h_state: Tuple[torch.Tensor, torch.Tensor] = None, + ): + 
num_samples = noisy.shape[-1] + noisy = self.signal_prepare(noisy) + batch_size, _, num_samples_pad = noisy.shape + # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}") + + mag_noisy, pha_noisy = self.mag_pha_stft(noisy) + # shape: (b, f, t) + # t = (num_samples - win_size) / hop_size + 1 + + mask, h_state = self.forward_chunk(mag_noisy, h_state) + # mask shape: (b, f, t) + + stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask) + denoise = self.istft.forward(stft_denoise) + # denoise shape: [b, 1, num_samples_pad] + + denoise = denoise[:, :, :num_samples] + # denoise shape: [b, 1, num_samples] + return denoise, mask, h_state + + def mag_pha_stft(self, noisy: torch.Tensor): + # noisy shape: [b, num_samples] + stft_noisy = self.stft.forward(noisy) + # stft_noisy shape: [b, f, t], torch.complex64 + + real = torch.real(stft_noisy) + imag = torch.imag(stft_noisy) + mag_noisy = torch.sqrt(real ** 2 + imag ** 2) + pha_noisy = torch.atan2(imag, real) + # shape: (b, f, t) + return mag_noisy, pha_noisy + + def forward_chunk(self, + mag_noisy: torch.Tensor, + h_state: Tuple[torch.Tensor, torch.Tensor] = None, + ): + # mag_noisy shape: (b, f, t) + x = torch.transpose(mag_noisy, dim0=2, dim1=1) + # x shape: (b, t, f) + x, h_state = self.lstm.forward(x, hx=h_state) + x = self.linear.forward(x) + mask = self.activation(x) + # mask shape: (b, t, f) + mask = torch.transpose(mask, dim0=2, dim1=1) + # mask shape: (b, f, t) + return mask, h_state + + def do_mask(self, + mag_noisy: torch.Tensor, + pha_noisy: torch.Tensor, + mask: torch.Tensor, + ): + # (b, f, t) + mag_denoise = mag_noisy * mask + stft_denoise = mag_denoise * torch.exp((1j * pha_noisy)) + return stft_denoise + + +class LstmPretrainedModel(LstmModel): + def __init__(self, + config: LstmConfig, + ): + super(LstmPretrainedModel, self).__init__( + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + hidden_size=config.hidden_size, + num_layers=config.num_layers, + dropout=config.dropout, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = LstmConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + config = LstmConfig() + model = LstmPretrainedModel(config) + model.eval() + + noisy = torch.randn(size=(1, 16000), dtype=torch.float32) + noisy = model.signal_prepare(noisy) + b, _, num_samples = noisy.shape + t = (num_samples - config.win_size) / config.hop_size + 1 + + waveform, mask, h_state = model.forward(noisy) + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + # 
noisy_pad shape: [b, 1, num_samples_pad] + + h_state = None + sub_spec_list = list() + for i in range(int(t)): + begin = i * config.hop_size + end = begin + config.win_size + sub_noisy = noisy[:, :, begin:end] + mag_noisy, pha_noisy = model.mag_pha_stft(sub_noisy) + mask, h_state = model.forward_chunk(mag_noisy, h_state) + sub_spec = model.do_mask(mag_noisy, pha_noisy, mask) + sub_spec_list.append(sub_spec) + + spec = torch.concat(sub_spec_list, dim=2) + + # 1 + waveform = model.istft.forward(spec) + waveform = waveform[:, :, :num_samples] + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + # 2 + cache_dict = None + waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32) + for i in range(int(t)): + sub_spec = spec[:, :, i:i+1] + begin = i * config.hop_size + end = begin + config.win_size - config.hop_size + sub_waveform, cache_dict = model.istft.forward_chunk(sub_spec, cache_dict=cache_dict) + # end = begin + config.win_size + # sub_waveform = model.istft.forward(sub_spec) + + # (b, 1, win_size) + waveform[:, :, begin:end] = sub_waveform + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/lstm/yaml/config.yaml b/toolbox/torchaudio/models/lstm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cbc362dccca99c9dcd7bc746998df1c169ef4c10 --- /dev/null +++ b/toolbox/torchaudio/models/lstm/yaml/config.yaml @@ -0,0 +1,32 @@ +model_name: "lstm" + +# spec +sample_rate: 8000 +segment_size: 32000 +n_fft: 320 +win_size: 320 +hop_size: 160 +win_type: hann + +# data +max_snr_db: 20 +min_snr_db: -10 + +# model +hidden_size: 512 +num_layers: 3 +dropout: 0.1 + +# train +max_epochs: 100 +batch_size: 32 +num_workers: 4 +seed: 1234 + +lr: 0.001 +lr_scheduler: CosineAnnealingLR +lr_scheduler_kwargs: {} + +weight_decay: 0.00001 +clip_grad_norm: 10.0 +eval_steps: 25000 diff --git a/toolbox/torchaudio/models/mpnet/__init__.py b/toolbox/torchaudio/models/mpnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/mpnet/configuration_mpnet.py b/toolbox/torchaudio/models/mpnet/configuration_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5f45850b15d1f305859210eaf52a64e3410f0ac2 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/configuration_mpnet.py @@ -0,0 +1,74 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class MPNetConfig(PretrainedConfig): + """ + https://github.com/yxlu-0102/MP-SENet/blob/main/config.json + """ + def __init__(self, + num_gpus: int = 0, + batch_size: int = 4, + learning_rate: float = 0.0005, + adam_b1: float = 0.8, + adam_b2: float = 0.99, + lr_decay: float = 0.99, + seed: int = 1234, + + dense_channel: int = 64, + compress_factor: float = 0.3, + num_tsconformers: int = 4, + beta: float = 2.0, + + sample_rate: int = 16000, + segment_size: int = 32000, + n_fft: int = 400, + hop_size: int = 100, + win_size: int = 400, + + num_workers: int = 4, + + dist_config: dict = None, + + discriminator_dim: int = 32, + discriminator_in_channel: int 
= 2, + + **kwargs + ): + super(MPNetConfig, self).__init__(**kwargs) + self.num_gpus = num_gpus + self.batch_size = batch_size + self.learning_rate = learning_rate + self.adam_b1 = adam_b1 + self.adam_b2 = adam_b2 + self.lr_decay = lr_decay + self.seed = seed + + self.dense_channel = dense_channel + self.compress_factor = compress_factor + self.num_tsconformers = num_tsconformers + self.beta = beta + + self.sample_rate = sample_rate + self.segment_size = segment_size + self.n_fft = n_fft + self.hop_size = hop_size + self.win_size = win_size + + self.num_workers = num_workers + + self.dist_config = dist_config or { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } + + self.discriminator_dim = discriminator_dim + self.discriminator_in_channel = discriminator_in_channel + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/mpnet/conformer.py b/toolbox/torchaudio/models/mpnet/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd26d8ff7a9e56ca73930ec356c7d78916332b0 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/conformer.py @@ -0,0 +1,83 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from einops.layers.torch import Rearrange +import torch.nn as nn + + +def get_padding(kernel_size: int, dilation: int = 1): + return int((kernel_size * dilation - dilation) / 2) + + +class FeedForwardModule(nn.Module): + def __init__(self, dim, mult=4, dropout=0): + super(FeedForwardModule, self).__init__() + self.ffm = nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * mult), + nn.SiLU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.ffm(x) + + +class ConformerConvModule(nn.Module): + def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0.): + super(ConformerConvModule, self).__init__() + inner_dim = dim * expansion_factor + self.ccm = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim*2, 1), + nn.GLU(dim=1), + nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size, + padding=get_padding(kernel_size), groups=inner_dim), # DepthWiseConv1d + nn.BatchNorm1d(inner_dim), + nn.SiLU(), + nn.Conv1d(inner_dim, dim, 1), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.ccm(x) + + +class AttentionModule(nn.Module): + def __init__(self, dim, n_head=8, dropout=0.): + super(AttentionModule, self).__init__() + self.attn = nn.MultiheadAttention(dim, n_head, dropout=dropout) + self.layernorm = nn.LayerNorm(dim) + + def forward(self, x, attn_mask=None, key_padding_mask=None): + x = self.layernorm(x) + x, _ = self.attn(x, x, x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask) + return x + + +class ConformerBlock(nn.Module): + def __init__(self, dim, n_head=8, ffm_mult=4, ccm_expansion_factor=2, ccm_kernel_size=31, + ffm_dropout=0., attn_dropout=0., ccm_dropout=0.): + super(ConformerBlock, self).__init__() + self.ffm1 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout) + self.attn = AttentionModule(dim, n_head, dropout=attn_dropout) + self.ccm = ConformerConvModule(dim, ccm_expansion_factor, ccm_kernel_size, dropout=ccm_dropout) + self.ffm2 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout) + self.post_norm = nn.LayerNorm(dim) + + def forward(self, x): + x = x + 0.5 * self.ffm1(x) + x = x + self.attn(x) + x = x + self.ccm(x) + x = x + 0.5 * self.ffm2(x) + x = self.post_norm(x) + return x + + +if __name__ == '__main__': + pass diff 
--git a/toolbox/torchaudio/models/mpnet/discriminator.py b/toolbox/torchaudio/models/mpnet/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..055ffe8ca37da23675c2363e50b522b1ebc91282 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/discriminator.py @@ -0,0 +1,129 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +import numpy as np +import torch.nn.functional as F +from pesq import pesq +from joblib import Parallel, delayed + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d + + +# def cal_pesq(clean, noisy, sr=16000): +# try: +# pesq_score = pesq(sr, clean, noisy, 'wb') +# except: +# # error can happen due to silent period +# pesq_score = -1 +# return pesq_score + + +# def batch_pesq(clean, noisy): +# pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy)) +# pesq_score = np.array(pesq_score) +# if -1 in pesq_score: +# return None +# pesq_score = (pesq_score - 1) / 3.5 +# return torch.FloatTensor(pesq_score) + + +def metric_loss(metric_ref, metrics_gen): + loss = 0 + for metric_gen in metrics_gen: + metric_loss = F.mse_loss(metric_ref, metric_gen.flatten()) + loss += metric_loss + + return loss + + +class MetricDiscriminator(nn.Module): + def __init__(self, config: MPNetConfig): + super(MetricDiscriminator, self).__init__() + dim = config.discriminator_dim + in_channel = config.discriminator_in_channel + + self.layers = nn.Sequential( + nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim, affine=True), + nn.PReLU(dim), + nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim*2, affine=True), + nn.PReLU(dim*2), + nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim*4, affine=True), + nn.PReLU(dim*4), + nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)), + nn.InstanceNorm2d(dim*8, affine=True), + nn.PReLU(dim*8), + nn.AdaptiveMaxPool2d(1), + nn.Flatten(), + nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)), + nn.Dropout(0.3), + nn.PReLU(dim*4), + nn.utils.spectral_norm(nn.Linear(dim*4, 1)), + LearnableSigmoid1d(1) + ) + + def forward(self, x, y): + xy = torch.stack((x, y), dim=1) + return self.layers(xy) + + +MODEL_FILE = "discriminator.pt" + + +class MetricDiscriminatorPretrainedModel(MetricDiscriminator): + def __init__(self, + config: MPNetConfig, + ): + super(MetricDiscriminatorPretrainedModel, self).__init__( + config=config, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + 
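+ # save_pretrained() writes two artifacts side by side into save_directory: the
+ # discriminator weights under MODEL_FILE ("discriminator.pt") and the yaml config
+ # under CONFIG_FILE, which is the directory layout from_pretrained() above expects.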
os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/mpnet/inference_mpnet.py b/toolbox/torchaudio/models/mpnet/inference_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..43b2c5ed66fe97171e328fc1adc18f4beb794d7c --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/inference_mpnet.py @@ -0,0 +1,118 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +from pathlib import Path +import shutil +import tempfile +import zipfile + +import librosa +import numpy as np +import torch +import torchaudio + +from project_settings import project_path +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel, MODEL_FILE +from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft + +logger = logging.getLogger("toolbox") + + +class InferenceMPNet(object): + def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): + self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file + self.device = torch.device(device) + + logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") + config, generator = self.load_models(self.pretrained_model_path_or_zip_file) + logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") + + self.config = config + self.generator = generator + self.generator.to(device) + self.generator.eval() + + def load_models(self, model_path: str): + model_path = Path(model_path) + if model_path.name.endswith(".zip"): + with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: + out_root = Path(tempfile.gettempdir()) / "nx_denoise" + out_root.mkdir(parents=True, exist_ok=True) + f_zip.extractall(path=out_root) + model_path = out_root / model_path.stem + + config = MPNetConfig.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + generator = MPNetPretrainedModel.from_pretrained( + pretrained_model_name_or_path=model_path.as_posix(), + ) + generator.to(self.device) + generator.eval() + + shutil.rmtree(model_path) + return config, generator + + def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray: + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + # noisy_audio shape: [batch_size, n_samples] + enhanced_audio = self.enhancement_by_tensor(noisy_audio) + # enhanced_audio shape: [channels, num_samples] + enhanced_audio = enhanced_audio[0] + # enhanced_audio shape: [num_samples] + return enhanced_audio.cpu().numpy() + + def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: + if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: + raise AssertionError(f"The value range of audio samples should be between -1 and 1.") + + noisy_audio = noisy_audio.to(self.device) + + with torch.no_grad(): + noisy_mag, noisy_pha, noisy_com = mag_pha_stft( + noisy_audio, + self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor + ) + mag_g, pha_g, com_g = self.generator.forward(noisy_mag, noisy_pha) + audio_g = mag_pha_istft( + mag_g, pha_g, + self.config.n_fft, self.config.hop_size, 
self.config.win_size, self.config.compress_factor + ) + enhanced_audio = audio_g.detach() + + # shape: [batch_size, num_samples] + enhanced_audio = torch.unsqueeze(enhanced_audio, dim=1) + # shape: [batch_size, 1, num_samples] + + enhanced_audio = enhanced_audio[0] + # shape: [channels, num_samples] + return enhanced_audio + + +def main(): + model_zip_file = project_path / "trained_models/mpnet-aishell-1-epoch.zip" + infer_mpnet = InferenceMPNet(model_zip_file) + + sample_rate = 8000 + noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav" + noisy_audio, _ = librosa.load( + noisy_audio_file.as_posix(), + sr=sample_rate, + ) + noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] + noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) + noisy_audio = noisy_audio.unsqueeze(dim=0) + + enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio) + + filename = "enhanced_audio.wav" + torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate) + + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/mpnet/metrics.py b/toolbox/torchaudio/models/mpnet/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..78468894a56d4488021e83ea47e07c785a385269 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/metrics.py @@ -0,0 +1,80 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from joblib import Parallel, delayed +import numpy as np +from pesq import pesq +from typing import List + +from pesq import cypesq + + +def run_pesq(clean_audio: np.ndarray, + noisy_audio: np.ndarray, + sample_rate: int = 16000, + mode: str = "wb", + ) -> float: + if sample_rate == 8000 and mode == "wb": + raise AssertionError(f"mode should be `nb` when sample_rate is 8000") + try: + pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode) + except cypesq.NoUtterancesError as e: + pesq_score = -1 + except Exception as e: + print(f"pesq failed. 
error type: {type(e)}, error text: {str(e)}") + pesq_score = -1 + return pesq_score + + +def run_batch_pesq(clean_audio_list: List[np.ndarray], + noisy_audio_list: List[np.ndarray], + sample_rate: int = 16000, + mode: str = "wb", + n_jobs: int = 4, + ) -> List[float]: + parallel = Parallel(n_jobs=n_jobs) + + parallel_tasks = list() + for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list): + parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode) + parallel_tasks.append(parallel_task) + + pesq_score_list = parallel.__call__(parallel_tasks) + return pesq_score_list + + +def run_pesq_score(clean_audio_list: List[np.ndarray], + noisy_audio_list: List[np.ndarray], + sample_rate: int = 16000, + mode: str = "wb", + n_jobs: int = 4, + ) -> List[float]: + + pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list, + noisy_audio_list=noisy_audio_list, + sample_rate=sample_rate, + mode=mode, + n_jobs=n_jobs, + ) + + pesq_score = np.mean(pesq_score_list) + return pesq_score + + +def main(): + clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) + noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) + + clean_audio_list = list(clean_audio) + noisy_audio_list = list(noisy_audio) + + pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list) + print(pesq_score_list) + + pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list) + print(pesq_score) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/mpnet/modeling_mpnet.py b/toolbox/torchaudio/models/mpnet/modeling_mpnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f74b46d411ccfee193e26efd9eaa0afce114ca --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/modeling_mpnet.py @@ -0,0 +1,301 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py + +https://huggingface.co/spaces/JacobLinCool/MP-SENet + +https://arxiv.org/abs/2305.13686 +https://github.com/yxlu-0102/MP-SENet + +应该是不支持流式改造的。 + +""" +import os +from typing import Optional, Union + +from pesq import pesq +from joblib import Parallel, delayed +import numpy as np +import torch +import torch.nn as nn + +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock +from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock +from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig +from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d + + +class SPConvTranspose2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, r=1): + super(SPConvTranspose2d, self).__init__() + self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.) 
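+ # Sub-pixel "transposed" convolution: pad the frequency axis by one bin on each side,
+ # run a plain Conv2d that emits out_channels * r feature maps, and let forward()
+ # reshuffle those r groups into the last (frequency) dimension, upsampling F by a
+ # factor of r without using nn.ConvTranspose2d.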
+ self.out_channels = out_channels + self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1)) + self.r = r + + def forward(self, x): + x = self.pad1(x) + out = self.conv(x) + batch_size, nchannels, H, W = out.shape + out = out.view((batch_size, self.r, nchannels // self.r, H, W)) + out = out.permute(0, 2, 3, 4, 1) + out = out.contiguous().view((batch_size, nchannels // self.r, H, -1)) + return out + + +class DenseBlock(nn.Module): + def __init__(self, h, kernel_size=(2, 3), depth=4): + super(DenseBlock, self).__init__() + self.h = h + self.depth = depth + self.dense_block = nn.ModuleList([]) + for i in range(depth): + dilation = 2 ** i + pad_length = dilation + dense_conv = nn.Sequential( + nn.ConstantPad2d((1, 1, pad_length, 0), value=0.), + nn.Conv2d(h.dense_channel*(i+1), h.dense_channel, kernel_size, dilation=(dilation, 1)), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel) + ) + self.dense_block.append(dense_conv) + + def forward(self, x): + skip = x + for i in range(self.depth): + x = self.dense_block[i](skip) + skip = torch.cat([x, skip], dim=1) + return x + + +class DenseEncoder(nn.Module): + def __init__(self, h, in_channel): + super(DenseEncoder, self).__init__() + self.h = h + self.dense_conv_1 = nn.Sequential( + nn.Conv2d(in_channel, h.dense_channel, (1, 1)), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel)) + + self.dense_block = DenseBlock(h, depth=4) + + self.dense_conv_2 = nn.Sequential( + nn.Conv2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2), padding=(0, 1)), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel)) + + def forward(self, x): + x = self.dense_conv_1(x) # [b, 64, T, F] + x = self.dense_block(x) # [b, 64, T, F] + x = self.dense_conv_2(x) # [b, 64, T, F//2] + return x + + +class MaskDecoder(nn.Module): + def __init__(self, h, out_channel=1): + super(MaskDecoder, self).__init__() + self.dense_block = DenseBlock(h, depth=4) + self.mask_conv = nn.Sequential( + SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel), + nn.Conv2d(h.dense_channel, out_channel, (1, 2)) + ) + self.lsigmoid = LearnableSigmoid2d(h.n_fft//2+1, beta=h.beta) + + def forward(self, x): + x = self.dense_block(x) + x = self.mask_conv(x) + x = x.permute(0, 3, 2, 1).squeeze(-1) # [B, F, T] + x = self.lsigmoid(x) + return x + + +class PhaseDecoder(nn.Module): + def __init__(self, h, out_channel=1): + super(PhaseDecoder, self).__init__() + self.dense_block = DenseBlock(h, depth=4) + self.phase_conv = nn.Sequential( + SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2), + nn.InstanceNorm2d(h.dense_channel, affine=True), + nn.PReLU(h.dense_channel) + ) + self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, (1, 2)) + self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, (1, 2)) + + def forward(self, x): + x = self.dense_block(x) + x = self.phase_conv(x) + x_r = self.phase_conv_r(x) + x_i = self.phase_conv_i(x) + x = torch.atan2(x_i, x_r) + x = x.permute(0, 3, 2, 1).squeeze(-1) # [B, F, T] + return x + + +class TSTransformerBlock(nn.Module): + def __init__(self, h): + super(TSTransformerBlock, self).__init__() + self.h = h + self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4) + self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4) + + def forward(self, x): + b, c, t, f = x.size() + x = x.permute(0, 3, 2, 
1).contiguous().view(b*f, t, c) + x = self.time_transformer(x) + x + x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c) + x = self.freq_transformer(x) + x + x = x.view(b, t, f, c).permute(0, 3, 1, 2) + return x + + +class MPNet(nn.Module): + def __init__(self, config: MPNetConfig, num_tsblocks=4): + super(MPNet, self).__init__() + self.num_tscblocks = num_tsblocks + self.dense_encoder = DenseEncoder(config, in_channel=2) + + self.TSTransformer = nn.ModuleList([]) + for i in range(num_tsblocks): + self.TSTransformer.append(TSTransformerBlock(config)) + + self.mask_decoder = MaskDecoder(config, out_channel=1) + self.phase_decoder = PhaseDecoder(config, out_channel=1) + + def forward(self, noisy_amp, noisy_pha): # [B, F, T] + + x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1) # [B, 2, T, F] + x = self.dense_encoder(x) + + for i in range(self.num_tscblocks): + x = self.TSTransformer[i](x) + + denoised_amp = noisy_amp * self.mask_decoder(x) + denoised_pha = self.phase_decoder(x) + denoised_com = torch.stack( + tensors=( + denoised_amp * torch.cos(denoised_pha), + denoised_amp * torch.sin(denoised_pha) + ), + dim=-1 + ) + + return denoised_amp, denoised_pha, denoised_com + + +MODEL_FILE = "generator.pt" + + +class MPNetPretrainedModel(MPNet): + def __init__(self, + config: MPNetConfig, + ): + super(MPNetPretrainedModel, self).__init__( + config=config, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def phase_losses(phase_r, phase_g): + + ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) + gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))) + iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))) + + return ip_loss, gd_loss, iaf_loss + + +def anti_wrapping_function(x): + + return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) + + +# def pesq_score(utts_r, utts_g, h): +# +# pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)( +# utts_r[i].squeeze().cpu().numpy(), +# utts_g[i].squeeze().cpu().numpy(), +# h.sample_rate, ) +# for i in range(len(utts_r))) +# pesq_score = np.mean(pesq_score) +# +# return pesq_score +# +# +# def eval_pesq(clean_utt, esti_utt, sr): +# try: +# mode = "nb" if sr == 8000 else "wb" +# pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode) +# except: +# pesq_score = -1 +# +# return pesq_score + + +def main(): + import torchaudio + + config = MPNetConfig() + model = MPNet(config=config) + + transformer = 
torchaudio.transforms.Spectrogram( + n_fft=config.n_fft, + win_length=config.win_size, + hop_length=config.hop_size, + window_fn=torch.hamming_window, + ) + + inputs = torch.randn(size=(1, 32000), dtype=torch.float32) + spec = transformer.forward(inputs) + print(spec.shape) + + denoised_amp, denoised_pha, denoised_com = model.forward(spec, spec) + print(denoised_amp.shape) + print(denoised_pha.shape) + print(denoised_com.shape) + + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/mpnet/transformers.py b/toolbox/torchaudio/models/mpnet/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..359661442edc578fadaa01de95e288ad225a3143 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/transformers.py @@ -0,0 +1,70 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import MultiheadAttention, GRU, Linear, LayerNorm, Dropout + + +class FFN(nn.Module): + def __init__(self, d_model, bidirectional=True, dropout=0): + super(FFN, self).__init__() + self.gru = GRU(d_model, d_model * 2, 1, bidirectional=bidirectional) + if bidirectional: + self.linear = Linear(d_model * 2 * 2, d_model) + else: + self.linear = Linear(d_model * 2, d_model) + self.dropout = Dropout(dropout) + + def forward(self, x): + self.gru.flatten_parameters() + x, _ = self.gru(x) + x = F.leaky_relu(x) + x = self.dropout(x) + x = self.linear(x) + + return x + + +class TransformerBlock(nn.Module): + def __init__(self, d_model, n_heads, bidirectional=True, dropout=0): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(d_model) + self.attention = MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout1 = Dropout(dropout) + + self.norm2 = LayerNorm(d_model) + self.ffn = FFN(d_model, bidirectional=bidirectional) + self.dropout2 = Dropout(dropout) + + self.norm3 = LayerNorm(d_model) + + def forward(self, x, attn_mask=None, key_padding_mask=None): + xt = self.norm1(x) + xt, _ = self.attention(xt, xt, xt, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask) + x = x + self.dropout1(xt) + + xt = self.norm2(x) + xt = self.ffn(xt) + x = x + self.dropout2(xt) + + x = self.norm3(x) + + return x + + +def main(): + x = torch.randn(4, 64, 401, 201) + b, c, t, f = x.size() + x = x.permute(0, 3, 2, 1).contiguous().view(b, f * t, c) + transformer = TransformerBlock(d_model=64, n_heads=4) + x = transformer(x) + x = x.view(b, f, t, c).permute(0, 3, 2, 1) + print(x.size()) + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/mpnet/utils.py b/toolbox/torchaudio/models/mpnet/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c705751229d6f929ef3e7465649f4981799ee94 --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/utils.py @@ -0,0 +1,106 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from einops.layers.torch import Rearrange +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pesq import pesq +from joblib import Parallel, delayed + + +def phase_losses(phase_r, phase_g): + + ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) + gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))) + iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))) + + return ip_loss, gd_loss, iaf_loss + + +def anti_wrapping_function(x): + + return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * 
np.pi) + + +def pesq_score(utts_r, utts_g, h): + + pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)( + utts_r[i].squeeze().cpu().numpy(), + utts_g[i].squeeze().cpu().numpy(), + h.sample_rate) + for i in range(len(utts_r))) + pesq_score = np.mean(pesq_score) + + return pesq_score + + +def eval_pesq(clean_utt, esti_utt, sr): + try: + pesq_score = pesq(sr, clean_utt, esti_utt) + except: + pesq_score = -1 + + return pesq_score + + +def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True): + + hann_window = torch.hann_window(win_size).to(y.device) + stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, + center=center, pad_mode='reflect', normalized=False, return_complex=True) + stft_spec = torch.view_as_real(stft_spec) + mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9) + pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5) + # Magnitude Compression + mag = torch.pow(mag, compress_factor) + com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1) + + return mag, pha, com + + +def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): + # Magnitude Decompression + mag = torch.pow(mag, (1.0/compress_factor)) + com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha)) + hann_window = torch.hann_window(win_size).to(com.device) + wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center) + + return wav + + +class LearnableSigmoid1d(nn.Module): + def __init__(self, in_features, beta=1): + super().__init__() + self.beta = beta + self.slope = nn.Parameter(torch.ones(in_features)) + self.slope.requiresGrad = True + + def forward(self, x): + # x shape: [batch_size, time_steps, spec_bins] + return self.beta * torch.sigmoid(self.slope * x) + + +class LearnableSigmoid2d(nn.Module): + def __init__(self, in_features, beta=1): + super().__init__() + self.beta = beta + self.slope = nn.Parameter(torch.ones(in_features, 1)) + self.slope.requiresGrad = True + + def forward(self, x): + return self.beta * torch.sigmoid(self.slope * x) + + +def main(): + learnable_sigmoid = LearnableSigmoid1d(201) + a = torch.randn(4, 100, 201) + + result = learnable_sigmoid.forward(a) + print(result.shape) + + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/mpnet/yaml/config.yaml b/toolbox/torchaudio/models/mpnet/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3bd73873bbd2a8f8599952dfea872a6dc0c3ed2c --- /dev/null +++ b/toolbox/torchaudio/models/mpnet/yaml/config.yaml @@ -0,0 +1,30 @@ +model_name: "mpnet" + +num_gpus: 0 +batch_size: 4 +learning_rate: 0.0005 +adam_b1: 0.8 +adam_b2: 0.99 +lr_decay: 0.99 +seed: 1234 + +dense_channel: 64 +compress_factor: 0.3 +num_tsconformers: 4 +beta: 2.0 + +sample_rate: 16000 +segment_size: 32000 +n_fft: 400 +hop_size: 100 +win_size: 400 + +num_workers: 4 + +dist_config: + dist_backend: nccl + dist_url: tcp://localhost:54321 + world_size: 1 + +discriminator_dim: 32 +discriminator_in_channel: 2 diff --git a/toolbox/torchaudio/models/percepnet/__init__.py b/toolbox/torchaudio/models/percepnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/percepnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git 
a/toolbox/torchaudio/models/percepnet/modeling_percetnet.py b/toolbox/torchaudio/models/percepnet/modeling_percetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdfcfa2069f8225a9b304ce960d724bc1b9557a --- /dev/null +++ b/toolbox/torchaudio/models/percepnet/modeling_percetnet.py @@ -0,0 +1,100 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/jzi040941/PercepNet + +https://arxiv.org/abs/2008.04259 + +https://modelzoo.co/model/percepnet + +太复杂了。 +(1)pytorch 模型只是整个 pipeline 中的一部分。 +(2)训练样本需经过基音分析,频谱包络之类的计算。 + +""" +import torch +import torch.nn as nn + + +class PercepNet(nn.Module): + """ + https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105 + + 4.1% of an x86 CPU core + """ + def __init__(self, input_dim=70): + super(PercepNet, self).__init__() + # self.hidden_dim = hidden_dim + # self.n_layers = n_layers + + self.fc = nn.Sequential( + nn.Linear(input_dim, 128), + nn.ReLU() + ) + self.conv1 = nn.Sequential( + nn.Conv1d(128, 512, 5, stride=1, padding=4), + nn.ReLU() + )#padding for align with c++ dnn + self.conv2 = nn.Sequential( + nn.Conv1d(512, 512, 3, stride=1, padding=2), + nn.Tanh() + ) + #self.gru = nn.GRU(512, 512, 3, batch_first=True) + self.gru1 = nn.GRU(512, 512, 1, batch_first=True) + self.gru2 = nn.GRU(512, 512, 1, batch_first=True) + self.gru3 = nn.GRU(512, 512, 1, batch_first=True) + + self.gru_gb = nn.GRU(512, 512, 1, batch_first=True) + self.gru_rb = nn.GRU(1024, 128, 1, batch_first=True) + + self.fc_gb = nn.Sequential( + nn.Linear(512*5, 34), + nn.Sigmoid() + ) + self.fc_rb = nn.Sequential( + nn.Linear(128, 34), + nn.Sigmoid() + ) + + def forward(self, x: torch.Tensor): + # x shape: [b, t, f] + x = self.fc(x) + x = x.permute([0, 2, 1]) + # x shape: [b, f, t] + + # causal conv + x = self.conv1(x) + x = x[:, :, :-4] + + # x shape: [b, f, t] + convout = self.conv2(x) + convout = convout[:, :, :-2] + convout = convout.permute([0, 2, 1]) + # convout shape: [b, t, f] + + gru1_out, gru1_state = self.gru1(convout) + gru2_out, gru2_state = self.gru2(gru1_out) + gru3_out, gru3_state = self.gru3(gru2_out) + + gru_gb_out, gru_gb_state = self.gru_gb(gru3_out) + concat_gb_layer = torch.cat(tensors=(convout, gru1_out, gru2_out, gru3_out, gru_gb_out), dim=-1) + gb = self.fc_gb(concat_gb_layer) + + # concat rb need fix + concat_rb_layer = torch.cat(tensors=(gru3_out, convout), dim=-1) + rnn_rb_out, gru_rb_state = self.gru_rb(concat_rb_layer) + rb = self.fc_rb(rnn_rb_out) + + output = torch.cat((gb, rb), dim=-1) + return output + + +def main(): + model = PercepNet() + x = torch.randn(20, 8, 70) + out = model(x) + print(out.shape) + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/phm_unet/__init__.py b/toolbox/torchaudio/models/phm_unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/phm_unet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/phm_unet/modeling_phm_unet.py b/toolbox/torchaudio/models/phm_unet/modeling_phm_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..0042796e18b4c97fbb05b80d537532f2f7c0460c --- /dev/null +++ b/toolbox/torchaudio/models/phm_unet/modeling_phm_unet.py @@ -0,0 +1,9 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2006.00687 +""" + + +if __name__ == '__main__': + pass diff --git 
a/toolbox/torchaudio/models/rnnoise/__init__.py b/toolbox/torchaudio/models/rnnoise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/rnnoise/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/rnnoise/configuration_rnnoise.py b/toolbox/torchaudio/models/rnnoise/configuration_rnnoise.py new file mode 100644 index 0000000000000000000000000000000000000000..faa773688cdac3bdb2667c6a2fbd1464c2a93708 --- /dev/null +++ b/toolbox/torchaudio/models/rnnoise/configuration_rnnoise.py @@ -0,0 +1,75 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class RNNoiseConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + segment_size: int = 32000, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 256, + win_type: str = "hann", + + erb_bins: int = 32, + min_freq_bins_for_erb: int = 2, + + conv_size: int = 128, + gru_size: int = 256, + + min_snr_db: float = -10, + max_snr_db: float = 20, + + lr: float = 0.001, + lr_scheduler: str = "CosineAnnealingLR", + lr_scheduler_kwargs: dict = None, + + max_epochs: int = 100, + clip_grad_norm: float = 10., + seed: int = 1234, + + batch_size: int = 64, + num_workers: int = 4, + eval_steps: int = 25000, + + **kwargs + ): + super(RNNoiseConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.segment_size = segment_size + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.erb_bins = erb_bins + self.min_freq_bins_for_erb = min_freq_bins_for_erb + + self.conv_size = conv_size + self.gru_size = gru_size + + self.min_snr_db = min_snr_db + self.max_snr_db = max_snr_db + + self.lr = lr + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict() + + self.max_epochs = max_epochs + self.clip_grad_norm = clip_grad_norm + self.seed = seed + + self.batch_size = batch_size + self.num_workers = num_workers + self.eval_steps = eval_steps + + +def main(): + config = RNNoiseConfig() + config.to_yaml_file("yaml/config.yaml") + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py b/toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py new file mode 100644 index 0000000000000000000000000000000000000000..165d73a2028d71cd1d4b71b46cbf65af27eb6f39 --- /dev/null +++ b/toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py @@ -0,0 +1,401 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/xiph/rnnoise +https://github.com/xiph/rnnoise/blob/main/torch/rnnoise/rnnoise.py + +https://arxiv.org/abs/1709.08243 + +""" +import os +from typing import Optional, Union, Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from toolbox.torch.sparsification.gru_sparsifier import GRUSparsifier +from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE +from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT +from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands + + +sparsify_start = 6000 +sparsify_stop = 20000 +sparsify_interval = 100 +sparsify_exponent = 3 + + +sparse_params1 = { + "W_hr" : (0.3, [8, 4], True), + "W_hz" : (0.2, [8, 4], True), + "W_hn" : (0.5, [8, 4], 
True), + "W_ir" : (0.3, [8, 4], False), + "W_iz" : (0.2, [8, 4], False), + "W_in" : (0.5, [8, 4], False), +} + + +def init_weights(module): + if isinstance(module, nn.GRU): + for p in module.named_parameters(): + if p[0].startswith("weight_hh_"): + nn.init.orthogonal_(p[1]) + + +class RNNoise(nn.Module): + def __init__(self, + sample_rate: int = 8000, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 256, + win_type: str = "hann", + erb_bins: int = 32, + min_freq_bins_for_erb: int = 2, + conv_size: int = 128, + gru_size: int = 256, + ): + super(RNNoise, self).__init__() + self.sample_rate = sample_rate + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.erb_bins = erb_bins + self.min_freq_bins_for_erb = min_freq_bins_for_erb + self.conv_size = conv_size + self.gru_size = gru_size + + self.input_dim = nfft // 2 + 1 + + self.eps = 1e-12 + + self.erb_bands = ErbBands( + sample_rate=self.sample_rate, + nfft=self.nfft, + erb_bins=self.erb_bins, + min_freq_bins_for_erb=self.min_freq_bins_for_erb, + ) + + self.stft = ConvSTFT( + nfft=self.nfft, + win_size=self.win_size, + hop_size=self.hop_size, + win_type=self.win_type, + power=None, + requires_grad=False + ) + self.istft = ConviSTFT( + nfft=self.nfft, + win_size=self.win_size, + hop_size=self.hop_size, + win_type=self.win_type, + requires_grad=False + ) + + self.pad = nn.ConstantPad1d(padding=(2, 2), value=0) + self.conv1 = nn.Conv1d(self.erb_bins, conv_size, kernel_size=3, padding="valid") + self.conv2 = nn.Conv1d(conv_size, gru_size, kernel_size=3, padding="valid") + + self.gru1 = nn.GRU(self.gru_size, self.gru_size, batch_first=True) + self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True) + self.gru3 = nn.GRU(self.gru_size, self.gru_size, batch_first=True) + + self.dense_out = nn.Linear(4*self.gru_size, self.erb_bins) + + nb_params = sum(p.numel() for p in self.parameters()) + print(f"model: {nb_params} weights") + self.apply(init_weights) + + self.sparsifier = [ + GRUSparsifier( + task_list=[(self.gru1, sparse_params1)], + start=sparsify_start, + stop=sparsify_stop, + interval=sparsify_interval, + exponent=sparsify_exponent, + ), + GRUSparsifier( + task_list=[(self.gru2, sparse_params1)], + start=sparsify_start, + stop=sparsify_stop, + interval=sparsify_interval, + exponent=sparsify_exponent, + ), + GRUSparsifier( + task_list=[(self.gru3, sparse_params1)], + start=sparsify_start, + stop=sparsify_stop, + interval=sparsify_interval, + exponent=sparsify_exponent, + ) + ] + + def sparsify(self): + for sparsifier in self.sparsifier: + sparsifier.step() + + def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: + if signal.dim() == 2: + signal = torch.unsqueeze(signal, dim=1) + _, _, n_samples = signal.shape + remainder = (n_samples - self.win_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) + return signal + + def forward(self, + noisy: torch.Tensor, + states: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None, + ): + num_samples = noisy.shape[-1] + noisy = self.signal_prepare(noisy) + batch_size, _, num_samples_pad = noisy.shape + # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}") + + mag_noisy, pha_noisy = self.mag_pha_stft(noisy) + # shape: (b, f, t) + # t = (num_samples - win_size) / hop_size + 1 + + mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2) + # shape: (b, t, f) + mag_noisy_t_erb = 
self.erb_bands.erb_scale(mag_noisy_t, db=True) + # shape: (b, t, erb_bins) + mag_noisy_t_erb = torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2) + # shape: (b, erb_bins, t) + + mag_noisy_t_erb = self.pad(mag_noisy_t_erb) + mag_noisy_t_erb = self.forward_conv(mag_noisy_t_erb) + gru_out, states = self.forward_gru(mag_noisy_t_erb, states) + # gru_out shape: [b, t, f] + mask_erb = torch.sigmoid(self.dense_out(gru_out)) + # mask_erb shape: (b, t, erb_bins) + + mask = self.erb_bands.erb_scale_inv(mask_erb) + # mask shape: (b, t, f) + mask = torch.transpose(mask, dim0=1, dim1=2) + # mask shape: (b, f, t) + + stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask) + denoise = self.istft.forward(stft_denoise) + # denoise shape: [b, 1, num_samples_pad] + + denoise = denoise[:, :, :num_samples] + # denoise shape: [b, 1, num_samples] + return denoise, mask, states + + def forward_conv(self, mag_noisy: torch.Tensor): + # mag_noisy shape: [b, f, t] + tmp = mag_noisy + # tmp shape: [b, f, t] + tmp = torch.tanh(self.conv1(tmp)) + tmp = torch.tanh(self.conv2(tmp)) + # tmp shape: [b, f, t] + return tmp + + def forward_gru(self, + mag_noisy: torch.Tensor, + states: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None, + ): + if states is None: + gru1_state = None + gru2_state = None + gru3_state = None + else: + gru1_state = states[0] + gru2_state = states[1] + gru3_state = states[2] + + # mag_noisy shape: [b, f, t] + tmp = mag_noisy.permute(0, 2, 1) + # tmp shape: [b, t, f] + + gru1_out, gru1_state = self.gru1(tmp, gru1_state) + gru2_out, gru2_state = self.gru2(gru1_out, gru2_state) + gru3_out, gru3_state = self.gru3(gru2_out, gru3_state) + new_states = [gru1_state, gru2_state, gru3_state] + + gru_out = torch.cat(tensors=[tmp, gru1_out, gru2_out, gru3_out], dim=-1) + # gru_out shape: [b, t, f] + return gru_out, new_states + + def forward_chunk_by_chunk(self, + noisy: torch.Tensor, + ): + noisy = self.signal_prepare(noisy) + b, _, num_samples = noisy.shape + t = (num_samples - self.win_size) / self.hop_size + 1 + + waveform = torch.zeros(size=(b, 1, 0), dtype=torch.float32) + + states = None + cache_dict = None + + cache_list = list() + for i in range(int(t)): + begin = i * self.hop_size + end = begin + self.win_size + sub_noisy = noisy[:, :, begin:end] + mag_noisy, pha_noisy = self.mag_pha_stft(sub_noisy) + mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2) + mag_noisy_t_erb = self.erb_bands.erb_scale(mag_noisy_t, db=True) + mag_noisy_t_erb = torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2) + # mag_noisy_t_erb shape: (b, erb_bins, t) + + if len(cache_list) == 0: + cache_list.extend([{ + "mag_noisy": torch.zeros_like(mag_noisy), + "pha_noisy": torch.zeros_like(pha_noisy), + "mag_noisy_t_erb": torch.zeros_like(mag_noisy_t_erb), + }] * 2) + cache_list.append({ + "mag_noisy": mag_noisy, + "pha_noisy": pha_noisy, + "mag_noisy_t_erb": mag_noisy_t_erb, + }) + if len(cache_list) < 5: + continue + mag_noisy_t_erb = torch.concat( + tensors=[c["mag_noisy_t_erb"] for c in cache_list], + dim=-1 + ) + mag_noisy = cache_list[2]["mag_noisy"] + pha_noisy = cache_list[2]["pha_noisy"] + cache_list.pop(0) + # mag_noisy_t_erb shape: [b, f, 5] + mag_noisy_t_erb = self.forward_conv(mag_noisy_t_erb) + # mag_noisy_t_erb shape: [b, f, 1] + gru_out, states = self.forward_gru(mag_noisy_t_erb, states) + mask_erb = torch.sigmoid(self.dense_out(gru_out)) + mask = self.erb_bands.erb_scale_inv(mask_erb) + mask = torch.transpose(mask, dim0=1, dim1=2) + stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask) + sub_waveform, cache_dict 
= self.istft.forward_chunk(stft_denoise, cache_dict=cache_dict) + waveform = torch.concat(tensors=[waveform, sub_waveform], dim=-1) + + return waveform + + def do_mask(self, + mag_noisy: torch.Tensor, + pha_noisy: torch.Tensor, + mask: torch.Tensor, + ): + # (b, f, t) + mag_denoise = mag_noisy * mask + stft_denoise = mag_denoise * torch.exp((1j * pha_noisy)) + return stft_denoise + + def mag_pha_stft(self, noisy: torch.Tensor): + # noisy shape: [b, num_samples] + stft_noisy = self.stft.forward(noisy) + # stft_noisy shape: [b, f, t], torch.complex64 + + real = torch.real(stft_noisy) + imag = torch.imag(stft_noisy) + mag_noisy = torch.sqrt(real ** 2 + imag ** 2) + pha_noisy = torch.atan2(imag, real) + # shape: (b, f, t) + return mag_noisy, pha_noisy + + +MODEL_FILE = "model.pt" + + +class RNNoisePretrainedModel(RNNoise): + def __init__(self, + config: RNNoiseConfig, + ): + super(RNNoisePretrainedModel, self).__init__( + sample_rate=config.sample_rate, + nfft=config.nfft, + win_size=config.win_size, + hop_size=config.hop_size, + win_type=config.win_type, + erb_bins=config.erb_bins, + min_freq_bins_for_erb=config.min_freq_bins_for_erb, + conv_size=config.conv_size, + gru_size=config.gru_size, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = RNNoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main1(): + config = RNNoiseConfig() + model = RNNoisePretrainedModel(config) + model.eval() + + noisy = torch.randn(size=(1, 16000), dtype=torch.float32) + noisy = model.signal_prepare(noisy) + b, _, num_samples = noisy.shape + t = (num_samples - config.win_size) / config.hop_size + 1 + + waveform, mask, h_state = model.forward(noisy) + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + return + + +def main2(): + config = RNNoiseConfig() + model = RNNoisePretrainedModel(config) + model.eval() + + noisy = torch.randn(size=(1, 16000), dtype=torch.float32) + noisy = model.signal_prepare(noisy) + b, _, num_samples = noisy.shape + t = (num_samples - config.win_size) / config.hop_size + 1 + + waveform, mask, h_state = model.forward(noisy) + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + waveform = model.forward_chunk_by_chunk(noisy) + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + return + + +if __name__ == "__main__": + main2() diff --git a/toolbox/torchaudio/models/rnnoise/yaml/config.yaml b/toolbox/torchaudio/models/rnnoise/yaml/config.yaml new file mode 100644 
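A minimal usage sketch for the RNNoise model defined above (the checkpoint directory below is hypothetical; it assumes a model already trained and exported with save_pretrained()):

```python
import torch

from toolbox.torchaudio.models.rnnoise.modeling_rnnoise import RNNoisePretrainedModel

# "trained_models/rnnoise" is a hypothetical directory previously written by
# save_pretrained(); it must contain model.pt plus the yaml config.
model = RNNoisePretrainedModel.from_pretrained("trained_models/rnnoise")
model.eval()

noisy = torch.randn(size=(1, 16000), dtype=torch.float32)  # [batch, num_samples]
with torch.no_grad():
    # offline: the whole utterance in one call
    denoise, mask, states = model.forward(noisy)
    # streaming: one win_size window per step, reusing GRU states and the iSTFT cache
    denoise_stream = model.forward_chunk_by_chunk(noisy)

print(denoise.shape)  # [1, 1, 16000]
print(denoise_stream.shape)
```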
index 0000000000000000000000000000000000000000..b1e5fd7b66220585736aa8e20bb5fc3d482b7e87 --- /dev/null +++ b/toolbox/torchaudio/models/rnnoise/yaml/config.yaml @@ -0,0 +1,35 @@ +model_name: "rnnoise" + +# spec +sample_rate: 8000 +segment_size: 32000 +nfft: 160 +win_size: 160 +hop_size: 80 +win_type: hann + +erb_bins: 32 +min_freq_bins_for_erb: 2 + +# model +conv_size: 256 +gru_size: 256 + +# data +max_snr_db: 20 +min_snr_db: -10 + +# train +lr: 0.001 +lr_scheduler: "CosineAnnealingLR" +lr_scheduler_kwargs: + T_max: 250000 + eta_min: 0.0001 + +max_epochs: 100 +clip_grad_norm: 10.0 +seed: 1234 + +batch_size: 64 +num_workers: 4 +eval_steps: 15000 diff --git a/toolbox/torchaudio/models/simple_linear_irm/__init__.py b/toolbox/torchaudio/models/simple_linear_irm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79ab8cb427356b85a4fd65def3fc8101a5ee301a --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/__init__.py @@ -0,0 +1,59 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/zhaoforever/nn-irm +""" +import torch +import torch.nn as nn + +import torchaudio + + +class NNIRM(nn.Module): + """ + Ideal ratio mask estimator: + default config: 1799(257 x 7) => 2048 => 2048 => 2048 => 257 + """ + + def __init__(self, num_bins=257, n_frames=7, hidden_size=2048): + super(NNIRM, self).__init__() + self.nn = nn.Sequential( + nn.Linear(num_bins * n_frames, hidden_size), + nn.ReLU(), + nn.BatchNorm1d(hidden_size), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.BatchNorm1d(hidden_size), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.BatchNorm1d(hidden_size), + nn.Linear(hidden_size, num_bins), + nn.Sigmoid() + ) + + def forward(self, x): + return self.nn(x) + + +def main(): + signal = torch.rand(size=(16000,)) + + transformer = torchaudio.transforms.MelSpectrogram( + sample_rate=8000, + n_fft=512, + win_length=200, + hop_length=80, + f_min=10, + f_max=3800, + window_fn=torch.hamming_window, + n_mels=80, + ) + + inputs = torch.tensor([signal], dtype=torch.float32) + output = transformer.forward(inputs) + print(output) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/simple_linear_irm/configuration_simple_linear_irm.py b/toolbox/torchaudio/models/simple_linear_irm/configuration_simple_linear_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..732016b9bb1cfee9b528b69776a1a700a22170c0 --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/configuration_simple_linear_irm.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class SimpleLinearIRMConfig(PretrainedConfig): + def __init__(self, + sample_rate: int, + n_fft: int, + win_length: int, + hop_length: int, + + num_bins: int, + hidden_size: int, + lookback: int, + lookahead: int, + + **kwargs + ): + super(SimpleLinearIRMConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + self.num_bins = num_bins + self.hidden_size = hidden_size + self.lookback = lookback + self.lookahead = lookahead + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/simple_linear_irm/modeling_simple_linear_irm.py b/toolbox/torchaudio/models/simple_linear_irm/modeling_simple_linear_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc4f78150cfd9962a23e376b1dca60fa7bab5f4 --- /dev/null +++ 
b/toolbox/torchaudio/models/simple_linear_irm/modeling_simple_linear_irm.py @@ -0,0 +1,167 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +https://github.com/zhaoforever/nn-irm +""" +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE + + +MODEL_FILE = "model.pt" + + +class Transpose(nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, inputs: torch.Tensor): + inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1) + return inputs + + +class SimpleLinearIRM(nn.Module): + """ + Ideal ratio mask estimator: + default config: 1799(257 x 7) => 2048 => 2048 => 2048 => 257 + """ + + def __init__(self, num_bins=257, hidden_size=2048, lookback: int = 3, lookahead: int = 3): + super(SimpleLinearIRM, self).__init__() + self.num_bins = num_bins + self.hidden_size = hidden_size + self.lookback = lookback + self.lookahead = lookahead + + self.n_frames = lookback + 1 + lookahead + + self.nn = nn.Sequential( + Transpose(dim0=2, dim1=1), + nn.Linear(num_bins * self.n_frames, hidden_size), + nn.ReLU(), + Transpose(dim0=2, dim1=1), + nn.BatchNorm1d(hidden_size), + Transpose(dim0=2, dim1=1), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + Transpose(dim0=2, dim1=1), + nn.BatchNorm1d(hidden_size), + Transpose(dim0=2, dim1=1), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + Transpose(dim0=2, dim1=1), + nn.BatchNorm1d(hidden_size), + Transpose(dim0=2, dim1=1), + nn.Linear(hidden_size, num_bins), + Transpose(dim0=2, dim1=1), + + nn.Sigmoid() + ) + + @staticmethod + def frame_spec(spec: torch.Tensor, lookback: int = 3, lookahead: int = 3): + context = lookback + 1 + lookahead + + # batch, num_bins, time_steps + batch, num_bins, time_steps = spec.shape + + spec_ = torch.zeros([batch, context * num_bins, time_steps], dtype=spec.dtype) + for t in range(time_steps): + for c in range(context): + begin = c * num_bins + + t_ = t - lookback + c + t_ = 0 if t_ < 0 else t_ + t_ = time_steps - 1 if t_ > time_steps - 1 else t_ + spec_[:, begin: begin + num_bins, t] = spec[:, :, t_] + spec_ = spec_.to(spec.device) + return spec_ + + def forward(self, spec: torch.Tensor): + # spec shape: (batch_size, num_bins, time_steps) + frame_spec = self.frame_spec(spec, 3, 3) + # frame_spec shape: (batch_size, context * num_bins, time_steps) + mask = self.nn.forward(frame_spec) + return mask + + +class SimpleLinearIRMPretrainedModel(SimpleLinearIRM): + def __init__(self, + config: SimpleLinearIRMConfig, + ): + super(SimpleLinearIRMPretrainedModel, self).__init__( + num_bins=config.num_bins, + hidden_size=config.hidden_size, + lookback=config.lookback, + lookahead=config.lookahead, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = SimpleLinearIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + 
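+ # weights_only=True keeps torch.load from unpickling arbitrary objects, and
+ # strict=True requires the checkpoint keys to match this architecture exactly.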
return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + transformer = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, + ) + + model = SimpleLinearIRM() + + inputs = torch.randn(size=(1, 1600), dtype=torch.float32) + spec = transformer.forward(inputs) + + output = model.forward(spec) + print(output.shape) + print(output) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/simple_linear_irm/yaml/config.yaml b/toolbox/torchaudio/models/simple_linear_irm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f269122b6f589c012376cda288e9a256f1453de2 --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/yaml/config.yaml @@ -0,0 +1,13 @@ +model_name: "simple_linear_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +# model +num_bins: 257 +hidden_size: 2048 +lookback: 3 +lookahead: 3 diff --git a/toolbox/torchaudio/models/spectrum_dfnet/__init__.py b/toolbox/torchaudio/models/spectrum_dfnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_dfnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/spectrum_dfnet/configuration_spectrum_dfnet.py b/toolbox/torchaudio/models/spectrum_dfnet/configuration_spectrum_dfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..29e25e659b2e800347be6b6e90e48e5a8a0694d8 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_dfnet/configuration_spectrum_dfnet.py @@ -0,0 +1,107 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class SpectrumDfNetConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + + spec_bins: int = 256, + + conv_channels: int = 64, + conv_kernel_size_input: Tuple[int, int] = (3, 3), + conv_kernel_size_inner: Tuple[int, int] = (1, 3), + conv_lookahead: int = 0, + + convt_kernel_size_inner: Tuple[int, int] = (1, 3), + + embedding_hidden_size: int = 256, + encoder_combine_op: str = "concat", + + encoder_emb_skip_op: str = "none", + encoder_emb_linear_groups: int = 16, + encoder_emb_hidden_size: int = 256, + + encoder_linear_groups: int = 32, + + lsnr_max: int = 30, + lsnr_min: int = -15, + norm_tau: float = 1., + + decoder_emb_num_layers: int = 3, + decoder_emb_skip_op: str = "none", + decoder_emb_linear_groups: int = 16, + decoder_emb_hidden_size: int = 256, + + df_decoder_hidden_size: int = 256, + df_num_layers: int = 2, + df_order: int = 5, + df_bins: int = 96, + df_gru_skip: str = "grouped_linear", + df_decoder_linear_groups: int = 16, + df_pathway_kernel_size_t: int = 5, + df_lookahead: int = 2, + + use_post_filter: bool = False, + **kwargs + ): + super(SpectrumDfNetConfig, 
self).__init__(**kwargs) + # transform + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + # spectrum + self.spec_bins = spec_bins + + # conv + self.conv_channels = conv_channels + self.conv_kernel_size_input = conv_kernel_size_input + self.conv_kernel_size_inner = conv_kernel_size_inner + self.conv_lookahead = conv_lookahead + + self.convt_kernel_size_inner = convt_kernel_size_inner + + self.embedding_hidden_size = embedding_hidden_size + + # encoder + self.encoder_emb_skip_op = encoder_emb_skip_op + self.encoder_emb_linear_groups = encoder_emb_linear_groups + self.encoder_emb_hidden_size = encoder_emb_hidden_size + + self.encoder_linear_groups = encoder_linear_groups + self.encoder_combine_op = encoder_combine_op + + self.lsnr_max = lsnr_max + self.lsnr_min = lsnr_min + self.norm_tau = norm_tau + + # decoder + self.decoder_emb_num_layers = decoder_emb_num_layers + self.decoder_emb_skip_op = decoder_emb_skip_op + self.decoder_emb_linear_groups = decoder_emb_linear_groups + self.decoder_emb_hidden_size = decoder_emb_hidden_size + + # df decoder + self.df_decoder_hidden_size = df_decoder_hidden_size + self.df_num_layers = df_num_layers + self.df_order = df_order + self.df_bins = df_bins + self.df_gru_skip = df_gru_skip + self.df_decoder_linear_groups = df_decoder_linear_groups + self.df_pathway_kernel_size_t = df_pathway_kernel_size_t + self.df_lookahead = df_lookahead + + # runtime + self.use_post_filter = use_post_filter + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py b/toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1d82439a597b86a67a9570643bec64e1823e5247 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py @@ -0,0 +1,933 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.spectrum_dfnet.configuration_spectrum_dfnet import SpectrumDfNetConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE + + +MODEL_FILE = "model.pt" + + +norm_layer_dict = { + "batch_norm_2d": torch.nn.BatchNorm2d +} + + +activation_layer_dict = { + "relu": torch.nn.ReLU, + "identity": torch.nn.Identity, + "sigmoid": torch.nn.Sigmoid, +} + + +class CausalConv2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal Conv2d by delaying the signal for any lookahead. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + + :param in_channels: + :param out_channels: + :param kernel_size: + :param fstride: + :param dilation: + :param fpad: + """ + super(CausalConv2d, self).__init__() + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + + if fpad: + fpad_ = kernel_size[1] // 2 + dilation - 1 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
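+ # The time axis gets (kernel_size_t - 1 - lookahead) zero-frames on the past
+ # side and `lookahead` frames on the future side. Illustrative example:
+ # kernel_size=(3, 3) with lookahead=0 prepends two frames, so output frame t
+ # depends only on input frames <= t.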
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = list() + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + + layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + def forward(self, inputs): + for module in self: + inputs = module(inputs) + return inputs + + +class CausalConvTranspose2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal ConvTranspose2d. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + """ + super(CausalConvTranspose2d, self).__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if fpad: + fpad_ = kernel_size[1] // 2 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
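+ # The (kernel_size_t - 1) frames added here are removed again by the time
+ # padding of the ConvTranspose2d below (padding=kernel_size[0] - 1 with
+ # stride 1 over time), so the frame count is unchanged and the layer stays
+ # causal up to the configured lookahead.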
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + layers.append( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad_ + dilation - 1), + output_padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + +class GroupedLinear(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" + assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + torch.nn.Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., I] + b, t, _ = x.shape + # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] + new_shape = (b, t, self.groups, self.ws) + x = x.view(new_shape) + # The better way, but not supported by torchscript + # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) # [B, T, H] + return x + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class SqueezedGRU_S(nn.Module): + """ + SGE net: Video object detection with squeezed GRU and information entropy map + https://arxiv.org/abs/2106.07224 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + skip_op: str = "none", + activation_layer: str = "identity", + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.linear_in = nn.Sequential( + GroupedLinear( + input_size=input_size, + hidden_size=hidden_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + + # gru skip operator + self.gru_skip_op = None + + if skip_op == "none": + self.gru_skip_op = None + elif skip_op == "identity": + if not input_size != output_size: + raise AssertionError("Dimensions do not match") + self.gru_skip_op = nn.Identity() + elif skip_op == "grouped_linear": + self.gru_skip_op = GroupedLinear( + input_size=hidden_size, + hidden_size=hidden_size, + groups=linear_groups, + ) + else: + raise NotImplementedError() + + self.gru = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + 
num_layers=num_layers, + batch_first=batch_first, + bidirectional=False, + ) + + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinear( + input_size=hidden_size, + hidden_size=output_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.linear_in(inputs) + + x, h = self.gru.forward(x, h) + + x = self.linear_out(x) + + if self.gru_skip_op is not None: + x = x + self.gru_skip_op(inputs) + + return x, h + + +class Add(nn.Module): + def forward(self, a, b): + return a + b + + +class Concat(nn.Module): + def forward(self, a, b): + return torch.cat((a, b), dim=-1) + + +class Encoder(nn.Module): + def __init__(self, config: SpectrumDfNetConfig): + super(Encoder, self).__init__() + self.embedding_input_size = config.conv_channels * config.spec_bins // 4 + self.embedding_output_size = config.conv_channels * config.spec_bins // 4 + self.embedding_hidden_size = config.embedding_hidden_size + + self.spec_conv0 = CausalConv2d( + in_channels=1, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.spec_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.spec_conv2 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.spec_conv3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + self.df_conv0 = CausalConv2d( + in_channels=2, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + ) + self.df_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.df_fc_emb = nn.Sequential( + GroupedLinear( + config.conv_channels * config.df_bins // 2, + self.embedding_input_size, + groups=config.encoder_linear_groups + ), + nn.ReLU(inplace=True) + ) + + if config.encoder_combine_op == "concat": + self.embedding_input_size *= 2 + self.combine = Concat() + else: + self.combine = Add() + + # emb_gru + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_gru = SqueezedGRU_S( + self.embedding_input_size, + self.embedding_hidden_size, + output_size=self.embedding_output_size, + num_layers=1, + batch_first=True, + skip_op=config.encoder_emb_skip_op, + linear_groups=config.encoder_emb_linear_groups, + activation_layer="relu", + ) + + # lsnr + self.lsnr_fc = nn.Sequential( + nn.Linear(self.embedding_output_size, 1), + nn.Sigmoid() + ) + self.lsnr_scale = config.lsnr_max - config.lsnr_min + self.lsnr_offset = config.lsnr_min + + def forward(self, + feat_power: torch.Tensor, + feat_spec: torch.Tensor, + hidden_state: torch.Tensor = None, + ): + # feat_power shape: (batch_size, 1, time_steps, spec_dim) + e0 = self.spec_conv0.forward(feat_power) + e1 = 
self.spec_conv1.forward(e0) + e2 = self.spec_conv2.forward(e1) + e3 = self.spec_conv3.forward(e2) + # e0 shape: [batch_size, channels, time_steps, spec_dim] + # e1 shape: [batch_size, channels, time_steps, spec_dim // 2] + # e2 shape: [batch_size, channels, time_steps, spec_dim // 4] + # e3 shape: [batch_size, channels, time_steps, spec_dim // 4] + + # feat_spec, shape: (batch_size, 2, time_steps, df_bins) + c0 = self.df_conv0(feat_spec) + c1 = self.df_conv1(c0) + # c0 shape: [batch_size, channels, time_steps, df_bins] + # c1 shape: [batch_size, channels, time_steps, df_bins // 2] + + cemb = c1.permute(0, 2, 3, 1) + # cemb shape: [batch_size, time_steps, df_bins // 2, channels] + cemb = cemb.flatten(2) + # cemb shape: [batch_size, time_steps, df_bins // 2 * channels] + cemb = self.df_fc_emb(cemb) + # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels] + + # e3 shape: [batch_size, channels, time_steps, spec_dim // 4] + emb = e3.permute(0, 2, 3, 1) + # emb shape: [batch_size, time_steps, spec_dim // 4, channels] + emb = emb.flatten(2) + # emb shape: [batch_size, time_steps, spec_dim // 4 * channels] + + emb = self.combine(emb, cemb) + # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2] + # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels] + + emb, h = self.emb_gru.forward(emb, hidden_state) + # emb shape: [batch_size, time_steps, spec_dim // 4 * channels] + # h shape: [batch_size, 1, spec_dim] + + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + # lsnr shape: [batch_size, time_steps, 1] + + return e0, e1, e2, e3, emb, c0, lsnr, h + + +class Decoder(nn.Module): + def __init__(self, config: SpectrumDfNetConfig): + super(Decoder, self).__init__() + + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * config.spec_bins // 4 + self.emb_out_dim = config.conv_channels * config.spec_bins // 4 + self.emb_hidden_dim = config.decoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=config.decoder_emb_num_layers - 1, + batch_first=True, + skip_op=config.decoder_emb_skip_op, + linear_groups=config.decoder_emb_linear_groups, + activation_layer="relu", + ) + self.conv3p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv2p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt2 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv1p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt1 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + 
kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv0p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv0_out = CausalConv2d( + in_channels=config.conv_channels, + out_channels=1, + kernel_size=config.conv_kernel_size_inner, + activation_layer="sigmoid", + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: + # Estimates erb mask + b, _, t, f8 = e3.shape + + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + emb, _ = self.emb_gru(emb) + # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) + e3 = self.convt3(self.conv3p(e3) + emb) + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + e2 = self.convt2(self.conv2p(e2) + e3) + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + e1 = self.convt1(self.conv1p(e1) + e2) + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] + mask = self.conv0_out(self.conv0p(e0) + e1) + # mask shape: [batch_size, 1, time_steps, freq_dim] + return mask + + +class DfDecoder(nn.Module): + def __init__(self, config: SpectrumDfNetConfig): + super(DfDecoder, self).__init__() + + self.embedding_input_size = config.conv_channels * config.spec_bins // 4 + self.df_decoder_hidden_size = config.df_decoder_hidden_size + self.df_num_layers = config.df_num_layers + + self.df_order = config.df_order + + self.df_bins = config.df_bins + self.df_out_ch = config.df_order * 2 + + self.df_convp = CausalConv2d( + config.conv_channels, + self.df_out_ch, + fstride=1, + kernel_size=(config.df_pathway_kernel_size_t, 1), + separable=True, + bias=False, + ) + self.df_gru = SqueezedGRU_S( + self.embedding_input_size, + self.df_decoder_hidden_size, + num_layers=self.df_num_layers, + batch_first=True, + skip_op="none", + activation_layer="relu", + ) + + if config.df_gru_skip == "none": + self.df_skip = None + elif config.df_gru_skip == "identity": + if config.embedding_hidden_size != config.df_decoder_hidden_size: + raise AssertionError("Dimensions do not match") + self.df_skip = nn.Identity() + elif config.df_gru_skip == "grouped_linear": + self.df_skip = GroupedLinear( + self.embedding_input_size, + self.df_decoder_hidden_size, + groups=config.df_decoder_linear_groups + ) + else: + raise NotImplementedError() + + self.df_out: nn.Module + out_dim = self.df_bins * self.df_out_ch + + self.df_out = nn.Sequential( + GroupedLinear( + input_size=self.df_decoder_hidden_size, + hidden_size=out_dim, + groups=config.df_decoder_linear_groups + ), + nn.Tanh() + ) + self.df_fc_a = nn.Sequential( + nn.Linear(self.df_decoder_hidden_size, 1), + nn.Sigmoid() + ) + + def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor: + # emb shape: [batch_size, time_steps, df_bins // 4 * channels] + b, t, _ = emb.shape + df_coefs, _ = self.df_gru(emb) + if self.df_skip is not None: + df_coefs = df_coefs + self.df_skip(emb) + # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] + + # c0 shape: [batch_size, channels, time_steps, df_bins] + c0 = self.df_convp(c0) + # c0 shape: [batch_size, df_order * 2, time_steps, df_bins] + c0 = c0.permute(0, 2, 3, 1) + # c0 shape: [batch_size, time_steps, df_bins, df_order * 2] + + df_coefs = 
self.df_out(df_coefs) # [B, T, F*O*2], O: df_order + # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] + df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch) + # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] + df_coefs = df_coefs + c0 + # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] + return df_coefs + + +class DfOutputReshapeMF(nn.Module): + """Coefficients output reshape for multiframe/MultiFrameModule + + Requires input of shape B, C, T, F, 2. + """ + + def __init__(self, df_order: int, df_bins: int): + super().__init__() + self.df_order = df_order + self.df_bins = df_bins + + def forward(self, coefs: torch.Tensor) -> torch.Tensor: + # [B, T, F, O*2] -> [B, O, T, F, 2] + new_shape = list(coefs.shape) + new_shape[-1] = -1 + new_shape.append(2) + coefs = coefs.view(new_shape) + coefs = coefs.permute(0, 3, 1, 2, 4) + return coefs + + +class Mask(nn.Module): + def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): + super().__init__() + self.use_post_filter = use_post_filter + self.eps = eps + + def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: + """ + Post-Filter + + A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + https://arxiv.org/abs/2008.04259 + + :param mask: Real valued mask, typically of shape [B, C, T, F]. + :param beta: Global gain factor. + :return: + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + # spec shape: [batch_size, 1, time_steps, spec_bins, 2] + + if not self.training and self.use_post_filter: + mask = self.post_filter(mask) + + # mask shape: [batch_size, 1, time_steps, spec_bins] + mask = mask.unsqueeze(4) + # mask shape: [batch_size, 1, time_steps, spec_bins, 1] + return spec * mask + + +class DeepFiltering(nn.Module): + def __init__(self, + df_bins: int, + df_order: int, + lookahead: int = 0, + ): + super(DeepFiltering, self).__init__() + self.df_bins = df_bins + self.df_order = df_order + self.need_unfold = df_order > 1 + self.lookahead = lookahead + + self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) + + def spec_unfold(self, spec: torch.Tensor): + """ + Pads and unfolds the spectrogram according to frame_size. + :param spec: complex Tensor, Spectrogram of shape [B, C, T, F]. + :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size. 
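+
+ Illustrative example: with df_order=5 a complex spec of shape [B, 1, T, F]
+ is padded to T + 4 frames and unfolded back to [B, 1, T, F, 5], so the
+ number of frames is preserved.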
+ """ + if self.need_unfold: + # spec shape: [batch_size, spec_bins, time_steps] + spec_pad = self.pad(spec) + # spec_pad shape: [batch_size, 1, time_steps_pad, spec_bins] + spec_unfold = spec_pad.unfold(2, self.df_order, 1) + # spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order] + return spec_unfold + else: + return spec.unsqueeze(-1) + + def forward(self, + spec: torch.Tensor, + coefs: torch.Tensor, + ): + # spec shape: [batch_size, 1, time_steps, spec_bins, 2] + spec_u = self.spec_unfold(torch.view_as_complex(spec)) + # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order] + + # coefs shape: [batch_size, df_order, time_steps, df_bins, 2] + coefs = torch.view_as_complex(coefs) + # coefs shape: [batch_size, df_order, time_steps, df_bins] + spec_f = spec_u.narrow(-2, 0, self.df_bins) + # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order] + + coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:]) + # coefs shape: [batch_size, 1, df_order, time_steps, df_bins] + + spec_f = self.df(spec_f, coefs) + # spec_f shape: [batch_size, 1, time_steps, df_bins] + + if self.training: + spec = spec.clone() + spec[..., :self.df_bins, :] = torch.view_as_real(spec_f) + # spec shape: [batch_size, 1, time_steps, spec_bins, 2] + return spec + + @staticmethod + def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: + """ + Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. + :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N]. + :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F]. + :return: (complex Tensor). Spectrogram of shape [B, C, T, F]. + """ + return torch.einsum("...tfn,...ntf->...tf", spec, coefs) + + +class SpectrumDfNet(nn.Module): + def __init__(self, config: SpectrumDfNetConfig): + super(SpectrumDfNet, self).__init__() + self.config = config + self.encoder = Encoder(config) + self.decoder = Decoder(config) + + self.df_decoder = DfDecoder(config) + self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) + self.df_op = DeepFiltering( + df_bins=config.df_bins, + df_order=config.df_order, + lookahead=config.df_lookahead, + ) + + self.mask = Mask(use_post_filter=config.use_post_filter) + + def forward(self, + spec_complex: torch.Tensor, + ): + feat_power = torch.square(torch.abs(spec_complex)) + feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2) + # feat_power shape: [batch_size, spec_bins, time_steps] + # feat_power shape: [batch_size, 1, spec_bins, time_steps] + # feat_power shape: [batch_size, 1, time_steps, spec_bins] + feat_power = feat_power.detach() + + # spec shape: [batch_size, spec_bins, time_steps] + feat_spec = torch.view_as_real(spec_complex) + # spec shape: [batch_size, spec_bins, time_steps, 2] + feat_spec = feat_spec.permute(0, 3, 2, 1) + # feat_spec shape: [batch_size, 2, time_steps, spec_bins] + feat_spec = feat_spec[..., :self.df_decoder.df_bins] + # feat_spec shape: [batch_size, 2, time_steps, df_bins] + feat_spec = feat_spec.detach() + + # spec shape: [batch_size, spec_bins, time_steps] + spec = torch.unsqueeze(spec_complex, dim=1) + # spec shape: [batch_size, 1, spec_bins, time_steps] + spec = spec.permute(0, 1, 3, 2) + # spec shape: [batch_size, 1, time_steps, spec_bins] + spec = torch.view_as_real(spec) + # spec shape: [batch_size, 1, time_steps, spec_bins, 2] + spec = spec.detach() + + e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec) + + mask = self.decoder.forward(emb, e3, e2, e1, e0) + # mask shape: 
[batch_size, 1, time_steps, spec_bins] + if torch.any(mask > 1) or torch.any(mask < 0): + raise AssertionError + + spec_m = self.mask.forward(spec, mask) + + # lsnr shape: [batch_size, time_steps, 1] + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + # lsnr shape: [batch_size, 1, time_steps] + + df_coefs = self.df_decoder.forward(emb, c0) + df_coefs = self.df_out_transform(df_coefs) + # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2] + + spec_e = self.df_op.forward(spec.clone(), df_coefs) + # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2] + + spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :] + + spec_e = torch.squeeze(spec_e, dim=1) + spec_e = spec_e.permute(0, 2, 1, 3) + # spec_e shape: [batch_size, spec_bins, time_steps, 2] + + mask = torch.squeeze(mask, dim=1) + mask = mask.permute(0, 2, 1) + # mask shape: [batch_size, spec_bins, time_steps] + + return spec_e, mask, lsnr + + +class SpectrumDfNetPretrainedModel(SpectrumDfNet): + def __init__(self, + config: SpectrumDfNetConfig, + ): + super(SpectrumDfNetPretrainedModel, self).__init__( + config=config, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = SpectrumDfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + + transformer = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, + power=None, + ) + + config = SpectrumDfNetConfig() + model = SpectrumDfNet(config=config) + + inputs = torch.randn(size=(1, 16000), dtype=torch.float32) + spec_complex = transformer.forward(inputs) + spec_complex = spec_complex[:, :-1, :] + + output = model.forward(spec_complex) + print(output[1].shape) + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/__init__.py b/toolbox/torchaudio/models/spectrum_unet_irm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py b/toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb4253eafffae89476b13feefc5b1ca8f59c581 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py @@ -0,0 +1,76 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import 
Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class SpectrumUnetIRMConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + + spec_bins: int = 256, + + conv_channels: int = 64, + conv_kernel_size_input: Tuple[int, int] = (3, 3), + conv_kernel_size_inner: Tuple[int, int] = (1, 3), + conv_lookahead: int = 0, + + convt_kernel_size_inner: Tuple[int, int] = (1, 3), + + encoder_emb_skip_op: str = "none", + encoder_emb_linear_groups: int = 16, + encoder_emb_hidden_size: int = 256, + + lsnr_max: int = 30, + lsnr_min: int = -15, + + decoder_emb_num_layers: int = 3, + decoder_emb_skip_op: str = "none", + decoder_emb_linear_groups: int = 16, + decoder_emb_hidden_size: int = 256, + + use_post_filter: bool = False, + **kwargs + ): + super(SpectrumUnetIRMConfig, self).__init__(**kwargs) + # transform + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + # spectrum + self.spec_bins = spec_bins + + # conv + self.conv_channels = conv_channels + self.conv_kernel_size_input = conv_kernel_size_input + self.conv_kernel_size_inner = conv_kernel_size_inner + self.conv_lookahead = conv_lookahead + + self.convt_kernel_size_inner = convt_kernel_size_inner + + # encoder + self.encoder_emb_skip_op = encoder_emb_skip_op + self.encoder_emb_linear_groups = encoder_emb_linear_groups + self.encoder_emb_hidden_size = encoder_emb_hidden_size + + self.lsnr_max = lsnr_max + self.lsnr_min = lsnr_min + + # decoder + self.decoder_emb_num_layers = decoder_emb_num_layers + self.decoder_emb_skip_op = decoder_emb_skip_op + self.decoder_emb_linear_groups = decoder_emb_linear_groups + self.decoder_emb_hidden_size = decoder_emb_hidden_size + + # runtime + self.use_post_filter = use_post_filter + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py b/toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..67e578f177c053fd0b9ee1a0d67098253621fd3b --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py @@ -0,0 +1,650 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE + + +MODEL_FILE = "model.pt" + + +norm_layer_dict = { + "batch_norm_2d": torch.nn.BatchNorm2d +} + + +activation_layer_dict = { + "relu": torch.nn.ReLU, + "identity": torch.nn.Identity, + "sigmoid": torch.nn.Sigmoid, +} + + +class CausalConv2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal Conv2d by delaying the signal for any lookahead. 
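+ With lookahead == 0 the layer is strictly causal; a positive lookahead lets
+ each output frame additionally see that many future frames.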
+ + Expected input format: [batch_size, channels, time_steps, spec_dim] + + :param in_channels: + :param out_channels: + :param kernel_size: + :param fstride: + :param dilation: + :param fpad: + """ + super(CausalConv2d, self).__init__() + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + + if fpad: + fpad_ = kernel_size[1] // 2 + dilation - 1 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). + pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = list() + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + + layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + def forward(self, inputs): + for module in self: + inputs = module(inputs) + return inputs + + +class CausalConvTranspose2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal ConvTranspose2d. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + """ + super(CausalConvTranspose2d, self).__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if fpad: + fpad_ = kernel_size[1] // 2 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
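+ # In the decoder this layer is used with the default (1, 3) kernel and
+ # fstride=2, which doubles the number of frequency bins; the constant padding
+ # below only affects the time axis (and is a no-op when kernel_size_t == 1).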
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + layers.append( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad_ + dilation - 1), + output_padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + +class GroupedLinear(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" + assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + torch.nn.Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., I] + b, t, _ = x.shape + # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] + new_shape = (b, t, self.groups, self.ws) + x = x.view(new_shape) + # The better way, but not supported by torchscript + # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) # [B, T, H] + return x + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class SqueezedGRU_S(nn.Module): + """ + SGE net: Video object detection with squeezed GRU and information entropy map + https://arxiv.org/abs/2106.07224 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + skip_op: str = "none", + activation_layer: str = "identity", + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.linear_in = nn.Sequential( + GroupedLinear( + input_size=input_size, + hidden_size=hidden_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + + # gru skip operator + self.gru_skip_op = None + + if skip_op == "none": + self.gru_skip_op = None + elif skip_op == "identity": + if not input_size != output_size: + raise AssertionError("Dimensions do not match") + self.gru_skip_op = nn.Identity() + elif skip_op == "grouped_linear": + self.gru_skip_op = GroupedLinear( + input_size=hidden_size, + hidden_size=hidden_size, + groups=linear_groups, + ) + else: + raise NotImplementedError() + + self.gru = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + 
num_layers=num_layers, + batch_first=batch_first, + bidirectional=False, + ) + + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinear( + input_size=hidden_size, + hidden_size=output_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.linear_in(inputs) + + x, h = self.gru.forward(x, h) + + x = self.linear_out(x) + + if self.gru_skip_op is not None: + x = x + self.gru_skip_op(inputs) + + return x, h + + +class Encoder(nn.Module): + def __init__(self, config: SpectrumUnetIRMConfig): + super(Encoder, self).__init__() + + self.conv0 = CausalConv2d( + in_channels=1, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv2 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + # emb_gru + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * (config.spec_bins // 4) + self.emb_out_dim = config.conv_channels * (config.spec_bins // 4) + self.emb_hidden_dim = config.encoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=1, + batch_first=True, + skip_op=config.encoder_emb_skip_op, + linear_groups=config.encoder_emb_linear_groups, + activation_layer="relu", + ) + + # lsnr + self.lsnr_fc = nn.Sequential( + nn.Linear(self.emb_out_dim, 1), + nn.Sigmoid() + ) + self.lsnr_scale = config.lsnr_max - config.lsnr_min + self.lsnr_offset = config.lsnr_min + + def forward(self, + spec: torch.Tensor, + hidden_state: torch.Tensor = None, + ): + # spec shape: (batch_size, 1, time_steps, spec_dim) + e0 = self.conv0.forward(spec) + e1 = self.conv1.forward(e0) + e2 = self.conv2.forward(e1) + e3 = self.conv3.forward(e2) + + # e3 shape: [batch_size, channels, time_steps, hidden_size] + emb = e3.permute(0, 2, 3, 1) + # emb shape: [batch_size, time_steps, hidden_size, channels] + emb = emb.flatten(2) + # emb shape: [batch_size, time_steps, hidden_size * channels] + emb, h = self.emb_gru.forward(emb, hidden_state) + + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + + return e0, e1, e2, e3, emb, lsnr + + +class Decoder(nn.Module): + def __init__(self, config: SpectrumUnetIRMConfig): + super(Decoder, self).__init__() + + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * config.spec_bins // 4 + self.emb_out_dim = config.conv_channels * config.spec_bins // 4 + self.emb_hidden_dim = config.decoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + 
num_layers=config.decoder_emb_num_layers - 1, + batch_first=True, + skip_op=config.decoder_emb_skip_op, + linear_groups=config.decoder_emb_linear_groups, + activation_layer="relu", + ) + self.conv3p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv2p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt2 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv1p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt1 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv0p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv0_out = CausalConv2d( + in_channels=config.conv_channels, + out_channels=1, + kernel_size=config.conv_kernel_size_inner, + activation_layer="sigmoid", + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + + def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: + # Estimates erb mask + b, _, t, f8 = e3.shape + + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + emb, _ = self.emb_gru(emb) + # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) + e3 = self.convt3(self.conv3p(e3) + emb) + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + e2 = self.convt2(self.conv2p(e2) + e3) + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + e1 = self.convt1(self.conv1p(e1) + e2) + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] + mask = self.conv0_out(self.conv0p(e0) + e1) + # mask shape: [batch_size, 1, time_steps, freq_dim] + return mask + + +class SpectrumUnetIRM(nn.Module): + def __init__(self, config: SpectrumUnetIRMConfig, eps: float = 1e-8): + super(SpectrumUnetIRM, self).__init__() + self.config = config + self.encoder = Encoder(config) + self.decoder = Decoder(config) + + self.eps = eps + + def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: + """ + 总体上来说, 它会将 mask 中的值都调大一点. 可能是为了保留更多的声音以免损伤音质, 因为预测的 mask 肯定不是特别正确. + 这个不参与训练, 只在推理时应用在 mask 上. + + Post-Filter + + A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + https://arxiv.org/abs/2008.04259 + + :param mask: Real valued mask, typically of shape [B, C, T, F]. + :param beta: Global gain factor. 
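+ With the default beta the mask is raised slightly (attenuated bins are
+ boosted a little, a value of 1.0 passes through unchanged), so more of the
+ signal is kept at the cost of less aggressive noise suppression. It is
+ applied only at inference time, never during training.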
+ :return: + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, + spec: torch.Tensor, + ): + # spec shape: [batch_size, freq_dim, time_steps] + # spec shape: [batch_size, 1, freq_dim, time_steps] + + spec = spec.unsqueeze(1).permute(0, 1, 3, 2) + # spec shape: [batch_size, channel, time_steps, freq_dim] + + e0, e1, e2, e3, emb, lsnr = self.encoder.forward(spec) + + # e0 shape: [batch_size, conv_channels, time_steps, freq_dim] + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + # lsnr shape: [batch_size, time_steps, 1] + # h shape: [batch_size, 1, freq_dim] + + mask = self.decoder.forward(emb, e3, e2, e1, e0) + + if torch.any(mask > 1): + raise AssertionError + if torch.any(mask < 0): + raise AssertionError + + # mask shape: [batch_size, 1, time_steps, freq_dim] + # lsnr shape: [batch_size, time_steps, 1] + mask = torch.squeeze(mask, dim=1) + mask = torch.transpose(mask, dim0=2, dim1=1) + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + + if not self.training and self.config.use_post_filter: + mask = self.post_filter(mask) + + # mask shape: [batch_size, freq_dim, time_steps] + # lsnr shape: [batch_size, 1, time_steps] + return mask, lsnr + + +class SpectrumUnetIRMPretrainedModel(SpectrumUnetIRM): + def __init__(self, + config: SpectrumUnetIRMConfig, + ): + super(SpectrumUnetIRMPretrainedModel, self).__init__( + config=config, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = SpectrumUnetIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + transformer = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, + ) + + config = SpectrumUnetIRMConfig() + model = SpectrumUnetIRM(config=config) + + inputs = torch.randn(size=(1, 16000), dtype=torch.float32) + spec = transformer.forward(inputs) + spec = spec[:, :-1, :] + + output = model.forward(spec) + print(output[0].shape) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml b/toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..756e93747642a4c1bdfebc82c739c341c85a5855 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml @@ -0,0 
+1,38 @@ +model_name: "spectrum_unet_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +lsnr_max: 30 +lsnr_min: -15 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 + +# runtime +use_post_filter: true diff --git a/toolbox/torchaudio/models/tcnn/__init__.py b/toolbox/torchaudio/models/tcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/tcnn/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/tcnn/modeling_tcnn.py b/toolbox/torchaudio/models/tcnn/modeling_tcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..39b43f6155670b6879843f6b20f77d54c0903bbc --- /dev/null +++ b/toolbox/torchaudio/models/tcnn/modeling_tcnn.py @@ -0,0 +1,353 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/LXP-Never/TCNN +https://github.com/LXP-Never/TCNN/blob/main/TCNN_model.py +https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement + +https://ieeexplore.ieee.org/abstract/document/8683634 + +参考来源: +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement + +""" +from typing import Union + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t + + +class Chomp1d(nn.Module): + def __init__(self, chomp_size: int): + super(Chomp1d, self).__init__() + self.chomp_size = chomp_size + + def forward(self, x: torch.Tensor): + return x[:, :, :-self.chomp_size].contiguous() + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: _size_1_t, + stride: _size_1_t = 1, + padding: Union[str, _size_1_t] = 0, + dilation: _size_1_t = 1, + causal: bool = False, + ): + super(DepthwiseSeparableConv, self).__init__() + # Use `groups` option to implement depthwise convolution + self.depthwise_conv = nn.Conv1d( + in_channels=in_channels, out_channels=in_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + groups=in_channels, + bias=False, + ) + self.chomp1d = Chomp1d(padding) if causal else nn.Identity() + self.prelu = nn.PReLU() + self.norm = nn.BatchNorm1d(in_channels) + self.pointwise_conv = nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, + kernel_size=1, + bias=False, + ) + + def forward(self, x: torch.Tensor): + # x shape: [b, c, t] + x = self.depthwise_conv.forward(x) + # x shape: [b, c, t_pad] + x = self.chomp1d(x) + # x shape: [b, c, t] + x = self.prelu(x) + x = self.norm(x) + x = self.pointwise_conv.forward(x) + return x + + +class ResBlock(nn.Module): + def __init__(self, + in_channels: int, + hidden_channels: int, + kernel_size: _size_1_t, + dilation: _size_1_t = 1, + ): + super(ResBlock, self).__init__() + + self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1) + self.prelu = nn.PReLU(num_parameters=1) + self.norm = nn.BatchNorm1d(num_features=hidden_channels) + self.sconv = DepthwiseSeparableConv( + 
in_channels=hidden_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) * dilation, + dilation=dilation, + causal=True, + ) + + def forward(self, inputs: torch.Tensor): + x = inputs + # x shape: [b, in_channels, t] + x = self.conv1d.forward(x) + # x shape: [b, out_channels, t] + x = self.prelu(x) + x = self.norm(x) + # x shape: [b, out_channels, t] + x = self.sconv.forward(x) + # x shape: [b, in_channels, t] + result = x + inputs + return result + + +class TCNNBlock(nn.Module): + def __init__(self, + in_channels: int, + hidden_channels: int, + kernel_size: int = 3, + init_dilation: int = 2, + num_layers: int = 6 + ): + super(TCNNBlock, self).__init__() + self.layers = nn.ModuleList(modules=[]) + for i in range(num_layers): + dilation_size = init_dilation ** i + # in_channels = in_channels if i == 0 else out_channels + + self.layers.append( + ResBlock( + in_channels, + hidden_channels, + kernel_size, + dilation=dilation_size, + ) + ) + + def forward(self, x: torch.Tensor): + for layer in self.layers: + # x shape: [b, c, t] + x = layer.forward(x) + # x shape: [b, c, t] + return x + + +class TCNN(nn.Module): + def __init__(self): + super(TCNN, self).__init__() + self.win_size = 320 + self.hop_size = 160 + + self.conv2d_1 = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2)), + nn.BatchNorm2d(num_features=16), + nn.PReLU() + ) + self.conv2d_2 = nn.Sequential( + nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)), + nn.BatchNorm2d(num_features=16), + nn.PReLU() + ) + self.conv2d_3 = nn.Sequential( + nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)), + nn.BatchNorm2d(num_features=16), + nn.PReLU() + ) + self.conv2d_4 = nn.Sequential( + nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)), + nn.BatchNorm2d(num_features=32), + nn.PReLU() + ) + self.conv2d_5 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)), + nn.BatchNorm2d(num_features=32), + nn.PReLU() + ) + self.conv2d_6 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)), + nn.BatchNorm2d(num_features=64), + nn.PReLU() + ) + self.conv2d_7 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)), + nn.BatchNorm2d(num_features=64), + nn.PReLU() + ) + + # 256 = 64 * 4 + self.tcnn_block_1 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6) + self.tcnn_block_2 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6) + self.tcnn_block_3 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6) + + self.dconv2d_7 = nn.Sequential( + nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), + output_padding=(0, 0)), + nn.BatchNorm2d(num_features=64), + nn.PReLU() + ) + self.dconv2d_6 = nn.Sequential( + nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), + output_padding=(0, 0)), + nn.BatchNorm2d(num_features=32), + nn.PReLU() + ) + self.dconv2d_5 = nn.Sequential( + nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), + output_padding=(0, 0)), + nn.BatchNorm2d(num_features=32), + 
nn.PReLU() + ) + self.dconv2d_4 = nn.Sequential( + nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), + output_padding=(0, 0)), + nn.BatchNorm2d(num_features=16), + nn.PReLU() + ) + self.dconv2d_3 = nn.Sequential( + nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), + output_padding=(0, 1)), + nn.BatchNorm2d(num_features=16), + nn.PReLU() + ) + self.dconv2d_2 = nn.Sequential( + nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2), + output_padding=(0, 1)), + nn.BatchNorm2d(num_features=16), + nn.PReLU() + ) + self.dconv2d_1 = nn.Sequential( + nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2), + output_padding=(0, 0)), + nn.BatchNorm2d(num_features=1), + nn.PReLU() + ) + + def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: + if signal.dim() == 2: + signal = torch.unsqueeze(signal, dim=1) + _, _, n_samples = signal.shape + remainder = (n_samples - self.win_size) % self.hop_size + if remainder > 0: + n_samples_pad = self.hop_size - remainder + signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) + return signal + + def forward(self, + noisy: torch.Tensor, + ): + num_samples = noisy.shape[-1] + noisy = self.signal_prepare(noisy) + batch_size, _, num_samples_pad = noisy.shape + + # n_frame = (num_samples_pad - self.win_size) / self.hop_size + 1 + + # unfold + # noisy shape: [b, 1, num_samples_pad] + noisy = noisy.unsqueeze(1) + # noisy shape: [b, 1, 1, num_samples_pad] + noisy_frame = torch.nn.functional.unfold( + input=noisy, + kernel_size=(1, self.win_size), + padding=(0, 0), + stride=(1, self.hop_size), + ) + # noisy_frame shape: [b, win_size, n_frame] + noisy_frame = noisy_frame.unsqueeze(1) + # noisy_frame shape: [b, 1, win_size, n_frame] + noisy_frame = noisy_frame.permute(0, 1, 3, 2) + # noisy_frame shape: [b, 1, n_frame, win_size] + + denoise_frame = self.forward_chunk(noisy_frame) + # denoise_frame shape: [b, c, n_frame, win_size] + denoise_frame = denoise_frame.squeeze(1) + # denoise_frame shape: [b, n_frame, win_size] + denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad) + # denoise shape: [b, num_samples_pad] + + denoise = denoise[:, :num_samples] + # denoise shape: [b, num_samples] + denoise = torch.unsqueeze(denoise, dim=1) + # denoise shape: [b, 1, num_samples] + return denoise + + def forward_chunk(self, inputs: torch.Tensor): + # inputs shape: [b, c, t, segment_length] + conv2d_1 = self.conv2d_1(inputs) + conv2d_2 = self.conv2d_2(conv2d_1) + conv2d_3 = self.conv2d_3(conv2d_2) + conv2d_4 = self.conv2d_4(conv2d_3) + conv2d_5 = self.conv2d_5(conv2d_4) + conv2d_6 = self.conv2d_6(conv2d_5) + conv2d_7 = self.conv2d_7(conv2d_6) + # shape: [b, c, t, 4] + + reshape_1 = conv2d_7.permute(0, 1, 3, 2) + # shape: [b, c, 4, t] + batch_size, C, frame_len, frame_num = reshape_1.shape + reshape_1 = reshape_1.reshape(batch_size, C * frame_len, frame_num) + # shape: [b, c*4, t] + + tcnn_block_1 = self.tcnn_block_1.forward(reshape_1) + tcnn_block_2 = self.tcnn_block_2.forward(tcnn_block_1) + tcnn_block_3 = self.tcnn_block_3.forward(tcnn_block_2) + + # shape: [b, c*4, t] + reshape_2 = tcnn_block_3.reshape(batch_size, C, frame_len, frame_num) + reshape_2 = reshape_2.permute(0, 1, 3, 2) + # shape: [b, c, t, 4] + + dconv2d_7 = self.dconv2d_7(torch.cat((conv2d_7, reshape_2), dim=1)) + dconv2d_6 = self.dconv2d_6(torch.cat((conv2d_6, 
dconv2d_7), dim=1)) + dconv2d_5 = self.dconv2d_5(torch.cat((conv2d_5, dconv2d_6), dim=1)) + dconv2d_4 = self.dconv2d_4(torch.cat((conv2d_4, dconv2d_5), dim=1)) + dconv2d_3 = self.dconv2d_3(torch.cat((conv2d_3, dconv2d_4), dim=1)) + dconv2d_2 = self.dconv2d_2(torch.cat((conv2d_2, dconv2d_3), dim=1)) + dconv2d_1 = self.dconv2d_1(torch.cat((conv2d_1, dconv2d_2), dim=1)) + + return dconv2d_1 + + def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int): + # overlap and add + # https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement/blob/main/TCNN/util/utils.py#L40 + + b, t, f = denoise_frame.shape + if f != self.win_size: + raise AssertionError + + denoise = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype) + count = torch.zeros(size=(b, num_samples), dtype=torch.float32) + + start = 0 + end = start + self.win_size + for i in range(t): + denoise[..., start:end] += denoise_frame[:, i, :] + count[..., start:end] += 1. + + start += self.hop_size + end = start + self.win_size + + denoise = denoise / count + return denoise + + +def main(): + model = TCNN() + model.eval() + + x = torch.randn(64, 1, 5, 320) + # x = torch.randn(64, 1, 5, 160) + y = model.forward_chunk(x) + print("output", y.shape) + + noisy = torch.randn(size=(2, 16000), dtype=torch.float32) + denoise = model.forward(noisy) + print(f"denoise.shape: {denoise.shape}") + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/models/wave_unet/__init__.py b/toolbox/torchaudio/models/wave_unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/wave_unet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/wave_unet/modeling_wave_unet.py b/toolbox/torchaudio/models/wave_unet/modeling_wave_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e6429cae0b97658ecf47842007ce4fd7839443f5 --- /dev/null +++ b/toolbox/torchaudio/models/wave_unet/modeling_wave_unet.py @@ -0,0 +1,12 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/YosukeSugiura/Wave-U-Net-for-Speech-Enhancement-NNabla + +https://arxiv.org/abs/1811.11307 + +""" + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/zip_enhancer/__init__.py b/toolbox/torchaudio/models/zip_enhancer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4aad738e112896111c38ae6624c8632aee62a234 --- /dev/null +++ b/toolbox/torchaudio/models/zip_enhancer/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py b/toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..884b5f71a7a7be095232c0452131b0b552adbeae --- /dev/null +++ b/toolbox/torchaudio/models/zip_enhancer/modeling_zip_enhancer.py @@ -0,0 +1,154 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2501.05183 +https://zipenhancer.github.io/ZipEnhancer/ + +https://modelscope.cn/models/iic/speech_zipenhancer_ans_multiloss_16k_base + +https://github.com/boreas-l/zipEnhancer +""" +import torch +import torch.nn as nn + + +class DenseBlockV2(nn.Module): + def __init__(self, config, kernel_size=(2, 3), 
depth=4): + super(DenseBlockV2, self).__init__() + self.config = config + self.depth = depth + + self.dense_block = nn.ModuleList([]) + for i in range(depth): + dil = 2 ** i + pad_length = kernel_size[0] + (dil - 1) * (kernel_size[0] - 1) - 1 + dense_conv = nn.Sequential( + nn.ConstantPad2d((1, 1, pad_length, 0), value=0.), + nn.Conv2d( + config.dense_channel * (i + 1), + config.dense_channel, + kernel_size, + dilation=(dil, 1) + ), + nn.InstanceNorm2d(config.dense_channel, affine=True), + nn.PReLU(config.dense_channel) + ) + self.dense_block.append(dense_conv) + + def forward(self, x): + skip = x + # b, c, t, f + for i in range(self.depth): + _x = skip + x = self.dense_block[i](_x) + # print(x.size()) + skip = torch.cat([x, skip], dim=1) + return x + + +class DenseEncoder(nn.Module): + + def __init__(self, config, in_channel): + super(DenseEncoder, self).__init__() + self.config = config + self.dense_conv_1 = nn.Sequential( + nn.Conv2d(in_channel, config.dense_channel, (1, 1)), + nn.InstanceNorm2d(config.dense_channel, affine=True), + nn.PReLU(config.dense_channel) + ) + + self.dense_block = DenseBlockV2(config, depth=4) + + encoder_pad_kersize = (0, 1) + # Here pad was originally (0, 0),now change to (0, 1) + self.dense_conv_2 = nn.Sequential( + nn.Conv2d( + config.dense_channel, + config.dense_channel, + kernel_size=(1, 3), + stride=(1, 2), + padding=encoder_pad_kersize + ), + nn.InstanceNorm2d(config.dense_channel, affine=True), + nn.PReLU(config.dense_channel) + ) + + def forward(self, x): + """ + Forward pass of the DenseEncoder module. + + Args: + x (Tensor): Input tensor of shape [B, C=in_channel, T, F]. + + Returns: + Tensor: Output tensor after passing through the dense encoder. Maybe: [B, C=dense_channel, T, F // 2]. + """ + # print("x: {}".format(x.size())) + x = self.dense_conv_1(x) # [b, 64, T, F] + if self.dense_block is not None: + x = self.dense_block(x) # [b, 64, T, F] + x = self.dense_conv_2(x) # [b, 64, T, F//2] + return x + + +class ZipEnhancer(nn.Module): + + def __init__(self, config): + super(ZipEnhancer, self).__init__() + self.config = config + + num_tsconformers = config.num_tsconformers + self.num_tscblocks = num_tsconformers + + self.dense_encoder = DenseEncoder(config, in_channel=2) + + self.TSConformer = Zipformer2DualPathEncoder( + output_downsampling_factor=1, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + **config.former_conf + ) + + self.mask_decoder = MappingDecoder(config, out_channel=config.model_num_spks) + self.phase_decoder = PhaseDecoder(config, out_channel=config.model_num_spks) + + def forward(self, noisy_mag, noisy_pha): # [B, F, T] + """ + Forward pass of the ZipEnhancer module. + + Args: + noisy_mag (torch.Tensor): Noisy magnitude input torch.tensor of shape [B, F, T]. + noisy_pha (torch.Tensor): Noisy phase input torch.tensor of shape [B, F, T]. + + Returns: + Tuple: denoised magnitude, denoised phase, denoised complex representation, + (optional) predicted noise components, and other auxiliary information. 
+ """ + others = dict() + + noisy_mag = noisy_mag.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F] + noisy_pha = noisy_pha.unsqueeze(-1).permute(0, 3, 2, 1) # [B, 1, T, F] + x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F] + x = self.dense_encoder(x) + + # [B, C, T, F] + x = self.TSConformer(x) + + pred_mag = self.mask_decoder(x) + pred_pha = self.phase_decoder(x) + # b, c, t, f -> b, 1, t, f -> b, f, t, 1 -> b, f, t + denoised_mag = pred_mag[:, 0, :, :].unsqueeze(1).permute(0, 3, 2, + 1).squeeze(-1) + + # b, t, f + denoised_pha = pred_pha[:, 0, :, :].unsqueeze(1).permute(0, 3, 2, + 1).squeeze(-1) + # b, t, f + denoised_com = torch.stack((denoised_mag * torch.cos(denoised_pha), + denoised_mag * torch.sin(denoised_pha)), + dim=-1) + + return denoised_mag, denoised_pha, denoised_com, None, others + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/zip_enhancer/scaling.py b/toolbox/torchaudio/models/zip_enhancer/scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..5005bfe3c8d8ad9f7559eb2bc0202c24085f0214 --- /dev/null +++ b/toolbox/torchaudio/models/zip_enhancer/scaling.py @@ -0,0 +1,249 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/scaling.py +""" +import logging +import random +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + + +def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + + for x, y in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) + + def __str__(self): + # e.g. 
'PiecewiseLinear((0., 10.), (100., 0.))' + return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if cur_x <= x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / ( + next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear(*[(sp[0], sp[1] + xp[1]) + for sp, xp in zip(s.pairs, x.pairs)]) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear(*[(sp[0], max(sp[1], xp[1])) + for sp, xp in zip(s.pairs, x.pairs)]) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear(*[(sp[0], min(sp[1], xp[1])) + for sp, xp in zip(s.pairs, x.pairs)]) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, + p: 'PiecewiseLinear', + include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise linear + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p cross. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted( + set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + _compare_results1 = (y_vals1[i] > y_vals2[i]) + _compare_results2 = (y_vals1[i + 1] > y_vals2[i + 1]) + if _compare_results1 != _compare_results2: + # if ((y_vals1[i] > y_vals2[i]) != + # (y_vals1[i + 1] > y_vals2[i + 1])): + # if the two lines in this subsegment potentially cross each other. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specify the (x,y) pairs + in sorted order on x; x corresponds to the batch index. 
For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + ) + + def __float__(self): + batch_count = self.batch_count + if (batch_count is None or not self.training + or torch.jit.is_scripting() or torch.jit.is_tracing()): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.info( + f'ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}' + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), + default=max(self.default, x.default)) + + +FloatLike = Union[float, ScheduledFloat] + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. 
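+ # Casting the saved softmax output back to float16 under autocast halves the memory held for backward; backward() below recomputes in float32 and applies the softmax VJP: x_grad = ans_grad * ans - ans * sum(ans_grad * ans, dim).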
+ if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: torch.Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py b/toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e998208122e7f32ebb63a86ed0bcf22263b58d26 --- /dev/null +++ b/toolbox/torchaudio/models/zip_enhancer/zip_enhancer_layer.py @@ -0,0 +1,9 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/zipenhancer_layer.py +""" + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/zip_enhancer/zipformer.py b/toolbox/torchaudio/models/zip_enhancer/zipformer.py new file mode 100644 index 0000000000000000000000000000000000000000..476e5dae72d8e2db7e69d130167a1569e3841a80 --- /dev/null +++ b/toolbox/torchaudio/models/zip_enhancer/zipformer.py @@ -0,0 +1,9 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/boreas-l/zipEnhancer/blob/main/models/layers/zipformer.py +""" + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/modules/__init__.py b/toolbox/torchaudio/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/modules/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/modules/conv_stft.py b/toolbox/torchaudio/modules/conv_stft.py new file mode 100644 index 0000000000000000000000000000000000000000..8702596edf527a94deb8a121b46299d04aee8d1c --- /dev/null +++ b/toolbox/torchaudio/modules/conv_stft.py @@ -0,0 +1,271 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py +""" +from collections import defaultdict +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.signal import get_window + + +def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False): + if win_type == "None" or win_type is None: + window = np.ones(win_size) + else: + window = get_window(win_type, win_size, fftbins=True)**0.5 + + fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size] + real_kernel = np.real(fourier_basis) + image_kernel = np.imag(fourier_basis) + kernel = np.concatenate([real_kernel, image_kernel], 1).T + + if inverse: + kernel = np.linalg.pinv(kernel).T + + kernel = kernel * window + kernel = kernel[:, None, :] + result = ( + torch.from_numpy(kernel.astype(np.float32)), + torch.from_numpy(window[None, :, None].astype(np.float32)) + ) + return result + + +class ConvSTFT(nn.Module): + + def __init__(self, + nfft: int, + win_size: int, + hop_size: int, + win_type: str = "hamming", + power: int = None, + requires_grad: bool = False): + super(ConvSTFT, self).__init__() + + if nfft is None: + self.nfft = int(2**np.ceil(np.log2(win_size))) + else: + self.nfft = nfft + + kernel, _ = init_kernels(self.nfft, win_size, hop_size, 
win_type) + self.weight = nn.Parameter(kernel, requires_grad=requires_grad) + + self.win_size = win_size + self.hop_size = hop_size + + self.stride = hop_size + self.dim = self.nfft + self.power = power + + def forward(self, waveform: torch.Tensor): + if waveform.dim() == 2: + waveform = torch.unsqueeze(waveform, 1) + + matrix = F.conv1d(waveform, self.weight, stride=self.stride) + dim = self.dim // 2 + 1 + real = matrix[:, :dim, :] + imag = matrix[:, dim:, :] + spec = torch.complex(real, imag) + # spec shape: [b, f, t], torch.complex64 + + if self.power is None: + return spec + elif self.power == 1: + mags = torch.sqrt(real**2 + imag**2) + # phase = torch.atan2(imag, real) + return mags + elif self.power == 2: + power = real**2 + imag**2 + return power + else: + raise AssertionError + + +class ConviSTFT(nn.Module): + + def __init__(self, + win_size: int, + hop_size: int, + nfft: int = None, + win_type: str = "hamming", + requires_grad: bool = False): + super(ConviSTFT, self).__init__() + if nfft is None: + self.nfft = int(2**np.ceil(np.log2(win_size))) + else: + self.nfft = nfft + + kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True) + self.weight = nn.Parameter(kernel, requires_grad=requires_grad) + # weight shape: [f*2, 1, nfft] + # f = nfft // 2 + 1 + + self.win_size = win_size + self.hop_size = hop_size + self.win_type = win_type + + self.stride = hop_size + self.dim = self.nfft + + self.register_buffer("window", window) + self.register_buffer("enframe", torch.eye(win_size)[:, None, :]) + # window shape: [1, nfft, 1] + # enframe shape: [nfft, 1, nfft] + + def forward(self, + spec: torch.Tensor): + """ + self.weight shape: [f*2, 1, win_size] + self.window shape: [1, win_size, 1] + self.enframe shape: [win_size, 1, win_size] + + :param spec: torch.Tensor, shape: [b, f, t, 2] + :return: + """ + spec = torch.view_as_real(spec) + # spec shape: [b, f, t, 2] + matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1) + # matrix shape: [b, f*2, t] + + waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride) + # waveform shape: [b, 1, num_samples] + + # this is from torch-stft: https://github.com/pseeth/torch-stft + t = self.window.repeat(1, 1, matrix.size(-1))**2 + # t shape: [1, win_size, t] + coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) + # coff shape: [1, 1, num_samples] + waveform = waveform / (coff + 1e-8) + # waveform = waveform / coff + return waveform + + @torch.no_grad() + def forward_chunk(self, + spec: torch.Tensor, + cache_dict: dict = None + ): + """ + :param spec: shape: [b, f, t] + :param cache_dict: dict, + waveform_cache shape: [b, 1, win_size - hop_size] + coff_cache shape: [b, 1, win_size - hop_size] + :return: + """ + if cache_dict is None: + cache_dict = defaultdict(lambda: None) + waveform_cache = cache_dict["waveform_cache"] + coff_cache = cache_dict["coff_cache"] + + spec = torch.view_as_real(spec) + matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1) + + waveform_current = F.conv_transpose1d(matrix, self.weight, stride=self.stride) + + t = self.window.repeat(1, 1, matrix.size(-1))**2 + coff_current = F.conv_transpose1d(t, self.enframe, stride=self.stride) + + overlap_size = self.win_size - self.hop_size + + if waveform_cache is not None: + waveform_current[:, :, :overlap_size] += waveform_cache + waveform_output = waveform_current[:, :, :self.hop_size] + new_waveform_cache = waveform_current[:, :, self.hop_size:] + + if coff_cache is not None: + coff_current[:, :, :overlap_size] 
+= coff_cache + coff_output = coff_current[:, :, :self.hop_size] + new_coff_cache = coff_current[:, :, self.hop_size:] + + waveform_output = waveform_output / (coff_output + 1e-8) + + new_cache_dict = { + "waveform_cache": new_waveform_cache, + "coff_cache": new_coff_cache, + } + return waveform_output, new_cache_dict + + +def main(): + nfft = 512 + win_size = 512 + hop_size = 256 + + stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None) + istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size) + + mixture = torch.rand(size=(1, 16000), dtype=torch.float32) + b, num_samples = mixture.shape + t = (num_samples - win_size) / hop_size + 1 + + spec = stft.forward(mixture) + b, f, t = spec.shape + + # 如果 spec 是由 stft 变换得来的,以下两种 waveform 还原方法就是一致的,否则还原出的 waveform 会有差异。 + # spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32) + print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}") + + waveform = istft.forward(spec) + # shape: [batch_size, channels, num_samples] + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32) + for i in range(int(t)): + begin = i * hop_size + end = begin + win_size + sub_spec = spec[:, :, i:i+1] + sub_waveform = istft.forward(sub_spec) + # (b, 1, win_size) + waveform[:, :, begin:end] = sub_waveform + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + return + + +def main2(): + nfft = 512 + win_size = 512 + hop_size = 256 + + stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None) + istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size) + + mixture = torch.rand(size=(1, 16128), dtype=torch.float32) + b, num_samples = mixture.shape + + spec = stft.forward(mixture) + b, f, t = spec.shape + + # 如果 spec 是由 stft 变换得来的,以下两种 waveform 还原方法就是一致的,否则还原出的 waveform 会有差异。 + spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32) + print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}") + + waveform = istft.forward(spec) + # shape: [batch_size, channels, num_samples] + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + cache_dict = None + waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32) + for i in range(int(t)): + sub_spec = spec[:, :, i:i+1] + begin = i * hop_size + + end = begin + win_size - hop_size + sub_waveform, cache_dict = istft.forward_chunk(sub_spec, cache_dict=cache_dict) + # end = begin + win_size + # sub_waveform = istft.forward(sub_spec) + + waveform[:, :, begin:end] = sub_waveform + print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") + print(waveform[:, :, 300: 302]) + + return + + +if __name__ == "__main__": + main2() diff --git a/toolbox/torchaudio/modules/freq_bands/__init__.py b/toolbox/torchaudio/modules/freq_bands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/modules/freq_bands/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/modules/freq_bands/erb_bands.py b/toolbox/torchaudio/modules/freq_bands/erb_bands.py new file mode 100644 index 0000000000000000000000000000000000000000..5c94da702b7f6deaba6fbb7c21675fc2e726e861 --- /dev/null +++ 
b/toolbox/torchaudio/modules/freq_bands/erb_bands.py @@ -0,0 +1,176 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import math + +import numpy as np +import torch +import torch.nn as nn + + +class ErbBandsNumpy(object): + + @staticmethod + def freq2erb(freq_hz: float) -> float: + """ + https://www.cnblogs.com/LXP-Never/p/16011229.html + 1 / (24.7 * 9.265) = 0.00436976 + """ + return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1) + + @staticmethod + def erb2freq(n_erb: float) -> float: + return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1) + + @classmethod + def get_erb_widths(cls, sample_rate: int, nfft: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray: + """ + https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs + :param sample_rate: + :param nfft: + :param erb_bins: erb (Equivalent Rectangular Bandwidth) 等效矩形带宽的通道数. + :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band + :return: + """ + nyq_freq = sample_rate / 2. + freq_width: float = sample_rate / nfft + + min_erb: float = cls.freq2erb(0.) + max_erb: float = cls.freq2erb(nyq_freq) + + erb = [0] * erb_bins + step = (max_erb - min_erb) / erb_bins + + prev_freq_bin = 0 + freq_over = 0 + for i in range(1, erb_bins + 1): + f = cls.erb2freq(min_erb + i * step) + freq_bin = int(round(f / freq_width)) + freq_bins = freq_bin - prev_freq_bin - freq_over + + if freq_bins < min_freq_bins_for_erb: + freq_over = min_freq_bins_for_erb - freq_bins + freq_bins = min_freq_bins_for_erb + else: + freq_over = 0 + erb[i - 1] = freq_bins + prev_freq_bin = freq_bin + + erb[erb_bins - 1] += 1 + too_large = sum(erb) - (nfft / 2 + 1) + if too_large > 0: + erb[erb_bins - 1] -= too_large + return np.array(erb, dtype=np.uint64) + + @staticmethod + def get_erb_filter_bank(erb_widths: np.ndarray, + normalized: bool = True, + inverse: bool = False, + ): + num_freq_bins = int(np.sum(erb_widths)) + num_erb_bins = len(erb_widths) + + fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins)) + + points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1] + for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())): + fb[b: b + w, i] = 1 + + if inverse: + fb = fb.T + if not normalized: + fb /= np.sum(fb, axis=1, keepdims=True) + else: + if normalized: + fb /= np.sum(fb, axis=0) + return fb + + @staticmethod + def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True): + """ + ERB filterbank and transform to decibel scale. + + :param spec: Spectrum of shape [B, C, T, F]. + :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths, + where B are the number of ERB bins. + :param db: Whether to transform the output into decibel scale. Defaults to `True`. + :return: + """ + # complex spec to power spec. (real * real + image * image) + spec_ = np.abs(spec) ** 2 + + # spec to erb feature. 
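+ # erb_fb is the (num_freq_bins, num_erb_bins) matrix built by get_erb_filter_bank, so this matmul pools the power spectrum's frequency bins into ERB bands (a per-band average when the filterbank columns are normalized).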
+ erb_feat = np.matmul(spec_, erb_fb) + + if db: + erb_feat = 10 * np.log10(erb_feat + 1e-10) + + erb_feat = np.array(erb_feat, dtype=np.float32) + return erb_feat + + +class ErbBands(nn.Module): + def __init__(self, + sample_rate: int = 8000, + nfft: int = 512, + erb_bins: int = 32, + min_freq_bins_for_erb: int = 2, + ): + super().__init__() + self.sample_rate = sample_rate + self.nfft = nfft + self.erb_bins = erb_bins + self.min_freq_bins_for_erb = min_freq_bins_for_erb + + erb_fb, erb_fb_inv = self.init_erb_fb() + erb_fb = torch.tensor(erb_fb, dtype=torch.float32, requires_grad=False) + erb_fb_inv = torch.tensor(erb_fb_inv, dtype=torch.float32, requires_grad=False) + self.erb_fb = nn.Parameter(erb_fb, requires_grad=False) + self.erb_fb_inv = nn.Parameter(erb_fb_inv, requires_grad=False) + + def init_erb_fb(self): + erb_widths = ErbBandsNumpy.get_erb_widths( + sample_rate=self.sample_rate, + nfft=self.nfft, + erb_bins=self.erb_bins, + min_freq_bins_for_erb=self.min_freq_bins_for_erb, + ) + erb_fb = ErbBandsNumpy.get_erb_filter_bank( + erb_widths=erb_widths, + normalized=True, + inverse=False, + ) + erb_fb_inv = ErbBandsNumpy.get_erb_filter_bank( + erb_widths=erb_widths, + normalized=True, + inverse=True, + ) + return erb_fb, erb_fb_inv + + def erb_scale(self, spec: torch.Tensor, db: bool = True): + # spec shape: (b, t, f) + spec_erb = torch.matmul(spec, self.erb_fb) + if db: + spec_erb = 10 * torch.log10(spec_erb + 1e-10) + return spec_erb + + def erb_scale_inv(self, spec_erb: torch.Tensor): + spec = torch.matmul(spec_erb, self.erb_fb_inv) + return spec + + +def main(): + + erb_bands = ErbBands() + + spec = torch.randn(size=(2, 199, 257), dtype=torch.float32) + spec_erb = erb_bands.erb_scale(spec) + print(spec_erb.shape) + + spec = erb_bands.erb_scale_inv(spec_erb) + print(spec.shape) + + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/modules/local_snr_target.py b/toolbox/torchaudio/modules/local_snr_target.py new file mode 100644 index 0000000000000000000000000000000000000000..e09be53e2f3550c943e5743ea2f597327ac7fdb2 --- /dev/null +++ b/toolbox/torchaudio/modules/local_snr_target.py @@ -0,0 +1,151 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816 +""" +from typing import Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F +import torchaudio + + +def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor: + if n_frame % 2 == 0: + n_frame += 1 + n_frame_half = n_frame // 2 + + # spec shape: [b, c, t, f, 2] + spec = spec.pow(2).sum(-1).sum(-1) + # spec shape: [b, c, t] + spec = F.pad(spec, (n_frame_half, n_frame_half, 0, 0)) + # spec shape: [b, c, t-pad] + + weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype) + # w shape: [n_frame] + + spec = spec.unfold(-1, size=n_frame, step=1) * weight + # x shape: [b, c, t, n_frame] + + result = torch.sum(spec, dim=-1).div(n_frame) + # result shape: [b, c, t] + return result + + +def local_snr(spec_clean: torch.Tensor, + spec_noise: torch.Tensor, + n_frame: int = 5, + db: bool = False, + eps: float = 1e-12, + ): + # [b, c, t, f] + spec_clean = torch.view_as_real(spec_clean) + spec_noise = torch.view_as_real(spec_noise) + # [b, c, t, f, 2] + + energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device) + energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device) + # [b, c, t] + + snr = energy_clean / 
energy_noise.clamp_min(eps) + # snr shape: [b, c, t] + + if db: + snr = snr.clamp_min(eps).log10().mul(10) + return snr, energy_clean, energy_noise + + +class LocalSnrTarget(nn.Module): + def __init__(self, + sample_rate: int = 8000, + nfft: int = 512, + win_size: int = 512, + hop_size: int = 256, + + n_frame: int = 3, + + min_local_snr: int = -15, + max_local_snr: int = 30, + + db: bool = True, + ): + super().__init__() + self.sample_rate = sample_rate + self.nfft = nfft + self.win_size = win_size + self.hop_size = hop_size + + self.n_frame = n_frame + + self.min_local_snr = min_local_snr + self.max_local_snr = max_local_snr + + self.db = db + + def forward(self, + spec_clean: torch.Tensor, + spec_noise: torch.Tensor, + ) -> torch.Tensor: + """ + + :param spec_clean: torch.complex, shape: [b, c, t, f] + :param spec_noise: torch.complex, shape: [b, c, t, f] + :return: lsnr, shape: [b, t] + """ + + lsnr, _, _ = local_snr( + spec_clean=spec_clean, + spec_noise=spec_noise, + n_frame=self.n_frame, + db=self.db, + ) + # lsnr shape: [b, c, t] + lsnr = lsnr.clamp(self.min_local_snr, self.max_local_snr).squeeze(1) + # lsnr shape: [b, t] + return lsnr + + +def main(): + sample_rate = 8000 + nfft = 512 + win_size = 512 + hop_size = 256 + window_fn = "hamming" + + transform = torchaudio.transforms.Spectrogram( + n_fft=nfft, + win_length=win_size, + hop_length=hop_size, + power=None, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + noisy = torch.randn(size=(1, 16000), dtype=torch.float32) + + spec = transform.forward(noisy) + spec = spec.permute(0, 2, 1) + spec = torch.unsqueeze(spec, dim=1) + print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}") + + # [b, c, t, f] + # spec = torch.view_as_real(spec) + # [b, c, t, f, 2] + + local = LocalSnrTarget( + sample_rate=sample_rate, + nfft=nfft, + win_size=win_size, + hop_size=hop_size, + n_frame=5, + min_local_snr=-15, + max_local_snr=30, + db=True, + ) + lsnr_target = local.forward(spec, spec) + print(f"lsnr_target.shape: {lsnr_target.shape}") + return + + +if __name__ == "__main__": + main() diff --git a/toolbox/torchaudio/modules/utils/__init__.py b/toolbox/torchaudio/modules/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a66fc40cec5e1bad20c94ebc03002f9772eb07 --- /dev/null +++ b/toolbox/torchaudio/modules/utils/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/modules/utils/ema.py b/toolbox/torchaudio/modules/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..8829990465bd46bbf6c3f49b178f9d338bfc6190 --- /dev/null +++ b/toolbox/torchaudio/modules/utils/ema.py @@ -0,0 +1,203 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import math + +import numpy as np +import torch +import torch.nn as nn + + +class EMANumpy(object): + + @classmethod + def _calculate_norm_alpha(cls, sample_rate: int, hop_size: int, tau: float): + """Exponential decay factor alpha for a given tau (decay window size [s]).""" + dt = hop_size / sample_rate + result = math.exp(-dt / tau) + return result + + @classmethod + def get_norm_alpha(cls, sample_rate: int, hop_size: int, norm_tau: float) -> float: + a_ = cls._calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau) + + precision = 3 + a = 1.0 + while a >= 1.0: + a = round(a_, precision) + precision += 1 + + return a + + +class ErbEMA(nn.Module, EMANumpy): + def __init__(self, + sample_rate: int = 
8000, + hop_size: int = 80, + erb_bins: int = 32, + mean_norm_init_start: float = -60., + mean_norm_init_end: float = -90., + norm_tau: float = 1., + ): + super().__init__() + self.sample_rate = sample_rate + self.hop_size = hop_size + self.erb_bins = erb_bins + self.mean_norm_init_start = mean_norm_init_start + self.mean_norm_init_end = mean_norm_init_end + self.norm_tau = norm_tau + + self.alpha = self.get_norm_alpha(sample_rate, hop_size, norm_tau) + + def make_erb_norm_state(self) -> torch.Tensor: + state = torch.linspace(start=self.mean_norm_init_start, end=self.mean_norm_init_end, + steps=self.erb_bins) + state = state.unsqueeze(0).unsqueeze(0) + # state shape: [b, c, erb_bins] + # state shape: [1, 1, erb_bins] + return state + + def norm(self, + feat_erb: torch.Tensor, + state: torch.Tensor = None, + ): + feat_erb = feat_erb.clone() + b, c, t, f = feat_erb.shape + + # erb_feat shape: [b, c, t, f] + if state is None: + state = self.make_erb_norm_state() + state = state.to(feat_erb.device) + state = state.clone() + + for j in range(t): + current = feat_erb[:, :, j, :] + new_state = current * (1 - self.alpha) + state * self.alpha + + feat_erb[:, :, j, :] = (current - new_state) / 40.0 + state = new_state + + return feat_erb, state + + +class SpecEMA(nn.Module, EMANumpy): + """ + https://github.com/grazder/DeepFilterNet/blob/torchDF_main/libDF/src/lib.rs + """ + def __init__(self, + sample_rate: int = 8000, + hop_size: int = 80, + df_bins: int = 96, + unit_norm_init_start: float = 0.001, + unit_norm_init_end: float = 0.0001, + norm_tau: float = 1., + ): + super().__init__() + self.sample_rate = sample_rate + self.hop_size = hop_size + self.df_bins = df_bins + self.unit_norm_init_start = unit_norm_init_start + self.unit_norm_init_end = unit_norm_init_end + self.norm_tau = norm_tau + + self.alpha = self.get_norm_alpha(sample_rate, hop_size, norm_tau) + + def make_spec_norm_state(self) -> torch.Tensor: + state = torch.linspace(start=self.unit_norm_init_start, end=self.unit_norm_init_end, + steps=self.df_bins) + state = state.unsqueeze(0).unsqueeze(0) + # state shape: [b, c, df_bins] + # state shape: [1, 1, df_bins] + return state + + def norm(self, + feat_spec: torch.Tensor, + state: torch.Tensor = None, + ): + feat_spec = feat_spec.clone() + b, c, t, f = feat_spec.shape + + # feat_spec shape: [b, 2, t, df_bins] + if state is None: + state = self.make_spec_norm_state() + state = state.to(feat_spec.device) + state = state.clone() + + for j in range(t): + current = feat_spec[:, :, j, :] + current_abs = torch.sum(torch.square(current), dim=1, keepdim=True) + # current_abs shape: [b, 1, df_bins] + new_state = current_abs * (1 - self.alpha) + state * self.alpha + + feat_spec[:, :, j, :] = current / torch.sqrt(new_state) + state = new_state + + return feat_spec, state + + +MEAN_NORM_INIT = [-60., -90.] 
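+ # The helpers below are plain-NumPy counterparts of the EMA normalization in ErbEMA / SpecEMA above; MEAN_NORM_INIT matches the default mean_norm_init_start / mean_norm_init_end of -60 dB to -90 dB.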
+ + +def make_erb_norm_state(erb_bins: int, channels: int) -> np.ndarray: + state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins) + state = np.expand_dims(state, axis=0) + state = np.repeat(state, channels, axis=0) + + # state shape: (audio_channels, erb_bins) + return state + + +def erb_normalize(erb_feat: np.ndarray, alpha: float, state: np.ndarray = None): + erb_feat = np.copy(erb_feat) + batch_size, time_steps, erb_bins = erb_feat.shape + + if state is None: + state = make_erb_norm_state(erb_bins, erb_feat.shape[0]) + # state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins) + # state = np.expand_dims(state, axis=0) + # state = np.repeat(state, erb_feat.shape[0], axis=0) + + for i in range(batch_size): + for j in range(time_steps): + for k in range(erb_bins): + x = erb_feat[i][j][k] + s = state[i][k] + + state[i][k] = x * (1. - alpha) + s * alpha + erb_feat[i][j][k] -= state[i][k] + erb_feat[i][j][k] /= 40. + + return erb_feat + + +UNIT_NORM_INIT = [0.001, 0.0001] + + +def make_spec_norm_state(df_bins: int, channels: int) -> np.ndarray: + state = np.linspace(UNIT_NORM_INIT[0], UNIT_NORM_INIT[1], df_bins) + state = np.expand_dims(state, axis=0) + state = np.repeat(state, channels, axis=0) + + # state shape: (audio_channels, df_bins) + return state + + +def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None): + spec_feat = np.copy(spec_feat) + batch_size, time_steps, df_bins = spec_feat.shape + + if state is None: + state = make_spec_norm_state(df_bins, spec_feat.shape[0]) + + for i in range(batch_size): + for j in range(time_steps): + for k in range(df_bins): + x = spec_feat[i][j][k] + s = state[i][k] + + state[i][k] = np.abs(x) * (1. - alpha) + s * alpha + spec_feat[i][j][k] /= np.sqrt(state[i][k]) + return spec_feat + + +if __name__ == "__main__": + pass
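The `ErbEMA` / `SpecEMA` modules and the NumPy helpers above share one idea: each band keeps an exponential moving average of the feature, and every frame is normalized against that running statistic so the features stay roughly scale-invariant during streaming inference. Below is a minimal standalone sketch of the ERB-side recursion; it assumes a `[T, F]` feature already in dB, uses a constant initial state instead of the linspace that `ErbEMA` uses, and its function names are illustrative rather than part of the repository.

```python
import math

import numpy as np


def norm_alpha(sample_rate: int = 8000, hop_size: int = 80, tau: float = 1.0) -> float:
    # Per-frame decay factor for a time constant of `tau` seconds,
    # the same formula as EMANumpy._calculate_norm_alpha.
    dt = hop_size / sample_rate
    return math.exp(-dt / tau)


def ema_mean_normalize(feat_db: np.ndarray, alpha: float, init_db: float = -60.0) -> np.ndarray:
    """Frame-by-frame mean normalization of a [T, F] dB feature, scaled by 40 dB."""
    state = np.full(feat_db.shape[-1], init_db, dtype=np.float32)
    out = np.empty_like(feat_db, dtype=np.float32)
    for t in range(feat_db.shape[0]):
        state = feat_db[t] * (1.0 - alpha) + state * alpha  # running per-band mean
        out[t] = (feat_db[t] - state) / 40.0                # same 40 dB scaling as ErbEMA.norm
    return out


if __name__ == "__main__":
    alpha = norm_alpha()  # ~0.990 for 8 kHz audio with a 10 ms hop
    feat = np.random.uniform(-90.0, -20.0, size=(100, 32)).astype(np.float32)
    print(f"alpha={alpha:.3f}, normalized shape={ema_mean_normalize(feat, alpha).shape}")
```

Larger `tau` pushes `alpha` toward 1; `get_norm_alpha` keeps increasing the rounding precision until the rounded value stays strictly below 1.0, so the running statistics always keep adapting.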