diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..39facf920bf31aaed825ccce83f7aa8a11956241 --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ + +.git/ +.idea/ + +**/evaluation_audio/ +**/file_dir/ +**/flagged/ +**/log/ +**/logs/ +**/__pycache__/ + +/data/ +/docs/ +/dotenv/ +/hub_datasets/ +/thirdparty/ +/trained_models/ +/temp/ + +#**/*.wav +**/*.xlsx diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..20a85a401d90f6602a79eca3bb2b2e49c5fe6891 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,21 @@ +FROM python:3.12 + +WORKDIR /code + +COPY . /code + +RUN pip install --upgrade pip +RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt + +RUN useradd -m -u 1000 user + +USER user + +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +WORKDIR $HOME/app + +COPY --chown=user . 
$HOME/app + +CMD ["python3", "main.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..732a5d6a46b1fc516801d4bd44080b480ff4fb86 --- /dev/null +++ b/README.md @@ -0,0 +1,26 @@ +--- +title: VM Sound Classification +emoji: 🐢 +colorFrom: purple +colorTo: blue +sdk: docker +pinned: false +license: apache-2.0 +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +## NX Denoise + + +### speech datasets + +```text + +AISHELL (15G) +https://openslr.trmal.net/resources/33/ + +AISHELL-3 (19G) +http://www.openslr.org/93/ + +``` + diff --git a/examples/simple_linear_irm_aishell/run.sh b/examples/simple_linear_irm_aishell/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..52d9691c1c3ced86192dfae6da6bed1106306390 --- /dev/null +++ b/examples/simple_linear_irm_aishell/run.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." 
|| exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/simple_linear_irm_aishell/step_1_prepare_data.py b/examples/simple_linear_irm_aishell/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f1ff4f3b60ae19116d3716f2c234733f053e95 --- /dev/null +++ b/examples/simple_linear_irm_aishell/step_1_prepare_data.py @@ -0,0 +1,196 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_nsr_db", default=-20, type=float) + parser.add_argument("--max_nsr_db", default=5, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/simple_linear_irm_aishell/step_2_train_model.py b/examples/simple_linear_irm_aishell/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..27742b223497c49b314d41d9e8ff058e7e2f2e64 --- /dev/null +++ b/examples/simple_linear_irm_aishell/step_2_train_model.py @@ -0,0 +1,348 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json 
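+# note: console/file logging is configured in logging_config() below, and per-epoch metrics are written to metrics_epoch.json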
+import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +from torch import dtype + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig +from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-3, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + window_fn: str = "hamming", + irm_beta: float = 1.0, + epsilon: float = 1e-8, + ): + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.window_fn = window_fn + self.irm_beta = irm_beta + self.epsilon = epsilon + + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=2.0, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + def __call__(self, batch: List[dict]): + mix_spec_list = list() + speech_irm_list = list() + snr_db_list = list() + for sample in batch: + noise_wave: torch.Tensor = sample["noise_wave"] + speech_wave: torch.Tensor = sample["speech_wave"] + mix_wave: torch.Tensor = sample["mix_wave"] + snr_db: float = sample["snr_db"] + + noise_spec = self.transform.forward(noise_wave) + speech_spec = self.transform.forward(speech_wave) + mix_spec = self.transform.forward(mix_wave) + + # noise_irm = noise_spec / (noise_spec + speech_spec) + speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon) + speech_irm = torch.pow(speech_irm, self.irm_beta) + + mix_spec_list.append(mix_spec) + speech_irm_list.append(speech_irm) + snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32)) + + mix_spec_list = 
torch.stack(mix_spec_list) + speech_irm_list = torch.stack(speech_irm_list) + snr_db_list = torch.stack(snr_db_list) # shape: (batch_size,) + + # assert + if torch.any(torch.isnan(mix_spec_list)): + raise AssertionError("nan in mix_spec Tensor") + if torch.any(torch.isnan(speech_irm_list)): + raise AssertionError("nan in speech_irm Tensor") + if torch.any(torch.isnan(snr_db_list)): + raise AssertionError("nan in snr_db Tensor") + + return mix_spec_list, speech_irm_list, snr_db_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + config = SimpleLinearIRMConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + # num_labels=vocabulary.get_vocab_size(namespace="labels") + ) + model = SimpleLinearIRMPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + param_optimizer = model.parameters() + optimizer = torch.optim.Adam( + param_optimizer, + lr=args.learning_rate, + ) + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=2000 + # ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + mse_loss = nn.MSELoss( + reduction="mean", + ) + + # training loop + logger.info("training") + + training_loss = 10000000000 + evaluation_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + for idx_epoch in range(args.max_epochs): + total_loss = 0. + total_examples = 0. 
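+        # running totals used to compute the training_loss value shown on the progress bar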
+ progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + + for batch in train_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + training_loss = total_loss / total_examples + training_loss = round(training_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "training_loss": training_loss, + }) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + for batch in valid_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + elif evaluation_loss < best_metric: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "training_loss": training_loss, + "evaluation_loss": evaluation_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + return + + +if __name__ == '__main__': + main() diff --git a/examples/simple_linear_irm_aishell/step_3_evaluation.py b/examples/simple_linear_irm_aishell/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..fe60fead3d247509e3916f435dc722f9cedad8f9 --- /dev/null +++ b/examples/simple_linear_irm_aishell/step_3_evaluation.py @@ -0,0 +1,239 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd 
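+# audio dependencies: librosa for loading/resampling, torchaudio for the STFT/ISTFT transforms used below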
+from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU 
available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + model = SimpleLinearIRMPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/simple_linear_irm_aishell/yaml/config.yaml b/examples/simple_linear_irm_aishell/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f269122b6f589c012376cda288e9a256f1453de2 --- /dev/null +++ b/examples/simple_linear_irm_aishell/yaml/config.yaml @@ -0,0 +1,13 @@ +model_name: "simple_linear_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +# model +num_bins: 257 +hidden_size: 2048 +lookback: 3 +lookahead: 3 diff --git a/examples/simple_lstm_irm_aishell/run.sh b/examples/simple_lstm_irm_aishell/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..52d9691c1c3ced86192dfae6da6bed1106306390 --- /dev/null +++ b/examples/simple_lstm_irm_aishell/run.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash + +: <<'END' + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir + +sh run.sh --stage 3 --stop_stage 3 
--system_version windows --file_folder_name file_dir + +sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" +file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir 
"${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/simple_lstm_irm_aishell/step_1_prepare_data.py b/examples/simple_lstm_irm_aishell/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..16398c263d39e9accfff6ffb4a64650037eba385 --- /dev/null +++ b/examples/simple_lstm_irm_aishell/step_1_prepare_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_nsr_db", default=-20, type=float) + parser.add_argument("--max_nsr_db", default=5, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/simple_lstm_irm_aishell/step_2_train_model.py b/examples/simple_lstm_irm_aishell/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..522fb72aa7d27560e1cda32a3fba109bf3fa4422 --- /dev/null +++ b/examples/simple_lstm_irm_aishell/step_2_train_model.py @@ -0,0 +1,348 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json +import 
logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +from torch import dtype + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig +from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-3, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + window_fn: str = "hamming", + irm_beta: float = 1.0, + epsilon: float = 1e-8, + ): + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.window_fn = window_fn + self.irm_beta = irm_beta + self.epsilon = epsilon + + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=2.0, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + def __call__(self, batch: List[dict]): + mix_spec_list = list() + speech_irm_list = list() + snr_db_list = list() + for sample in batch: + noise_wave: torch.Tensor = sample["noise_wave"] + speech_wave: torch.Tensor = sample["speech_wave"] + mix_wave: torch.Tensor = sample["mix_wave"] + snr_db: float = sample["snr_db"] + + noise_spec = self.transform.forward(noise_wave) + speech_spec = self.transform.forward(speech_wave) + mix_spec = self.transform.forward(mix_wave) + + # noise_irm = noise_spec / (noise_spec + speech_spec) + speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon) + speech_irm = torch.pow(speech_irm, self.irm_beta) + + mix_spec_list.append(mix_spec) + speech_irm_list.append(speech_irm) + snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32)) + + mix_spec_list = 
torch.stack(mix_spec_list) + speech_irm_list = torch.stack(speech_irm_list) + snr_db_list = torch.stack(snr_db_list) # shape: (batch_size,) + + # assert + if torch.any(torch.isnan(mix_spec_list)): + raise AssertionError("nan in mix_spec Tensor") + if torch.any(torch.isnan(speech_irm_list)): + raise AssertionError("nan in speech_irm Tensor") + if torch.any(torch.isnan(snr_db_list)): + raise AssertionError("nan in snr_db Tensor") + + return mix_spec_list, speech_irm_list, snr_db_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. config_file: {args.config_file}") + config = SimpleLstmIRMConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + # num_labels=vocabulary.get_vocab_size(namespace="labels") + ) + model = SimpleLstmIRMPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + param_optimizer = model.parameters() + optimizer = torch.optim.Adam( + param_optimizer, + lr=args.learning_rate, + ) + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=2000 + # ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + mse_loss = nn.MSELoss( + reduction="mean", + ) + + # training loop + logger.info("training") + + training_loss = 10000000000 + evaluation_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + for idx_epoch in range(args.max_epochs): + total_loss = 0. + total_examples = 0. 
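+        # accumulators for the epoch-level training loss reported on the progress bar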
+ progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + + for batch in train_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + training_loss = total_loss / total_examples + training_loss = round(training_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "training_loss": training_loss, + }) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + for batch in valid_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + elif evaluation_loss < best_metric: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "training_loss": training_loss, + "evaluation_loss": evaluation_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + return + + +if __name__ == '__main__': + main() diff --git a/examples/simple_lstm_irm_aishell/step_3_evaluation.py b/examples/simple_lstm_irm_aishell/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc2616d3a88153416a8ab0a5ab16ec8816b3faa --- /dev/null +++ b/examples/simple_lstm_irm_aishell/step_3_evaluation.py @@ -0,0 +1,239 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from 
scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: 
{}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + model = SimpleLstmIRMPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. + progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + with torch.no_grad(): + speech_irm_prediction = model.forward(mix_spec) + loss = mse_loss.forward(speech_irm_prediction, speech_irm_target) + + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_unet_irm_aishell/run.sh b/examples/spectrum_unet_irm_aishell/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..1759595e54141f900ef32a9eba90f1e700b26969 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/run.sh @@ -0,0 +1,174 @@ +#!/usr/bin/env bash + +: <<'END' + + +sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \ +--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train" + + +sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" + + +END + + +# params +system_version="windows"; +verbose=true; +stage=0 # start from 0 if you need to start from data preparation +stop_stage=9 + +work_dir="$(pwd)" 
+file_folder_name=file_folder_name +final_model_name=final_model_name +config_file="yaml/config.yaml" +limit=10 + +noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise +speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train + +nohup_name=nohup.out + +# model params +batch_size=64 +max_epochs=200 +save_top_k=10 +patience=5 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +file_dir="${work_dir}/${file_folder_name}" +final_model_dir="${work_dir}/../../trained_models/${final_model_name}"; +evaluation_audio_dir="${file_dir}/evaluation_audio" + +dataset="${file_dir}/dataset.xlsx" +train_dataset="${file_dir}/train.xlsx" +valid_dataset="${file_dir}/valid.xlsx" + +$verbose && echo "system_version: ${system_version}" +$verbose && echo "file_folder_name: ${file_folder_name}" + +if [ $system_version == "windows" ]; then + alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe' +elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then + #source /data/local/bin/nx_denoise/bin/activate + alias python3='/data/local/bin/nx_denoise/bin/python3' +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: prepare data" + cd "${work_dir}" || exit 1 + python3 step_1_prepare_data.py \ + --file_dir "${file_dir}" \ + --noise_dir "${noise_dir}" \ + --speech_dir "${speech_dir}" \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: train model" + cd "${work_dir}" || exit 1 + python3 step_2_train_model.py \ + --train_dataset "${train_dataset}" \ + --valid_dataset "${valid_dataset}" \ + --serialization_dir "${file_dir}" \ + --config_file "${config_file}" \ + +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + $verbose && echo "stage 3: test model" + cd "${work_dir}" || exit 1 + python3 step_3_evaluation.py \ + --valid_dataset "${valid_dataset}" \ + --model_dir "${file_dir}/best" \ + --evaluation_audio_dir "${evaluation_audio_dir}" \ + --limit "${limit}" \ + +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + $verbose && echo "stage 4: export model" + cd "${work_dir}" || exit 1 + python3 step_5_export_models.py \ + --vocabulary_dir "${vocabulary_dir}" \ + --model_dir "${file_dir}/best" \ + --serialization_dir "${file_dir}" \ + +fi + + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + $verbose && echo "stage 5: collect files" + cd "${work_dir}" || exit 1 + + mkdir -p ${final_model_dir} + + cp "${file_dir}/best"/* "${final_model_dir}" + cp -r "${file_dir}/vocabulary" "${final_model_dir}" + + cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx" + + cp "${file_dir}/trace_model.zip" 
"${final_model_dir}/trace_model.zip" + cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip" + cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip" + cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip" + + cd "${final_model_dir}/.." || exit 1; + + if [ -e "${final_model_name}.zip" ]; then + rm -rf "${final_model_name}_backup.zip" + mv "${final_model_name}.zip" "${final_model_name}_backup.zip" + fi + + zip -r "${final_model_name}.zip" "${final_model_name}" + rm -rf "${final_model_name}" + +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + $verbose && echo "stage 6: clear file_dir" + cd "${work_dir}" || exit 1 + + rm -rf "${file_dir}"; + +fi diff --git a/examples/spectrum_unet_irm_aishell/step_1_prepare_data.py b/examples/spectrum_unet_irm_aishell/step_1_prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..16398c263d39e9accfff6ffb4a64650037eba385 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/step_1_prepare_data.py @@ -0,0 +1,197 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import os +from pathlib import Path +import random +import sys +import shutil + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import pandas as pd +from scipy.io import wavfile +from tqdm import tqdm +import librosa + +from project_settings import project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--file_dir", default="./", type=str) + + parser.add_argument( + "--noise_dir", + default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise", + type=str + ) + parser.add_argument( + "--speech_dir", + default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train", + type=str + ) + + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--duration", default=2.0, type=float) + parser.add_argument("--min_nsr_db", default=-20, type=float) + parser.add_argument("--max_nsr_db", default=5, type=float) + + parser.add_argument("--target_sample_rate", default=8000, type=int) + + args = parser.parse_args() + return args + + +def filename_generator(data_dir: str): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + yield filename.as_posix() + + +def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000): + data_dir = Path(data_dir) + for filename in data_dir.glob("**/*.wav"): + signal, _ = librosa.load(filename.as_posix(), sr=sample_rate) + raw_duration = librosa.get_duration(y=signal, sr=sample_rate) + + if raw_duration < duration: + # print(f"duration less than {duration} s. 
skip filename: {filename.as_posix()}") + continue + if signal.ndim != 1: + raise AssertionError(f"expected ndim 1, instead of {signal.ndim}") + + signal_length = len(signal) + win_size = int(duration * sample_rate) + for begin in range(0, signal_length - win_size, win_size): + row = { + "filename": filename.as_posix(), + "raw_duration": round(raw_duration, 4), + "offset": round(begin / sample_rate, 4), + "duration": round(duration, 4), + } + yield row + + +def get_dataset(args): + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + noise_dir = Path(args.noise_dir) + speech_dir = Path(args.speech_dir) + + noise_generator = target_second_signal_generator( + noise_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + speech_generator = target_second_signal_generator( + speech_dir.as_posix(), + duration=args.duration, + sample_rate=args.target_sample_rate + ) + + dataset = list() + + count = 0 + process_bar = tqdm(desc="build dataset excel") + for noise, speech in zip(noise_generator, speech_generator): + + noise_filename = noise["filename"] + noise_raw_duration = noise["raw_duration"] + noise_offset = noise["offset"] + noise_duration = noise["duration"] + + speech_filename = speech["filename"] + speech_raw_duration = speech["raw_duration"] + speech_offset = speech["offset"] + speech_duration = speech["duration"] + + random1 = random.random() + random2 = random.random() + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db), + + "random1": random1, + "random2": random2, + "flag": "TRAIN" if random2 < 0.8 else "TEST", + } + dataset.append(row) + count += 1 + duration_seconds = count * args.duration + duration_hours = duration_seconds / 3600 + + process_bar.update(n=1) + process_bar.set_postfix({ + # "duration_seconds": round(duration_seconds, 4), + "duration_hours": round(duration_hours, 4), + + }) + + dataset = pd.DataFrame(dataset) + dataset = dataset.sort_values(by=["random1"], ascending=False) + dataset.to_excel( + file_dir / "dataset.xlsx", + index=False, + ) + return + + + +def split_dataset(args): + """分割训练集, 测试集""" + file_dir = Path(args.file_dir) + file_dir.mkdir(exist_ok=True) + + df = pd.read_excel(file_dir / "dataset.xlsx") + + train = list() + test = list() + + for i, row in df.iterrows(): + flag = row["flag"] + if flag == "TRAIN": + train.append(row) + else: + test.append(row) + + train = pd.DataFrame(train) + train.to_excel( + args.train_dataset, + index=False, + # encoding="utf_8_sig" + ) + test = pd.DataFrame(test) + test.to_excel( + args.valid_dataset, + index=False, + # encoding="utf_8_sig" + ) + + return + + +def main(): + args = get_args() + + get_dataset(args) + split_dataset(args) + return + + +if __name__ == "__main__": + main() diff --git a/examples/spectrum_unet_irm_aishell/step_2_train_model.py b/examples/spectrum_unet_irm_aishell/step_2_train_model.py new file mode 100644 index 0000000000000000000000000000000000000000..da54e0a3a91bee5fa0c8851f934b3ec778852fe1 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/step_2_train_model.py @@ -0,0 +1,371 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +""" +import argparse +import json 
+import logging +from logging.handlers import TimedRotatingFileHandler +import os +import platform +from pathlib import Path +import random +import sys +import shutil +from typing import List + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data.dataloader import DataLoader +import torchaudio +from tqdm import tqdm + +from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset +from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig +from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_dataset", default="train.xlsx", type=str) + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + + parser.add_argument("--max_epochs", default=100, type=int) + + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--learning_rate", default=1e-3, type=float) + parser.add_argument("--num_serialized_models_to_keep", default=10, type=int) + parser.add_argument("--patience", default=5, type=int) + parser.add_argument("--serialization_dir", default="serialization_dir", type=str) + parser.add_argument("--seed", default=0, type=int) + + parser.add_argument("--config_file", default="config.yaml", type=str) + + args = parser.parse_args() + return args + + +def logging_config(file_dir: str): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + file_handler = TimedRotatingFileHandler( + filename=os.path.join(file_dir, "main.log"), + encoding="utf-8", + when="D", + interval=1, + backupCount=7 + ) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter(fmt)) + logger = logging.getLogger(__name__) + logger.addHandler(file_handler) + + return logger + + +class CollateFunction(object): + def __init__(self, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + window_fn: str = "hamming", + irm_beta: float = 1.0, + epsilon: float = 1e-8, + ): + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.window_fn = window_fn + self.irm_beta = irm_beta + self.epsilon = epsilon + + self.transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + power=2.0, + window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, + ) + + def __call__(self, batch: List[dict]): + mix_spec_list = list() + speech_irm_list = list() + snr_db_list = list() + for sample in batch: + noise_wave: torch.Tensor = sample["noise_wave"] + speech_wave: torch.Tensor = sample["speech_wave"] + mix_wave: torch.Tensor = sample["mix_wave"] + # snr_db: float = sample["snr_db"] + + noise_spec = self.transform.forward(noise_wave) + speech_spec = self.transform.forward(speech_wave) + mix_spec = self.transform.forward(mix_wave) + + # noise_irm = noise_spec / (noise_spec + speech_spec) + speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon) + speech_irm = torch.pow(speech_irm, self.irm_beta) + + # noise_spec, speech_spec, mix_spec, speech_irm + # shape: [freq_dim, time_steps] + + snr_db: torch.Tensor = 10 * torch.log10( + speech_spec / (noise_spec + self.epsilon) + ) + 
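+ # average the per-bin SNR over the frequency axis to get one log-SNR value
+ # per frame; it is returned as the regression target for the model's lsnr
+ # output (the snr_mse_loss on it is commented out in the training loop).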
snr_db = torch.mean(snr_db, dim=0, keepdim=True) + # snr_db shape: [1, time_steps] + + mix_spec_list.append(mix_spec) + speech_irm_list.append(speech_irm) + snr_db_list.append(snr_db) + + mix_spec_list = torch.stack(mix_spec_list) + speech_irm_list = torch.stack(speech_irm_list) + snr_db_list = torch.stack(snr_db_list) # shape: (batch_size, time_steps, 1) + + mix_spec_list = mix_spec_list[:, :-1, :] + speech_irm_list = speech_irm_list[:, :-1, :] + + # mix_spec_list shape: [batch_size, freq_dim, time_steps] + # speech_irm_list shape: [batch_size, freq_dim, time_steps] + # snr_db shape: [batch_size, 1, time_steps] + + # assert + if torch.any(torch.isnan(mix_spec_list)): + raise AssertionError("nan in mix_spec Tensor") + if torch.any(torch.isnan(speech_irm_list)): + raise AssertionError("nan in speech_irm Tensor") + if torch.any(torch.isnan(snr_db_list)): + raise AssertionError("nan in snr_db Tensor") + + return mix_spec_list, speech_irm_list, snr_db_list + + +collate_fn = CollateFunction() + + +def main(): + args = get_args() + + serialization_dir = Path(args.serialization_dir) + serialization_dir.mkdir(parents=True, exist_ok=True) + + logger = logging_config(serialization_dir) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + logger.info("set seed: {}".format(args.seed)) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + # datasets + logger.info("prepare datasets") + train_dataset = DenoiseExcelDataset( + excel_file=args.train_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + valid_dataset = DenoiseExcelDataset( + excel_file=args.valid_dataset, + expected_sample_rate=8000, + max_wave_value=32768.0, + ) + train_data_loader = DataLoader( + dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + valid_data_loader = DataLoader( + dataset=valid_dataset, + batch_size=args.batch_size, + shuffle=True, + # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能. + num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2, + collate_fn=collate_fn, + pin_memory=False, + # prefetch_factor=64, + ) + + # models + logger.info(f"prepare models. 
config_file: {args.config_file}") + config = SpectrumUnetIRMConfig.from_pretrained( + pretrained_model_name_or_path=args.config_file, + # num_labels=vocabulary.get_vocab_size(namespace="labels") + ) + model = SpectrumUnetIRMPretrainedModel( + config=config, + ) + model.to(device) + model.train() + + # optimizer + logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy") + param_optimizer = model.parameters() + optimizer = torch.optim.Adam( + param_optimizer, + lr=args.learning_rate, + ) + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=2000 + # ) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5 + ) + irm_mse_loss = nn.MSELoss( + reduction="mean", + ) + snr_mse_loss = nn.MSELoss( + reduction="mean", + ) + + # training loop + logger.info("training") + + training_loss = 10000000000 + evaluation_loss = 10000000000 + + model_list = list() + best_idx_epoch = None + best_metric = None + patience_count = 0 + + for idx_epoch in range(args.max_epochs): + total_loss = 0. + total_examples = 0. + progress_bar = tqdm( + total=len(train_data_loader), + desc="Training; epoch: {}".format(idx_epoch), + ) + + for batch in train_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + speech_irm_prediction, lsnr_prediction = model.forward(mix_spec) + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + # loss = irm_loss + 0.1 * snr_loss + loss = irm_loss + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + training_loss = total_loss / total_examples + training_loss = round(training_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "training_loss": training_loss, + }) + + total_loss = 0. + total_examples = 0. 
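+ # run evaluation on the validation set at the end of each training epoch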
+ progress_bar = tqdm( + total=len(valid_data_loader), + desc="Evaluation; epoch: {}".format(idx_epoch), + ) + for batch in valid_data_loader: + mix_spec, speech_irm, snr_db = batch + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction, lsnr_prediction = model.forward(mix_spec) + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + # loss = irm_loss + 0.1 * snr_loss + loss = irm_loss + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + # save path + epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch) + epoch_dir.mkdir(parents=True, exist_ok=False) + + # save models + model.save_pretrained(epoch_dir.as_posix()) + + model_list.append(epoch_dir) + if len(model_list) >= args.num_serialized_models_to_keep: + model_to_delete: Path = model_list.pop(0) + shutil.rmtree(model_to_delete.as_posix()) + + # save metric + if best_metric is None: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + elif evaluation_loss < best_metric: + best_idx_epoch = idx_epoch + best_metric = evaluation_loss + else: + pass + + metrics = { + "idx_epoch": idx_epoch, + "best_idx_epoch": best_idx_epoch, + "training_loss": training_loss, + "evaluation_loss": evaluation_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + } + metrics_filename = epoch_dir / "metrics_epoch.json" + with open(metrics_filename, "w", encoding="utf-8") as f: + json.dump(metrics, f, indent=4, ensure_ascii=False) + + # save best + best_dir = serialization_dir / "best" + if best_idx_epoch == idx_epoch: + if best_dir.exists(): + shutil.rmtree(best_dir) + shutil.copytree(epoch_dir, best_dir) + + # early stop + early_stop_flag = False + if best_idx_epoch == idx_epoch: + patience_count = 0 + else: + patience_count += 1 + if patience_count >= args.patience: + early_stop_flag = True + + # early stop + if early_stop_flag: + break + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_unet_irm_aishell/step_3_evaluation.py b/examples/spectrum_unet_irm_aishell/step_3_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..38ef3054e4688d1b94f8d14cb3f3f3ca6b8b8296 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/step_3_evaluation.py @@ -0,0 +1,270 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import logging +import os +from pathlib import Path +import sys +import uuid + +pwd = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.join(pwd, "../../")) + +import librosa +import numpy as np +import pandas as pd +from scipy.io import wavfile +import torch +import torch.nn as nn +import torchaudio +from tqdm import tqdm + +from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--valid_dataset", default="valid.xlsx", type=str) + parser.add_argument("--model_dir", default="serialization_dir/best", type=str) + parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str) + + parser.add_argument("--limit", default=10, type=int) + + args = parser.parse_args() + return args + + +def 
logging_config(): + fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s" + + logging.basicConfig(format=fmt, + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(logging.Formatter(fmt)) + + logger = logging.getLogger(__name__) + + return logger + + +def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). + + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal + + +stft_power = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=2.0, + window_fn=torch.hamming_window, +) + + +stft_complex = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + power=None, + window_fn=torch.hamming_window, +) + + +istft = torchaudio.transforms.InverseSpectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, +) + + +def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor): + mix_spec_complex = mix_spec_complex.detach().cpu() + speech_irm_prediction = speech_irm_prediction.detach().cpu() + + mask_speech = speech_irm_prediction + mask_noise = 1.0 - speech_irm_prediction + + speech_spec = mix_spec_complex * mask_speech + noise_spec = mix_spec_complex * mask_noise + + speech_wave = istft.forward(speech_spec) + noise_wave = istft.forward(noise_spec) + + return speech_wave, noise_wave + + +def save_audios(noise_wave: torch.Tensor, + speech_wave: torch.Tensor, + mix_wave: torch.Tensor, + speech_wave_enhanced: torch.Tensor, + noise_wave_enhanced: torch.Tensor, + output_dir: str, + sample_rate: int = 8000, + ): + basename = uuid.uuid4().__str__() + output_dir = Path(output_dir) / basename + output_dir.mkdir(parents=True, exist_ok=True) + + filename = output_dir / "noise_wave.wav" + torchaudio.save(filename, noise_wave, sample_rate) + filename = output_dir / "speech_wave.wav" + torchaudio.save(filename, speech_wave, sample_rate) + filename = output_dir / "mix_wave.wav" + torchaudio.save(filename, mix_wave, sample_rate) + + filename = output_dir / "speech_wave_enhanced.wav" + torchaudio.save(filename, speech_wave_enhanced, sample_rate) + filename = output_dir / "noise_wave_enhanced.wav" + torchaudio.save(filename, noise_wave_enhanced, sample_rate) + + return output_dir.as_posix() + + +def main(): + args = get_args() + + logger = logging_config() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_gpu = torch.cuda.device_count() + logger.info("GPU available count: {}; device: {}".format(n_gpu, device)) + + logger.info("prepare model") + model = SpectrumUnetIRMPretrainedModel.from_pretrained( + pretrained_model_name_or_path=args.model_dir, + ) + model.to(device) + model.eval() + + # optimizer + logger.info("prepare loss_fn") + irm_mse_loss = nn.MSELoss( + reduction="mean", + ) + snr_mse_loss = nn.MSELoss( + reduction="mean", + ) + + logger.info("read excel") + df = pd.read_excel(args.valid_dataset) + + total_loss = 0. + total_examples = 0. 
+ progress_bar = tqdm(total=len(df), desc="Evaluation") + for idx, row in df.iterrows(): + noise_filename = row["noise_filename"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + noise_wave, _ = librosa.load( + noise_filename, + sr=8000, + offset=noise_offset, + duration=noise_duration, + ) + speech_wave, _ = librosa.load( + speech_filename, + sr=8000, + offset=speech_offset, + duration=speech_duration, + ) + mix_wave: np.ndarray = mix_speech_and_noise( + speech=speech_wave, + noise=noise_wave, + snr_db=snr_db, + ) + noise_wave = torch.tensor(noise_wave, dtype=torch.float32) + speech_wave = torch.tensor(speech_wave, dtype=torch.float32) + mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32) + + noise_wave = noise_wave.unsqueeze(dim=0) + speech_wave = speech_wave.unsqueeze(dim=0) + mix_wave = mix_wave.unsqueeze(dim=0) + + noise_spec: torch.Tensor = stft_power.forward(noise_wave) + speech_spec: torch.Tensor = stft_power.forward(speech_wave) + mix_spec: torch.Tensor = stft_power.forward(mix_wave) + + noise_spec = noise_spec[:, :-1, :] + speech_spec = speech_spec[:, :-1, :] + mix_spec = mix_spec[:, :-1, :] + + mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave) + # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2] + + speech_irm = speech_spec / (noise_spec + speech_spec) + speech_irm = torch.pow(speech_irm, 1.0) + + snr_db: torch.Tensor = 10 * torch.log10( + speech_spec / (noise_spec + 1e-8) + ) + snr_db = torch.mean(snr_db, dim=1, keepdim=True) + # snr_db shape: [batch_size, 1, time_steps] + + mix_spec = mix_spec.to(device) + speech_irm_target = speech_irm.to(device) + snr_db_target = snr_db.to(device) + + with torch.no_grad(): + speech_irm_prediction, lsnr_prediction = model.forward(mix_spec) + irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target) + # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target) + # loss = irm_loss + 0.1 * snr_loss + loss = irm_loss + + # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2] + # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps] + batch_size, _, time_steps = speech_irm_prediction.shape + speech_irm_prediction = torch.concat( + [ + speech_irm_prediction, + 0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device) + ], + dim=1, + ) + # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps] + speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction) + save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir) + + total_loss += loss.item() + total_examples += mix_spec.size(0) + + evaluation_loss = total_loss / total_examples + evaluation_loss = round(evaluation_loss, 4) + + progress_bar.update(1) + progress_bar.set_postfix({ + "evaluation_loss": evaluation_loss, + }) + + if idx > args.limit: + break + + return + + +if __name__ == '__main__': + main() diff --git a/examples/spectrum_unet_irm_aishell/yaml/config.yaml b/examples/spectrum_unet_irm_aishell/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b756c46aeccf91428ed137a1c11aeb127509ac00 --- /dev/null +++ b/examples/spectrum_unet_irm_aishell/yaml/config.yaml @@ -0,0 +1,35 @@ +model_name: "spectrum_unet_irm" + +# spec +sample_rate: 
8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +lsnr_max: 20 +lsnr_min: -10 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..826e360a65034f9b489f2ff0e9feb1a286fede9b --- /dev/null +++ b/install.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash + +# bash install.sh --stage 2 --stop_stage 2 --system_version centos + + +python_version=3.8.10 +system_version="centos"; + +verbose=true; +stage=-1 +stop_stage=0 + + +# parse options +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g); + eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + old_value="(eval echo \\$$name)"; + if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval "${name}=\"$2\""; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + + *) break; + esac +done + +work_dir="$(pwd)" + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + $verbose && echo "stage 1: install python" + cd "${work_dir}" || exit 1; + + sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}" +fi + + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + $verbose && echo "stage 2: create virtualenv" + + # /usr/local/python-3.9.9/bin/virtualenv nx_denoise + # source /data/local/bin/nx_denoise/bin/activate + /usr/local/python-${python_version}/bin/pip3 install virtualenv + mkdir -p /data/local/bin + cd /data/local/bin || exit 1; + /usr/local/python-${python_version}/bin/virtualenv nx_denoise + +fi diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..68d8d1a6d3204bbff2ad36f1f6761029e210f33d --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import argparse +import platform + +import gradio as gr + +from project_settings import environment, project_path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--hf_token", + default=environment.get("hf_token"), + type=str, + ) + parser.add_argument( + "--server_port", + default=environment.get("server_port", 7860), + type=int + ) + + args = parser.parse_args() + return args + + +def main(): + args = get_args() + + # ui + with gr.Blocks() as blocks: + gr.Markdown(value="in progress.") + + # http://127.0.0.1:7864/ + blocks.queue().launch( + share=False if platform.system() == "Windows" else False, + server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0", + server_port=args.server_port + ) + return + + +if __name__ == "__main__": + main() diff --git a/project_settings.py b/project_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..61d3f885814edf7f3f31e5ffff7f3bdae0e330e2 --- /dev/null 
+++ b/project_settings.py @@ -0,0 +1,25 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +from pathlib import Path + +from toolbox.os.environment import EnvironmentManager + + +project_path = os.path.abspath(os.path.dirname(__file__)) +project_path = Path(project_path) + +log_directory = project_path / "logs" +log_directory.mkdir(parents=True, exist_ok=True) + +temp_directory = project_path / "temp" +temp_directory.mkdir(parents=True, exist_ok=True) + +environment = EnvironmentManager( + path=os.path.join(project_path, "dotenv"), + env=os.environ.get("environment", "dev"), +) + + +if __name__ == '__main__': + pass diff --git a/requirements-python-3-9-9.txt b/requirements-python-3-9-9.txt new file mode 100644 index 0000000000000000000000000000000000000000..342f55ae208753e8c19adebbdce603128cc7a95b --- /dev/null +++ b/requirements-python-3-9-9.txt @@ -0,0 +1,10 @@ +gradio==4.44.1 +datasets==3.2.0 +python-dotenv==1.0.1 +scipy==1.13.1 +librosa==0.10.2.post1 +pandas==2.2.3 +openpyxl==3.1.5 +torch==2.5.1 +torchaudio==2.5.1 +overrides==7.7.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e4af41e62949fc1b7f3839b84b4498d9f5692bb4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +gradio==5.12.0 +datasets==3.2.0 +python-dotenv==1.0.1 +scipy==1.15.1 +librosa==0.10.2.post1 +pandas==2.2.3 +openpyxl==3.1.5 +torch==2.5.1 +torchaudio==2.5.1 +overrides==7.7.0 diff --git a/toolbox/__init__.py b/toolbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/json/__init__.py b/toolbox/json/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/json/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/json/misc.py b/toolbox/json/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..56022e111a555aa370ff833f4ee68f880d183ed9 --- /dev/null +++ b/toolbox/json/misc.py @@ -0,0 +1,63 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Callable + + +def traverse(js, callback: Callable, *args, **kwargs): + if isinstance(js, list): + result = list() + for l in js: + l = traverse(l, callback, *args, **kwargs) + result.append(l) + return result + elif isinstance(js, tuple): + result = list() + for l in js: + l = traverse(l, callback, *args, **kwargs) + result.append(l) + return tuple(result) + elif isinstance(js, dict): + result = dict() + for k, v in js.items(): + k = traverse(k, callback, *args, **kwargs) + v = traverse(v, callback, *args, **kwargs) + result[k] = v + return result + elif isinstance(js, int): + return callback(js, *args, **kwargs) + elif isinstance(js, str): + return callback(js, *args, **kwargs) + else: + return js + + +def demo1(): + d = { + "env": "ppe", + "mysql_connect": { + "host": "$mysql_connect_host", + "port": 3306, + "user": "callbot", + "password": "NxcloudAI2021!", + "database": "callbot_ppe", + "charset": "utf8" + }, + "es_connect": { + "hosts": ["10.20.251.8"], + "http_auth": ["elastic", "ElasticAI2021!"], + "port": 9200 + } + } + + def callback(s): + if isinstance(s, str) and s.startswith('$'): + return s[1:] + return s + + result = traverse(d, callback=callback) + 
print(result) + return + + +if __name__ == '__main__': + demo1() diff --git a/toolbox/os/__init__.py b/toolbox/os/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/os/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/os/command.py b/toolbox/os/command.py new file mode 100644 index 0000000000000000000000000000000000000000..40a1880fb08d21be79f31a13e51e389adb18c283 --- /dev/null +++ b/toolbox/os/command.py @@ -0,0 +1,59 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os + + +class Command(object): + custom_command = [ + "cd" + ] + + @staticmethod + def _get_cmd(command): + command = str(command).strip() + if command == "": + return None + cmd_and_args = command.split(sep=" ") + cmd = cmd_and_args[0] + args = " ".join(cmd_and_args[1:]) + return cmd, args + + @classmethod + def popen(cls, command): + cmd, args = cls._get_cmd(command) + if cmd in cls.custom_command: + method = getattr(cls, cmd) + return method(args) + else: + resp = os.popen(command) + result = resp.read() + resp.close() + return result + + @classmethod + def cd(cls, args): + if args.startswith("/"): + os.chdir(args) + else: + pwd = os.getcwd() + path = os.path.join(pwd, args) + os.chdir(path) + + @classmethod + def system(cls, command): + return os.system(command) + + def __init__(self): + pass + + +def ps_ef_grep(keyword: str): + cmd = "ps -ef | grep {}".format(keyword) + rows = Command.popen(cmd) + rows = str(rows).split("\n") + rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")] + return rows + + +if __name__ == "__main__": + pass diff --git a/toolbox/os/environment.py b/toolbox/os/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..8a9bc004fc8777ae906c1caccd47bb257b126d55 --- /dev/null +++ b/toolbox/os/environment.py @@ -0,0 +1,114 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import json +import os + +from dotenv import load_dotenv +from dotenv.main import DotEnv + +from toolbox.json.misc import traverse + + +class EnvironmentManager(object): + def __init__(self, path, env, override=False): + filename = os.path.join(path, '{}.env'.format(env)) + self.filename = filename + + load_dotenv( + dotenv_path=filename, + override=override + ) + + self._environ = dict() + + def open_dotenv(self, filename: str = None): + filename = filename or self.filename + dotenv = DotEnv( + dotenv_path=filename, + stream=None, + verbose=False, + interpolate=False, + override=False, + encoding="utf-8", + ) + result = dotenv.dict() + return result + + def get(self, key, default=None, dtype=str): + result = os.environ.get(key) + if result is None: + if default is None: + result = None + else: + result = default + else: + result = dtype(result) + self._environ[key] = result + return result + + +_DEFAULT_DTYPE_MAP = { + 'int': int, + 'float': float, + 'str': str, + 'json.loads': json.loads +} + + +class JsonConfig(object): + """ + 将 json 中, 形如 `$float:threshold` 的值, 处理为: + 从环境变量中查到 threshold, 再将其转换为 float 类型. 
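+ i.e. a value such as "$float:threshold" is resolved by reading `threshold`
+ from the environment and casting it to float.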
+ """ + def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None): + self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP + self.environment = environment or os.environ + + def sanitize_by_filename(self, filename: str): + with open(filename, 'r', encoding='utf-8') as f: + js = json.load(f) + + return self.sanitize_by_json(js) + + def sanitize_by_json(self, js): + js = traverse( + js, + callback=self.sanitize, + environment=self.environment + ) + return js + + def sanitize(self, string, environment): + """支持 $ 符开始的, 环境变量配置""" + if isinstance(string, str) and string.startswith('$'): + dtype, key = string[1:].split(':') + dtype = self.dtype_map[dtype] + + value = environment.get(key) + if value is None: + raise AssertionError('environment not exist. key: {}'.format(key)) + + value = dtype(value) + result = value + else: + result = string + return result + + +def demo1(): + import json + + from project_settings import project_path + + environment = EnvironmentManager( + path=os.path.join(project_path, 'server/callbot_server/dotenv'), + env='dev', + ) + init_scenes = environment.get(key='init_scenes', dtype=json.loads) + print(init_scenes) + print(environment._environ) + return + + +if __name__ == '__main__': + demo1() diff --git a/toolbox/os/other.py b/toolbox/os/other.py new file mode 100644 index 0000000000000000000000000000000000000000..f215505eedfd714442d2fab241c8b1aff871d18a --- /dev/null +++ b/toolbox/os/other.py @@ -0,0 +1,9 @@ +import os +import inspect + + +def pwd(): + """你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标""" + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + return os.path.dirname(os.path.abspath(module.__file__)) diff --git a/toolbox/torch/__init__.py b/toolbox/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/__init__.py b/toolbox/torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/utils/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/__init__.py b/toolbox/torch/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/utils/data/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/dataset/__init__.py b/toolbox/torch/utils/data/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torch/utils/data/dataset/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torch/utils/data/dataset/denoise_excel_dataset.py b/toolbox/torch/utils/data/dataset/denoise_excel_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..de7832efb1ac2de1f856df7fbaf6ae41a0f30545 --- /dev/null +++ b/toolbox/torch/utils/data/dataset/denoise_excel_dataset.py @@ -0,0 +1,131 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os + +import librosa +import numpy as np +import pandas as pd +from scipy.io 
import wavfile +import torch +import torchaudio +from torch.utils.data import Dataset +from tqdm import tqdm + + +class DenoiseExcelDataset(Dataset): + def __init__(self, + excel_file: str, + expected_sample_rate: int, + resample: bool = False, + max_wave_value: float = 1.0, + ): + self.excel_file = excel_file + self.expected_sample_rate = expected_sample_rate + self.resample = resample + self.max_wave_value = max_wave_value + + self.samples = self.load_samples(excel_file) + + @staticmethod + def load_samples(filename: str): + df = pd.read_excel(filename) + samples = list() + for i, row in tqdm(df.iterrows(), total=len(df)): + noise_filename = row["noise_filename"] + noise_raw_duration = row["noise_raw_duration"] + noise_offset = row["noise_offset"] + noise_duration = row["noise_duration"] + + speech_filename = row["speech_filename"] + speech_raw_duration = row["speech_raw_duration"] + speech_offset = row["speech_offset"] + speech_duration = row["speech_duration"] + + snr_db = row["snr_db"] + + row = { + "noise_filename": noise_filename, + "noise_raw_duration": noise_raw_duration, + "noise_offset": noise_offset, + "noise_duration": noise_duration, + + "speech_filename": speech_filename, + "speech_raw_duration": speech_raw_duration, + "speech_offset": speech_offset, + "speech_duration": speech_duration, + + "snr_db": snr_db, + } + samples.append(row) + return samples + + def __getitem__(self, index): + sample = self.samples[index] + noise_filename = sample["noise_filename"] + noise_offset = sample["noise_offset"] + noise_duration = sample["noise_duration"] + + speech_filename = sample["speech_filename"] + speech_offset = sample["speech_offset"] + speech_duration = sample["speech_duration"] + + snr_db = sample["snr_db"] + + noise_wave = self.filename_to_waveform(noise_filename, noise_offset, noise_duration) + speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration) + + mix_wave, noise_wave_adjusted = self.mix_speech_and_noise( + speech=speech_wave.numpy(), + noise=noise_wave.numpy(), + snr_db=snr_db, + ) + mix_wave = torch.tensor(mix_wave, dtype=torch.float32) + noise_wave_adjusted = torch.tensor(noise_wave_adjusted, dtype=torch.float32) + + result = { + "noise_wave": noise_wave_adjusted, + "speech_wave": speech_wave, + "mix_wave": mix_wave, + "snr_db": snr_db, + } + return result + + def __len__(self): + return len(self.samples) + + def filename_to_waveform(self, filename: str, offset: float, duration: float): + try: + waveform, sample_rate = librosa.load( + filename, + sr=self.expected_sample_rate, + offset=offset, + duration=duration, + ) + except ValueError as e: + print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}") + raise e + waveform = torch.tensor(waveform, dtype=torch.float32) + return waveform + + @staticmethod + def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float): + l1 = len(speech) + l2 = len(noise) + l = min(l1, l2) + speech = speech[:l] + noise = noise[:l] + + # np.float32, value between (-1, 1). 
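+ # scale the noise so that 10 * log10(speech_power / noise_power) equals snr_db:
+ # the noise is normalized to unit power, then multiplied by
+ # sqrt(speech_power / 10 ** (snr_db / 10)).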
+ + speech_power = np.mean(np.square(speech)) + noise_power = speech_power / (10 ** (snr_db / 10)) + + noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2)) + + noisy_signal = speech + noise_adjusted + + return noisy_signal, noise_adjusted + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/__init__.py b/toolbox/torchaudio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4aad738e112896111c38ae6624c8632aee62a234 --- /dev/null +++ b/toolbox/torchaudio/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/configuration_utils.py b/toolbox/torchaudio/configuration_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5ef0d2009681d4aa0c9b79a3728aa40622d0fe85 --- /dev/null +++ b/toolbox/torchaudio/configuration_utils.py @@ -0,0 +1,63 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import copy +import os +from typing import Any, Dict, Union + +import yaml + + +CONFIG_FILE = "config.yaml" + + +class PretrainedConfig(object): + def __init__(self, **kwargs): + pass + + @classmethod + def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]): + with open(yaml_file, encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + return config_dict + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike] + ) -> Dict[str, Any]: + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE) + else: + config_file = pretrained_model_name_or_path + config_dict = cls._dict_from_yaml_file(config_file) + return config_dict + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs): + for k, v in kwargs.items(): + if k in config_dict.keys(): + config_dict[k] = v + config = cls(**config_dict) + return config + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ): + config_dict = cls.get_config_dict(pretrained_model_name_or_path) + return cls.from_dict(config_dict, **kwargs) + + def to_dict(self): + output = copy.deepcopy(self.__dict__) + return output + + def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]): + config_dict = self.to_dict() + + with open(yaml_file_path, "w", encoding="utf-8") as writer: + yaml.safe_dump(config_dict, writer) + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/__init__.py b/toolbox/torchaudio/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4aad738e112896111c38ae6624c8632aee62a234 --- /dev/null +++ b/toolbox/torchaudio/models/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/__init__.py b/toolbox/torchaudio/models/clean_unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/clean_unet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py b/toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..b73ae694b16235cc8c4e7b369e7ed19f1007100e --- /dev/null +++ 
b/toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py @@ -0,0 +1,9 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2202.07790 +""" + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/dfnet3/__init__.py b/toolbox/torchaudio/models/dfnet3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet3/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/dfnet3/configuration_dfnet3.py b/toolbox/torchaudio/models/dfnet3/configuration_dfnet3.py new file mode 100644 index 0000000000000000000000000000000000000000..571c3088a0063596c7b8da92beeb236eed40b528 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet3/configuration_dfnet3.py @@ -0,0 +1,89 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Any, Dict, List, Tuple, Union + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class DfNetConfig(PretrainedConfig): + def __init__(self, + sample_rate: int, + fft_size: int, + hop_size: int, + df_bins: int, + erb_bins: int, + min_freq_bins_for_erb: int, + df_order: int, + df_lookahead: int, + norm_tau: int, + lsnr_max: int, + lsnr_min: int, + conv_channels: int, + conv_kernel_size_input: Tuple[int, int], + conv_kernel_size_inner: Tuple[int, int], + convt_kernel_size_inner: Tuple[int, int], + conv_lookahead: int, + emb_hidden_dim: int, + mask_post_filter: bool, + df_hidden_dim: int, + df_num_layers: int, + df_pathway_kernel_size_t: int, + df_gru_skip: str, + post_filter_beta: float, + df_n_iter: float, + lsnr_dropout: bool, + encoder_gru_skip_op: str, + encoder_linear_groups: int, + encoder_squeezed_gru_linear_groups: int, + encoder_concat: bool, + erb_decoder_gru_skip_op: str, + erb_decoder_linear_groups: int, + erb_decoder_emb_num_layers: int, + df_decoder_linear_groups: int, + **kwargs + ): + super(DfNetConfig, self).__init__(**kwargs) + if df_gru_skip not in ("none", "identity", "grouped_linear"): + raise AssertionError + + self.sample_rate = sample_rate + self.fft_size = fft_size + self.hop_size = hop_size + self.df_bins = df_bins + self.erb_bins = erb_bins + self.min_freq_bins_for_erb = min_freq_bins_for_erb + self.df_order = df_order + self.df_lookahead = df_lookahead + self.norm_tau = norm_tau + self.lsnr_max = lsnr_max + self.lsnr_min = lsnr_min + + self.conv_channels = conv_channels + self.conv_kernel_size_input = conv_kernel_size_input + self.conv_kernel_size_inner = conv_kernel_size_inner + self.convt_kernel_size_inner = convt_kernel_size_inner + self.conv_lookahead = conv_lookahead + + self.emb_hidden_dim = emb_hidden_dim + self.mask_post_filter = mask_post_filter + self.df_hidden_dim = df_hidden_dim + self.df_num_layers = df_num_layers + self.df_pathway_kernel_size_t = df_pathway_kernel_size_t + self.df_gru_skip = df_gru_skip + self.post_filter_beta = post_filter_beta + self.df_n_iter = df_n_iter + self.lsnr_dropout = lsnr_dropout + self.encoder_gru_skip_op = encoder_gru_skip_op + self.encoder_linear_groups = encoder_linear_groups + self.encoder_squeezed_gru_linear_groups = encoder_squeezed_gru_linear_groups + self.encoder_concat = encoder_concat + + self.erb_decoder_gru_skip_op = erb_decoder_gru_skip_op + self.erb_decoder_linear_groups = erb_decoder_linear_groups + self.erb_decoder_emb_num_layers = erb_decoder_emb_num_layers + + self.df_decoder_linear_groups = df_decoder_linear_groups 
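+
+# Like SpectrumUnetIRMConfig above, DfNetConfig is normally built through
+# PretrainedConfig.from_pretrained, which reads a config.yaml whose keys must
+# match the __init__ arguments. An illustrative call (the path is an assumption,
+# no such yaml ships in this diff):
+#
+#   config = DfNetConfig.from_pretrained("toolbox/torchaudio/models/dfnet3/config.yaml")
+#   assert config.df_gru_skip in ("none", "identity", "grouped_linear")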
+ + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dfnet3/features.py b/toolbox/torchaudio/models/dfnet3/features.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba9f800e9dcbc7cf7730e3b7f00858fee072dfe --- /dev/null +++ b/toolbox/torchaudio/models/dfnet3/features.py @@ -0,0 +1,192 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import math + +import numpy as np + + +def freq2erb(freq_hz: float) -> float: + """ + https://www.cnblogs.com/LXP-Never/p/16011229.html + 1 / (24.7 * 9.265) = 0.00436976 + """ + return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1) + + +def erb2freq(n_erb: float) -> float: + return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1) + + +def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray: + """ + https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs + :param sample_rate: + :param fft_size: + :param erb_bins: erb (Equivalent Rectangular Bandwidth) 等效矩形带宽的通道数. + :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band + :return: + """ + nyq_freq = sample_rate / 2. + freq_width: float = sample_rate / fft_size + + min_erb: float = freq2erb(0.) + max_erb: float = freq2erb(nyq_freq) + + erb = [0] * erb_bins + step = (max_erb - min_erb) / erb_bins + + prev_freq_bin = 0 + freq_over = 0 + for i in range(1, erb_bins + 1): + f = erb2freq(min_erb + i * step) + freq_bin = int(round(f / freq_width)) + freq_bins = freq_bin - prev_freq_bin - freq_over + + if freq_bins < min_freq_bins_for_erb: + freq_over = min_freq_bins_for_erb - freq_bins + freq_bins = min_freq_bins_for_erb + else: + freq_over = 0 + erb[i - 1] = freq_bins + prev_freq_bin = freq_bin + + erb[erb_bins - 1] += 1 + too_large = sum(erb) - (fft_size / 2 + 1) + if too_large > 0: + erb[erb_bins - 1] -= too_large + return np.array(erb, dtype=np.uint64) + + +def get_erb_filter_bank(erb_widths: np.ndarray, + sample_rate: int, + normalized: bool = True, + inverse: bool = False, + ): + num_freq_bins = int(np.sum(erb_widths)) + num_erb_bins = len(erb_widths) + + fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins)) + + points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1] + for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())): + fb[b: b + w, i] = 1 + + if inverse: + fb = fb.T + if not normalized: + fb /= np.sum(fb, axis=1, keepdims=True) + else: + if normalized: + fb /= np.sum(fb, axis=0) + return fb + + +def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True): + """ + ERB filterbank and transform to decibel scale. + + :param spec: Spectrum of shape [B, C, T, F]. + :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths, + where B are the number of ERB bins. + :param db: Whether to transform the output into decibel scale. Defaults to `True`. + :return: + """ + # complex spec to power spec. (real * real + image * image) + spec_ = np.abs(spec) ** 2 + + # spec to erb feature. 
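+ # erb_fb has shape [num_freq_bins, num_erb_bins]; the matmul pools the
+ # linear-frequency power bins along the last axis into ERB bands.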
+ erb_feat = np.matmul(spec_, erb_fb) + + if db: + erb_feat = 10 * np.log10(erb_feat + 1e-10) + + erb_feat = np.array(erb_feat, dtype=np.float32) + return erb_feat + + +def _calculate_norm_alpha(sample_rate: int, hop_size: int, tau: float): + """Exponential decay factor alpha for a given tau (decay window size [s]).""" + dt = hop_size / sample_rate + result = math.exp(-dt / tau) + return result + + +def get_norm_alpha(sample_rate: int, hop_size: int, norm_tau: float) -> float: + a_ = _calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau) + + precision = 3 + a = 1.0 + while a >= 1.0: + a = round(a_, precision) + precision += 1 + + return a + + +MEAN_NORM_INIT = [-60., -90.] + + +def make_erb_norm_state(erb_bins: int, channels: int) -> np.ndarray: + state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins) + state = np.expand_dims(state, axis=0) + state = np.repeat(state, channels, axis=0) + + # state shape: (audio_channels, erb_bins) + return state + + +def erb_normalize(erb_feat: np.ndarray, alpha: float, state: np.ndarray = None): + erb_feat = np.copy(erb_feat) + batch_size, time_steps, erb_bins = erb_feat.shape + + if state is None: + state = make_erb_norm_state(erb_bins, erb_feat.shape[0]) + # state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins) + # state = np.expand_dims(state, axis=0) + # state = np.repeat(state, erb_feat.shape[0], axis=0) + + for i in range(batch_size): + for j in range(time_steps): + for k in range(erb_bins): + x = erb_feat[i][j][k] + s = state[i][k] + + state[i][k] = x * (1. - alpha) + s * alpha + erb_feat[i][j][k] -= state[i][k] + erb_feat[i][j][k] /= 40. + + return erb_feat + + +UNIT_NORM_INIT = [0.001, 0.0001] + + +def make_spec_norm_state(df_bins: int, channels: int) -> np.ndarray: + state = np.linspace(UNIT_NORM_INIT[0], UNIT_NORM_INIT[1], df_bins) + state = np.expand_dims(state, axis=0) + state = np.repeat(state, channels, axis=0) + + # state shape: (audio_channels, df_bins) + return state + + +def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None): + spec_feat = np.copy(spec_feat) + batch_size, time_steps, df_bins = spec_feat.shape + + if state is None: + state = make_spec_norm_state(df_bins, spec_feat.shape[0]) + + for i in range(batch_size): + for j in range(time_steps): + for k in range(df_bins): + x = spec_feat[i][j][k] + s = state[i][k] + + state[i][k] = np.abs(x) * (1. 
- alpha) + s * alpha + spec_feat[i][j][k] /= np.sqrt(state[i][k]) + return spec_feat + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/dfnet3/modeling_dfnet3.py b/toolbox/torchaudio/models/dfnet3/modeling_dfnet3.py new file mode 100644 index 0000000000000000000000000000000000000000..999aa8de1103bd30edbdc63628de9a44cad57814 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet3/modeling_dfnet3.py @@ -0,0 +1,835 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import logging +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from toolbox.torchaudio.models.dfnet3.configuration_dfnet3 import DfNetConfig +from toolbox.torchaudio.models.dfnet3 import multiframes as MF +from toolbox.torchaudio.models.dfnet3 import utils + +logger = logging.getLogger("toolbox") + +PI = 3.1415926535897932384626433 + + +norm_layer_dict = { + "batch_norm_2d": torch.nn.BatchNorm2d +} + +activation_layer_dict = { + "relu": torch.nn.ReLU, + "identity": torch.nn.Identity, + "sigmoid": torch.nn.Sigmoid, +} + + +class CausalConv2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + ): + """ + Causal Conv2d by delaying the signal for any lookahead. + + Expected input format: [B, C, T, F] + + :param in_channels: + :param out_channels: + :param kernel_size: + :param fstride: + :param dilation: + :param fpad: + """ + super(CausalConv2d, self).__init__() + lookahead = 0 + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + + if fpad: + fpad_ = kernel_size[1] // 2 + dilation - 1 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). + pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + + layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + +class CausalConvTranspose2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + ): + """ + Causal ConvTranspose2d. 
+ + Expected input format: [B, C, T, F] + """ + super(CausalConvTranspose2d, self).__init__() + lookahead = 0 + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if fpad: + fpad_ = kernel_size[1] // 2 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). + pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + layers.append( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad_ + dilation - 1), + output_padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + +class GroupedLinear(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" + assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + torch.nn.Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., I] + b, t, _ = x.shape + # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] + new_shape = (b, t, self.groups, self.ws) + x = x.view(new_shape) + # The better way, but not supported by torchscript + # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) # [B, T, H] + return x + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class SqueezedGRU_S(nn.Module): + """ + SGE net: Video object detection with squeezed GRU and information entropy map + https://arxiv.org/abs/2106.07224 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + skip_op: str = "none", + activation_layer: str = "identity", + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.linear_in = nn.Sequential( + GroupedLinear( + input_size=input_size, + hidden_size=hidden_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + + # gru skip operator + self.gru_skip_op = None + + if skip_op == "none": + self.gru_skip_op = None + elif skip_op == "identity": + if not input_size != output_size: + raise AssertionError("Dimensions do not 
match") + self.gru_skip_op = nn.Identity() + elif skip_op == "grouped_linear": + self.gru_skip_op = GroupedLinear( + input_size=hidden_size, + hidden_size=hidden_size, + groups=linear_groups, + ) + else: + raise NotImplementedError() + + self.gru = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=batch_first, + ) + + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinear( + input_size=hidden_size, + hidden_size=output_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.linear_in(inputs) + + x, h = self.gru(x, h) + + x = self.linear_out(x) + + if self.gru_skip_op is not None: + x = x + self.gru_skip_op(inputs) + + return x, h + + +class Add(nn.Module): + def forward(self, a, b): + return a + b + + +class Concat(nn.Module): + def forward(self, a, b): + return torch.cat((a, b), dim=-1) + + +class Encoder(nn.Module): + def __init__(self, config: DfNetConfig): + super(Encoder, self).__init__() + self.emb_in_dim = config.conv_channels * config.erb_bins // 4 + self.emb_out_dim = config.conv_channels * config.erb_bins // 4 + self.emb_hidden_dim = config.emb_hidden_dim + + self.erb_conv0 = CausalConv2d( + in_channels=1, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + ) + self.erb_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.erb_conv2 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + self.erb_conv3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + ) + + self.df_conv0 = CausalConv2d( + in_channels=2, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + ) + self.df_conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + ) + + self.df_fc_emb = nn.Sequential( + GroupedLinear( + config.conv_channels * config.df_bins // 2, + self.emb_in_dim, + groups=config.encoder_linear_groups + ), + nn.ReLU(inplace=True) + ) + + if config.encoder_concat: + self.emb_in_dim *= 2 + self.combine = Concat() + else: + self.combine = Add() + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=1, + batch_first=True, + skip_op=config.encoder_gru_skip_op, + linear_groups=config.encoder_squeezed_gru_linear_groups, + activation_layer="relu", + ) + + self.lsnr_fc = nn.Sequential( + nn.Linear(self.emb_out_dim, 1), + nn.Sigmoid() + ) + self.lsnr_scale = config.lsnr_max - config.lsnr_min + self.lsnr_offset = config.lsnr_min + + def forward(self, + feat_erb: torch.Tensor, + feat_spec: torch.Tensor, + h: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands. 
+ # erb: [B, 1, T, Fe] + # spec: [B, 2, T, Fc] + # b, _, t, _ = feat_erb.shape + e0 = self.erb_conv0(feat_erb) # [B, C, T, F] + e1 = self.erb_conv1(e0) # [B, C*2, T, F/2] + e2 = self.erb_conv2(e1) # [B, C*4, T, F/4] + e3 = self.erb_conv3(e2) # [B, C*4, T, F/4] + c0 = self.df_conv0(feat_spec) # [B, C, T, Fc] + c1 = self.df_conv1(c0) # [B, C*2, T, Fc/2] + cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1] + cemb = self.df_fc_emb(cemb) # [T, B, C * F/4] + emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F] + emb = self.combine(emb, cemb) + emb, h = self.emb_gru(emb, h) # [B, T, -1] + + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + return e0, e1, e2, e3, emb, c0, lsnr, h + + +class ErbDecoder(nn.Module): + def __init__(self, + config: DfNetConfig, + ): + super(ErbDecoder, self).__init__() + if config.erb_bins % 8 != 0: + raise AssertionError("erb_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * config.erb_bins // 4 + self.emb_out_dim = config.conv_channels * config.erb_bins // 4 + self.emb_hidden_dim = config.emb_hidden_dim + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=config.erb_decoder_emb_num_layers - 1, + batch_first=True, + skip_op=config.erb_decoder_gru_skip_op, + linear_groups=config.erb_decoder_linear_groups, + activation_layer="relu", + ) + + # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions + self.conv3p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + ) + self.convt3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + ) + self.conv2p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + ) + self.convt2 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + fstride=2, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + ) + self.conv1p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + ) + self.convt1 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + fstride=2, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + ) + self.conv0p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + ) + self.conv0_out = CausalConv2d( + in_channels=config.conv_channels, + out_channels=1, + kernel_size=config.conv_kernel_size_inner, + activation_layer="sigmoid", + bias=False, + separable=True, + ) + + def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: + # Estimates erb mask + b, _, t, f8 = e3.shape + emb, _ = self.emb_gru(emb) + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8] + e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4] + e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2] + e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F] + m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F] + return m + + +class Mask(nn.Module): + def __init__(self, erb_inv_fb: torch.FloatTensor, post_filter: bool = False, eps: float = 1e-12): + super().__init__() + self.erb_inv_fb: torch.FloatTensor + 
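+ # erb_inv_fb maps the Fe-band mask back onto the F linear frequency bins
+ # (used as mask.matmul(self.erb_inv_fb) in forward); registering it as a
+ # buffer lets it follow .to(device) and state_dict() without being trained.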
self.register_buffer("erb_inv_fb", erb_inv_fb.float()) + self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0" + self.post_filter = post_filter + self.eps = eps + + def pf(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: + """ + Post-Filter + + A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + https://arxiv.org/abs/2008.04259 + + :param mask: Real valued mask, typically of shape [B, C, T, F]. + :param beta: Global gain factor. + :return: + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, spec: torch.Tensor, mask: torch.Tensor, atten_lim: Optional[torch.Tensor] = None) -> torch.Tensor: + # spec (real) [B, 1, T, F, 2], F: freq_bins + # mask (real): [B, 1, T, Fe], Fe: erb_bins + # atten_lim: [B] + if not self.training and self.post_filter: + mask = self.pf(mask) + if atten_lim is not None: + # dB to amplitude + atten_lim = 10 ** (-atten_lim / 20) + # Greater equal (__ge__) not implemented for TorchVersion. + if self.clamp_tensor: + # Supported by torch >= 1.9 + mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1)) + else: + m_out = [] + for i in range(atten_lim.shape[0]): + m_out.append(mask[i].clamp_min(atten_lim[i].item())) + mask = torch.stack(m_out, dim=0) + mask = mask.matmul(self.erb_inv_fb) # [B, 1, T, F] + if not spec.is_complex(): + mask = mask.unsqueeze(4) + return spec * mask + + +class DfDecoder(nn.Module): + def __init__(self, + config: DfNetConfig, + ): + super().__init__() + layer_width = config.conv_channels + + self.emb_in_dim = config.conv_channels * config.erb_bins // 4 + self.emb_dim = config.df_hidden_dim + + self.df_n_hidden = config.df_hidden_dim + self.df_n_layers = config.df_num_layers + self.df_order = config.df_order + self.df_bins = config.df_bins + self.df_out_ch = config.df_order * 2 + + self.df_convp = CausalConv2d( + layer_width, + self.df_out_ch, + fstride=1, + kernel_size=(config.df_pathway_kernel_size_t, 1), + separable=True, + bias=False, + ) + self.df_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_dim, + num_layers=self.df_n_layers, + batch_first=True, + skip_op="none", + activation_layer="relu", + ) + + if config.df_gru_skip == "none": + self.df_skip = None + elif config.df_gru_skip == "identity": + if config.emb_hidden_dim != config.df_hidden_dim: + raise AssertionError("Dimensions do not match") + self.df_skip = nn.Identity() + elif config.df_gru_skip == "grouped_linear": + self.df_skip = GroupedLinear(self.emb_in_dim, self.emb_dim, groups=config.df_decoder_linear_groups) + else: + raise NotImplementedError() + + self.df_out: nn.Module + out_dim = self.df_bins * self.df_out_ch + + self.df_out = nn.Sequential( + GroupedLinear( + input_size=self.df_n_hidden, + hidden_size=out_dim, + groups=config.df_decoder_linear_groups + ), + nn.Tanh() + ) + self.df_fc_a = nn.Sequential( + nn.Linear(self.df_n_hidden, 1), + nn.Sigmoid() + ) + + def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor: + b, t, _ = emb.shape + c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden + if self.df_skip is not None: + c = c + self.df_skip(emb) + c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last + c = self.df_out(c) # [B, T, F*O*2], O: df_order + c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2] + return c + + +class DfOutputReshapeMF(nn.Module): + """Coefficients output reshape for 
multiframe/MultiFrameModule + + Requires input of shape B, C, T, F, 2. + """ + + def __init__(self, df_order: int, df_bins: int): + super().__init__() + self.df_order = df_order + self.df_bins = df_bins + + def forward(self, coefs: torch.Tensor) -> torch.Tensor: + # [B, T, F, O*2] -> [B, O, T, F, 2] + new_shape = list(coefs.shape) + new_shape[-1] = -1 + new_shape.append(2) + coefs = coefs.view(new_shape) + coefs = coefs.permute(0, 3, 1, 2, 4) + return coefs + + +class DfNet(nn.Module): + """ + DeepFilterNet: Perceptually Motivated Real-Time Speech Enhancement + https://arxiv.org/abs/2305.08227 + + hendrik.m.schroeter@fau.de + """ + def __init__(self, + config: DfNetConfig, + erb_fb: torch.FloatTensor, + erb_inv_fb: torch.FloatTensor, + run_df: bool = True, + train_mask: bool = True, + ): + """ + :param erb_fb: erb filter bank. + """ + super(DfNet, self).__init__() + if config.erb_bins % 8 != 0: + raise AssertionError("erb_bins should be divisible by 8") + + self.df_lookahead = config.df_lookahead + self.df_bins = config.df_bins + self.freq_bins: int = config.fft_size // 2 + 1 + self.emb_dim: int = config.conv_channels * config.erb_bins + self.erb_bins: int = config.erb_bins + + if config.conv_lookahead > 0: + if config.conv_lookahead < config.df_lookahead: + raise AssertionError + # for last 2 dim, pad (left, right, top, bottom). + self.pad_feat = nn.ConstantPad2d((0, 0, -config.conv_lookahead, config.conv_lookahead), 0.0) + else: + self.pad_feat = nn.Identity() + + if config.df_lookahead > 0: + # for last 3 dim, pad (left, right, top, bottom, front, back). + self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -config.df_lookahead, config.df_lookahead), 0.0) + else: + self.pad_spec = nn.Identity() + + self.register_buffer("erb_fb", erb_fb) + + self.enc = Encoder(config) + self.erb_dec = ErbDecoder(config) + self.mask = Mask(erb_inv_fb) + + self.erb_inv_fb = erb_inv_fb + self.post_filter = config.mask_post_filter + self.post_filter_beta = config.post_filter_beta + + self.df_order = config.df_order + self.df_op = MF.DF(num_freqs=config.df_bins, frame_size=config.df_order, lookahead=self.df_lookahead) + self.df_dec = DfDecoder(config) + self.df_out_transform = DfOutputReshapeMF(self.df_order, config.df_bins) + + self.run_erb = config.df_bins + 1 < self.freq_bins + if not self.run_erb: + logger.warning("Running without ERB stage") + self.run_df = run_df + if not run_df: + logger.warning("Running without DF stage") + self.train_mask = train_mask + self.lsnr_dropout = config.lsnr_dropout + if config.df_n_iter != 1: + raise AssertionError + + def forward1( + self, + spec: torch.Tensor, + feat_erb: torch.Tensor, + feat_spec: torch.Tensor, # Not used, take spec modified by mask instead + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward method of DeepFilterNet2. + + Args: + spec (Tensor): Spectrum of shape [B, 1, T, F, 2] + feat_erb (Tensor): ERB features of shape [B, 1, T, E] + feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2] + + Returns: + spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2] + m (Tensor): ERB mask estimate of shape [B, 1, T, E] + lsnr (Tensor): Local SNR estimate of shape [B, T, 1] + """ + # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2] + feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2) + # feat_spec shape: [batch_size, 2, time_steps, freq_dim] + + # feat_erb shape: [batch_size, 1, time_steps, erb_bins] + # assert time_steps >= conv_lookahead. 
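+ # pad_feat (ConstantPad2d with a negative pad) drops conv_lookahead frames at
+ # the start of the time axis and appends the same number of zero frames at the
+ # end, shifting the features so the causal convolutions gain that lookahead.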
+ feat_erb = self.pad_feat(feat_erb) + feat_spec = self.pad_feat(feat_spec) + e0, e1, e2, e3, emb, c0, lsnr, h = self.enc(feat_erb, feat_spec) + + if self.lsnr_droput: + idcs = lsnr.squeeze() > -10.0 + b, t = (spec.shape[0], spec.shape[2]) + m = torch.zeros((b, 1, t, self.erb_bins), device=spec.device) + df_coefs = torch.zeros((b, t, self.nb_df, self.df_order * 2)) + spec_m = spec.clone() + emb = emb[:, idcs] + e0 = e0[:, :, idcs] + e1 = e1[:, :, idcs] + e2 = e2[:, :, idcs] + e3 = e3[:, :, idcs] + c0 = c0[:, :, idcs] + + if self.run_erb: + if self.lsnr_dropout: + m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0) + else: + m = self.erb_dec(emb, e3, e2, e1, e0) + spec_m = self.mask(spec, m) + else: + m = torch.zeros((), device=spec.device) + spec_m = torch.zeros_like(spec) + + if self.run_df: + if self.lsnr_dropout: + df_coefs[:, idcs] = self.df_dec(emb, c0) + else: + df_coefs = self.df_dec(emb, c0) + df_coefs = self.df_out_transform(df_coefs) + spec_e = self.df_op(spec.clone(), df_coefs) + spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :] + else: + df_coefs = torch.zeros((), device=spec.device) + spec_e = spec_m + + if self.post_filter: + beta = self.post_filter_beta + eps = 1e-12 + mask = (utils.as_complex(spec_e).abs() / utils.as_complex(spec).abs().add(eps)).clamp(eps, 1) + mask_sin = mask * torch.sin(PI * mask / 2).clamp_min(eps) + pf = (1 + beta) / (1 + beta * mask.div(mask_sin).pow(2)) + spec_e = spec_e * pf.unsqueeze(-1) + + return spec_e, m, lsnr, df_coefs + + def forward( + self, + spec: torch.Tensor, + feat_erb: torch.Tensor, + feat_spec: torch.Tensor, # Not used, take spec modified by mask instead + erb_encoder_h: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2] + feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2) + # feat_spec shape: [batch_size, 2, time_steps, freq_dim] + + # feat_erb shape: [batch_size, 1, time_steps, erb_bins] + # assert time_steps >= conv_lookahead. + feat_erb = self.pad_feat(feat_erb) + feat_spec = self.pad_feat(feat_spec) + e0, e1, e2, e3, emb, c0, lsnr, erb_encoder_h = self.enc(feat_erb, feat_spec, erb_encoder_h) + + m = self.erb_dec(emb, e3, e2, e1, e0) + spec_m = self.mask(spec, m) + # spec_e = spec_m + + df_coefs = self.df_dec(emb, c0) + df_coefs = self.df_out_transform(df_coefs) + spec_e = self.df_op(spec.clone(), df_coefs) + spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :] + + return spec_e, m, lsnr, df_coefs, erb_encoder_h + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dfnet3/multiframes.py b/toolbox/torchaudio/models/dfnet3/multiframes.py new file mode 100644 index 0000000000000000000000000000000000000000..23d0b9c121f81ee2b732f26708a593d62053acc8 --- /dev/null +++ b/toolbox/torchaudio/models/dfnet3/multiframes.py @@ -0,0 +1,145 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn + + +# From torchaudio +def _compute_mat_trace(input: torch.Tensor, dim1: int = -2, dim2: int = -1) -> torch.Tensor: + r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions. + Args: + input (torch.Tensor): Tensor of dimension `(..., channel, channel)` + dim1 (int, optional): the first dimension of the diagonal matrix + (Default: -1) + dim2 (int, optional): the second dimension of the diagonal matrix + (Default: -2) + Returns: + Tensor: trace of the input Tensor + """ + assert input.ndim >= 2, "The dimension of the tensor must be at least 2." 
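+ # The matrix must be square over (dim1, dim2) so that a diagonal, and hence a
+ # trace, is well defined; _tik_reg() below uses this trace to scale the
+ # Tikhonov regularization added to the correlation matrix.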
+ assert ( + input.shape[dim1] == input.shape[dim2] + ), "The size of ``dim1`` and ``dim2`` must be the same." + input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2) + return input.sum(dim=-1) + + +def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor: + """Perform Tikhonov regularization (only modifying real part). + Args: + mat (torch.Tensor): input matrix (..., channel, channel) + reg (float, optional): regularization factor (Default: 1e-8) + eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``) + Returns: + Tensor: regularized matrix (..., channel, channel) + """ + # Add eps + C = mat.size(-1) + eye = torch.eye(C, dtype=mat.dtype, device=mat.device) + epsilon = _compute_mat_trace(mat).real[..., None, None] * reg + # in case that correlation_matrix is all-zero + epsilon = epsilon + eps + mat = mat + epsilon * eye[..., :, :] + return mat + + +class MultiFrameModule(nn.Module): + """ + Multi-frame speech enhancement modules. + + Signal model and notation: + Noisy: `x = s + n` + Enhanced: `y = f(x)` + Objective: `min ||s - y||` + + PSD: Power spectral density, notated eg. as `Rxx` for noisy PSD. + IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx` + RTF: Relative transfere function, also called steering vector. + """ + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bool = False): + """ + Multi-Frame filtering module. + + :param num_freqs: int. Number of frequency bins used for filtering. + :param frame_size: int. Frame size in FD domain. + :param lookahead: int. Lookahead, may be used to select the output time step. + Note: This module does not add additional padding according to lookahead! + :param real: + """ + super().__init__() + self.num_freqs = num_freqs + self.frame_size = frame_size + self.real = real + if real: + self.pad = nn.ConstantPad3d((0, 0, 0, 0, frame_size - 1 - lookahead, lookahead), 0.0) + else: + self.pad = nn.ConstantPad2d((0, 0, frame_size - 1 - lookahead, lookahead), 0.0) + self.need_unfold = frame_size > 1 + self.lookahead = lookahead + + def spec_unfold_real(self, spec: torch.Tensor): + if self.need_unfold: + spec = self.pad(spec).unfold(-3, self.frame_size, 1) + return spec.permute(0, 1, 5, 2, 3, 4) + # return as_windowed(self.pad(spec), self.frame_size, 1, dim=-3) + return spec.unsqueeze(-1) + + def spec_unfold(self, spec: torch.Tensor): + """Pads and unfolds the spectrogram according to frame_size. + + Args: + spec (complex Tensor): Spectrogram of shape [B, C, T, F] + Returns: + spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size. 
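+ Note: self.pad prepends (frame_size - 1 - lookahead) past frames and appends
+ `lookahead` future frames before unfolding, so each output step only sees
+ causal context plus the configured lookahead.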
+ """ + if self.need_unfold: + return self.pad(spec).unfold(2, self.frame_size, 1) + return spec.unsqueeze(-1) + + @staticmethod + def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> torch.Tensor: + return torch.einsum( + "...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss + ) # [T, F, N] + + @staticmethod + def apply_coefs(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: + # spec: [B, C, T, F, N] + # coefs: [B, C, T, F, N] + return torch.einsum("...n,...n->...", spec, coefs) + + +class DF(MultiFrameModule): + """Deep Filtering.""" + + def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False): + super().__init__(num_freqs, frame_size, lookahead) + self.conj: bool = conj + + def forward(self, spec: torch.Tensor, coefs: torch.Tensor): + spec_u = self.spec_unfold(torch.view_as_complex(spec)) + coefs = torch.view_as_complex(coefs) + spec_f = spec_u.narrow(-2, 0, self.num_freqs) + coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:]) + if self.conj: + coefs = coefs.conj() + spec_f = self.df(spec_f, coefs) + if self.training: + spec = spec.clone() + spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f) + return spec + + @staticmethod + def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: + """ + Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. + :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N]. + :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F]. + :return: (complex Tensor). Spectrogram of shape [B, C, T, F]. + """ + return torch.einsum("...tfn,...ntf->...tf", spec, coefs) + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/dfnet3/utils.py b/toolbox/torchaudio/models/dfnet3/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57ddff152e685aeeb9df59bd67fb77415f5a6cdd --- /dev/null +++ b/toolbox/torchaudio/models/dfnet3/utils.py @@ -0,0 +1,17 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import torch + + +def as_complex(x: torch.Tensor): + if torch.is_complex(x): + return x + if x.shape[-1] != 2: + raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}") + if x.stride(-1) != 1: + x = x.contiguous() + return torch.view_as_complex(x) + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/ehnet/__init__.py b/toolbox/torchaudio/models/ehnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/ehnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/ehnet/modeling_ehnet.py b/toolbox/torchaudio/models/ehnet/modeling_ehnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0acf083dc5fb6f61c5da78d866081310f3768d1c --- /dev/null +++ b/toolbox/torchaudio/models/ehnet/modeling_ehnet.py @@ -0,0 +1,132 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/1805.00579 + +https://github.com/haoxiangsnr/A-Convolutional-Recurrent-Neural-Network-for-Real-Time-Speech-Enhancement + +""" +import torch +import torch.nn as nn + + +class CausalConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 2), + stride=(2, 1), + padding=(0, 1) + ) + self.norm 
= nn.BatchNorm2d(num_features=out_channels) + self.activation = nn.ELU() + + def forward(self, x): + """ + 2D Causal convolution. + Args: + x: [B, C, F, T] + + Returns: + [B, C, F, T] + """ + x = self.conv(x) + x = x[:, :, :, :-1] # chomp size + x = self.norm(x) + x = self.activation(x) + return x + + +class CausalTransConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)): + super().__init__() + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 2), + stride=(2, 1), + output_padding=output_padding + ) + self.norm = nn.BatchNorm2d(num_features=out_channels) + if is_last: + self.activation = nn.ReLU() + else: + self.activation = nn.ELU() + + def forward(self, x): + """ + 2D Causal convolution. + Args: + x: [B, C, F, T] + + Returns: + [B, C, F, T] + """ + x = self.conv(x) + x = x[:, :, :, :-1] # chomp size + x = self.norm(x) + x = self.activation(x) + return x + + + +class CRN(nn.Module): + """ + Input: [batch size, channels=1, T, n_fft] + Output: [batch size, T, n_fft] + """ + + def __init__(self): + super(CRN, self).__init__() + # Encoder + self.conv_block_1 = CausalConvBlock(1, 16) + self.conv_block_2 = CausalConvBlock(16, 32) + self.conv_block_3 = CausalConvBlock(32, 64) + self.conv_block_4 = CausalConvBlock(64, 128) + self.conv_block_5 = CausalConvBlock(128, 256) + + # LSTM + self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True) + + self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128) + self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64) + self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32) + self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0)) + self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True) + + def forward(self, x): + self.lstm_layer.flatten_parameters() + + e_1 = self.conv_block_1(x) + e_2 = self.conv_block_2(e_1) + e_3 = self.conv_block_3(e_2) + e_4 = self.conv_block_4(e_3) + e_5 = self.conv_block_5(e_4) # [2, 256, 4, 200] + + batch_size, n_channels, n_f_bins, n_frame_size = e_5.shape + + # [2, 256, 4, 200] = [2, 1024, 200] => [2, 200, 1024] + lstm_in = e_5.reshape(batch_size, n_channels * n_f_bins, n_frame_size).permute(0, 2, 1) + lstm_out, _ = self.lstm_layer(lstm_in) # [2, 200, 1024] + lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size) # [2, 256, 4, 200] + + d_1 = self.tran_conv_block_1(torch.cat((lstm_out, e_5), 1)) + d_2 = self.tran_conv_block_2(torch.cat((d_1, e_4), 1)) + d_3 = self.tran_conv_block_3(torch.cat((d_2, e_3), 1)) + d_4 = self.tran_conv_block_4(torch.cat((d_3, e_2), 1)) + d_5 = self.tran_conv_block_5(torch.cat((d_4, e_1), 1)) + + return d_5 + + +def main(): + layer = CRN() + a = torch.rand(2, 1, 161, 200) + print(layer(a).shape) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/percepnet/__init__.py b/toolbox/torchaudio/models/percepnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/percepnet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/percepnet/modeling_percetnet.py b/toolbox/torchaudio/models/percepnet/modeling_percetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f38e243b78bebe6baa7a213d0a0b359e782ca1bc --- 
/dev/null +++ b/toolbox/torchaudio/models/percepnet/modeling_percetnet.py @@ -0,0 +1,11 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/jzi040941/PercepNet + +https://arxiv.org/abs/2008.04259 +""" + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/phm_unet/__init__.py b/toolbox/torchaudio/models/phm_unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/phm_unet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/phm_unet/modeling_phm_unet.py b/toolbox/torchaudio/models/phm_unet/modeling_phm_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..0042796e18b4c97fbb05b80d537532f2f7c0460c --- /dev/null +++ b/toolbox/torchaudio/models/phm_unet/modeling_phm_unet.py @@ -0,0 +1,9 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://arxiv.org/abs/2006.00687 +""" + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/rnnoise/__init__.py b/toolbox/torchaudio/models/rnnoise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/rnnoise/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py b/toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py new file mode 100644 index 0000000000000000000000000000000000000000..8260a3f0f002956538fe10f32b50d20f9ba0712c --- /dev/null +++ b/toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py @@ -0,0 +1,11 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/xiph/rnnoise + +https://arxiv.org/abs/1709.08243 + +""" + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/simple_linear_irm/__init__.py b/toolbox/torchaudio/models/simple_linear_irm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79ab8cb427356b85a4fd65def3fc8101a5ee301a --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/__init__.py @@ -0,0 +1,59 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/zhaoforever/nn-irm +""" +import torch +import torch.nn as nn + +import torchaudio + + +class NNIRM(nn.Module): + """ + Ideal ratio mask estimator: + default config: 1799(257 x 7) => 2048 => 2048 => 2048 => 257 + """ + + def __init__(self, num_bins=257, n_frames=7, hidden_size=2048): + super(NNIRM, self).__init__() + self.nn = nn.Sequential( + nn.Linear(num_bins * n_frames, hidden_size), + nn.ReLU(), + nn.BatchNorm1d(hidden_size), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.BatchNorm1d(hidden_size), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.BatchNorm1d(hidden_size), + nn.Linear(hidden_size, num_bins), + nn.Sigmoid() + ) + + def forward(self, x): + return self.nn(x) + + +def main(): + signal = torch.rand(size=(16000,)) + + transformer = torchaudio.transforms.MelSpectrogram( + sample_rate=8000, + n_fft=512, + win_length=200, + hop_length=80, + f_min=10, + f_max=3800, + window_fn=torch.hamming_window, + n_mels=80, + ) + + inputs = torch.tensor([signal], dtype=torch.float32) + output = transformer.forward(inputs) + print(output) + return + + +if __name__ == '__main__': + main() diff --git 
a/toolbox/torchaudio/models/simple_linear_irm/configuration_simple_linear_irm.py b/toolbox/torchaudio/models/simple_linear_irm/configuration_simple_linear_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..732016b9bb1cfee9b528b69776a1a700a22170c0 --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/configuration_simple_linear_irm.py @@ -0,0 +1,33 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class SimpleLinearIRMConfig(PretrainedConfig): + def __init__(self, + sample_rate: int, + n_fft: int, + win_length: int, + hop_length: int, + + num_bins: int, + hidden_size: int, + lookback: int, + lookahead: int, + + **kwargs + ): + super(SimpleLinearIRMConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + self.num_bins = num_bins + self.hidden_size = hidden_size + self.lookback = lookback + self.lookahead = lookahead + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/simple_linear_irm/modeling_simple_linear_irm.py b/toolbox/torchaudio/models/simple_linear_irm/modeling_simple_linear_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc4f78150cfd9962a23e376b1dca60fa7bab5f4 --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/modeling_simple_linear_irm.py @@ -0,0 +1,167 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/WenzheLiu-Speech/awesome-speech-enhancement +https://github.com/zhaoforever/nn-irm +""" +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE + + +MODEL_FILE = "model.pt" + + +class Transpose(nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, inputs: torch.Tensor): + inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1) + return inputs + + +class SimpleLinearIRM(nn.Module): + """ + Ideal ratio mask estimator: + default config: 1799(257 x 7) => 2048 => 2048 => 2048 => 257 + """ + + def __init__(self, num_bins=257, hidden_size=2048, lookback: int = 3, lookahead: int = 3): + super(SimpleLinearIRM, self).__init__() + self.num_bins = num_bins + self.hidden_size = hidden_size + self.lookback = lookback + self.lookahead = lookahead + + self.n_frames = lookback + 1 + lookahead + + self.nn = nn.Sequential( + Transpose(dim0=2, dim1=1), + nn.Linear(num_bins * self.n_frames, hidden_size), + nn.ReLU(), + Transpose(dim0=2, dim1=1), + nn.BatchNorm1d(hidden_size), + Transpose(dim0=2, dim1=1), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + Transpose(dim0=2, dim1=1), + nn.BatchNorm1d(hidden_size), + Transpose(dim0=2, dim1=1), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + Transpose(dim0=2, dim1=1), + nn.BatchNorm1d(hidden_size), + Transpose(dim0=2, dim1=1), + nn.Linear(hidden_size, num_bins), + Transpose(dim0=2, dim1=1), + + nn.Sigmoid() + ) + + @staticmethod + def frame_spec(spec: torch.Tensor, lookback: int = 3, lookahead: int = 3): + context = lookback + 1 + lookahead + + # batch, num_bins, time_steps + batch, num_bins, time_steps = spec.shape + + spec_ = torch.zeros([batch, context * num_bins, time_steps], dtype=spec.dtype) + for t in range(time_steps): + for c 
in range(context): + begin = c * num_bins + + t_ = t - lookback + c + t_ = 0 if t_ < 0 else t_ + t_ = time_steps - 1 if t_ > time_steps - 1 else t_ + spec_[:, begin: begin + num_bins, t] = spec[:, :, t_] + spec_ = spec_.to(spec.device) + return spec_ + + def forward(self, spec: torch.Tensor): + # spec shape: (batch_size, num_bins, time_steps) + frame_spec = self.frame_spec(spec, 3, 3) + # frame_spec shape: (batch_size, context * num_bins, time_steps) + mask = self.nn.forward(frame_spec) + return mask + + +class SimpleLinearIRMPretrainedModel(SimpleLinearIRM): + def __init__(self, + config: SimpleLinearIRMConfig, + ): + super(SimpleLinearIRMPretrainedModel, self).__init__( + num_bins=config.num_bins, + hidden_size=config.hidden_size, + lookback=config.lookback, + lookahead=config.lookahead, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = SimpleLinearIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + transformer = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, + ) + + model = SimpleLinearIRM() + + inputs = torch.randn(size=(1, 1600), dtype=torch.float32) + spec = transformer.forward(inputs) + + output = model.forward(spec) + print(output.shape) + print(output) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/simple_linear_irm/yaml/config.yaml b/toolbox/torchaudio/models/simple_linear_irm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f269122b6f589c012376cda288e9a256f1453de2 --- /dev/null +++ b/toolbox/torchaudio/models/simple_linear_irm/yaml/config.yaml @@ -0,0 +1,13 @@ +model_name: "simple_linear_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +# model +num_bins: 257 +hidden_size: 2048 +lookback: 3 +lookahead: 3 diff --git a/toolbox/torchaudio/models/simple_lstm_irm/__init__.py b/toolbox/torchaudio/models/simple_lstm_irm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/simple_lstm_irm/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/simple_lstm_irm/configuration_simple_lstm_irm.py b/toolbox/torchaudio/models/simple_lstm_irm/configuration_simple_lstm_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..a9240de3d2c833039aa65c6ec1d711af767bdc18 --- /dev/null +++ 
b/toolbox/torchaudio/models/simple_lstm_irm/configuration_simple_lstm_irm.py @@ -0,0 +1,38 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class SimpleLstmIRMConfig(PretrainedConfig): + def __init__(self, + sample_rate: int, + n_fft: int, + win_length: int, + hop_length: int, + + num_bins: int, + hidden_size: int, + num_layers: int, + batch_first: bool, + dropout: float, + lookback: int, + lookahead: int, + **kwargs + ): + super(SimpleLstmIRMConfig, self).__init__(**kwargs) + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + self.num_bins = num_bins + self.hidden_size = hidden_size + self.num_layers = num_layers + self.batch_first = batch_first + self.dropout = dropout + self.lookback = lookback + self.lookahead = lookahead + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py b/toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6aaed6a2e335c2df227e489f672829e5d4dffe --- /dev/null +++ b/toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py @@ -0,0 +1,141 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py +""" +import os +from typing import Optional, Union + +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE + + +MODEL_FILE = "model.pt" + + +class Transpose(nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, inputs: torch.Tensor): + inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1) + return inputs + + +class SimpleLstmIRM(nn.Module): + """ + Ideal ratio mask estimator: + + """ + + def __init__(self, num_bins=257, hidden_size=1024, + num_layers: int = 2, + batch_first: bool = True, + dropout: float = 0.4, + lookback: int = 3, + lookahead: int = 3, + ): + super(SimpleLstmIRM, self).__init__() + self.num_bins = num_bins + self.hidden_size = hidden_size + self.lookback = lookback + self.lookahead = lookahead + + # self.n_frames = lookback + 1 + lookahead + + self.lstm = nn.LSTM(input_size=num_bins, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=batch_first, + dropout=dropout, + ) + self.linear = nn.Linear(in_features=hidden_size, out_features=num_bins) + self.activation = nn.Sigmoid() + + def forward(self, spec: torch.Tensor): + # spec shape: (batch_size, num_bins, time_steps) + spec = torch.transpose(spec, dim0=2, dim1=1) + # frame_spec shape: (batch_size, time_steps, num_bins) + spec, _ = self.lstm(spec) + spec = self.linear(spec) + mask = self.activation(spec) + return mask + + +class SimpleLstmIRMPretrainedModel(SimpleLstmIRM): + def __init__(self, + config: SimpleLstmIRMConfig, + ): + super(SimpleLstmIRMPretrainedModel, self).__init__( + num_bins=config.num_bins, + hidden_size=config.hidden_size, + lookback=config.lookback, + lookahead=config.lookahead, + ) + self.config = config + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = SimpleLstmIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = 
cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + transformer = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, + ) + + model = SimpleLstmIRM() + + inputs = torch.randn(size=(1, 1600), dtype=torch.float32) + spec = transformer.forward(inputs) + + output = model.forward(spec) + print(output.shape) + print(output) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml b/toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04cfeff801f17f382c0702b1ab506b9c275239ff --- /dev/null +++ b/toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml @@ -0,0 +1,16 @@ +model_name: "simple_lstm_irm" + +# spec +sample_rate: 8000 +n_fft: 512 +win_length: 200 +hop_length: 80 + +# model +num_bins: 257 +hidden_size: 1024 +num_layers: 2 +batch_first: true +dropout: 0.4 +lookback: 3 +lookahead: 3 diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/__init__.py b/toolbox/torchaudio/models/spectrum_unet_irm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py b/toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..f9642b31370bbcf43d6b2556e00be6c8dba373db --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/configuration_specturm_unet_irm.py @@ -0,0 +1,72 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +from typing import Tuple + +from toolbox.torchaudio.configuration_utils import PretrainedConfig + + +class SpectrumUnetIRMConfig(PretrainedConfig): + def __init__(self, + sample_rate: int = 8000, + n_fft: int = 512, + win_length: int = 200, + hop_length: int = 80, + + spec_bins: int = 256, + + conv_channels: int = 64, + conv_kernel_size_input: Tuple[int, int] = (3, 3), + conv_kernel_size_inner: Tuple[int, int] = (1, 3), + conv_lookahead: int = 0, + + convt_kernel_size_inner: Tuple[int, int] = (1, 3), + + encoder_emb_skip_op: str = "none", + encoder_emb_linear_groups: int = 16, + encoder_emb_hidden_size: int = 256, + + lsnr_max: int = 20, + lsnr_min: int = -10, + + decoder_emb_num_layers: int = 3, + decoder_emb_skip_op: str = "none", + decoder_emb_linear_groups: int = 16, + decoder_emb_hidden_size: int = 256, + + **kwargs + ): + super(SpectrumUnetIRMConfig, 
self).__init__(**kwargs) + # transform + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + + # spectrum + self.spec_bins = spec_bins + + # conv + self.conv_channels = conv_channels + self.conv_kernel_size_input = conv_kernel_size_input + self.conv_kernel_size_inner = conv_kernel_size_inner + self.conv_lookahead = conv_lookahead + + self.convt_kernel_size_inner = convt_kernel_size_inner + + # encoder + self.encoder_emb_skip_op = encoder_emb_skip_op + self.encoder_emb_linear_groups = encoder_emb_linear_groups + self.encoder_emb_hidden_size = encoder_emb_hidden_size + + self.lsnr_max = lsnr_max + self.lsnr_min = lsnr_min + + # decoder + self.decoder_emb_num_layers = decoder_emb_num_layers + self.decoder_emb_skip_op = decoder_emb_skip_op + self.decoder_emb_linear_groups = decoder_emb_linear_groups + self.decoder_emb_hidden_size = decoder_emb_hidden_size + + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py b/toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f99f88e03b522d0a0ce3c321d388f60b452ca7 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/modeling_spectrum_unet_irm.py @@ -0,0 +1,646 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import os +import math +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torchaudio + +from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig +from toolbox.torchaudio.configuration_utils import CONFIG_FILE + + +MODEL_FILE = "model.pt" + + +norm_layer_dict = { + "batch_norm_2d": torch.nn.BatchNorm2d +} + + +activation_layer_dict = { + "relu": torch.nn.ReLU, + "identity": torch.nn.Identity, + "sigmoid": torch.nn.Sigmoid, +} + + +class CausalConv2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal Conv2d by delaying the signal for any lookahead. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + + :param in_channels: + :param out_channels: + :param kernel_size: + :param fstride: + :param dilation: + :param fpad: + """ + super(CausalConv2d, self).__init__() + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) + + if fpad: + fpad_ = kernel_size[1] // 2 + dilation - 1 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
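+ # Causality on the time axis: (kernel_size[0] - 1 - lookahead) zero frames are
+ # prepended and `lookahead` frames appended, so the convolution output at
+ # frame t depends on at most `lookahead` future frames.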
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = list() + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + if max(kernel_size) == 1: + separable = False + + layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + def forward(self, inputs): + for module in self: + inputs = module(inputs) + return inputs + + +class CausalConvTranspose2d(nn.Sequential): + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Iterable[int]], + fstride: int = 1, + dilation: int = 1, + fpad: bool = True, + bias: bool = True, + separable: bool = False, + norm_layer: str = "batch_norm_2d", + activation_layer: str = "relu", + lookahead: int = 0 + ): + """ + Causal ConvTranspose2d. + + Expected input format: [batch_size, channels, time_steps, spec_dim] + """ + super(CausalConvTranspose2d, self).__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + if fpad: + fpad_ = kernel_size[1] // 2 + else: + fpad_ = 0 + + # for last 2 dim, pad (left, right, top, bottom). 
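+ # Same causal padding as in CausalConv2d; the `padding` argument of the
+ # transposed convolution below then trims the time axis back to its original
+ # length, keeping the layer causal overall.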
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) + + layers = [] + if any(x > 0 for x in pad): + layers.append(nn.ConstantPad2d(pad, 0.0)) + + groups = math.gcd(in_channels, out_channels) if separable else 1 + if groups == 1: + separable = False + + layers.append( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=(kernel_size[0] - 1, fpad_ + dilation - 1), + output_padding=(0, fpad_), + stride=(1, fstride), # stride over time is always 1 + dilation=(1, dilation), # dilation over time is always 1 + groups=groups, + bias=bias, + ) + ) + + if separable: + layers.append( + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + bias=False, + ) + ) + + if norm_layer is not None: + norm_layer = norm_layer_dict[norm_layer] + layers.append(norm_layer(out_channels)) + + if activation_layer is not None: + activation_layer = activation_layer_dict[activation_layer] + layers.append(activation_layer()) + + super().__init__(*layers) + + +class GroupedLinear(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, groups: int = 1): + super().__init__() + # self.weight: Tensor + self.input_size = input_size + self.hidden_size = hidden_size + self.groups = groups + assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" + assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" + self.ws = input_size // groups + self.register_parameter( + "weight", + torch.nn.Parameter( + torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True + ), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., I] + b, t, _ = x.shape + # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] + new_shape = (b, t, self.groups, self.ws) + x = x.view(new_shape) + # The better way, but not supported by torchscript + # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] + x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] + x = x.flatten(2, 3) # [B, T, H] + return x + + def __repr__(self): + cls = self.__class__.__name__ + return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" + + +class SqueezedGRU_S(nn.Module): + """ + SGE net: Video object detection with squeezed GRU and information entropy map + https://arxiv.org/abs/2106.07224 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: Optional[int] = None, + num_layers: int = 1, + linear_groups: int = 8, + batch_first: bool = True, + skip_op: str = "none", + activation_layer: str = "identity", + ): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + + self.linear_in = nn.Sequential( + GroupedLinear( + input_size=input_size, + hidden_size=hidden_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + + # gru skip operator + self.gru_skip_op = None + + if skip_op == "none": + self.gru_skip_op = None + elif skip_op == "identity": + if not input_size != output_size: + raise AssertionError("Dimensions do not match") + self.gru_skip_op = nn.Identity() + elif skip_op == "grouped_linear": + self.gru_skip_op = GroupedLinear( + input_size=hidden_size, + hidden_size=hidden_size, + groups=linear_groups, + ) + else: + raise NotImplementedError() + + self.gru = nn.GRU( + input_size=hidden_size, + hidden_size=hidden_size, + 
num_layers=num_layers, + batch_first=batch_first, + bidirectional=False, + ) + + if output_size is not None: + self.linear_out = nn.Sequential( + GroupedLinear( + input_size=hidden_size, + hidden_size=output_size, + groups=linear_groups, + ), + activation_layer_dict[activation_layer](), + ) + else: + self.linear_out = nn.Identity() + + def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.linear_in(inputs) + + x, h = self.gru.forward(x, h) + + x = self.linear_out(x) + + if self.gru_skip_op is not None: + x = x + self.gru_skip_op(inputs) + + return x, h + + +class Encoder(nn.Module): + def __init__(self, config: SpectrumUnetIRMConfig): + super(Encoder, self).__init__() + + self.conv0 = CausalConv2d( + in_channels=1, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_input, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv1 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv2 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + # emb_gru + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * (config.spec_bins // 4) + self.emb_out_dim = config.conv_channels * (config.spec_bins // 4) + self.emb_hidden_dim = config.encoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + num_layers=1, + batch_first=True, + skip_op=config.encoder_emb_skip_op, + linear_groups=config.encoder_emb_linear_groups, + activation_layer="relu", + ) + + # lsnr + self.lsnr_fc = nn.Sequential( + nn.Linear(self.emb_out_dim, 1), + nn.Sigmoid() + ) + self.lsnr_scale = config.lsnr_max - config.lsnr_min + self.lsnr_offset = config.lsnr_min + + + def forward(self, + spec: torch.Tensor, + hidden_state: torch.Tensor = None, + ): + # spec shape: (batch_size, 1, time_steps, spec_dim) + e0 = self.conv0.forward(spec) + e1 = self.conv1.forward(e0) + e2 = self.conv2.forward(e1) + e3 = self.conv3.forward(e2) + + # e3 shape: [batch_size, channels, time_steps, hidden_size] + emb = e3.permute(0, 2, 3, 1) + # emb shape: [batch_size, time_steps, hidden_size, channels] + emb = emb.flatten(2) + # emb shape: [batch_size, time_steps, hidden_size * channels] + emb, h = self.emb_gru.forward(emb, hidden_state) + + lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset + + return e0, e1, e2, e3, emb, lsnr + + +class Decoder(nn.Module): + def __init__(self, config: SpectrumUnetIRMConfig): + super(Decoder, self).__init__() + + if config.spec_bins % 8 != 0: + raise AssertionError("spec_bins should be divisible by 8") + + self.emb_in_dim = config.conv_channels * config.spec_bins // 4 + self.emb_out_dim = config.conv_channels * config.spec_bins // 4 + self.emb_hidden_dim = config.decoder_emb_hidden_size + + self.emb_gru = SqueezedGRU_S( + self.emb_in_dim, + self.emb_hidden_dim, + output_size=self.emb_out_dim, + 
num_layers=config.decoder_emb_num_layers - 1, + batch_first=True, + skip_op=config.decoder_emb_skip_op, + linear_groups=config.decoder_emb_linear_groups, + activation_layer="relu", + ) + self.conv3p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt3 = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.conv_kernel_size_inner, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv2p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt2 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv1p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.convt1 = CausalConvTranspose2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=config.convt_kernel_size_inner, + bias=False, + separable=True, + fstride=2, + lookahead=config.conv_lookahead, + ) + self.conv0p = CausalConv2d( + in_channels=config.conv_channels, + out_channels=config.conv_channels, + kernel_size=1, + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + self.conv0_out = CausalConv2d( + in_channels=config.conv_channels, + out_channels=1, + kernel_size=config.conv_kernel_size_inner, + activation_layer="sigmoid", + bias=False, + separable=True, + fstride=1, + lookahead=config.conv_lookahead, + ) + + + def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: + # Estimates the IRM over the spectrum + b, _, t, f8 = e3.shape + + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + emb, _ = self.emb_gru(emb) + # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) + e3 = self.convt3(self.conv3p(e3) + emb) + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + e2 = self.convt2(self.conv2p(e2) + e3) + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + e1 = self.convt1(self.conv1p(e1) + e2) + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] + mask = self.conv0_out(self.conv0p(e0) + e1) + # mask shape: [batch_size, 1, time_steps, freq_dim] + return mask + + +class SpectrumUnetIRM(nn.Module): + def __init__(self, config: SpectrumUnetIRMConfig): + super(SpectrumUnetIRM, self).__init__() + self.config = config + self.encoder = Encoder(config) + self.decoder = Decoder(config) + # numerical floor for post_filter (the original code referenced self.eps without defining it; 1e-12 assumed) + self.eps = 1.0e-12 + + def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: + """ + Overall, this nudges the mask values up a little, presumably to keep more of the signal and avoid degrading speech quality, since the predicted mask is never perfectly accurate. + It is not used during training; it is only applied to the mask at inference time. + + Post-Filter + + A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. + https://arxiv.org/abs/2008.04259 + + :param mask: Real valued mask, typically of shape [B, C, T, F]. + :param beta: Global gain factor. 
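+ The filter applied element-wise is: mask_pf = (1 + beta) * mask / (1 + beta * (mask / sin(pi * mask / 2))^2), with the sine term clamped away from zero for numerical stability.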
+ :return: + """ + mask_sin = mask * torch.sin(np.pi * mask / 2) + mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) + return mask_pf + + def forward(self, + spec: torch.Tensor, + ): + # spec shape: [batch_size, freq_dim, time_steps] + # spec shape: [batch_size, 1, freq_dim, time_steps] + + spec = spec.unsqueeze(1).permute(0, 1, 3, 2) + # spec shape: [batch_size, channel, time_steps, freq_dim] + + e0, e1, e2, e3, emb, lsnr = self.encoder.forward(spec) + + # e0 shape: [batch_size, conv_channels, time_steps, freq_dim] + # e1 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] + # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] + # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] + # lsnr shape: [batch_size, time_steps, 1] + # h shape: [batch_size, 1, freq_dim] + + mask = self.decoder.forward(emb, e3, e2, e1, e0) + + if torch.any(mask > 1): + raise AssertionError + if torch.any(mask < 0): + raise AssertionError + + # mask shape: [batch_size, 1, time_steps, freq_dim] + # lsnr shape: [batch_size, time_steps, 1] + mask = torch.squeeze(mask, dim=1) + mask = torch.transpose(mask, dim0=2, dim1=1) + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + + # mask shape: [batch_size, freq_dim, time_steps] + # lsnr shape: [batch_size, 1, time_steps] + return mask, lsnr + + +class SpectrumUnetIRMPretrainedModel(SpectrumUnetIRM): + def __init__(self, + config: SpectrumUnetIRMConfig, + ): + super(SpectrumUnetIRMPretrainedModel, self).__init__( + config=config, + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = SpectrumUnetIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + model = cls(config) + + if os.path.isdir(pretrained_model_name_or_path): + ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) + else: + ckpt_file = pretrained_model_name_or_path + + with open(ckpt_file, "rb") as f: + state_dict = torch.load(f, map_location="cpu", weights_only=True) + model.load_state_dict(state_dict, strict=True) + return model + + def save_pretrained(self, + save_directory: Union[str, os.PathLike], + state_dict: Optional[dict] = None, + ): + + model = self + + if state_dict is None: + state_dict = model.state_dict() + + os.makedirs(save_directory, exist_ok=True) + + # save state dict + model_file = os.path.join(save_directory, MODEL_FILE) + torch.save(state_dict, model_file) + + # save config + config_file = os.path.join(save_directory, CONFIG_FILE) + self.config.to_yaml_file(config_file) + return save_directory + + +def main(): + transformer = torchaudio.transforms.Spectrogram( + n_fft=512, + win_length=200, + hop_length=80, + window_fn=torch.hamming_window, + ) + + config = SpectrumUnetIRMConfig() + model = SpectrumUnetIRM(config=config) + + inputs = torch.randn(size=(1, 16000), dtype=torch.float32) + spec = transformer.forward(inputs) + spec = spec[:, :-1, :] + + output = model.forward(spec) + print(output[0].shape) + return + + +if __name__ == '__main__': + main() diff --git a/toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml b/toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b756c46aeccf91428ed137a1c11aeb127509ac00 --- /dev/null +++ b/toolbox/torchaudio/models/spectrum_unet_irm/yaml/config.yaml @@ -0,0 +1,35 @@ +model_name: "spectrum_unet_irm" + +# spec +sample_rate: 8000 +n_fft: 512 
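+# note: n_fft=512 yields 257 STFT bins; the example usage above drops the last bin so that the model input matches spec_bins = 256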
+win_length: 200 +hop_length: 80 + +spec_bins: 256 + +# model +conv_channels: 64 +conv_kernel_size_input: + - 3 + - 3 +conv_kernel_size_inner: + - 1 + - 3 +conv_lookahead: 0 + +convt_kernel_size_inner: + - 1 + - 3 + +encoder_emb_skip_op: "none" +encoder_emb_linear_groups: 16 +encoder_emb_hidden_size: 256 + +lsnr_max: 20 +lsnr_min: -10 + +decoder_emb_num_layers: 3 +decoder_emb_skip_op: "none" +decoder_emb_linear_groups: 16 +decoder_emb_hidden_size: 256 diff --git a/toolbox/torchaudio/models/wave_unet/__init__.py b/toolbox/torchaudio/models/wave_unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc5155c67cae42f80e8126d1727b0edc1e02398 --- /dev/null +++ b/toolbox/torchaudio/models/wave_unet/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + + +if __name__ == '__main__': + pass diff --git a/toolbox/torchaudio/models/wave_unet/modeling_wave_unet.py b/toolbox/torchaudio/models/wave_unet/modeling_wave_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e6429cae0b97658ecf47842007ce4fd7839443f5 --- /dev/null +++ b/toolbox/torchaudio/models/wave_unet/modeling_wave_unet.py @@ -0,0 +1,12 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" +https://github.com/YosukeSugiura/Wave-U-Net-for-Speech-Enhancement-NNabla + +https://arxiv.org/abs/1811.11307 + +""" + + +if __name__ == '__main__': + pass
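Below is a minimal inference sketch for the SpectrumUnetIRM model added above. The import path, the audio file names, and the way the estimated mask is applied to the complex STFT are illustrative assumptions (the repository's actual training/inference recipe may differ); the spectrogram settings mirror the main() example in the model file.

```python
import torch
import torchaudio

# assumed import path, following the toolbox layout in this diff
from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import (
    SpectrumUnetIRM,
    SpectrumUnetIRMConfig,
)

config = SpectrumUnetIRMConfig()
model = SpectrumUnetIRM(config=config)
# a trained checkpoint would normally be loaded via SpectrumUnetIRMPretrainedModel.from_pretrained(...)
model.eval()

noisy, sample_rate = torchaudio.load("noisy.wav")  # assumed mono, 8 kHz

# same spectrogram settings as the main() example
transform = torchaudio.transforms.Spectrogram(
    n_fft=512, win_length=200, hop_length=80, window_fn=torch.hamming_window,
)
spec = transform(noisy)[:, :-1, :]  # drop the last bin -> [1, spec_bins=256, time_steps]

with torch.no_grad():
    mask, lsnr = model.forward(spec)  # mask: [1, 256, T], lsnr: [1, 1, T]
    mask = model.post_filter(mask)    # inference-only post-filter

# one plausible way to use the IRM: scale the complex STFT and resynthesize
window = torch.hamming_window(200)
stft = torch.stft(noisy, n_fft=512, win_length=200, hop_length=80,
                  window=window, return_complex=True)
mask = torch.cat([mask, torch.ones_like(mask[:, :1, :])], dim=1)  # restore the dropped bin
enhanced = torch.istft(stft * mask, n_fft=512, win_length=200, hop_length=80,
                       window=window, length=noisy.shape[-1])
torchaudio.save("enhanced.wav", enhanced, sample_rate)
```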