HoneyTian committed
Commit
0d6ae9b
·
1 Parent(s): 65a472d
examples/nx_denoise/run.sh ADDED
@@ -0,0 +1,170 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
7
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
9
+
10
+
11
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
12
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
+
15
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
16
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
18
+ --max_epochs 100
19
+
20
+
21
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
22
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
23
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
24
+ --max_epochs 100 --max_count 10000
25
+
26
+
27
+ END
28
+
29
+
30
+ # params
31
+ system_version="windows";
32
+ verbose=true;
33
+ stage=0 # start from 0 if you need to start from data preparation
34
+ stop_stage=9
35
+
36
+ work_dir="$(pwd)"
37
+ file_folder_name=file_folder_name
38
+ final_model_name=final_model_name
39
+ config_file="yaml/config.yaml"
40
+ limit=10
41
+
42
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
43
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
44
+
45
+ max_count=10000000
46
+
47
+ nohup_name=nohup.out
48
+
49
+ # model params
50
+ batch_size=64
51
+ max_epochs=200
52
+ save_top_k=10
53
+ patience=5
54
+
55
+
56
+ # parse options
57
+ while true; do
58
+ [ -z "${1:-}" ] && break; # break if there are no arguments
59
+ case "$1" in
60
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
61
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
62
+ old_value="$(eval echo \$$name)";
63
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
64
+ was_bool=true;
65
+ else
66
+ was_bool=false;
67
+ fi
68
+
69
+ # Set the variable to the right value-- the escaped quotes make it work if
70
+ # the option had spaces, like --cmd "queue.pl -sync y"
71
+ eval "${name}=\"$2\"";
72
+
73
+ # Check that Boolean-valued arguments are really Boolean.
74
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
75
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
76
+ exit 1;
77
+ fi
78
+ shift 2;
79
+ ;;
80
+
81
+ *) break;
82
+ esac
83
+ done
84
+
85
+ file_dir="${work_dir}/${file_folder_name}"
86
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
87
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
88
+
89
+ dataset="${file_dir}/dataset.xlsx"
90
+ train_dataset="${file_dir}/train.xlsx"
91
+ valid_dataset="${file_dir}/valid.xlsx"
92
+
93
+ $verbose && echo "system_version: ${system_version}"
94
+ $verbose && echo "file_folder_name: ${file_folder_name}"
95
+
96
+ if [ $system_version == "windows" ]; then
97
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
98
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
99
+ #source /data/local/bin/nx_denoise/bin/activate
100
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
101
+ fi
102
+
103
+
104
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
105
+ $verbose && echo "stage 1: prepare data"
106
+ cd "${work_dir}" || exit 1
107
+ python3 step_1_prepare_data.py \
108
+ --file_dir "${file_dir}" \
109
+ --noise_dir "${noise_dir}" \
110
+ --speech_dir "${speech_dir}" \
111
+ --train_dataset "${train_dataset}" \
112
+ --valid_dataset "${valid_dataset}" \
113
+ --max_count "${max_count}" \
114
+
115
+ fi
116
+
117
+
118
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
119
+ $verbose && echo "stage 2: train model"
120
+ cd "${work_dir}" || exit 1
121
+ python3 step_2_train_model.py \
122
+ --train_dataset "${train_dataset}" \
123
+ --valid_dataset "${valid_dataset}" \
124
+ --serialization_dir "${file_dir}" \
125
+ --config_file "${config_file}" \
126
+
127
+ fi
128
+
129
+
130
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
131
+ $verbose && echo "stage 3: test model"
132
+ cd "${work_dir}" || exit 1
133
+ python3 step_3_evaluation.py \
134
+ --valid_dataset "${valid_dataset}" \
135
+ --model_dir "${file_dir}/best" \
136
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
137
+ --limit "${limit}" \
138
+
139
+ fi
140
+
141
+
142
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
143
+ $verbose && echo "stage 4: collect files"
144
+ cd "${work_dir}" || exit 1
145
+
146
+ mkdir -p "${final_model_dir}"
147
+
148
+ cp "${file_dir}/best"/* "${final_model_dir}"
149
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
150
+
151
+ cd "${final_model_dir}/.." || exit 1;
152
+
153
+ if [ -e "${final_model_name}.zip" ]; then
154
+ rm -rf "${final_model_name}_backup.zip"
155
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
156
+ fi
157
+
158
+ zip -r "${final_model_name}.zip" "${final_model_name}"
159
+ rm -rf "${final_model_name}"
160
+
161
+ fi
162
+
163
+
164
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
165
+ $verbose && echo "stage 5: clear file_dir"
166
+ cd "${work_dir}" || exit 1
167
+
168
+ rm -rf "${file_dir}";
169
+
170
+ fi
examples/nx_denoise/step_1_prepare_data.py ADDED
@@ -0,0 +1,201 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--max_count", default=10000, type=int)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def filename_generator(data_dir: str):
52
+ data_dir = Path(data_dir)
53
+ for filename in data_dir.glob("**/*.wav"):
54
+ yield filename.as_posix()
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000):
58
+ data_dir = Path(data_dir)
59
+ for filename in data_dir.glob("**/*.wav"):
60
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
61
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
62
+
63
+ if raw_duration < duration:
64
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
65
+ continue
66
+ if signal.ndim != 1:
67
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
68
+
69
+ signal_length = len(signal)
70
+ win_size = int(duration * sample_rate)
71
+ for begin in range(0, signal_length - win_size, win_size):
72
+ row = {
73
+ "filename": filename.as_posix(),
74
+ "raw_duration": round(raw_duration, 4),
75
+ "offset": round(begin / sample_rate, 4),
76
+ "duration": round(duration, 4),
77
+ }
78
+ yield row
79
+
80
+
81
+ def get_dataset(args):
82
+ file_dir = Path(args.file_dir)
83
+ file_dir.mkdir(exist_ok=True)
84
+
85
+ noise_dir = Path(args.noise_dir)
86
+ speech_dir = Path(args.speech_dir)
87
+
88
+ noise_generator = target_second_signal_generator(
89
+ noise_dir.as_posix(),
90
+ duration=args.duration,
91
+ sample_rate=args.target_sample_rate
92
+ )
93
+ speech_generator = target_second_signal_generator(
94
+ speech_dir.as_posix(),
95
+ duration=args.duration,
96
+ sample_rate=args.target_sample_rate
97
+ )
98
+
99
+ dataset = list()
100
+
101
+ count = 0
102
+ process_bar = tqdm(desc="build dataset excel")
103
+ for noise, speech in zip(noise_generator, speech_generator):
104
+ if count >= args.max_count:
105
+ break
106
+
107
+ noise_filename = noise["filename"]
108
+ noise_raw_duration = noise["raw_duration"]
109
+ noise_offset = noise["offset"]
110
+ noise_duration = noise["duration"]
111
+
112
+ speech_filename = speech["filename"]
113
+ speech_raw_duration = speech["raw_duration"]
114
+ speech_offset = speech["offset"]
115
+ speech_duration = speech["duration"]
116
+
117
+ random1 = random.random()
118
+ random2 = random.random()
119
+
120
+ row = {
121
+ "noise_filename": noise_filename,
122
+ "noise_raw_duration": noise_raw_duration,
123
+ "noise_offset": noise_offset,
124
+ "noise_duration": noise_duration,
125
+
126
+ "speech_filename": speech_filename,
127
+ "speech_raw_duration": speech_raw_duration,
128
+ "speech_offset": speech_offset,
129
+ "speech_duration": speech_duration,
130
+
131
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
132
+
133
+ "random1": random1,
134
+ "random2": random2,
135
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
136
+ }
137
+ dataset.append(row)
138
+ count += 1
139
+ duration_seconds = count * args.duration
140
+ duration_hours = duration_seconds / 3600
141
+
142
+ process_bar.update(n=1)
143
+ process_bar.set_postfix({
144
+ # "duration_seconds": round(duration_seconds, 4),
145
+ "duration_hours": round(duration_hours, 4),
146
+
147
+ })
148
+
149
+ dataset = pd.DataFrame(dataset)
150
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
151
+ dataset.to_excel(
152
+ file_dir / "dataset.xlsx",
153
+ index=False,
154
+ )
155
+ return
156
+
157
+
158
+
159
+ def split_dataset(args):
160
+ """分割训练集, 测试集"""
161
+ file_dir = Path(args.file_dir)
162
+ file_dir.mkdir(exist_ok=True)
163
+
164
+ df = pd.read_excel(file_dir / "dataset.xlsx")
165
+
166
+ train = list()
167
+ test = list()
168
+
169
+ for i, row in df.iterrows():
170
+ flag = row["flag"]
171
+ if flag == "TRAIN":
172
+ train.append(row)
173
+ else:
174
+ test.append(row)
175
+
176
+ train = pd.DataFrame(train)
177
+ train.to_excel(
178
+ args.train_dataset,
179
+ index=False,
180
+ # encoding="utf_8_sig"
181
+ )
182
+ test = pd.DataFrame(test)
183
+ test.to_excel(
184
+ args.valid_dataset,
185
+ index=False,
186
+ # encoding="utf_8_sig"
187
+ )
188
+
189
+ return
190
+
191
+
192
+ def main():
193
+ args = get_args()
194
+
195
+ get_dataset(args)
196
+ split_dataset(args)
197
+ return
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
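
Note: a minimal inspection sketch (not part of the commit), assuming the script has already been run and that file_dir/train.xlsx exists; the column names simply mirror the row dict built in get_dataset above.

import pandas as pd

df = pd.read_excel("file_dir/train.xlsx")  # illustrative path
print(df.columns.tolist())
# ['noise_filename', 'noise_raw_duration', 'noise_offset', 'noise_duration',
#  'speech_filename', 'speech_raw_duration', 'speech_offset', 'speech_duration',
#  'snr_db', 'random1', 'random2', 'flag']
print(len(df))  # roughly 80% of the generated rows; the remaining TEST rows go to valid.xlsx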
examples/nx_denoise/step_2_train_model.py ADDED
@@ -0,0 +1,439 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch.nn import functional as F
24
+ from torch.utils.data.dataloader import DataLoader
25
+ from tqdm import tqdm
26
+
27
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
+ from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
29
+ from toolbox.torchaudio.models.nx_denoise.discriminator import MetricDiscriminator, MetricDiscriminatorPretrainedModel
30
+ from toolbox.torchaudio.models.nx_denoise.modeling_nx_denoise import NXDenoise, NXDenoisePretrainedModel
31
+ from toolbox.torchaudio.models.nx_denoise.metrics import run_batch_pesq, run_pesq_score
32
+ from toolbox.torchaudio.models.nx_denoise.utils import mag_pha_stft, mag_pha_istft
33
+ from toolbox.torchaudio.models.nx_denoise.loss import phase_losses
34
+
35
+
36
+ def get_args():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
40
+
41
+ parser.add_argument("--max_epochs", default=100, type=int)
42
+
43
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
+ parser.add_argument("--patience", default=5, type=int)
45
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ clean_audios = list()
80
+ noisy_audios = list()
81
+
82
+ for sample in batch:
83
+ # noise_wave: torch.Tensor = sample["noise_wave"]
84
+ clean_audio: torch.Tensor = sample["speech_wave"]
85
+ noisy_audio: torch.Tensor = sample["mix_wave"]
86
+ # snr_db: float = sample["snr_db"]
87
+
88
+ clean_audios.append(clean_audio)
89
+ noisy_audios.append(noisy_audio)
90
+
91
+ clean_audios = torch.stack(clean_audios)
92
+ noisy_audios = torch.stack(noisy_audios)
93
+
94
+ # assert
95
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
96
+ raise AssertionError("nan or inf in clean_audios")
97
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
98
+ raise AssertionError("nan or inf in noisy_audios")
99
+ return clean_audios, noisy_audios
100
+
101
+
102
+ collate_fn = CollateFunction()
103
+
104
+
105
+ def main():
106
+ args = get_args()
107
+
108
+ config = NXDenoiseConfig.from_pretrained(
109
+ pretrained_model_name_or_path=args.config_file,
110
+ )
111
+
112
+ serialization_dir = Path(args.serialization_dir)
113
+ serialization_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ logger = logging_config(serialization_dir)
116
+
117
+ random.seed(config.seed)
118
+ np.random.seed(config.seed)
119
+ torch.manual_seed(config.seed)
120
+ logger.info(f"set seed: {config.seed}")
121
+
122
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+ n_gpu = torch.cuda.device_count()
124
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
125
+
126
+ # datasets
127
+ train_dataset = DenoiseExcelDataset(
128
+ excel_file=args.train_dataset,
129
+ expected_sample_rate=8000,
130
+ max_wave_value=32768.0,
131
+ )
132
+ valid_dataset = DenoiseExcelDataset(
133
+ excel_file=args.valid_dataset,
134
+ expected_sample_rate=8000,
135
+ max_wave_value=32768.0,
136
+ )
137
+ train_data_loader = DataLoader(
138
+ dataset=train_dataset,
139
+ batch_size=config.batch_size,
140
+ shuffle=True,
141
+ sampler=None,
142
+ # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
143
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
144
+ collate_fn=collate_fn,
145
+ pin_memory=False,
146
+ prefetch_factor=16,
147
+ )
148
+ valid_data_loader = DataLoader(
149
+ dataset=valid_dataset,
150
+ batch_size=config.batch_size,
151
+ shuffle=True,
152
+ sampler=None,
153
+ # On Linux, multiple worker processes can be used to load data; on Windows they cannot.
154
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
155
+ collate_fn=collate_fn,
156
+ pin_memory=False,
157
+ prefetch_factor=16,
158
+ )
159
+
160
+ # models
161
+ logger.info(f"prepare models. config_file: {args.config_file}")
162
+ generator = NXDenoisePretrainedModel(config).to(device)
163
+ discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
164
+
165
+ # optimizer
166
+ logger.info("prepare optimizer, lr_scheduler")
167
+ num_params = 0
168
+ for p in generator.parameters():
169
+ num_params += p.numel()
170
+ logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
171
+
172
+ optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
173
+ optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
174
+
175
+ # resume training
176
+ last_epoch = -1
177
+ for epoch_i in serialization_dir.glob("epoch-*"):
178
+ epoch_i = Path(epoch_i)
179
+ epoch_idx = epoch_i.stem.split("-")[1]
180
+ epoch_idx = int(epoch_idx)
181
+ if epoch_idx > last_epoch:
182
+ last_epoch = epoch_idx
183
+
184
+ if last_epoch != -1:
185
+ logger.info(f"resume from epoch-{last_epoch}.")
186
+ generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
187
+ discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
188
+ optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
189
+ optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
190
+
191
+ logger.info(f"load state dict for generator.")
192
+ with open(generator_pt.as_posix(), "rb") as f:
193
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
194
+ generator.load_state_dict(state_dict, strict=True)
195
+ logger.info(f"load state dict for discriminator.")
196
+ with open(discriminator_pt.as_posix(), "rb") as f:
197
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
198
+ discriminator.load_state_dict(state_dict, strict=True)
199
+
200
+ logger.info(f"load state dict for optim_g.")
201
+ with open(optim_g_pth.as_posix(), "rb") as f:
202
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
203
+ optim_g.load_state_dict(state_dict)
204
+ logger.info(f"load state dict for optim_d.")
205
+ with open(optim_d_pth.as_posix(), "rb") as f:
206
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
207
+ optim_d.load_state_dict(state_dict)
208
+
209
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
210
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
211
+
212
+ # training loop
213
+
214
+ # state
215
+ loss_d = 10000000000
216
+ loss_g = 10000000000
217
+ pesq_metric = 10000000000
218
+ mag_err = 10000000000
219
+ pha_err = 10000000000
220
+ com_err = 10000000000
221
+
222
+ model_list = list()
223
+ best_idx_epoch = None
224
+ best_metric = None
225
+ patience_count = 0
226
+
227
+ logger.info("training")
228
+ for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
229
+ # train
230
+ generator.train()
231
+ discriminator.train()
232
+
233
+ total_loss_d = 0.
234
+ total_loss_g = 0.
235
+ total_batches = 0.
236
+ progress_bar = tqdm(
237
+ total=len(train_data_loader),
238
+ desc="Training; epoch: {}".format(idx_epoch),
239
+ )
240
+ for batch in train_data_loader:
241
+ clean_audios, noisy_audios = batch
242
+ clean_audios = clean_audios.to(device)
243
+ noisy_audios = noisy_audios.to(device)
244
+ one_labels = torch.ones(clean_audios.shape[0]).to(device)
245
+
246
+ audio_g = generator.forward(noisy_audios)
247
+
248
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_length, config.win_length, config.compress_factor)
249
+ mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_length, config.win_length, config.compress_factor)
250
+
251
+ clean_audio_list = torch.split(clean_audios, 1, dim=0)
252
+ enhanced_audio_list = torch.split(audio_g, 1, dim=0)
253
+ clean_audio_list = [t.squeeze().detach().cpu().numpy() for t in clean_audio_list]
254
+ enhanced_audio_list = [t.squeeze().detach().cpu().numpy() for t in enhanced_audio_list]
255
+
256
+ pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb")
257
+
258
+ # Discriminator
259
+ optim_d.zero_grad()
260
+ metric_r = discriminator.forward(clean_audios, clean_audios)
261
+ metric_g = discriminator.forward(clean_audios, audio_g.detach())
262
+ loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
263
+
264
+ if -1 in pesq_score_list:
265
+ # print("-1 in batch_pesq_score!")
266
+ loss_disc_g = 0
267
+ else:
268
+ pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
269
+ loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
270
+
271
+ loss_disc_all = loss_disc_r + loss_disc_g
272
+ loss_disc_all.backward()
273
+ optim_d.step()
274
+
275
+ # Generator
276
+ optim_g.zero_grad()
277
+ # L2 Magnitude Loss
278
+ loss_mag = F.mse_loss(clean_mag, mag_g)
279
+ # Anti-wrapping Phase Loss
280
+ loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
281
+ loss_pha = loss_ip + loss_gd + loss_iaf
282
+ # L2 Complex Loss
283
+ loss_com = F.mse_loss(clean_com, com_g) * 2
284
+ # L2 Consistency Loss
285
+ # Time Loss
286
+ loss_time = F.l1_loss(clean_audios, audio_g)
287
+ # Metric Loss
288
+ metric_g = discriminator.forward(clean_audios, audio_g)  # no detach here: the metric loss must back-propagate into the generator
289
+ loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
290
+
291
+ # loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2
292
+ loss_gen_all = loss_mag * 0.1 + loss_pha * 0.1 + loss_com * 0.1 + loss_metric * 0.9 + loss_time * 0.9
293
+
294
+ loss_gen_all.backward()
295
+ optim_g.step()
296
+
297
+ total_loss_d += loss_disc_all.item()
298
+ total_loss_g += loss_gen_all.item()
299
+ total_batches += 1
300
+
301
+ loss_d = round(total_loss_d / total_batches, 4)
302
+ loss_g = round(total_loss_g / total_batches, 4)
303
+
304
+ progress_bar.update(1)
305
+ progress_bar.set_postfix({
306
+ "loss_d": loss_d,
307
+ "loss_g": loss_g,
308
+ })
309
+
310
+ # evaluation
311
+ generator.eval()
312
+ discriminator.eval()
313
+
314
+ torch.cuda.empty_cache()
315
+ total_pesq_score = 0.
316
+ total_mag_err = 0.
317
+ total_pha_err = 0.
318
+ total_com_err = 0.
319
+ total_batches = 0.
320
+
321
+ progress_bar = tqdm(
322
+ total=len(valid_data_loader),
323
+ desc="Evaluation; epoch: {}".format(idx_epoch),
324
+ )
325
+ with torch.no_grad():
326
+ for batch in valid_data_loader:
327
+ clean_audios, noisy_audios = batch
328
+ clean_audios = clean_audios.to(device)
329
+ noisy_audios = noisy_audios.to(device)
330
+
331
+ audio_g = generator.forward(noisy_audios)
332
+
333
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_length, config.win_length, config.compress_factor)
334
+ mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_length, config.win_length, config.compress_factor)
335
+
336
+ clean_audio_list = torch.split(clean_audios, 1, dim=0)
337
+ enhanced_audio_list = torch.split(audio_g, 1, dim=0)
338
+ clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
339
+ enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
340
+ pesq_score = run_pesq_score(
341
+ clean_audio_list,
342
+ enhanced_audio_list,
343
+ sample_rate = config.sample_rate,
344
+ mode = "nb",
345
+ )
346
+ total_pesq_score += pesq_score
347
+ total_mag_err += F.mse_loss(clean_mag, mag_g).item()
348
+ val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
349
+ total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
350
+ total_com_err += F.mse_loss(clean_com, com_g).item()
351
+
352
+ total_batches += 1
353
+
354
+ pesq_metric = round(total_pesq_score / total_batches, 4)
355
+ mag_err = round(total_mag_err / total_batches, 4)
356
+ pha_err = round(total_pha_err / total_batches, 4)
357
+ com_err = round(total_com_err / total_batches, 4)
358
+
359
+ progress_bar.update(1)
360
+ progress_bar.set_postfix({
361
+ "pesq_metric": pesq_metric,
362
+ "mag_err": mag_err,
363
+ "pha_err": pha_err,
364
+ "com_err": com_err,
365
+ })
366
+
367
+ # scheduler
368
+ scheduler_g.step()
369
+ scheduler_d.step()
370
+
371
+ # save path
372
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
373
+ epoch_dir.mkdir(parents=True, exist_ok=False)
374
+
375
+ # save models
376
+ generator.save_pretrained(epoch_dir.as_posix())
377
+ discriminator.save_pretrained(epoch_dir.as_posix())
378
+
379
+ # save optim
380
+ torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
381
+ torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
382
+
383
+ model_list.append(epoch_dir)
384
+ if len(model_list) > args.num_serialized_models_to_keep:
385
+ model_to_delete: Path = model_list.pop(0)
386
+ shutil.rmtree(model_to_delete.as_posix())
387
+
388
+ # save metric
389
+ if best_metric is None:
390
+ best_idx_epoch = idx_epoch
391
+ best_metric = pesq_metric
392
+ elif pesq_metric > best_metric:
393
+ # greater is better.
394
+ best_idx_epoch = idx_epoch
395
+ best_metric = pesq_metric
396
+ else:
397
+ pass
398
+
399
+ metrics = {
400
+ "idx_epoch": idx_epoch,
401
+ "best_idx_epoch": best_idx_epoch,
402
+ "loss_d": loss_d,
403
+ "loss_g": loss_g,
404
+
405
+ "pesq_metric": pesq_metric,
406
+ "mag_err": mag_err,
407
+ "pha_err": pha_err,
408
+ "com_err": com_err,
409
+
410
+ }
411
+ metrics_filename = epoch_dir / "metrics_epoch.json"
412
+ with open(metrics_filename, "w", encoding="utf-8") as f:
413
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
414
+
415
+ # save best
416
+ best_dir = serialization_dir / "best"
417
+ if best_idx_epoch == idx_epoch:
418
+ if best_dir.exists():
419
+ shutil.rmtree(best_dir)
420
+ shutil.copytree(epoch_dir, best_dir)
421
+
422
+ # early stop
423
+ early_stop_flag = False
424
+ if best_idx_epoch == idx_epoch:
425
+ patience_count = 0
426
+ else:
427
+ patience_count += 1
428
+ if patience_count >= args.patience:
429
+ early_stop_flag = True
430
+
431
+ # early stop
432
+ if early_stop_flag:
433
+ break
434
+
435
+ return
436
+
437
+
438
+ if __name__ == "__main__":
439
+ main()
examples/nx_denoise/step_3_evaluation.py ADDED
@@ -0,0 +1,55 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+ import uuid
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import pandas as pd
16
+ from scipy.io import wavfile
17
+ import torch
18
+ import torch.nn as nn
19
+ import torchaudio
20
+ from tqdm import tqdm
21
+
22
+
23
+ def get_args():
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
26
+ parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
27
+ parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
28
+
29
+ parser.add_argument("--limit", default=10, type=int)
30
+
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ def logging_config():
36
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
37
+
38
+ logging.basicConfig(format=fmt,
39
+ datefmt="%m/%d/%Y %H:%M:%S",
40
+ level=logging.INFO)
41
+ stream_handler = logging.StreamHandler()
42
+ stream_handler.setLevel(logging.INFO)
43
+ stream_handler.setFormatter(logging.Formatter(fmt))
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ return logger
48
+
49
+
50
+ def main():
51
+ return
52
+
53
+
54
+ if __name__ == '__main__':
55
+ main()
examples/nx_denoise/yaml/config.yaml ADDED
@@ -0,0 +1,42 @@
1
+ model_name: "nx_clean_unet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 16000
5
+ n_fft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+
9
+ down_sampling_num_layers: 6
10
+ down_sampling_in_channels: 1
11
+ down_sampling_hidden_channels: 64
12
+ down_sampling_kernel_size: 4
13
+ down_sampling_stride: 2
14
+
15
+ causal_in_channels: 1
16
+ causal_out_channels: 1
17
+ causal_kernel_size: 3
18
+ causal_bias: false
19
+ causal_separable: true
20
+ causal_f_stride: 1
21
+ causal_num_layers: 5
22
+
23
+ tsfm_hidden_size: 256
24
+ tsfm_attention_heads: 8
25
+ tsfm_num_blocks: 6
26
+ tsfm_dropout_rate: 0.1
27
+ tsfm_max_length: 512
28
+ tsfm_chunk_size: 1
29
+ tsfm_num_left_chunks: 128
30
+ tsfm_num_right_chunks: 4
31
+
32
+ discriminator_dim: 32
33
+ discriminator_in_channel: 2
34
+
35
+ compress_factor: 0.3
36
+
37
+ batch_size: 64
38
+ learning_rate: 0.0005
39
+ adam_b1: 0.8
40
+ adam_b2: 0.99
41
+ lr_decay: 0.99
42
+ seed: 1234
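
Note: a minimal loading sketch (not part of the commit), mirroring the call in step_2_train_model.py and assuming NXDenoiseConfig.from_pretrained maps the top-level YAML keys onto the constructor arguments.

from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig

config = NXDenoiseConfig.from_pretrained("examples/nx_denoise/yaml/config.yaml")
print(config.sample_rate, config.n_fft, config.batch_size)  # 8000 512 64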
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -143,7 +143,7 @@ class UpSampling(nn.Module):
143
  for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
144
  skip_x = skip_connection_list[idx]
145
  x = x + skip_x
146
- # x = x + skip_x[:, :, :x.shape[-1]]
146
+ # x = x + skip_x[:, :, :x.size(2)]
147
  x = up_sampling_block.forward(x)
148
  return x
149
 
toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py CHANGED
@@ -207,6 +207,13 @@ class RelativeMultiHeadSelfAttention(nn.Module):
207
  mask: torch.Tensor = None,
208
  cache: torch.Tensor = None
209
  ) -> Tuple[torch.Tensor, torch.Tensor]:
210
+ """
211
+
212
+ :param x:
213
+ :param mask:
214
+ :param cache: Tensor, shape: [1, n_heads, time_steps, dim]
215
+ :return:
216
+ """
217
  # attention! self attention.
218
 
219
  q, k, v = self.forward_qkv(x, x, x)
toolbox/torchaudio/models/nx_denoise/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py ADDED
@@ -0,0 +1,281 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ import os
5
+ from typing import List, Optional, Union, Iterable
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+
13
+ norm_layer_dict = {
14
+ "batch_norm_2d": torch.nn.BatchNorm2d
15
+ }
16
+
17
+
18
+ activation_layer_dict = {
19
+ "relu": torch.nn.ReLU,
20
+ "identity": torch.nn.Identity,
21
+ "sigmoid": torch.nn.Sigmoid,
22
+ }
23
+
24
+
25
+ class CausalConv2d(nn.Module):
26
+ def __init__(self,
27
+ in_channels: int,
28
+ out_channels: int,
29
+ kernel_size: Union[int, Iterable[int]],
30
+ f_stride: int = 1,
31
+ dilation: int = 1,
32
+ do_f_pad: bool = True,
33
+ bias: bool = True,
34
+ separable: bool = False,
35
+ norm_layer: str = "batch_norm_2d",
36
+ activation_layer: str = "relu",
37
+ lookahead: int = 0
38
+ ):
39
+ super(CausalConv2d, self).__init__()
40
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
41
+
42
+ if do_f_pad:
43
+ f_pad = kernel_size[1] // 2 + dilation - 1
44
+ else:
45
+ f_pad = 0
46
+
47
+ self.causal_left_pad = kernel_size[0] - 1 - lookahead
48
+ self.causal_right_pad = lookahead
49
+ self.constant_pad = nn.ConstantPad2d(
50
+ padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
51
+ value=0.0
52
+ )
53
+
54
+ groups = math.gcd(in_channels, out_channels) if separable else 1
55
+ self.conv1 = nn.Conv2d(
56
+ in_channels,
57
+ out_channels,
58
+ kernel_size=kernel_size,
59
+ padding=(0, f_pad),
60
+ stride=(1, f_stride),
61
+ dilation=(1, dilation),
62
+ groups=groups,
63
+ bias=bias,
64
+ )
65
+
66
+ self.conv2 = None
67
+ if not any([groups == 1, max(kernel_size) == 1]):
68
+ self.conv2 = nn.Conv2d(
69
+ out_channels,
70
+ out_channels,
71
+ kernel_size=1,
72
+ bias=False,
73
+ )
74
+
75
+ self.norm = None
76
+ if norm_layer is not None:
77
+ norm_layer = norm_layer_dict[norm_layer]
78
+ self.norm = norm_layer(out_channels)
79
+
80
+ self.activation = None
81
+ if activation_layer is not None:
82
+ activation_layer = activation_layer_dict[activation_layer]
83
+ self.activation = activation_layer()
84
+
85
+ def forward(self,
86
+ inputs: torch.Tensor,
87
+ causal_cache: torch.Tensor = None,
88
+ ):
89
+
90
+ if causal_cache is None:
91
+ # inputs shape: [batch_size, 1, time_steps, hidden_size]
92
+ x = self.constant_pad.forward(inputs)
93
+ else:
94
+ # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
95
+ # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
96
+ x = torch.concat(tensors=[causal_cache, inputs], dim=2)
97
+ # x shape: [batch_size, 1, time_steps2, hidden_size]
98
+ # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
99
+
100
+ x = self.conv1.forward(x)
101
+ # inputs shape: [batch_size, 1, time_steps, hidden_size]
102
+
103
+ if self.conv2:
104
+ x = self.conv2.forward(x)
105
+
106
+ if self.norm:
107
+ x = self.norm(x)
108
+ if self.activation:
109
+ x = self.activation(x)
110
+
111
+ causal_cache = x[:, :, -self.causal_left_pad:, :]
112
+
113
+ # inputs shape: [batch_size, 1, time_steps, hidden_size]
114
+ return x, causal_cache
115
+
116
+
117
+ class CausalConv2dEncoder(nn.Module):
118
+ def __init__(self,
119
+ in_channels: int,
120
+ hidden_channels: int,
121
+ out_channels: int,
122
+ kernel_size: Union[int, Iterable[int]],
123
+ f_stride: int = 1,
124
+ dilation: int = 1,
125
+ do_f_pad: bool = True,
126
+ bias: bool = True,
127
+ separable: bool = False,
128
+ norm_layer: str = "batch_norm_2d",
129
+ activation_layer: str = "relu",
130
+ lookahead: int = 0,
131
+ num_layers: int = 5,
132
+ ):
133
+ super(CausalConv2dEncoder, self).__init__()
134
+ self.num_layers = num_layers
135
+
136
+ self.total_causal_left_pad = 0
137
+ self.total_causal_right_pad = 0
138
+
139
+ self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
140
+ for i_layer in range(num_layers):
141
+ conv = CausalConv2d(
142
+ in_channels=in_channels,
143
+ out_channels=hidden_channels,
144
+ kernel_size=kernel_size,
145
+ f_stride=f_stride,
146
+ dilation=dilation,
147
+ do_f_pad=do_f_pad,
148
+ bias=bias,
149
+ separable=separable,
150
+ norm_layer=norm_layer,
151
+ activation_layer=activation_layer,
152
+ lookahead=lookahead,
153
+ )
154
+ self.causal_conv_list.append(conv)
155
+
156
+ self.total_causal_left_pad += conv.causal_left_pad
157
+ self.total_causal_right_pad += conv.causal_right_pad
158
+
159
+ in_channels = hidden_channels
160
+ else:
161
+ conv = CausalConv2d(
162
+ in_channels=hidden_channels,
163
+ out_channels=out_channels,
164
+ kernel_size=kernel_size,
165
+ f_stride=f_stride,
166
+ dilation=dilation,
167
+ do_f_pad=do_f_pad,
168
+ bias=bias,
169
+ separable=separable,
170
+ norm_layer=norm_layer,
171
+ activation_layer=activation_layer,
172
+ lookahead=lookahead,
173
+ )
174
+ self.causal_conv_list.append(conv)
175
+
176
+ self.total_causal_left_pad += conv.causal_left_pad
177
+ self.total_causal_right_pad += conv.causal_right_pad
178
+
179
+
180
+ def forward(self, inputs: torch.Tensor):
181
+ # inputs shape: [batch_size, 1, time_steps, hidden_size]
182
+
183
+ x = inputs
184
+ for layer in self.causal_conv_list:
185
+ x, _ = layer.forward(x)
186
+ return x
187
+
188
+ def forward_chunk(self,
189
+ chunk: torch.Tensor,
190
+ causal_cache: torch.Tensor = None,
191
+ ):
192
+ # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size]
193
+
194
+ new_causal_cache_list = list()
195
+ for idx, causal_conv in enumerate(self.causal_conv_list):
196
+ chunk, new_causal_cache = causal_conv.forward(
197
+ inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None
198
+ )
199
+ new_causal_cache_list.append(new_causal_cache)
200
+
201
+ new_causal_cache = torch.cat(new_causal_cache_list, dim=0)
202
+ return chunk, new_causal_cache
203
+
204
+ def forward_chunk_by_chunk(self, inputs: torch.Tensor):
205
+ # inputs shape: [batch_size, 1, time_steps, hidden_size]
206
+ # batch_size = 1
207
+
208
+ batch_size, channels, time_steps, hidden_size = inputs.shape
209
+
210
+ causal_cache = None
211
+
212
+ outputs = []
213
+ for idx in range(0, time_steps, 1):
214
+ begin = idx
215
+ end = begin + self.total_causal_right_pad + 1
216
+ chunk_xs = inputs[:, :, begin:end, :]
217
+
218
+ ys, causal_cache = self.forward_chunk(
219
+ chunk=chunk_xs,
220
+ causal_cache=causal_cache,
221
+ )
222
+ # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
223
+ ys = ys[:, :, :1, :]
224
+
225
+ # ys shape: [batch_size, chunk_size, hidden_size]
226
+ outputs.append(ys)
227
+
228
+ ys = torch.cat(outputs, 2)
229
+ return ys
230
+
231
+
232
+ def main2():
233
+ conv = CausalConv2d(
234
+ in_channels=1,
235
+ out_channels=64,
236
+ kernel_size=3,
237
+ bias=False,
238
+ separable=True,
239
+ f_stride=1,
240
+ lookahead=0,
241
+ )
242
+
243
+ spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
244
+ # spec shape: [batch_size, 1, time_steps, hidden_size]
245
+ cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
246
+
247
+ output, _ = conv.forward(spec)
248
+ print(output.shape)
249
+
250
+ output, _ = conv.forward(spec, cache)
251
+ print(output.shape)
252
+
253
+ return
254
+
255
+
256
+ def main():
257
+ causal = CausalConv2dEncoder(
258
+ in_channels=1,
259
+ out_channels=1,
260
+ kernel_size=3,
261
+ bias=False,
262
+ separable=True,
263
+ f_stride=1,
264
+ lookahead=0,
265
+ num_layers=3,
266
+ )
267
+
268
+ spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
269
+ # spec shape: [batch_size, 1, time_steps, hidden_size]
270
+
271
+ output = causal.forward(spec)
272
+ print(output.shape)
273
+
274
+ output = causal.forward_chunk_by_chunk(spec)
275
+ print(output.shape)
276
+
277
+ return
278
+
279
+
280
+ if __name__ == '__main__':
281
+ main()
toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py ADDED
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class NXDenoiseConfig(PretrainedConfig):
7
+ """
8
+ https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
9
+ """
10
+ def __init__(self,
11
+ sample_rate: int = 8000,
12
+ segment_size: int = 16000,
13
+ n_fft: int = 512,
14
+ win_length: int = 200,
15
+ hop_length: int = 80,
16
+
17
+ down_sampling_num_layers: int = 5,
18
+ down_sampling_in_channels: int = 1,
19
+ down_sampling_hidden_channels: int = 64,
20
+ down_sampling_kernel_size: int = 4,
21
+ down_sampling_stride: int = 2,
22
+
23
+ causal_in_channels: int = 1,
24
+ causal_hidden_channels: int = 64,
25
+ causal_kernel_size: int = 3,
26
+ causal_bias: bool = False,
27
+ causal_separable: bool = True,
28
+ causal_f_stride: int = 1,
29
+ # causal_lookahead: int = 0,
30
+ causal_num_layers: int = 3,
31
+
32
+ tsfm_hidden_size: int = 256,
33
+ tsfm_attention_heads: int = 4,
34
+ tsfm_num_blocks: int = 6,
35
+ tsfm_dropout_rate: float = 0.1,
36
+ tsfm_max_time_relative_position: int = 1024,
37
+ tsfm_max_freq_relative_position: int = 128,
38
+ tsfm_chunk_size: int = 4,
39
+ tsfm_num_left_chunks: int = 128,
40
+ tsfm_num_right_chunks: int = 2,
41
+
42
+ discriminator_dim: int = 16,
43
+ discriminator_in_channel: int = 2,
44
+
45
+ compress_factor: float = 0.3,
46
+
47
+ batch_size: int = 4,
48
+ learning_rate: float = 0.0005,
49
+ adam_b1: float = 0.8,
50
+ adam_b2: float = 0.99,
51
+ lr_decay: float = 0.99,
52
+ seed: int = 1234,
53
+
54
+ **kwargs
55
+ ):
56
+ super(NXDenoiseConfig, self).__init__(**kwargs)
57
+ self.sample_rate = sample_rate
58
+ self.segment_size = segment_size
59
+ self.n_fft = n_fft
60
+ self.win_length = win_length
61
+ self.hop_length = hop_length
62
+
63
+ self.down_sampling_num_layers = down_sampling_num_layers
64
+ self.down_sampling_in_channels = down_sampling_in_channels
65
+ self.down_sampling_hidden_channels = down_sampling_hidden_channels
66
+ self.down_sampling_kernel_size = down_sampling_kernel_size
67
+ self.down_sampling_stride = down_sampling_stride
68
+
69
+ self.causal_in_channels = causal_in_channels
70
+ self.causal_hidden_channels = causal_hidden_channels
71
+ self.causal_kernel_size = causal_kernel_size
72
+ self.causal_bias = causal_bias
73
+ self.causal_separable = causal_separable
74
+ self.causal_f_stride = causal_f_stride
75
+ # self.causal_lookahead = causal_lookahead
76
+ self.causal_num_layers = causal_num_layers
77
+
78
+ self.tsfm_hidden_size = tsfm_hidden_size
79
+ self.tsfm_attention_heads = tsfm_attention_heads
80
+ self.tsfm_num_blocks = tsfm_num_blocks
81
+ self.tsfm_dropout_rate = tsfm_dropout_rate
82
+ self.tsfm_max_time_relative_position = tsfm_max_time_relative_position
83
+ self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position
84
+ self.tsfm_chunk_size = tsfm_chunk_size
85
+ self.tsfm_num_left_chunks = tsfm_num_left_chunks
86
+ self.tsfm_num_right_chunks = tsfm_num_right_chunks
87
+
88
+ self.discriminator_dim = discriminator_dim
89
+ self.discriminator_in_channel = discriminator_in_channel
90
+
91
+ self.compress_factor = compress_factor
92
+
93
+ self.batch_size = batch_size
94
+ self.learning_rate = learning_rate
95
+ self.adam_b1 = adam_b1
96
+ self.adam_b2 = adam_b2
97
+ self.lr_decay = lr_decay
98
+ self.seed = seed
99
+
100
+
101
+ if __name__ == '__main__':
102
+ pass
toolbox/torchaudio/models/nx_denoise/discriminator.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchaudio
9
+
10
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
+ from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
12
+ from toolbox.torchaudio.models.nx_denoise.utils import LearnableSigmoid1d
13
+
14
+
15
+ class MetricDiscriminator(nn.Module):
16
+ def __init__(self, config: NXDenoiseConfig):
17
+ super(MetricDiscriminator, self).__init__()
18
+ dim = config.discriminator_dim
19
+ self.in_channel = config.discriminator_in_channel
20
+
21
+ self.n_fft = config.n_fft
22
+ self.win_length = config.win_length
23
+ self.hop_length = config.hop_length
24
+
25
+ self.transform = torchaudio.transforms.Spectrogram(
26
+ n_fft=self.n_fft,
27
+ win_length=self.win_length,
28
+ hop_length=self.hop_length,
29
+ power=1.0,
30
+ window_fn=torch.hann_window,
31
+ # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
32
+ )
33
+
34
+ self.layers = nn.Sequential(
35
+ nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
36
+ nn.InstanceNorm2d(dim, affine=True),
37
+ nn.PReLU(dim),
38
+ nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
39
+ nn.InstanceNorm2d(dim*2, affine=True),
40
+ nn.PReLU(dim*2),
41
+ nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
42
+ nn.InstanceNorm2d(dim*4, affine=True),
43
+ nn.PReLU(dim*4),
44
+ nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
45
+ nn.InstanceNorm2d(dim*8, affine=True),
46
+ nn.PReLU(dim*8),
47
+ nn.AdaptiveMaxPool2d(1),
48
+ nn.Flatten(),
49
+ nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
50
+ nn.Dropout(0.3),
51
+ nn.PReLU(dim*4),
52
+ nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
53
+ LearnableSigmoid1d(1)
54
+ )
55
+
56
+ def forward(self, x, y):
57
+ x = self.transform.forward(x)
58
+ y = self.transform.forward(y)
59
+
60
+ xy = torch.stack((x, y), dim=1)
61
+ return self.layers(xy)
62
+
63
+
64
+ MODEL_FILE = "discriminator.pt"
65
+
66
+
67
+ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
68
+ def __init__(self,
69
+ config: NXDenoiseConfig,
70
+ ):
71
+ super(MetricDiscriminatorPretrainedModel, self).__init__(
72
+ config=config,
73
+ )
74
+ self.config = config
75
+
76
+ @classmethod
77
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
78
+ config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
79
+
80
+ model = cls(config)
81
+
82
+ if os.path.isdir(pretrained_model_name_or_path):
83
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
84
+ else:
85
+ ckpt_file = pretrained_model_name_or_path
86
+
87
+ with open(ckpt_file, "rb") as f:
88
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
89
+ model.load_state_dict(state_dict, strict=True)
90
+ return model
91
+
92
+ def save_pretrained(self,
93
+ save_directory: Union[str, os.PathLike],
94
+ state_dict: Optional[dict] = None,
95
+ ):
96
+
97
+ model = self
98
+
99
+ if state_dict is None:
100
+ state_dict = model.state_dict()
101
+
102
+ os.makedirs(save_directory, exist_ok=True)
103
+
104
+ # save state dict
105
+ model_file = os.path.join(save_directory, MODEL_FILE)
106
+ torch.save(state_dict, model_file)
107
+
108
+ # save config
109
+ config_file = os.path.join(save_directory, CONFIG_FILE)
110
+ self.config.to_yaml_file(config_file)
111
+ return save_directory
112
+
113
+
114
+ def main():
115
+ config = NXDenoiseConfig()
116
+ discriminator = MetricDiscriminator(config=config)
117
+
118
+ # shape: [batch_size, num_samples]
119
+ # x = torch.ones([4, int(4.5 * 16000)])
120
+ # y = torch.ones([4, int(4.5 * 16000)])
121
+ x = torch.ones([4, 16000])
122
+ y = torch.ones([4, 16000])
123
+
124
+ output = discriminator.forward(x, y)
125
+ print(output.shape)
126
+ print(output)
127
+
128
+ return
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
toolbox/torchaudio/models/nx_denoise/loss.py ADDED
@@ -0,0 +1,22 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def anti_wrapping_function(x):
8
+
9
+ return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
10
+
11
+
12
+ def phase_losses(phase_r, phase_g):
13
+
14
+ ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
15
+ gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
16
+ iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
17
+
18
+ return ip_loss, gd_loss, iaf_loss
19
+
20
+
21
+ if __name__ == '__main__':
22
+ pass
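
Note: an illustrative usage sketch (not part of the commit); phase_losses expects phase tensors shaped [batch, freq_bins, frames], as produced by mag_pha_stft in step_2_train_model.py.

import torch
from toolbox.torchaudio.models.nx_denoise.loss import phase_losses

phase_r = (torch.rand(4, 257, 201) - 0.5) * 2 * torch.pi  # reference phase in [-pi, pi]
phase_g = (torch.rand(4, 257, 201) - 0.5) * 2 * torch.pi  # generated phase
ip_loss, gd_loss, iaf_loss = phase_losses(phase_r, phase_g)
print(ip_loss.item(), gd_loss.item(), iaf_loss.item())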
toolbox/torchaudio/models/nx_denoise/metrics.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from joblib import Parallel, delayed
4
+ import numpy as np
5
+ from pesq import pesq
6
+ from typing import List
7
+
8
+ from pesq import cypesq
9
+
10
+
11
+ def run_pesq(clean_audio: np.ndarray,
12
+ noisy_audio: np.ndarray,
13
+ sample_rate: int = 16000,
14
+ mode: str = "wb",
15
+ ) -> float:
16
+ if sample_rate == 8000 and mode == "wb":
17
+ raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
+ try:
19
+ pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
+ except cypesq.NoUtterancesError as e:
21
+ pesq_score = -1
22
+ except Exception as e:
23
+ print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
+ pesq_score = -1
25
+ return pesq_score
26
+
27
+
28
+ def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
+ noisy_audio_list: List[np.ndarray],
30
+ sample_rate: int = 16000,
31
+ mode: str = "wb",
32
+ n_jobs: int = 4,
33
+ ) -> List[float]:
34
+ parallel = Parallel(n_jobs=n_jobs)
35
+
36
+ parallel_tasks = list()
37
+ for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
+ parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
+ parallel_tasks.append(parallel_task)
40
+
41
+ pesq_score_list = parallel.__call__(parallel_tasks)
42
+ return pesq_score_list
43
+
44
+
45
+ def run_pesq_score(clean_audio_list: List[np.ndarray],
46
+ noisy_audio_list: List[np.ndarray],
47
+ sample_rate: int = 16000,
48
+ mode: str = "wb",
49
+ n_jobs: int = 4,
50
+ ) -> float:
51
+
52
+ pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
+ noisy_audio_list=noisy_audio_list,
54
+ sample_rate=sample_rate,
55
+ mode=mode,
56
+ n_jobs=n_jobs,
57
+ )
58
+
59
+ pesq_score = np.mean(pesq_score_list)
60
+ return pesq_score
61
+
62
+
63
+ def main():
64
+ clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
+ noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
+
67
+ clean_audio_list = list(clean_audio)
68
+ noisy_audio_list = list(noisy_audio)
69
+
70
+ pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
+ print(pesq_score_list)
72
+
73
+ pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
+ print(pesq_score)
75
+
76
+ return
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py ADDED
@@ -0,0 +1,338 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
12
+ from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
13
+ from toolbox.torchaudio.models.nx_denoise.causal_convolution.causal_conv2d import CausalConv2dEncoder
14
+ from toolbox.torchaudio.models.nx_denoise.transformers.transformers import TSTransformerEncoder
15
+
16
+
17
+ class DownSamplingBlock(nn.Module):
18
+ def __init__(self,
19
+ in_channels: int,
20
+ hidden_channels: int,
21
+ kernel_size: int,
22
+ stride: int,
23
+ ):
24
+ super(DownSamplingBlock, self).__init__()
25
+ self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride)
26
+ self.relu = nn.ReLU()
27
+ self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
28
+ self.glu = nn.GLU(dim=1)
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ # x shape: [batch_size, 1, num_samples]
32
+ x = self.conv1.forward(x)
33
+ # x shape: [batch_size, hidden_channels, new_num_samples]
34
+ x = self.relu(x)
35
+ x = self.conv2.forward(x)
36
+ # x shape: [batch_size, hidden_channels*2, new_num_samples]
37
+ x = self.glu(x)
38
+ # x shape: [batch_size, hidden_channels, new_num_samples]
39
+ # new_num_samples = (num_samples-kernel_size) // stride + 1
40
+ return x
41
+
42
+
43
+ class DownSampling(nn.Module):
44
+ def __init__(self,
45
+ num_layers: int,
46
+ in_channels: int,
47
+ hidden_channels: int,
48
+ kernel_size: int,
49
+ stride: int,
50
+ ):
51
+ super(DownSampling, self).__init__()
52
+ self.num_layers = num_layers
53
+
54
+ down_sampling_block_list = list()
55
+ for idx in range(self.num_layers):
56
+ down_sampling_block = DownSamplingBlock(
57
+ in_channels=in_channels,
58
+ hidden_channels=hidden_channels,
59
+ kernel_size=kernel_size,
60
+ stride=stride,
61
+ )
62
+ down_sampling_block_list.append(down_sampling_block)
63
+ in_channels = hidden_channels
64
+
65
+ self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)
66
+
67
+ def forward(self, x: torch.Tensor):
68
+ # x shape: [batch_size, channels, num_samples]
69
+ skip_connection_list = list()
70
+ for down_sampling_block in self.down_sampling_block_list:
71
+ x = down_sampling_block.forward(x)
72
+ skip_connection_list.append(x)
73
+ # x shape: [batch_size, hidden_channels, num_samples**]
74
+ return x, skip_connection_list
75
+
76
+
77
+ class UpSamplingBlock(nn.Module):
78
+ def __init__(self,
79
+ out_channels: int,
80
+ hidden_channels: int,
81
+ kernel_size: int,
82
+ stride: int,
83
+ do_relu: bool = True,
84
+ ):
85
+ super(UpSamplingBlock, self).__init__()
86
+ self.do_relu = do_relu
87
+
88
+ self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
89
+ self.glu = nn.GLU(dim=1)
90
+ self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride)
91
+ self.relu = nn.ReLU()
92
+
93
+ def forward(self, x: torch.Tensor):
94
+ # x shape: [batch_size, hidden_channels*2, num_samples]
95
+ x = self.conv1.forward(x)
96
+ # x shape: [batch_size, hidden_channels, num_samples]
97
+ x = self.glu(x)
98
+ # x shape: [batch_size, hidden_channels, num_samples]
99
+ x = self.convt.forward(x)
100
+ # x shape: [batch_size, hidden_channels, new_num_samples]
101
+ # new_num_samples = (num_samples - 1) * stride + kernel_size
102
+ if self.do_relu:
103
+ x = self.relu(x)
104
+ return x
105
+
106
+
107
+ class UpSampling(nn.Module):
108
+ def __init__(self,
109
+ num_layers: int,
110
+ out_channels: int,
111
+ hidden_channels: int,
112
+ kernel_size: int,
113
+ stride: int,
114
+ ):
115
+ super(UpSampling, self).__init__()
116
+ self.num_layers = num_layers
117
+
118
+ up_sampling_block_list = list()
119
+ for idx in range(self.num_layers-1):
120
+ up_sampling_block = UpSamplingBlock(
121
+ out_channels=hidden_channels,
122
+ hidden_channels=hidden_channels,
123
+ kernel_size=kernel_size,
124
+ stride=stride,
125
+ do_relu=True,
126
+ )
127
+ up_sampling_block_list.append(up_sampling_block)
128
+ else:
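+ # for-else: this branch runs once after the loop finishes, adding the final
+ # block that maps back to out_channels without a trailing ReLU.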
129
+ up_sampling_block = UpSamplingBlock(
130
+ out_channels=out_channels,
131
+ hidden_channels=hidden_channels,
132
+ kernel_size=kernel_size,
133
+ stride=stride,
134
+ do_relu=False,
135
+ )
136
+ up_sampling_block_list.append(up_sampling_block)
137
+ self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
138
+
139
+ def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
140
+ skip_connection_list = skip_connection_list[::-1]
141
+
142
+ # x shape: [batch_size, channels, num_samples]
143
+ for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
144
+ skip_x = skip_connection_list[idx]
145
+ x = x + skip_x
146
+ # x = x + skip_x[:, :, :x.size(2)]
147
+ x = up_sampling_block.forward(x)
148
+ return x
149
+
150
+
151
+ def get_padding_length(length, num_layers: int, kernel_size: int, stride: int):
152
+ for _ in range(num_layers):
153
+ if length < kernel_size:
154
+ length = 1
155
+ else:
156
+ length = 1 + np.ceil((length - kernel_size) / stride)
157
+
158
+ for _ in range(num_layers):
159
+ length = (length - 1) * stride + kernel_size
160
+
161
+ padded_length = int(length)
162
+ return padded_length
163
+
164
+
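+ # Illustrative check (a sketch, not part of the committed code), using the values
+ # from yaml/config.yaml (num_layers=6, kernel_size=4, stride=2):
+ #   get_padding_length(16000, num_layers=6, kernel_size=4, stride=2) == 16062
+ # i.e. a 16000-sample input is padded to 16062 samples so that the six
+ # DownSamplingBlock / UpSamplingBlock rounds invert each other exactly.
+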
165
+ class NXDenoise(nn.Module):
166
+ def __init__(self, config: NXDenoiseConfig):
167
+ super().__init__()
168
+ self.config = config
169
+
170
+ self.down_sampling = DownSampling(
171
+ num_layers=config.down_sampling_num_layers,
172
+ in_channels=config.down_sampling_in_channels,
173
+ hidden_channels=config.down_sampling_hidden_channels,
174
+ kernel_size=config.down_sampling_kernel_size,
175
+ stride=config.down_sampling_stride,
176
+ )
177
+ self.causal_conv_in = CausalConv2dEncoder(
178
+ in_channels=config.causal_in_channels,
179
+ hidden_channels=config.causal_hidden_channels,
180
+ out_channels=config.causal_hidden_channels,
181
+ kernel_size=config.causal_kernel_size,
182
+ bias=config.causal_bias,
183
+ separable=config.causal_separable,
184
+ f_stride=config.causal_f_stride,
185
+ lookahead=0,
186
+ num_layers=config.causal_num_layers,
187
+ )
188
+ self.ts_transformer = TSTransformerEncoder(
189
+ input_size=config.down_sampling_hidden_channels,
190
+ hidden_size=config.tsfm_hidden_size,
191
+ attention_heads=config.tsfm_attention_heads,
192
+ num_blocks=config.tsfm_num_blocks,
193
+ dropout_rate=config.tsfm_dropout_rate,
194
+ max_time_relative_position=config.tsfm_max_time_relative_position,
195
+ max_freq_relative_position=config.tsfm_max_freq_relative_position,
196
+ chunk_size=config.tsfm_chunk_size,
197
+ num_left_chunks=config.tsfm_num_left_chunks,
198
+ num_right_chunks=config.tsfm_num_right_chunks,
199
+ )
200
+ self.causal_conv_out = CausalConv2dEncoder(
201
+ in_channels=config.causal_hidden_channels,
202
+ hidden_channels=config.causal_hidden_channels,
203
+ out_channels=config.causal_in_channels,
204
+ kernel_size=config.causal_kernel_size,
205
+ bias=config.causal_bias,
206
+ separable=config.causal_separable,
207
+ f_stride=config.causal_f_stride,
208
+ lookahead=0,
209
+ num_layers=config.causal_num_layers,
210
+ )
211
+ self.up_sampling = UpSampling(
212
+ num_layers=config.down_sampling_num_layers,
213
+ out_channels=config.down_sampling_in_channels,
214
+ hidden_channels=config.down_sampling_hidden_channels,
215
+ kernel_size=config.down_sampling_kernel_size,
216
+ stride=config.down_sampling_stride,
217
+ )
218
+
219
+ def forward(self, noisy_audios: torch.Tensor):
220
+ # noisy_audios shape: [batch_size, n_samples]
221
+ noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
222
+ # noisy_audios shape: [batch_size, 1, n_samples]
223
+
224
+ n_samples = noisy_audios.shape[-1]
225
+ padded_length = get_padding_length(
226
+ n_samples,
227
+ num_layers=self.config.down_sampling_num_layers,
228
+ kernel_size=self.config.down_sampling_kernel_size,
229
+ stride=self.config.down_sampling_stride,
230
+ )
231
+ noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
232
+
233
+ # down sampling
234
+ bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
235
+ # bottle_neck shape: [batch_size, channels, time_steps]
236
+ bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
237
+ # bottle_neck shape: [batch_size, time_steps, channels]
238
+ bottle_neck = torch.unsqueeze(bottle_neck, dim=1)
239
+ # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
240
+
241
+ # causal conv in
242
+ bottle_neck = self.causal_conv_in.forward(bottle_neck)
243
+ # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
244
+
245
+ # ts transformer
+ bottle_neck = self.ts_transformer.forward(bottle_neck)
246
+ # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
247
+
248
+ # causal conv out
249
+ bottle_neck = self.causal_conv_out.forward(bottle_neck)
250
+ # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
251
+
252
+ # up sampling
253
+ bottle_neck = torch.squeeze(bottle_neck, dim=1)
254
+ # bottle_neck shape: [batch_size, time_steps, channels]
255
+ bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
256
+ # bottle_neck shape: [batch_size, channels, time_steps]
257
+
258
+ enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
259
+
260
+ enhanced_audios = enhanced_audios[:, :, :n_samples]
261
+ # enhanced_audios shape: [batch_size, 1, n_samples]
262
+
263
+ enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
264
+ # enhanced_audios shape: [batch_size, n_samples]
265
+
266
+ return enhanced_audios
267
+
268
+
269
+ MODEL_FILE = "generator.pt"
270
+
271
+
272
+ class NXDenoisePretrainedModel(NXDenoise):
273
+ def __init__(self,
274
+ config: NXDenoiseConfig,
275
+ ):
276
+ super(NXDenoisePretrainedModel, self).__init__(
277
+ config=config,
278
+ )
279
+ self.config = config
280
+
281
+ @classmethod
282
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
283
+ config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
284
+
285
+ model = cls(config)
286
+
287
+ if os.path.isdir(pretrained_model_name_or_path):
288
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
289
+ else:
290
+ ckpt_file = pretrained_model_name_or_path
291
+
292
+ with open(ckpt_file, "rb") as f:
293
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
294
+ model.load_state_dict(state_dict, strict=True)
295
+ return model
296
+
297
+ def save_pretrained(self,
298
+ save_directory: Union[str, os.PathLike],
299
+ state_dict: Optional[dict] = None,
300
+ ):
301
+
302
+ model = self
303
+
304
+ if state_dict is None:
305
+ state_dict = model.state_dict()
306
+
307
+ os.makedirs(save_directory, exist_ok=True)
308
+
309
+ # save state dict
310
+ model_file = os.path.join(save_directory, MODEL_FILE)
311
+ torch.save(state_dict, model_file)
312
+
313
+ # save config
314
+ config_file = os.path.join(save_directory, CONFIG_FILE)
315
+ self.config.to_yaml_file(config_file)
316
+ return save_directory
317
+
318
+
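+ # Illustrative usage (a sketch, not part of the committed code; the directory
+ # name below is hypothetical):
+ #   model = NXDenoisePretrainedModel(NXDenoiseConfig())
+ #   model.save_pretrained("file_dir/nx_denoise")   # writes generator.pt plus the yaml config
+ #   reloaded = NXDenoisePretrainedModel.from_pretrained("file_dir/nx_denoise")
+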
319
+ def main():
320
+
321
+ config = NXDenoiseConfig()
322
+
323
+ # x shape: [batch_size, num_samples]
324
+ # min length: 94, stride: 32, 32 == 2**5
325
+ # x = torch.ones([4, 94])
326
+ # x = torch.ones([4, 126])
327
+ # x = torch.ones([4, 158])
328
+ # x = torch.ones([4, 190])
329
+ x = torch.ones([4, 16000])
330
+
331
+ model = NXDenoise(config)
332
+ enhanced_audios = model.forward(x)
333
+ print(enhanced_audios.shape)
334
+ return
335
+
336
+
337
+ if __name__ == "__main__":
338
+ main()
toolbox/torchaudio/models/nx_denoise/transformers/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/nx_denoise/transformers/attention.py ADDED
@@ -0,0 +1,263 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class MultiHeadSelfAttention(nn.Module):
11
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
12
+ """
13
+ :param n_head: int. the number of heads.
14
+ :param n_feat: int. the number of features.
15
+ :param dropout_rate: float. dropout rate.
16
+ """
17
+ super().__init__()
18
+ assert n_feat % n_head == 0
19
+ # We assume d_v always equals d_k
20
+ self.d_k = n_feat // n_head
21
+ self.h = n_head
22
+ self.linear_q = nn.Linear(n_feat, n_feat)
23
+ self.linear_k = nn.Linear(n_feat, n_feat)
24
+ self.linear_v = nn.Linear(n_feat, n_feat)
25
+ self.linear_out = nn.Linear(n_feat, n_feat)
26
+ self.dropout = nn.Dropout(p=dropout_rate)
27
+
28
+ def forward_qkv(self,
29
+ query: torch.Tensor,
30
+ key: torch.Tensor,
31
+ value: torch.Tensor
32
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
33
+ """
34
+ transform query, key and value.
35
+ :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
36
+ :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
37
+ :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
38
+ :return:
39
+ """
40
+ n_batch = query.size(0)
41
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
42
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
43
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
44
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
45
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
46
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
47
+
48
+ return q, k, v
49
+
50
+ def forward_attention(self,
51
+ value: torch.Tensor,
52
+ scores: torch.Tensor,
53
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
54
+ ) -> torch.Tensor:
55
+ """
56
+ compute attention context vector.
57
+ :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
58
+ :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
59
+ :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
60
+ (batch_size, time1, time2), (0, 0, 0) means fake mask.
61
+ :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
62
+ weighted by the attention score (batch_size, time1, time2).
63
+ """
64
+ n_batch = value.size(0)
65
+ # NOTE: When will `if mask.size(2) > 0` be True?
66
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
67
+ # 1st chunk to ease the onnx export.]
68
+ # 2. pytorch training
69
+ if mask.size(2) > 0: # time2 > 0
70
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
71
+ # For last chunk, time2 might be larger than scores.size(-1)
72
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
73
+ scores = scores.masked_fill(mask, -float('inf'))
74
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
75
+
76
+ # NOTE: When will `if mask.size(2) > 0` be False?
77
+ # 1. onnx(16/-1, -1/-1, 16/0)
78
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
79
+ else:
80
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
81
+
82
+ p_attn = self.dropout(attn)
83
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
84
+ x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
85
+
86
+ return self.linear_out(x) # (batch, time1, n_feat)
87
+
88
+ def forward(self,
89
+ x: torch.Tensor,
90
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
91
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
92
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
93
+
94
+ q, k, v = self.forward_qkv(x, x, x)
95
+
96
+ if cache.size(0) > 0:
97
+ key_cache, value_cache = torch.split(
98
+ cache, cache.size(-1) // 2, dim=-1)
99
+ k = torch.cat([key_cache, k], dim=2)
100
+ v = torch.cat([value_cache, v], dim=2)
101
+ # NOTE: We do cache slicing in encoder.forward_chunk, since it's
102
+ # non-trivial to calculate `next_cache_start` here.
103
+ new_cache = torch.cat((k, v), dim=-1)
104
+
105
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
106
+ return self.forward_attention(v, scores, mask), new_cache
107
+
108
+
109
+ class RelativeMultiHeadSelfAttention(nn.Module):
110
+
111
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
112
+ """
113
+ :param n_head: int. the number of heads.
114
+ :param n_feat: int. the number of features.
115
+ :param dropout_rate: float. dropout rate.
116
+ :param max_relative_position: int. maximum relative position for relative position encoding.
117
+ """
118
+ super().__init__()
119
+ assert n_feat % n_head == 0
120
+ # We assume d_v always equals d_k
121
+ self.d_k = n_feat // n_head
122
+ self.h = n_head
123
+ self.linear_q = nn.Linear(n_feat, n_feat)
124
+ self.linear_k = nn.Linear(n_feat, n_feat)
125
+ self.linear_v = nn.Linear(n_feat, n_feat)
126
+ self.linear_out = nn.Linear(n_feat, n_feat)
127
+ self.dropout = nn.Dropout(p=dropout_rate)
128
+
129
+ # Relative position encoding
130
+ self.max_relative_position = max_relative_position
131
+ self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
132
+
133
+ def forward_qkv(self,
134
+ query: torch.Tensor,
135
+ key: torch.Tensor,
136
+ value: torch.Tensor
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
138
+ """
139
+ transform query, key and value.
140
+ :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
141
+ :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
142
+ :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
143
+ :return:
144
+ """
145
+ n_batch = query.size(0)
146
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
147
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
148
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
149
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
150
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
151
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
152
+
153
+ return q, k, v
154
+
155
+ def forward_attention(self,
156
+ value: torch.Tensor,
157
+ scores: torch.Tensor,
158
+ mask: torch.Tensor = None
159
+ ) -> torch.Tensor:
160
+ """
161
+ compute attention context vector.
162
+ :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
163
+ :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
164
+ :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
165
+ :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
166
+ weighted by the attention score (batch_size, query_time_steps, key_time_steps).
167
+ """
168
+ n_batch = value.size(0)
169
+ if mask is not None:
170
+ mask = mask.unsqueeze(1).eq(0)
171
+ # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
172
+ scores = scores.masked_fill(mask, -float('inf'))
173
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
174
+ else:
175
+ attn = torch.softmax(scores, dim=-1)
176
+ # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]
177
+
178
+ p_attn = self.dropout(attn)
179
+
180
+ x = torch.matmul(p_attn, value)
181
+ # x shape: [batch_size, n_head, query_time_steps, d_k]
182
+ x = x.transpose(1, 2)
183
+ # x shape: [batch_size, query_time_steps, n_head, d_k]
184
+
185
+ x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
186
+ # x shape: [batch_size, query_time_steps, n_head * d_k]
187
+ # x shape: [batch_size, query_time_steps, n_feat]
188
+
189
+ x = self.linear_out(x)
190
+ # x shape: [batch_size, query_time_steps, n_feat]
191
+ return x
192
+
193
+ def relative_position_encoding(self, length: int) -> torch.Tensor:
194
+ """
195
+ Generate relative position encoding.
196
+ :param length: int. length of the sequence.
197
+ :return: torch.Tensor. clipped relative position index matrix. shape=(length, length).
198
+ """
199
+ range_vec = torch.arange(length)
200
+ distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
201
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
202
+ final_mat = distance_mat_clipped + self.max_relative_position
203
+ return final_mat
204
+
205
+ def forward(self,
206
+ x: torch.Tensor,
207
+ mask: torch.Tensor = None,
208
+ cache: torch.Tensor = None
209
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
210
+ # attention! self attention.
211
+
212
+ q, k, v = self.forward_qkv(x, x, x)
213
+ # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]
214
+
215
+ if cache is not None:
216
+ key_cache, value_cache = torch.split(
217
+ cache, cache.size(-1) // 2, dim=-1)
218
+ k = torch.cat([key_cache, k], dim=2)
219
+ v = torch.cat([value_cache, v], dim=2)
220
+
221
+ # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
222
+ new_cache = torch.cat((k, v), dim=-1)
223
+
224
+ # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
225
+ native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
226
+
227
+ # Compute relative position encoding
228
+ q_length, k_length = q.size(2), k.size(2)
229
+ relative_position = self.relative_position_encoding(k_length)
230
+
231
+ relative_position = relative_position[-q_length:]
232
+
233
+ relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
234
+
235
+ relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k)
236
+ relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k)
237
+
238
+ relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
239
+ # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
240
+
241
+ # score
242
+ scores = native_scores + relative_position_scores
243
+
244
+ return self.forward_attention(v, scores, mask), new_cache
245
+
246
+
247
+ def main():
248
+ rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
249
+
250
+ x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
251
+ xt, new_cache = rel_attention.forward(x)
252
+
253
+ # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
254
+ # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
255
+ # xt, new_cache = rel_attention.forward(x, cache=cache)
256
+
257
+ print(xt.shape)
258
+ print(new_cache.shape)
259
+ return
260
+
261
+
262
+ if __name__ == '__main__':
263
+ main()
toolbox/torchaudio/models/nx_denoise/transformers/mask.py ADDED
@@ -0,0 +1,74 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+
6
+ def make_pad_mask(lengths: torch.Tensor,
7
+ max_len: int = 0,
8
+ ) -> torch.Tensor:
9
+ batch_size = lengths.size(0)
10
+ max_len = max_len if max_len > 0 else lengths.max().item()
11
+ seq_range = torch.arange(
12
+ 0,
13
+ max_len,
14
+ dtype=torch.int64,
15
+ device=lengths.device
16
+ )
17
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
18
+ seq_length_expand = lengths.unsqueeze(-1)
19
+ mask = seq_range_expand >= seq_length_expand
20
+ return mask
21
+
22
+
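+ # Illustrative output (a sketch, not part of the committed code):
+ #   make_pad_mask(torch.tensor([2, 3]))
+ #   -> tensor([[False, False,  True],
+ #              [False, False, False]])
+ # True marks padded positions beyond each sequence length.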
23
+
24
+ def subsequent_chunk_mask(
25
+ size: int,
26
+ chunk_size: int,
27
+ num_left_chunks: int = -1,
28
+ num_right_chunks: int = 0,
29
+ device: torch.device = torch.device("cpu"),
30
+ ) -> torch.Tensor:
31
+ """
32
+ Create mask for subsequent steps (size, size) with chunk size,
33
+ this is for streaming encoder
34
+
35
+ Examples:
36
+ > subsequent_chunk_mask(4, 2)
37
+ [[1, 1, 0, 0],
38
+ [1, 1, 0, 0],
39
+ [1, 1, 1, 1],
40
+ [1, 1, 1, 1]]
41
+
42
+ :param size: int. size of mask.
43
+ :param chunk_size: int. size of chunk.
44
+ :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
45
+ :param num_right_chunks: int. number of right chunks.
46
+ :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
47
+ :return: torch.Tensor. mask
48
+ """
49
+
50
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
51
+ for i in range(size):
52
+ if num_left_chunks < 0:
53
+ start = 0
54
+ else:
55
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
56
+ ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
57
+ ret[i, start:ending] = True
58
+ return ret
59
+
60
+
61
+ def main():
62
+ chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
63
+ print(chunk_mask)
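+ # Expected pattern (a sketch): each row attends to its own chunk plus at most
+ # two chunks to the left, e.g. rows 6-7 see columns 2..7 only:
+ # [[1, 1, 0, 0, 0, 0, 0, 0],
+ #  [1, 1, 0, 0, 0, 0, 0, 0],
+ #  [1, 1, 1, 1, 0, 0, 0, 0],
+ #  [1, 1, 1, 1, 0, 0, 0, 0],
+ #  [1, 1, 1, 1, 1, 1, 0, 0],
+ #  [1, 1, 1, 1, 1, 1, 0, 0],
+ #  [0, 0, 1, 1, 1, 1, 1, 1],
+ #  [0, 0, 1, 1, 1, 1, 1, 1]]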
64
+
65
+ chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
66
+ print(chunk_mask)
67
+
68
+ chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
69
+ print(chunk_mask)
70
+ return
71
+
72
+
73
+ if __name__ == '__main__':
74
+ main()
toolbox/torchaudio/models/nx_denoise/transformers/transformers.py ADDED
@@ -0,0 +1,479 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Dict, Optional, Tuple, List, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from toolbox.torchaudio.models.nx_denoise.transformers.mask import subsequent_chunk_mask
9
+ from toolbox.torchaudio.models.nx_denoise.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
10
+
11
+
12
+ class PositionwiseFeedForward(nn.Module):
13
+ def __init__(self,
14
+ input_dim: int,
15
+ hidden_units: int,
16
+ dropout_rate: float,
17
+ activation: torch.nn.Module = torch.nn.ReLU()):
18
+ """
19
+ FeedForward are applied on each position of the sequence.
20
+ the output dim is same with the input dim.
21
+
22
+ :param input_dim: int. input dimension.
23
+ :param hidden_units: int. the number of hidden units.
24
+ :param dropout_rate: float. dropout rate.
25
+ :param activation: torch.nn.Module. activation function.
26
+ """
27
+ super(PositionwiseFeedForward, self).__init__()
28
+ self.w_1 = torch.nn.Linear(input_dim, hidden_units)
29
+ self.activation = activation
30
+ self.dropout = torch.nn.Dropout(dropout_rate)
31
+ self.w_2 = torch.nn.Linear(hidden_units, input_dim)
32
+
33
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
34
+ """
35
+ Forward function.
36
+ :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
37
+ :return: output tensor. shape=(batch_size, max_length, dim).
38
+ """
39
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
40
+
41
+
42
+ class TransformerBlock(nn.Module):
43
+ def __init__(self,
44
+ input_dim: int,
45
+ dropout_rate: float = 0.1,
46
+ n_heads: int = 4,
47
+ max_relative_position: int = 5120
48
+ ):
49
+ super().__init__()
50
+ self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
51
+ self.attention = RelativeMultiHeadSelfAttention(
52
+ n_head=n_heads,
53
+ n_feat=input_dim,
54
+ dropout_rate=dropout_rate,
55
+ max_relative_position=max_relative_position,
56
+ )
57
+
58
+ self.dropout1 = nn.Dropout(dropout_rate)
59
+ self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
60
+ self.ffn = PositionwiseFeedForward(
61
+ input_dim=input_dim,
62
+ hidden_units=input_dim,
63
+ dropout_rate=dropout_rate
64
+ )
65
+ self.dropout2 = nn.Dropout(dropout_rate)
66
+ self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
67
+
68
+ def forward(
69
+ self,
70
+ x: torch.Tensor,
71
+ mask: torch.Tensor = None,
72
+ attention_cache: torch.Tensor = None,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+
76
+ :param x: torch.Tensor. shape=(batch_size, time, input_dim).
77
+ :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
78
+ :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
79
+ shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
80
+ :return:
81
+ torch.Tensor: Output tensor (batch_size, time, input_dim).
82
+ torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
83
+ """
84
+ xt = self.norm1(x)
85
+
86
+ x_att, new_att_cache = self.attention.forward(
87
+ xt, mask=mask, cache=attention_cache
88
+ )
89
+ x = x + self.dropout1(x_att)
90
+ xt = self.norm2(x)
91
+ xt = self.ffn.forward(xt)
92
+ x = x + self.dropout2(xt)
93
+
94
+ x = self.norm3(x)
95
+
96
+ return x, new_att_cache
97
+
98
+
99
+ class TransformerEncoder(nn.Module):
100
+ """
101
+ https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
102
+ """
103
+ def __init__(self,
104
+ input_size: int = 64,
105
+ hidden_size: int = 256,
106
+ attention_heads: int = 4,
107
+ num_blocks: int = 6,
108
+ dropout_rate: float = 0.1,
109
+ max_relative_position: int = 1024,
110
+ chunk_size: int = 1,
111
+ num_left_chunks: int = 128,
112
+ num_right_chunks: int = 2,
113
+ ):
114
+ super().__init__()
115
+ self.input_size = input_size
116
+ self.hidden_size = hidden_size
117
+
118
+ self.max_relative_position = max_relative_position
119
+ self.chunk_size = chunk_size
120
+ self.num_left_chunks = num_left_chunks
121
+ self.num_right_chunks = num_right_chunks
122
+
123
+ self.input_linear = nn.Linear(
124
+ in_features=self.input_size,
125
+ out_features=self.hidden_size,
126
+ )
127
+
128
+ self.encoder_layer_list = torch.nn.ModuleList([
129
+ TransformerBlock(
130
+ input_dim=hidden_size,
131
+ n_heads=attention_heads,
132
+ dropout_rate=dropout_rate,
133
+ max_relative_position=max_relative_position,
134
+ ) for _ in range(num_blocks)
135
+ ])
136
+
137
+ self.output_linear = nn.Linear(
138
+ in_features=self.hidden_size,
139
+ out_features=self.input_size,
140
+ )
141
+
142
+ def forward(self,
143
+ xs: torch.Tensor,
144
+ ):
145
+ """
146
+ :param xs: Tensor, shape: [batch_size, time_steps, input_size]
147
+ :return: Tensor, shape: [batch_size, time_steps, input_size]
148
+ """
149
+ batch_size, time_steps, _ = xs.shape
150
+ # xs shape: [batch_size, time_steps, input_size]
151
+ xs = self.input_linear.forward(xs)
152
+ # xs shape: [batch_size, time_steps, hidden_size]
153
+
154
+ chunk_masks = subsequent_chunk_mask(
155
+ size=time_steps,
156
+ chunk_size=self.chunk_size,
157
+ num_left_chunks=self.num_left_chunks,
158
+ num_right_chunks=self.num_right_chunks,
159
+ )
160
+ chunk_masks = chunk_masks.to(xs.device)
161
+ # chunk_masks shape: [time_steps, time_steps]
162
+ chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
163
+ # chunk_masks shape: [batch_size, time_steps, time_steps]
164
+
165
+ for encoder_layer in self.encoder_layer_list:
166
+ xs, _ = encoder_layer.forward(xs, chunk_masks)
167
+
168
+ # xs shape: [batch_size, time_steps, hidden_size]
169
+ xs = self.output_linear.forward(xs)
170
+ # xs shape: [batch_size, time_steps, input_size]
171
+
172
+ return xs
173
+
174
+ def forward_chunk(self,
175
+ xs: torch.Tensor,
176
+ max_att_cache_length: int,
177
+ attention_cache: torch.Tensor = None,
178
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
179
+ """
180
+
181
+ :param xs:
182
+ :param max_att_cache_length:
183
+ :param attention_cache: Tensor, [num_layers, ...]
184
+ :return:
185
+ """
186
+ # xs shape: [batch_size, time_steps, input_size]
187
+ xs = self.input_linear.forward(xs)
188
+ # xs shape: [batch_size, time_steps, hidden_size]
189
+
190
+ r_att_cache = []
191
+ for idx, encoder_layer in enumerate(self.encoder_layer_list):
192
+ xs, new_att_cache = encoder_layer.forward(
193
+ x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
194
+ )
195
+ # new_att_cache shape: [batch_size, n_heads, time_steps, dim]
196
+ if new_att_cache.size(2) > max_att_cache_length:
197
+ begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
198
+ end = self.num_right_chunks * self.chunk_size
199
+ new_att_cache = new_att_cache[:, :, -begin:-end, :]
200
+ r_att_cache.append(new_att_cache)
201
+
202
+ r_att_cache = torch.stack(r_att_cache, dim=0)
203
+
204
+ # xs shape: [batch_size, time_steps, hidden_size]
205
+ xs = self.output_linear.forward(xs)
206
+ # xs shape: [batch_size, time_steps, input_size]
207
+
208
+ return xs, r_att_cache
209
+
210
+ def forward_chunk_by_chunk(
211
+ self,
212
+ xs: torch.Tensor,
213
+ ) -> torch.Tensor:
214
+
215
+ batch_size, time_steps, _ = xs.shape
216
+
217
+ # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
218
+ max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
219
+ attention_cache = None
220
+
221
+ outputs = []
222
+ for idx in range(0, time_steps, self.chunk_size):
223
+ begin = idx
224
+ end = begin + self.chunk_size * (self.num_right_chunks + 1)
225
+ chunk_xs = xs[:, begin:end, :]
226
+ # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
227
+
228
+ ys, attention_cache = self.forward_chunk(
229
+ xs=chunk_xs,
230
+ max_att_cache_length=max_att_cache_length,
231
+ attention_cache=attention_cache,
232
+ )
233
+
234
+ # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), input_size]
235
+ ys = ys[:, :self.chunk_size, :]
236
+
237
+ outputs.append(ys)
238
+
239
+ ys = torch.cat(outputs, 1)
240
+ return ys
241
+
242
+
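+ # Illustrative streaming check (a sketch, not part of the committed code):
+ #   encoder = TransformerEncoder(input_size=64, hidden_size=256)
+ #   x = torch.randn(2, 32, 64)
+ #   full = encoder.forward(x)
+ #   streamed = encoder.forward_chunk_by_chunk(x)
+ # The two paths are intended to be close but not bit-identical, because
+ # forward_chunk_by_chunk builds up and truncates an attention cache instead of
+ # applying the chunk mask to the whole sequence at once.
+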
243
+ class TSTransformerBlock(nn.Module):
244
+ def __init__(self,
245
+ input_dim: int,
246
+ dropout_rate: float = 0.1,
247
+ n_heads: int = 4,
248
+ max_time_relative_position: int = 1024,
249
+ max_freq_relative_position: int = 128,
250
+ ):
251
+ super(TSTransformerBlock, self).__init__()
252
+ self.time_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_time_relative_position)
253
+ self.freq_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_freq_relative_position)
254
+
255
+ def forward(self,
256
+ x: torch.Tensor,
257
+ mask: torch.Tensor = None,
258
+ attention_cache: torch.Tensor = None,
259
+ ):
260
+ """
261
+
262
+ :param x: Tensor. shape: [batch_size, hidden_size, time_steps, input_size]
263
+ :param mask: Tensor. shape: [time_steps, time_steps]
264
+ :param attention_cache:
265
+ :return:
266
+ """
267
+ b, c, t, f = x.size()
268
+
269
+ mask = None if mask is None else torch.broadcast_to(mask, size=(b*f, t, t))
270
+
271
+ x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
272
+ x_, new_att_cache = self.time_transformer.forward(x, mask, attention_cache)
273
+ x = x_ + x
274
+ x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
275
+ x_, _ = self.freq_transformer.forward(x)
276
+ x = x_ + x
277
+ x = x.view(b, t, f, c).permute(0, 3, 1, 2)
278
+ return x, new_att_cache
279
+
280
+
281
+ class TSTransformerEncoder(nn.Module):
282
+ def __init__(self,
283
+ input_size: int = 64,
284
+ hidden_size: int = 256,
285
+ attention_heads: int = 4,
286
+ num_blocks: int = 6,
287
+ dropout_rate: float = 0.1,
288
+ max_time_relative_position: int = 1024,
289
+ max_freq_relative_position: int = 128,
290
+ chunk_size: int = 1,
291
+ num_left_chunks: int = 128,
292
+ num_right_chunks: int = 2,
293
+ ):
294
+ super().__init__()
295
+ self.input_size = input_size
296
+ self.hidden_size = hidden_size
297
+
298
+ self.max_time_relative_position = max_time_relative_position
299
+ self.max_freq_relative_position = max_freq_relative_position
300
+ self.chunk_size = chunk_size
301
+ self.num_left_chunks = num_left_chunks
302
+ self.num_right_chunks = num_right_chunks
303
+
304
+ self.input_linear = nn.Linear(
305
+ in_features=self.input_size,
306
+ out_features=self.hidden_size,
307
+ )
308
+
309
+ self.encoder_layer_list = torch.nn.ModuleList([
310
+ TSTransformerBlock(
311
+ input_dim=hidden_size,
312
+ n_heads=attention_heads,
313
+ dropout_rate=dropout_rate,
314
+ max_time_relative_position=max_time_relative_position,
315
+ max_freq_relative_position=max_freq_relative_position,
316
+ ) for _ in range(num_blocks)
317
+ ])
318
+
319
+ self.output_linear = nn.Linear(
320
+ in_features=self.hidden_size,
321
+ out_features=self.input_size,
322
+ )
323
+
324
+ def forward(self,
325
+ xs: torch.Tensor,
326
+ ):
327
+ """
328
+ :param xs: Tensor, shape: [batch_size, channels, time_steps, input_size]
329
+ :return: Tensor, shape: [batch_size, channels, time_steps, input_size]
330
+ """
331
+ batch_size, channels, time_steps, _ = xs.shape
332
+ # xs shape: [batch_size, channels, time_steps, input_size]
333
+ xs = xs.permute(0, 3, 2, 1)
334
+ # xs shape: [batch_size, input_size, time_steps, channels]
335
+ xs = self.input_linear.forward(xs)
336
+ # xs shape: [batch_size, input_size, time_steps, hidden_size]
337
+ xs = xs.permute(0, 3, 2, 1)
338
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
339
+
340
+ chunk_masks = subsequent_chunk_mask(
341
+ size=time_steps,
342
+ chunk_size=self.chunk_size,
343
+ num_left_chunks=self.num_left_chunks,
344
+ num_right_chunks=self.num_right_chunks,
345
+ )
346
+ chunk_masks = chunk_masks.to(xs.device)
347
+ # chunk_masks shape: [time_steps, time_steps]
348
+
349
+ for encoder_layer in self.encoder_layer_list:
350
+ xs, _ = encoder_layer.forward(xs, chunk_masks)
351
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
352
+ xs = xs.permute(0, 3, 2, 1)
353
+ # xs shape: [batch_size, input_size, time_steps, hidden_size]
354
+ xs = self.output_linear.forward(xs)
355
+ # xs shape: [batch_size, input_size, time_steps, channels]
356
+ xs = xs.permute(0, 3, 2, 1)
357
+ # xs shape: [batch_size, channels, time_steps, input_size]
358
+
359
+ return xs
360
+
361
+ def forward_chunk(self,
362
+ xs: torch.Tensor,
363
+ max_att_cache_length: int,
364
+ attention_cache: torch.Tensor = None,
365
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
366
+ """
367
+
368
+ :param xs:
369
+ :param max_att_cache_length:
370
+ :param attention_cache: Tensor, shape: [num_layers, ...]
371
+ :return:
372
+ """
373
+ # xs shape: [batch_size, channels, time_steps, input_size]
374
+ xs = xs.permute(0, 3, 2, 1)
375
+ xs = self.input_linear.forward(xs)
376
+ xs = xs.permute(0, 3, 2, 1)
377
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
378
+
379
+ r_att_cache = []
380
+ for idx, encoder_layer in enumerate(self.encoder_layer_list):
381
+ xs, new_att_cache = encoder_layer.forward(
382
+ x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
383
+ )
384
+ # new_att_cache shape: [b*f, n_heads, time_steps, dim]
385
+ if new_att_cache.size(2) > max_att_cache_length:
386
+ begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
387
+ end = self.num_right_chunks * self.chunk_size
388
+ new_att_cache = new_att_cache[:, :, -begin:-end, :]
389
+ r_att_cache.append(new_att_cache)
390
+
391
+ r_att_cache = torch.stack(r_att_cache, dim=0)
392
+
393
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
394
+ xs = xs.permute(0, 3, 2, 1)
395
+ xs = self.output_linear.forward(xs)
396
+ xs = xs.permute(0, 3, 2, 1)
397
+ # xs shape: [batch_size, channels, time_steps, input_size]
398
+
399
+ return xs, r_att_cache
400
+
401
+ def forward_chunk_by_chunk(
402
+ self,
403
+ xs: torch.Tensor,
404
+ ) -> torch.Tensor:
405
+
406
+ batch_size, channels, time_steps, _ = xs.shape
407
+
408
+ max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
409
+ attention_cache = None
410
+
411
+ outputs = []
412
+ for idx in range(0, time_steps, self.chunk_size):
413
+ begin = idx
414
+ end = begin + self.chunk_size * (self.num_right_chunks + 1)
415
+ chunk_xs = xs[:, :, begin:end, :]
416
+ # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
417
+
418
+ ys, attention_cache = self.forward_chunk(
419
+ xs=chunk_xs,
420
+ max_att_cache_length=max_att_cache_length,
421
+ attention_cache=attention_cache,
422
+ )
423
+ # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
424
+ ys = ys[:, :, :self.chunk_size, :]
425
+
426
+ outputs.append(ys)
427
+
428
+ ys = torch.cat(outputs, dim=2)
429
+ return ys
430
+
431
+
432
+ def main2():
433
+
434
+ encoder = TransformerEncoder(
435
+ input_size=64,
436
+ hidden_size=256,
437
+ attention_heads=4,
438
+ num_blocks=6,
439
+ dropout_rate=0.1,
440
+ )
441
+ print(encoder)
442
+
443
+ x = torch.ones([4, 200, 64])
444
+
445
+ x = torch.ones([4, 200, 64])
446
+ y = encoder.forward(xs=x)
447
+ print(y.shape)
448
+
449
+ x = torch.ones([4, 200, 64])
450
+ y = encoder.forward_chunk_by_chunk(xs=x)
451
+ print(y.shape)
452
+
453
+ return
454
+
455
+
456
+ def main():
457
+
458
+ encoder = TSTransformerEncoder(
459
+ input_size=16,
460
+ hidden_size=64,
461
+ attention_heads=4,
462
+ num_blocks=4,
463
+ dropout_rate=0.1,
464
+ )
465
+ # print(encoder)
466
+
467
+ x = torch.ones([4, 16, 200, 32])
468
+ y = encoder.forward(xs=x)
469
+ print(y.shape)
470
+
471
+ x = torch.ones([4, 16, 200, 32])
472
+ y = encoder.forward_chunk_by_chunk(xs=x)
473
+ print(y.shape)
474
+
475
+ return
476
+
477
+
478
+ if __name__ == '__main__':
479
+ main()
toolbox/torchaudio/models/nx_denoise/utils.py ADDED
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class LearnableSigmoid1d(nn.Module):
8
+ def __init__(self, in_features, beta=1):
9
+ super().__init__()
10
+ self.beta = beta
11
+ self.slope = nn.Parameter(torch.ones(in_features))
12
+ self.slope.requires_grad = True
13
+
14
+ def forward(self, x):
15
+ # x shape: [batch_size, time_steps, spec_bins]
16
+ return self.beta * torch.sigmoid(self.slope * x)
17
+
18
+
19
+ def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
20
+
21
+ hann_window = torch.hann_window(win_size).to(y.device)
22
+ stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
23
+ center=center, pad_mode='reflect', normalized=False, return_complex=True)
24
+ stft_spec = torch.view_as_real(stft_spec)
25
+ mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
26
+ pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
27
+ # Magnitude Compression
28
+ mag = torch.pow(mag, compress_factor)
29
+ com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
30
+
31
+ return mag, pha, com
32
+
33
+
34
+ def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
35
+ # Magnitude Decompression
36
+ mag = torch.pow(mag, (1.0/compress_factor))
37
+ com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
38
+ hann_window = torch.hann_window(win_size).to(com.device)
39
+ wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
40
+
41
+ return wav
42
+
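+ # Illustrative round trip (a sketch, not part of the committed code), assuming the
+ # values from yaml/config.yaml: n_fft=512, hop_size=80, win_size=200, compress_factor=0.3.
+ #   y = torch.randn(1, 16000)
+ #   mag, pha, com = mag_pha_stft(y, 512, 80, 200, compress_factor=0.3)
+ #   y_hat = mag_pha_istft(mag, pha, 512, 80, 200, compress_factor=0.3)
+ # y_hat approximately reconstructs y (up to windowing effects at the edges).
+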
43
+
44
+ if __name__ == '__main__':
45
+ pass
toolbox/torchaudio/models/nx_denoise/yaml/config.yaml ADDED
@@ -0,0 +1,51 @@
1
+ model_name: "nx_clean_unet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 16000
5
+ n_fft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+ # With hop_size = 80, one STFT step covers 10 ms, so the down sampling is chosen to give a similar time resolution.
9
+
10
+ # 2**down_sampling_num_layers,
11
+ # e.g. 2**6 = 64 means 64 samples collapse into one time step after down sampling,
12
+ # so one step is 64 / sample_rate = 0.008 s.
13
+ # Then tsfm_chunk_size=2 corresponds to 16 ms and tsfm_chunk_size=4 to 32 ms.
14
+ # Assuming roughly 1 s of left context and 30 ms of right context at each step:
15
+ # tsfm_chunk_size=1,tsfm_num_left_chunks=128,tsfm_num_right_chunks=4
16
+ # tsfm_chunk_size=2,tsfm_num_left_chunks=64,tsfm_num_right_chunks=2
17
+ # tsfm_chunk_size=4,tsfm_num_left_chunks=32,tsfm_num_right_chunks=1
18
+ down_sampling_num_layers: 6
19
+ down_sampling_in_channels: 1
20
+ down_sampling_hidden_channels: 64
21
+ down_sampling_kernel_size: 4
22
+ down_sampling_stride: 2
23
+
24
+ causal_in_channels: 1
25
+ causal_out_channels: 64
26
+ causal_kernel_size: 3
27
+ causal_bias: false
28
+ causal_separable: true
29
+ causal_f_stride: 1
30
+ causal_num_layers: 3
31
+
32
+ tsfm_hidden_size: 256
33
+ tsfm_attention_heads: 8
34
+ tsfm_num_blocks: 6
35
+ tsfm_dropout_rate: 0.1
36
+ tsfm_max_length: 512
37
+ tsfm_chunk_size: 1
38
+ tsfm_num_left_chunks: 128
39
+ tsfm_num_right_chunks: 4
40
+
41
+ discriminator_dim: 32
42
+ discriminator_in_channel: 2
43
+
44
+ compress_factor: 0.3
45
+
46
+ batch_size: 4
47
+ learning_rate: 0.0005
48
+ adam_b1: 0.8
49
+ adam_b2: 0.99
50
+ lr_decay: 0.99
51
+ seed: 1234