HoneyTian committed
Commit 33aff71
1 Parent(s): 3834772
Files changed (26)
  1. .gitignore +1 -1
  2. examples/nx_mpnet/run.sh +166 -0
  3. examples/nx_mpnet/step_1_prepare_data.py +202 -0
  4. examples/nx_mpnet/step_2_train_model.py +447 -0
  5. examples/nx_mpnet/step_3_evaluation.py +187 -0
  6. examples/nx_mpnet/yaml/config.yaml +27 -0
  7. toolbox/torchaudio/models/mpnet/inference_mpnet.py +4 -2
  8. toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py +12 -12
  9. toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py +97 -0
  10. toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py +52 -0
  11. toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py +6 -0
  12. toolbox/torchaudio/models/nx_denoise/stftnet/istftnet.py +9 -0
  13. toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py +9 -0
  14. toolbox/torchaudio/models/nx_mpnet/__init__.py +6 -0
  15. toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py +6 -0
  16. toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py +445 -0
  17. toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py +90 -0
  18. toolbox/torchaudio/models/nx_mpnet/discriminator.py +132 -0
  19. toolbox/torchaudio/models/nx_mpnet/loss.py +22 -0
  20. toolbox/torchaudio/models/nx_mpnet/metrics.py +80 -0
  21. toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py +143 -0
  22. toolbox/torchaudio/models/nx_mpnet/transformers/__init__.py +6 -0
  23. toolbox/torchaudio/models/nx_mpnet/transformers/attention.py +263 -0
  24. toolbox/torchaudio/models/nx_mpnet/transformers/mask.py +74 -0
  25. toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py +479 -0
  26. toolbox/torchaudio/models/nx_mpnet/utils.py +56 -0
.gitignore CHANGED
@@ -18,5 +18,5 @@
  /trained_models/
  /temp/
 
- #**/*.wav
+ **/*.wav
  **/*.xlsx
examples/nx_mpnet/run.sh ADDED
@@ -0,0 +1,166 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
+
+
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
+
+ sh run.sh --stage 5 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
+
+
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
+ --max_epochs 1
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0  # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="`eval echo \\$$name`";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ dataset="${file_dir}/dataset.xlsx"
+ train_dataset="${file_dir}/train.xlsx"
+ valid_dataset="${file_dir}/valid.xlsx"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+   #source /data/local/bin/nx_denoise/bin/activate
+   alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --file_dir "${file_dir}" \
+     --noise_dir "${noise_dir}" \
+     --speech_dir "${speech_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: train model"
+   cd "${work_dir}" || exit 1
+   python3 step_2_train_model.py \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}" \
+     --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: test model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_evaluation.py \
+     --valid_dataset "${valid_dataset}" \
+     --model_dir "${file_dir}/best" \
+     --evaluation_audio_dir "${evaluation_audio_dir}" \
+     --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: collect files"
+   cd "${work_dir}" || exit 1
+
+   mkdir -p ${final_model_dir}
+
+   cp "${file_dir}/best"/* "${final_model_dir}"
+   cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+   cd "${final_model_dir}/.." || exit 1;
+
+   if [ -e "${final_model_name}.zip" ]; then
+     rm -rf "${final_model_name}_backup.zip"
+     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+   fi
+
+   zip -r "${final_model_name}.zip" "${final_model_name}"
+   rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+   $verbose && echo "stage 5: clear file_dir"
+   cd "${work_dir}" || exit 1
+
+   rm -rf "${file_dir}";
+
+ fi
examples/nx_mpnet/step_1_prepare_data.py ADDED
@@ -0,0 +1,202 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--scale", default=1, type=float)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def filename_generator(data_dir: str):
52
+ data_dir = Path(data_dir)
53
+ for filename in data_dir.glob("**/*.wav"):
54
+ yield filename.as_posix()
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
58
+ data_dir = Path(data_dir)
59
+ for filename in data_dir.glob("**/*.wav"):
60
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
61
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
62
+
63
+ if raw_duration < duration:
64
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
65
+ continue
66
+ if signal.ndim != 1:
67
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
68
+
69
+ signal_length = len(signal)
70
+ win_size = int(duration * sample_rate)
71
+ for begin in range(0, signal_length - win_size, win_size):
72
+ row = {
73
+ "filename": filename.as_posix(),
74
+ "raw_duration": round(raw_duration, 4),
75
+ "offset": round(begin / sample_rate, 4),
76
+ "duration": round(duration, 4),
77
+ }
78
+ yield row
79
+
80
+
81
+ def get_dataset(args):
82
+ file_dir = Path(args.file_dir)
83
+ file_dir.mkdir(exist_ok=True)
84
+
85
+ noise_dir = Path(args.noise_dir)
86
+ speech_dir = Path(args.speech_dir)
87
+
88
+ noise_generator = target_second_signal_generator(
89
+ noise_dir.as_posix(),
90
+ duration=args.duration,
91
+ sample_rate=args.target_sample_rate
92
+ )
93
+ speech_generator = target_second_signal_generator(
94
+ speech_dir.as_posix(),
95
+ duration=args.duration,
96
+ sample_rate=args.target_sample_rate
97
+ )
98
+
99
+ dataset = list()
100
+
101
+ count = 0
102
+ process_bar = tqdm(desc="build dataset excel")
103
+ for noise, speech in zip(noise_generator, speech_generator):
104
+ flag = random.random()
105
+ if flag > args.scale:
106
+ continue
107
+
108
+ noise_filename = noise["filename"]
109
+ noise_raw_duration = noise["raw_duration"]
110
+ noise_offset = noise["offset"]
111
+ noise_duration = noise["duration"]
112
+
113
+ speech_filename = speech["filename"]
114
+ speech_raw_duration = speech["raw_duration"]
115
+ speech_offset = speech["offset"]
116
+ speech_duration = speech["duration"]
117
+
118
+ random1 = random.random()
119
+ random2 = random.random()
120
+
121
+ row = {
122
+ "noise_filename": noise_filename,
123
+ "noise_raw_duration": noise_raw_duration,
124
+ "noise_offset": noise_offset,
125
+ "noise_duration": noise_duration,
126
+
127
+ "speech_filename": speech_filename,
128
+ "speech_raw_duration": speech_raw_duration,
129
+ "speech_offset": speech_offset,
130
+ "speech_duration": speech_duration,
131
+
132
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
133
+
134
+ "random1": random1,
135
+ "random2": random2,
136
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
137
+ }
138
+ dataset.append(row)
139
+ count += 1
140
+ duration_seconds = count * args.duration
141
+ duration_hours = duration_seconds / 3600
142
+
143
+ process_bar.update(n=1)
144
+ process_bar.set_postfix({
145
+ # "duration_seconds": round(duration_seconds, 4),
146
+ "duration_hours": round(duration_hours, 4),
147
+
148
+ })
149
+
150
+ dataset = pd.DataFrame(dataset)
151
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
152
+ dataset.to_excel(
153
+ file_dir / "dataset.xlsx",
154
+ index=False,
155
+ )
156
+ return
157
+
158
+
159
+
160
+ def split_dataset(args):
161
+ """Split the dataset into train and test sets."""
162
+ file_dir = Path(args.file_dir)
163
+ file_dir.mkdir(exist_ok=True)
164
+
165
+ df = pd.read_excel(file_dir / "dataset.xlsx")
166
+
167
+ train = list()
168
+ test = list()
169
+
170
+ for i, row in df.iterrows():
171
+ flag = row["flag"]
172
+ if flag == "TRAIN":
173
+ train.append(row)
174
+ else:
175
+ test.append(row)
176
+
177
+ train = pd.DataFrame(train)
178
+ train.to_excel(
179
+ args.train_dataset,
180
+ index=False,
181
+ # encoding="utf_8_sig"
182
+ )
183
+ test = pd.DataFrame(test)
184
+ test.to_excel(
185
+ args.valid_dataset,
186
+ index=False,
187
+ # encoding="utf_8_sig"
188
+ )
189
+
190
+ return
191
+
192
+
193
+ def main():
194
+ args = get_args()
195
+
196
+ get_dataset(args)
197
+ split_dataset(args)
198
+ return
199
+
200
+
201
+ if __name__ == "__main__":
202
+ main()
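
The script above only writes metadata (filename, offset, duration, snr_db) to Excel; waveforms are cut and mixed later by the dataset and evaluation code. Below is a minimal sketch of how one row would be read back and mixed, following the same SNR convention as `mix_speech_and_noise` in step_3_evaluation.py. Column names and the default `train.xlsx` path come from the script; the snippet itself is illustrative and not part of the repo.

```python
import librosa
import numpy as np
import pandas as pd

# Assumes step_1_prepare_data.py has already produced train.xlsx.
row = pd.read_excel("train.xlsx").iloc[0]

speech, _ = librosa.load(row["speech_filename"], sr=8000,
                         offset=row["speech_offset"], duration=row["speech_duration"])
noise, _ = librosa.load(row["noise_filename"], sr=8000,
                        offset=row["noise_offset"], duration=row["noise_duration"])

# Scale the noise so that 10 * log10(P_speech / P_noise) equals the stored snr_db.
length = min(len(speech), len(noise))
speech, noise = speech[:length], noise[:length]
speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (row["snr_db"] / 10))
noise = np.sqrt(noise_power) * noise / np.sqrt(np.mean(np.square(noise)))

noisy = speech + noise  # the "mix_wave" the training collate function expects
```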
examples/nx_mpnet/step_2_train_model.py ADDED
@@ -0,0 +1,447 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch.nn import functional as F
24
+ from torch.utils.data.dataloader import DataLoader
25
+ from tqdm import tqdm
26
+
27
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
+ from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
29
+ from toolbox.torchaudio.models.nx_mpnet.discriminator import MetricDiscriminatorPretrainedModel
30
+ from toolbox.torchaudio.models.nx_mpnet.modeling_nx_mpnet import NXMPNet, NXMPNetPretrainedModel
31
+ from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft
32
+ from toolbox.torchaudio.models.nx_mpnet.metrics import run_batch_pesq, run_pesq_score
33
+ from toolbox.torchaudio.models.nx_mpnet.loss import phase_losses
34
+
35
+
36
+ def get_args():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
40
+
41
+ parser.add_argument("--max_epochs", default=100, type=int)
42
+
43
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
+ parser.add_argument("--patience", default=5, type=int)
45
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ clean_audios = list()
80
+ noisy_audios = list()
81
+
82
+ for sample in batch:
83
+ # noise_wave: torch.Tensor = sample["noise_wave"]
84
+ clean_audio: torch.Tensor = sample["speech_wave"]
85
+ noisy_audio: torch.Tensor = sample["mix_wave"]
86
+ # snr_db: float = sample["snr_db"]
87
+
88
+ clean_audios.append(clean_audio)
89
+ noisy_audios.append(noisy_audio)
90
+
91
+ clean_audios = torch.stack(clean_audios)
92
+ noisy_audios = torch.stack(noisy_audios)
93
+
94
+ # assert
95
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
96
+ raise AssertionError("nan or inf in clean_audios")
97
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
98
+ raise AssertionError("nan or inf in noisy_audios")
99
+ return clean_audios, noisy_audios
100
+
101
+
102
+ collate_fn = CollateFunction()
103
+
104
+
105
+ def main():
106
+ args = get_args()
107
+
108
+ config = NXMPNetConfig.from_pretrained(
109
+ pretrained_model_name_or_path=args.config_file,
110
+ )
111
+
112
+ serialization_dir = Path(args.serialization_dir)
113
+ serialization_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ logger = logging_config(serialization_dir)
116
+
117
+ random.seed(config.seed)
118
+ np.random.seed(config.seed)
119
+ torch.manual_seed(config.seed)
120
+ logger.info(f"set seed: {config.seed}")
121
+
122
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+ n_gpu = torch.cuda.device_count()
124
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
125
+
126
+ # datasets
127
+ train_dataset = DenoiseExcelDataset(
128
+ excel_file=args.train_dataset,
129
+ expected_sample_rate=8000,
130
+ max_wave_value=32768.0,
131
+ )
132
+ valid_dataset = DenoiseExcelDataset(
133
+ excel_file=args.valid_dataset,
134
+ expected_sample_rate=8000,
135
+ max_wave_value=32768.0,
136
+ )
137
+ train_data_loader = DataLoader(
138
+ dataset=train_dataset,
139
+ batch_size=config.batch_size,
140
+ shuffle=True,
141
+ sampler=None,
142
+ # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
143
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
144
+ collate_fn=collate_fn,
145
+ pin_memory=False,
146
+ # prefetch_factor=64,
147
+ )
148
+ valid_data_loader = DataLoader(
149
+ dataset=valid_dataset,
150
+ batch_size=config.batch_size,
151
+ shuffle=True,
152
+ sampler=None,
153
+ # On Linux, data can be loaded with multiple worker processes; on Windows it cannot.
154
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
155
+ collate_fn=collate_fn,
156
+ pin_memory=False,
157
+ # prefetch_factor=64,
158
+ )
159
+
160
+ # models
161
+ logger.info(f"prepare models. config_file: {args.config_file}")
162
+ generator = NXMPNetPretrainedModel(config).to(device)
163
+ discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
164
+
165
+ # optimizer
166
+ logger.info("prepare optimizer, lr_scheduler")
167
+ num_params = 0
168
+ for p in generator.parameters():
169
+ num_params += p.numel()
170
+ logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
171
+
172
+ optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
173
+ optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
174
+
175
+ # resume training
176
+ last_epoch = -1
177
+ for epoch_i in serialization_dir.glob("epoch-*"):
178
+ epoch_i = Path(epoch_i)
179
+ epoch_idx = epoch_i.stem.split("-")[1]
180
+ epoch_idx = int(epoch_idx)
181
+ if epoch_idx > last_epoch:
182
+ last_epoch = epoch_idx
183
+
184
+ if last_epoch != -1:
185
+ logger.info(f"resume from epoch-{last_epoch}.")
186
+ generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
187
+ discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
188
+ optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
189
+ optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
190
+
191
+ logger.info(f"load state dict for generator.")
192
+ with open(generator_pt.as_posix(), "rb") as f:
193
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
194
+ generator.load_state_dict(state_dict, strict=True)
195
+ logger.info(f"load state dict for discriminator.")
196
+ with open(discriminator_pt.as_posix(), "rb") as f:
197
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
198
+ discriminator.load_state_dict(state_dict, strict=True)
199
+
200
+ logger.info(f"load state dict for optim_g.")
201
+ with open(optim_g_pth.as_posix(), "rb") as f:
202
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
203
+ optim_g.load_state_dict(state_dict)
204
+ logger.info(f"load state dict for optim_d.")
205
+ with open(optim_d_pth.as_posix(), "rb") as f:
206
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
207
+ optim_d.load_state_dict(state_dict)
208
+
209
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
210
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
211
+
212
+ # training loop
213
+
214
+ # state
215
+ loss_d = 10000000000
216
+ loss_g = 10000000000
217
+ pesq_metric = 10000000000
218
+ mag_err = 10000000000
219
+ pha_err = 10000000000
220
+ com_err = 10000000000
221
+ stft_err = 10000000000
222
+
223
+ model_list = list()
224
+ best_idx_epoch = None
225
+ best_metric = None
226
+ patience_count = 0
227
+
228
+ logger.info("training")
229
+ for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
230
+ # train
231
+ generator.train()
232
+ discriminator.train()
233
+
234
+ total_loss_d = 0.
235
+ total_loss_g = 0.
236
+ total_batches = 0.
237
+ progress_bar = tqdm(
238
+ total=len(train_data_loader),
239
+ desc="Training; epoch: {}".format(idx_epoch),
240
+ )
241
+ for batch in train_data_loader:
242
+ clean_audio, noisy_audio = batch
243
+ clean_audio = clean_audio.to(device)
244
+ noisy_audio = noisy_audio.to(device)
245
+ one_labels = torch.ones(clean_audio.shape[0]).to(device)
246
+
247
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
248
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
249
+
250
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
251
+
252
+ audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
253
+ mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
254
+
255
+ audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
256
+ pesq_score_list: List[float] = run_batch_pesq(audio_list_r, audio_list_g, sample_rate=config.sample_rate, mode="nb")
257
+
258
+ # Discriminator
259
+ optim_d.zero_grad()
260
+ metric_r = discriminator.forward(clean_mag, clean_mag)
261
+ metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
262
+ loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
263
+
264
+ if -1 in pesq_score_list:
265
+ # print("-1 in batch_pesq_score!")
266
+ loss_disc_g = 0
267
+ else:
268
+ pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
269
+ loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
270
+
271
+ loss_disc_all = loss_disc_r + loss_disc_g
272
+ loss_disc_all.backward()
273
+ optim_d.step()
274
+
275
+ # Generator
276
+ optim_g.zero_grad()
277
+ # L2 Magnitude Loss
278
+ loss_mag = F.mse_loss(clean_mag, mag_g)
279
+ # Anti-wrapping Phase Loss
280
+ loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
281
+ loss_pha = loss_ip + loss_gd + loss_iaf
282
+ # L2 Complex Loss
283
+ loss_com = F.mse_loss(clean_com, com_g) * 2
284
+ # L2 Consistency Loss
285
+ loss_stft = F.mse_loss(com_g, com_g_hat) * 2
286
+ # Time Loss
287
+ loss_time = F.l1_loss(clean_audio, audio_g)
288
+ # Metric Loss
289
+ metric_g = discriminator.forward(clean_mag, mag_g_hat)
290
+ loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
291
+
292
+ loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2
293
+
294
+ loss_gen_all.backward()
295
+ optim_g.step()
296
+
297
+ total_loss_d += loss_disc_all.item()
298
+ total_loss_g += loss_gen_all.item()
299
+ total_batches += 1
300
+
301
+ loss_d = round(total_loss_d / total_batches, 4)
302
+ loss_g = round(total_loss_g / total_batches, 4)
303
+
304
+ progress_bar.update(1)
305
+ progress_bar.set_postfix({
306
+ "loss_d": loss_d,
307
+ "loss_g": loss_g,
308
+ })
309
+
310
+ # evaluation
311
+ generator.eval()
312
+ discriminator.eval()
313
+
314
+ torch.cuda.empty_cache()
315
+ total_pesq_score = 0.
316
+ total_mag_err = 0.
317
+ total_pha_err = 0.
318
+ total_com_err = 0.
319
+ total_stft_err = 0.
320
+ total_batches = 0.
321
+
322
+ progress_bar = tqdm(
323
+ total=len(valid_data_loader),
324
+ desc="Evaluation; epoch: {}".format(idx_epoch),
325
+ )
326
+ with torch.no_grad():
327
+ for batch in valid_data_loader:
328
+ clean_audio, noisy_audio = batch
329
+ clean_audio = clean_audio.to(device)
330
+ noisy_audio = noisy_audio.to(device)
331
+
332
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
333
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
334
+
335
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
336
+
337
+ audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
338
+ mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
339
+
340
+ clean_audio_list = torch.split(clean_audio, 1, dim=0)
341
+ enhanced_audio_list = torch.split(audio_g, 1, dim=0)
342
+ clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
343
+ enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
344
+ pesq_score = run_pesq_score(
345
+ clean_audio_list,
346
+ enhanced_audio_list,
347
+ sample_rate = config.sample_rate,
348
+ mode = "nb",
349
+ )
350
+ total_pesq_score += pesq_score
351
+ total_mag_err += F.mse_loss(clean_mag, mag_g).item()
352
+ val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
353
+ total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
354
+ total_com_err += F.mse_loss(clean_com, com_g).item()
355
+ total_stft_err += F.mse_loss(com_g, com_g_hat).item()
356
+
357
+ total_batches += 1
358
+
359
+ pesq_metric = round(total_pesq_score / total_batches, 4)
360
+ mag_err = round(total_mag_err / total_batches, 4)
361
+ pha_err = round(total_pha_err / total_batches, 4)
362
+ com_err = round(total_com_err / total_batches, 4)
363
+ stft_err = round(total_stft_err / total_batches, 4)
364
+
365
+ progress_bar.update(1)
366
+ progress_bar.set_postfix({
367
+ "pesq_metric": pesq_metric,
368
+ "mag_err": mag_err,
369
+ "pha_err": pha_err,
370
+ "com_err": com_err,
371
+ "stft_err": stft_err,
372
+ })
373
+
374
+ # scheduler
375
+ scheduler_g.step()
376
+ scheduler_d.step()
377
+
378
+ # save path
379
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
380
+ epoch_dir.mkdir(parents=True, exist_ok=False)
381
+
382
+ # save models
383
+ generator.save_pretrained(epoch_dir.as_posix())
384
+ discriminator.save_pretrained(epoch_dir.as_posix())
385
+
386
+ # save optim
387
+ torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
388
+ torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
389
+
390
+ model_list.append(epoch_dir)
391
+ if len(model_list) >= args.num_serialized_models_to_keep:
392
+ model_to_delete: Path = model_list.pop(0)
393
+ shutil.rmtree(model_to_delete.as_posix())
394
+
395
+ # save metric
396
+ if best_metric is None:
397
+ best_idx_epoch = idx_epoch
398
+ best_metric = pesq_metric
399
+ elif pesq_metric > best_metric:
400
+ # greater is better.
401
+ best_idx_epoch = idx_epoch
402
+ best_metric = pesq_metric
403
+ else:
404
+ pass
405
+
406
+ metrics = {
407
+ "idx_epoch": idx_epoch,
408
+ "best_idx_epoch": best_idx_epoch,
409
+ "loss_d": loss_d,
410
+ "loss_g": loss_g,
411
+
412
+ "pesq_metric": pesq_metric,
413
+ "mag_err": mag_err,
414
+ "pha_err": pha_err,
415
+ "com_err": com_err,
416
+ "stft_err": stft_err,
417
+
418
+ }
419
+ metrics_filename = epoch_dir / "metrics_epoch.json"
420
+ with open(metrics_filename, "w", encoding="utf-8") as f:
421
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
422
+
423
+ # save best
424
+ best_dir = serialization_dir / "best"
425
+ if best_idx_epoch == idx_epoch:
426
+ if best_dir.exists():
427
+ shutil.rmtree(best_dir)
428
+ shutil.copytree(epoch_dir, best_dir)
429
+
430
+ # early stop
431
+ early_stop_flag = False
432
+ if best_idx_epoch == idx_epoch:
433
+ patience_count = 0
434
+ else:
435
+ patience_count += 1
436
+ if patience_count >= args.patience:
437
+ early_stop_flag = True
438
+
439
+ # early stop
440
+ if early_stop_flag:
441
+ break
442
+
443
+ return
444
+
445
+
446
+ if __name__ == "__main__":
447
+ main()
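
Two small pieces of arithmetic from the training loop above, pulled out for clarity: the narrow-band PESQ score (nominally 1.0 to 4.5) is mapped into [0, 1] with (score - 1) / 3.5 before being used as the metric discriminator's target, and the generator objective is a fixed weighted sum of six terms. A self-contained sketch follows; the function names are mine, not the repo's.

```python
import torch

def normalize_pesq(scores):
    # Maps narrow-band PESQ (roughly 1.0 .. 4.5) into [0, 1], matching the
    # (score - 1) / 3.5 target used for the metric discriminator above.
    return torch.tensor([(s - 1.0) / 3.5 for s in scores], dtype=torch.float32)

def generator_loss(loss_mag, loss_pha, loss_com, loss_stft, loss_metric, loss_time):
    # Same fixed weighting as loss_gen_all in the training loop above.
    return (loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1
            + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2)

print(normalize_pesq([1.0, 2.5, 4.5]))           # tensor([0.0000, 0.4286, 1.0000])
print(generator_loss(*[torch.tensor(1.0)] * 6))  # tensor(1.6500)
```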
examples/nx_mpnet/step_3_evaluation.py ADDED
@@ -0,0 +1,187 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/yxlu-0102/MP-SENet/blob/main/inference.py
5
+ """
6
+ import argparse
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+ import sys
11
+ import uuid
12
+
13
+ pwd = os.path.abspath(os.path.dirname(__file__))
14
+ sys.path.append(os.path.join(pwd, "../../"))
15
+
16
+ import librosa
17
+ import numpy as np
18
+ import pandas as pd
19
+ from scipy.io import wavfile
20
+ import torch
21
+ import torch.nn as nn
22
+ import torchaudio
23
+ from tqdm import tqdm
24
+
25
+ from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
26
+ from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel
27
+ from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
33
+ parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
34
+ parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
35
+
36
+ parser.add_argument("--limit", default=10, type=int)
37
+
38
+ args = parser.parse_args()
39
+ return args
40
+
41
+
42
+ def logging_config():
43
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
44
+
45
+ logging.basicConfig(format=fmt,
46
+ datefmt="%m/%d/%Y %H:%M:%S",
47
+ level=logging.INFO)
48
+ stream_handler = logging.StreamHandler()
49
+ stream_handler.setLevel(logging.INFO)
50
+ stream_handler.setFormatter(logging.Formatter(fmt))
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+ return logger
55
+
56
+
57
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
58
+ l1 = len(speech)
59
+ l2 = len(noise)
60
+ l = min(l1, l2)
61
+ speech = speech[:l]
62
+ noise = noise[:l]
63
+
64
+ # np.float32, value between (-1, 1).
65
+
66
+ speech_power = np.mean(np.square(speech))
67
+ noise_power = speech_power / (10 ** (snr_db / 10))
68
+
69
+ noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
70
+
71
+ noisy_signal = speech + noise_adjusted
72
+
73
+ return noisy_signal
74
+
75
+
76
+ def save_audios(noise_audio: torch.Tensor,
77
+ clean_audio: torch.Tensor,
78
+ noisy_audio: torch.Tensor,
79
+ enhanced_audio: torch.Tensor,
80
+ output_dir: str,
81
+ sample_rate: int = 8000,
82
+ ):
83
+ basename = uuid.uuid4().__str__()
84
+ output_dir = Path(output_dir) / basename
85
+ output_dir.mkdir(parents=True, exist_ok=True)
86
+
87
+ filename = output_dir / "noise_audio.wav"
88
+ torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate, bits_per_sample=16)
89
+ filename = output_dir / "clean_audio.wav"
90
+ torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate, bits_per_sample=16)
91
+ filename = output_dir / "noisy_audio.wav"
92
+ torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate, bits_per_sample=16)
93
+
94
+ filename = output_dir / "enhanced_audio.wav"
95
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate, bits_per_sample=16)
96
+
97
+ return output_dir.as_posix()
98
+
99
+
100
+ def main():
101
+ args = get_args()
102
+
103
+ logger = logging_config()
104
+
105
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+ n_gpu = torch.cuda.device_count()
107
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
108
+
109
+ logger.info("prepare model")
110
+ config = MPNetConfig.from_pretrained(
111
+ pretrained_model_name_or_path=args.model_dir,
112
+ )
113
+ generator = MPNetPretrainedModel.from_pretrained(
114
+ pretrained_model_name_or_path=args.model_dir,
115
+ )
116
+ generator.to(device)
117
+ generator.eval()
118
+
119
+ logger.info("read excel")
120
+ df = pd.read_excel(args.valid_dataset)
121
+
122
+ progress_bar = tqdm(total=len(df), desc="Evaluation")
123
+ for idx, row in df.iterrows():
124
+ noise_filename = row["noise_filename"]
125
+ noise_offset = row["noise_offset"]
126
+ noise_duration = row["noise_duration"]
127
+
128
+ speech_filename = row["speech_filename"]
129
+ speech_offset = row["speech_offset"]
130
+ speech_duration = row["speech_duration"]
131
+
132
+ snr_db = row["snr_db"]
133
+
134
+ noise_audio, _ = librosa.load(
135
+ noise_filename,
136
+ sr=8000,
137
+ offset=noise_offset,
138
+ duration=noise_duration,
139
+ )
140
+ clean_audio, _ = librosa.load(
141
+ speech_filename,
142
+ sr=8000,
143
+ offset=speech_offset,
144
+ duration=speech_duration,
145
+ )
146
+ noisy_audio: np.ndarray = mix_speech_and_noise(
147
+ speech=clean_audio,
148
+ noise=noise_audio,
149
+ snr_db=snr_db,
150
+ )
151
+ noise_audio = torch.tensor(noise_audio, dtype=torch.float32)
152
+ clean_audio = torch.tensor(clean_audio, dtype=torch.float32)
153
+ noisy_audio: torch.Tensor = torch.tensor(noisy_audio, dtype=torch.float32)
154
+
155
+ noise_audio = noise_audio.unsqueeze(dim=0)
156
+ clean_audio = clean_audio.unsqueeze(dim=0)
157
+ noisy_audio: torch.Tensor = noisy_audio.unsqueeze(dim=0)
158
+
159
+ # inference
160
+ clean_audio = clean_audio.to(device)
161
+ noisy_audio = noisy_audio.to(device)
162
+ with torch.no_grad():
163
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
164
+ noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
165
+ )
166
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
167
+ audio_g = mag_pha_istft(
168
+ mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
169
+ )
170
+ enhanced_audio = audio_g.detach()
171
+
172
+ save_audios(
173
+ noise_audio, clean_audio, noisy_audio,
174
+ enhanced_audio,
175
+ args.evaluation_audio_dir
176
+ )
177
+
178
+ progress_bar.update(1)
179
+
180
+ if idx > args.limit:
181
+ break
182
+
183
+ return
184
+
185
+
186
+ if __name__ == '__main__':
187
+ main()
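
The evaluation script writes noise/clean/noisy/enhanced WAVs but does not print a score. A hedged sketch of scoring one pair offline, assuming the third-party `pesq` package and the same narrow-band mode ("nb") used during training; the directory name is a placeholder for one of the uuid folders created by save_audios.

```python
import librosa
from pesq import pesq  # third-party package: pip install pesq

sample_rate = 8000
clean, _ = librosa.load("evaluation_audio/<uuid>/clean_audio.wav", sr=sample_rate)
enhanced, _ = librosa.load("evaluation_audio/<uuid>/enhanced_audio.wav", sr=sample_rate)

# Narrow-band PESQ, matching mode="nb" in the training/validation metrics.
score = pesq(sample_rate, clean, enhanced, "nb")
print(round(score, 4))
```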
examples/nx_mpnet/yaml/config.yaml ADDED
@@ -0,0 +1,27 @@
+ model_name: "mpnet"
+
+ num_gpus: 0
+ batch_size: 3
+ learning_rate: 0.0005
+ adam_b1: 0.8
+ adam_b2: 0.99
+ lr_decay: 0.99
+ seed: 1234
+
+ dense_channel: 64
+ compress_factor: 0.3
+ num_tsconformers: 4
+ beta: 2.0
+
+ sample_rate: 8000
+ segment_size: 16000
+ n_fft: 512
+ hop_size: 80
+ win_size: 200
+
+ num_workers: 4
+
+ dist_config:
+   dist_backend: nccl
+   dist_url: tcp://localhost:54321
+   world_size: 1
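
For orientation, a minimal sketch (plain PyTorch, independent of the repo's mag_pha_stft helper) of what the STFT settings above mean at 8 kHz: hop_size 80 gives 100 frames per second, n_fft 512 gives 257 frequency bins, and compress_factor applies power compression to the magnitude. The snippet is illustrative only; the variable names and shapes are assumptions, not the repo's API.

```python
import torch

# Values copied from the config above.
sample_rate = 8000
n_fft, hop_size, win_size = 512, 80, 200
compress_factor = 0.3

signal = torch.randn(1, 2 * sample_rate)  # a 2-second dummy waveform

spec = torch.stft(
    signal,
    n_fft=n_fft,
    hop_length=hop_size,
    win_length=win_size,
    window=torch.hann_window(win_size),
    return_complex=True,
)
magnitude = spec.abs() ** compress_factor  # power-compressed magnitude
phase = spec.angle()

# 8000 / 80 = 100 frames per second, so 2 s yields about 201 frames (center padding).
print(magnitude.shape)  # torch.Size([1, 257, 201])
```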
toolbox/torchaudio/models/mpnet/inference_mpnet.py CHANGED
@@ -84,16 +84,18 @@ class InferenceMPNet(object):
          enhanced_audio = enhanced_audio[0]
          return enhanced_audio
 
+
  def main():
-     model_zip_file = project_path / "trained_models/mpnet_aishell_20250221.zip"
+     model_zip_file = project_path / "trained_models/mpnet-aishell-1-epoch.zip"
      infer_mpnet = InferenceMPNet(model_zip_file)
 
      sample_rate = 8000
-     noisy_audio_file = project_path / "data/examples/noisy_audio.wav"
+     noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
      noisy_audio, _ = librosa.load(
          noisy_audio_file.as_posix(),
          sr=sample_rate,
      )
+     noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
      noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
      noisy_audio = noisy_audio.unsqueeze(dim=0)
 
toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py CHANGED
@@ -84,7 +84,7 @@ class CausalConv2d(nn.Module):
84
 
85
  def forward(self,
86
  inputs: torch.Tensor,
87
- causal_cache: torch.Tensor = None,
88
  ):
89
 
90
  if causal_cache is None:
@@ -97,6 +97,8 @@ class CausalConv2d(nn.Module):
97
  # x shape: [batch_size, 1, time_steps2, hidden_size]
98
  # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
99
 
 
 
100
  x = self.conv1.forward(x)
101
  # inputs shape: [batch_size, 1, time_steps, hidden_size]
102
 
@@ -108,8 +110,6 @@ class CausalConv2d(nn.Module):
108
  if self.activation:
109
  x = self.activation(x)
110
 
111
- causal_cache = x[:, :, -self.causal_left_pad:, :]
112
-
113
  # inputs shape: [batch_size, 1, time_steps, hidden_size]
114
  return x, causal_cache
115
 
@@ -187,19 +187,19 @@ class CausalConv2dEncoder(nn.Module):
187
 
188
  def forward_chunk(self,
189
  chunk: torch.Tensor,
190
- causal_cache: torch.Tensor = None,
191
  ):
192
- # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size]
193
 
194
- new_causal_cache_list = list()
195
  for idx, causal_conv in enumerate(self.causal_conv_list):
196
  chunk, new_causal_cache = causal_conv.forward(
197
- inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None
198
  )
 
199
  new_causal_cache_list.append(new_causal_cache)
200
 
201
- new_causal_cache = torch.cat(new_causal_cache_list, dim=0)
202
- return chunk, new_causal_cache
203
 
204
  def forward_chunk_by_chunk(self, inputs: torch.Tensor):
205
  # inputs shape: [batch_size, 1, time_steps, hidden_size]
@@ -207,7 +207,7 @@ class CausalConv2dEncoder(nn.Module):
207
 
208
  batch_size, channels, time_steps, hidden_size = inputs.shape
209
 
210
- causal_cache = None
211
 
212
  outputs = []
213
  for idx in range(0, time_steps, 1):
@@ -215,9 +215,9 @@ class CausalConv2dEncoder(nn.Module):
215
  end = begin + self.total_causal_right_pad + 1
216
  chunk_xs = inputs[:, :, begin:end, :]
217
 
218
- ys, attention_cache = self.forward_chunk(
219
  chunk=chunk_xs,
220
- causal_cache=causal_cache,
221
  )
222
  # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
223
  ys = ys[:, :, :1, :]
 
84
 
85
  def forward(self,
86
  inputs: torch.Tensor,
87
+ causal_cache: List[torch.Tensor] = None,
88
  ):
89
 
90
  if causal_cache is None:
 
97
  # x shape: [batch_size, 1, time_steps2, hidden_size]
98
  # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
99
 
100
+ causal_cache = x[:, :, -self.causal_left_pad:, :]
101
+
102
  x = self.conv1.forward(x)
103
  # inputs shape: [batch_size, 1, time_steps, hidden_size]
104
 
 
110
  if self.activation:
111
  x = self.activation(x)
112
 
 
 
113
  # inputs shape: [batch_size, 1, time_steps, hidden_size]
114
  return x, causal_cache
115
 
 
187
 
188
  def forward_chunk(self,
189
  chunk: torch.Tensor,
190
+ causal_cache: List[torch.Tensor] = None,
191
  ):
192
+ # causal_cache shape: [self.num_layers, batch_size, 1, causal_left_pad, hidden_size]
193
 
194
+ new_causal_cache_list: List[torch.Tensor] = list()
195
  for idx, causal_conv in enumerate(self.causal_conv_list):
196
  chunk, new_causal_cache = causal_conv.forward(
197
+ inputs=chunk, causal_cache=causal_cache[idx] if causal_cache is not None else None
198
  )
199
+ # print(f"idx: {idx}, new_causal_cache: {new_causal_cache.shape}")
200
  new_causal_cache_list.append(new_causal_cache)
201
 
202
+ return chunk, new_causal_cache_list
 
203
 
204
  def forward_chunk_by_chunk(self, inputs: torch.Tensor):
205
  # inputs shape: [batch_size, 1, time_steps, hidden_size]
 
207
 
208
  batch_size, channels, time_steps, hidden_size = inputs.shape
209
 
210
+ new_causal_cache_list: List[torch.Tensor] = None
211
 
212
  outputs = []
213
  for idx in range(0, time_steps, 1):
 
215
  end = begin + self.total_causal_right_pad + 1
216
  chunk_xs = inputs[:, :, begin:end, :]
217
 
218
+ ys, new_causal_cache_list = self.forward_chunk(
219
  chunk=chunk_xs,
220
+ causal_cache=new_causal_cache_list,
221
  )
222
  # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
223
  ys = ys[:, :, :1, :]
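
The change above makes the per-layer cache carry the last causal_left_pad frames so that chunk-by-chunk inference matches the offline, left-padded convolution. Below is a self-contained toy version of that equivalence check; it illustrates the caching idea only and is not the repo's CausalConv2d.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(1, 1, kernel_size=(3, 1), bias=False)  # kernel spans 3 time steps
left_pad = 2  # kernel_size_t - 1 frames of causal context

x = torch.randn(1, 1, 10, 4)  # [batch, channels, time_steps, freq]

# Offline forward: pad the time axis on the left, then convolve the whole signal.
full = conv(F.pad(x, (0, 0, left_pad, 0)))

# Streaming forward: process one frame at a time, carrying the last `left_pad`
# frames (initially zeros, like the padding) as the causal cache.
cache = torch.zeros(1, 1, left_pad, 4)
chunks = []
for t in range(x.shape[2]):
    frame = torch.cat([cache, x[:, :, t:t + 1, :]], dim=2)
    chunks.append(conv(frame))
    cache = frame[:, :, -left_pad:, :]
streaming = torch.cat(chunks, dim=2)

print(torch.allclose(full, streaming, atol=1e-6))  # True
```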
toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py ADDED
@@ -0,0 +1,97 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile
7
+ import zipfile
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ from project_settings import project_path
15
+ from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
16
+ from toolbox.torchaudio.models.nx_denoise.modeling_nx_denoise import NXDenoisePretrainedModel, MODEL_FILE
17
+
18
+ logger = logging.getLogger("toolbox")
19
+
20
+
21
+ class InferenceNXDenoise(object):
22
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
23
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
24
+ self.device = torch.device(device)
25
+
26
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
27
+ config, model = self.load_models(self.pretrained_model_path_or_zip_file)
28
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
29
+
30
+ self.config = config
31
+ self.model = model
32
+ self.model.to(device)
33
+ self.model.eval()
34
+
35
+ def load_models(self, model_path: str):
36
+ model_path = Path(model_path)
37
+ if model_path.name.endswith(".zip"):
38
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
39
+ out_root = Path(tempfile.gettempdir()) / "nx_denoise"
40
+ out_root.mkdir(parents=True, exist_ok=True)
41
+ f_zip.extractall(path=out_root)
42
+ model_path = out_root / model_path.stem
43
+
44
+ config = NXDenoiseConfig.from_pretrained(
45
+ pretrained_model_name_or_path=model_path.as_posix(),
46
+ )
47
+ model = NXDenoisePretrainedModel.from_pretrained(
48
+ pretrained_model_name_or_path=model_path.as_posix(),
49
+ )
50
+ model.to(self.device)
51
+ model.eval()
52
+
53
+ shutil.rmtree(model_path)
54
+ return config, model
55
+
56
+ def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
57
+ if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
58
+ raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
59
+
60
+ # noisy_audio shape: [batch_size, num_samples]
61
+ noisy_audios = noisy_audio.to(self.device)
62
+
63
+ with torch.no_grad():
64
+ # enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
65
+ enhanced_audios = self.model.forward(noisy_audios)
66
+ # enhanced_audio shape: [batch_size, n_samples]
67
+ # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
68
+
69
+ enhanced_audio = enhanced_audios[0]
70
+ # enhanced_audio shape: [num_samples,]
71
+ return enhanced_audio
72
+
73
+
74
+ def main():
75
+ model_zip_file = project_path / "trained_models/nx-denoise.zip"
76
+ runtime = InferenceNXDenoise(model_zip_file)
77
+
78
+ sample_rate = 8000
79
+ noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
80
+ noisy_audio, _ = librosa.load(
81
+ noisy_audio_file.as_posix(),
82
+ sr=sample_rate,
83
+ )
84
+ noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
85
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
86
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
87
+
88
+ enhanced_audio = runtime.enhancement_by_tensor(noisy_audio)
89
+
90
+ filename = "enhanced_audio.wav"
91
+ torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)
92
+
93
+ return
94
+
95
+
96
+ if __name__ == '__main__':
97
+ main()
toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py CHANGED
@@ -268,6 +268,58 @@ class NXDenoise(nn.Module):
268
  return enhanced_audios
269
 
270
 
271
+ def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
272
+ # noisy_audios shape: [batch_size, n_samples]
273
+ noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
274
+ # noisy_audios shape: [batch_size, 1, n_samples]
275
+
276
+ n_samples = noisy_audios.shape[-1]
277
+ padded_length = get_padding_length(
278
+ n_samples,
279
+ num_layers=self.config.down_sampling_num_layers,
280
+ kernel_size=self.config.down_sampling_kernel_size,
281
+ stride=self.config.down_sampling_stride,
282
+ )
283
+ noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
284
+
285
+ # down sampling
286
+ bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
287
+ # bottle_neck shape: [batch_size, channels, time_steps]
288
+ bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
289
+ # bottle_neck shape: [batch_size, time_steps, channels]
290
+ bottle_neck = torch.unsqueeze(bottle_neck, dim=1)
291
+ # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
292
+
293
+ # causal conv in
294
+ bottle_neck = self.causal_conv_in.forward_chunk_by_chunk(bottle_neck)
295
+ # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
296
+
297
+ # ts transformer
298
+ # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
299
+ bottle_neck = self.ts_transformer.forward_chunk_by_chunk(bottle_neck)
300
+ # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
301
+
302
+ # causal conv out
303
+ bottle_neck = self.causal_conv_out.forward_chunk_by_chunk(bottle_neck)
304
+ # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
305
+
306
+ # up sampling
307
+ bottle_neck = torch.squeeze(bottle_neck, dim=1)
308
+ # bottle_neck shape: [batch_size, time_steps, channels]
309
+ bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
310
+ # bottle_neck shape: [batch_size, channels, time_steps]
311
+
312
+ enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
313
+
314
+ enhanced_audios = enhanced_audios[:, :, :n_samples]
315
+ # enhanced_audios shape: [batch_size, 1, n_samples]
316
+
317
+ enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
318
+ # enhanced_audios shape: [batch_size, n_samples]
319
+
320
+ return enhanced_audios
321
+
322
+
323
  MODEL_FILE = "generator.pt"
324
 
325
 
toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/nx_denoise/stftnet/istftnet.py ADDED
@@ -0,0 +1,9 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://arxiv.org/abs/2203.02395
+ """
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py ADDED
@@ -0,0 +1,9 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://arxiv.org/abs/1902.07849
+ """
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/nx_mpnet/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py ADDED
@@ -0,0 +1,445 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import List, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid2d
9
+
10
+
11
+ class SPConvTranspose2d(nn.Module):
12
+ def __init__(self,
13
+ in_channels: int,
14
+ out_channels: int,
15
+ kernel_size: Union[int, Tuple[int]],
16
+ r=1
17
+ ):
18
+ super(SPConvTranspose2d, self).__init__()
19
+ self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
20
+ self.out_channels = out_channels
21
+ self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
22
+ self.r = r
23
+
24
+ def forward(self, x: torch.Tensor):
25
+ x = self.pad_freq(x)
26
+ out = self.conv(x)
27
+
28
+ b, c, t, f = out.shape
29
+
30
+ out = out.view((b, self.r, c // self.r, t, f))
31
+ out = out.permute(0, 2, 3, 4, 1)
32
+ out = out.contiguous().view((b, c // self.r, t, -1))
33
+ return out
34
+
35
+
36
+ class CausalConv2dBlock(nn.Module):
37
+ def __init__(self,
38
+ in_channels: int,
39
+ out_channels: int,
40
+ dilation: int,
41
+ kernel_size: Tuple[int, int] = (2, 3),
42
+ ):
43
+ super(CausalConv2dBlock, self).__init__()
44
+ self.pad_length = dilation
45
+
46
+ self.pad_time = nn.ConstantPad2d((0, 0, self.pad_length, 0), value=0.)
47
+ self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
48
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=(dilation, 1))
49
+ self.norm = nn.InstanceNorm2d(out_channels, affine=True)
50
+ self.activation = nn.PReLU(out_channels)
51
+
52
+ def forward(self,
53
+ x: torch.Tensor,
54
+ cache_pad: torch.Tensor = None,
55
+ ):
56
+ """
57
+
58
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
59
+ :param cache_pad:
60
+ :return:
61
+ """
62
+ if cache_pad is None:
63
+ x = self.pad_time(x)
64
+ else:
65
+ x = torch.concat(tensors=[cache_pad, x], dim=2)
66
+ new_cache_pad = x[:, :, -self.pad_length:, :]
67
+
68
+ x = self.pad_freq(x)
69
+
70
+ x = self.conv(x)
71
+ x = self.norm(x)
72
+ x = self.activation(x)
73
+ return x, new_cache_pad
74
+
75
+
76
+ class CausalConv2dEncoder(nn.Module):
77
+ def __init__(self,
78
+ num_blocks: int,
79
+ hidden_size: int,
80
+ ):
81
+ super(CausalConv2dEncoder, self).__init__()
82
+ self.num_blocks = num_blocks
83
+
84
+ self.blocks: List[CausalConv2dBlock] = nn.ModuleList([])
85
+ for idx in range(num_blocks):
86
+ in_channels = hidden_size * (idx+1)
87
+ dilation = 2 ** idx
88
+ block = CausalConv2dBlock(
89
+ in_channels=in_channels,
90
+ out_channels=hidden_size,
91
+ dilation=dilation,
92
+ kernel_size=(2, 3),
93
+ )
94
+ self.blocks.append(block)
95
+
96
+ def forward(self,
97
+ x: torch.Tensor,
98
+ cache_pad_list: List[torch.Tensor] = None,
99
+ ):
100
+ """
101
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
102
+ :param cache_pad_list: List[Tensor]
103
+ :return:
104
+ """
105
+ new_cache_pad_list = list()
106
+
107
+ skip = x
108
+ for idx, block in enumerate(self.blocks):
109
+ x, new_cache_pad = block.forward(
110
+ skip,
111
+ cache_pad=None if cache_pad_list is None else cache_pad_list[idx]
112
+ )
113
+ new_cache_pad_list.append(new_cache_pad)
114
+ skip = torch.cat([x, skip], dim=1)
115
+ # x shape: [batch_size, channels, time_steps, dim].
116
+ return x, new_cache_pad_list
117
+
118
+ def forward_chunk(self,
119
+ chunk: torch.Tensor,
120
+ cache_pad_list: List[torch.Tensor] = None,
121
+ ):
122
+ return self.forward(chunk, cache_pad_list)
123
+
124
+ def forward_chunk_by_chunk(self,
125
+ x: torch.Tensor,
126
+ ):
127
+ """
128
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
129
+ :return:
130
+ """
131
+ batch_size, channels, time_steps, _ = x.shape
132
+
133
+ cache_pad_list = None
134
+
135
+ outputs = list()
136
+ for idx in range(time_steps):
137
+ chunk = x[:, :, idx:idx+1, :]
138
+
139
+ y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
140
+ outputs.append(y)
141
+
142
+ outputs = torch.concat(outputs, dim=2)
143
+ return outputs
144
+
145
+
146
+ class DenseEncoder(nn.Module):
147
+ def __init__(self,
148
+ num_blocks: int,
149
+ in_channels: int,
150
+ out_channels: int,
151
+ ):
152
+ super(DenseEncoder, self).__init__()
153
+ self.dense_conv_1 = nn.Sequential(
154
+ nn.Conv2d(in_channels, out_channels, (1, 1)),
155
+ nn.InstanceNorm2d(out_channels, affine=True),
156
+ nn.PReLU(out_channels)
157
+ )
158
+ self.dense_block = CausalConv2dEncoder(
159
+ num_blocks=num_blocks, hidden_size=out_channels,
160
+ )
161
+ self.dense_conv_2 = nn.Sequential(
162
+ nn.Conv2d(out_channels, out_channels, (1, 3), (1, 2), padding=(0, 1)),
163
+ nn.InstanceNorm2d(out_channels, affine=True),
164
+ nn.PReLU(out_channels)
165
+ )
166
+
167
+ def forward(self,
168
+ x: torch.Tensor,
169
+ ):
170
+ """
171
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
172
+ :return:
173
+ """
174
+ x = self.dense_conv_1(x)
175
+ x, _ = self.dense_block.forward(x)
176
+ x = self.dense_conv_2(x)
177
+ # x shape: [b, c, t, f//2]
178
+ return x
179
+
180
+ def forward_chunk(self,
181
+ x: torch.Tensor,
182
+ cache_pad_list: List[torch.Tensor] = None,
183
+ ):
184
+ x = self.dense_conv_1(x)
185
+ x, new_cache_pad_list = self.dense_block.forward(x, cache_pad_list)
186
+ x = self.dense_conv_2(x)
187
+ # x shape: [b, c, t, f//2]
188
+ return x, new_cache_pad_list
189
+
190
+ def forward_chunk_by_chunk(self,
191
+ x: torch.Tensor,
192
+ ):
193
+ """
194
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
195
+ :return:
196
+ """
197
+ batch_size, channels, time_steps, _ = x.shape
198
+
199
+ cache_pad_list = None
200
+
201
+ outputs = list()
202
+ for idx in range(time_steps):
203
+ chunk = x[:, :, idx:idx+1, :]
204
+
205
+ y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
206
+ outputs.append(y)
207
+
208
+ outputs = torch.concat(outputs, dim=2)
209
+ return outputs
210
+
211
+
212
+ class MaskDecoder(nn.Module):
213
+ def __init__(self,
214
+ num_blocks: int,
215
+ hidden_size: int,
216
+ out_channels: int = 1,
217
+ beta: float = 2.0,
218
+ n_fft: int = 512,
219
+ ):
220
+ super(MaskDecoder, self).__init__()
221
+ self.dense_block = CausalConv2dEncoder(
222
+ num_blocks=num_blocks, hidden_size=hidden_size,
223
+ )
224
+ self.mask_conv = nn.Sequential(
225
+ SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2),
226
+ nn.InstanceNorm2d(hidden_size, affine=True),
227
+ nn.PReLU(hidden_size),
228
+ nn.Conv2d(hidden_size, out_channels, (1, 2))
229
+ )
230
+ self.lsigmoid = LearnableSigmoid2d(n_fft//2+1, beta=beta)
231
+
232
+ def forward(self,
233
+ x: torch.Tensor,
234
+ ):
235
+ """
236
+
237
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
238
+ :return:
239
+ """
240
+ x, _ = self.dense_block(x)
241
+ x = self.mask_conv(x)
242
+ # x shape: [batch_size, 1, time_steps, dim*2-1]
243
+ x = x.permute(0, 3, 2, 1).squeeze(-1)
244
+ # x shape: [b, f, t]
245
+ x = self.lsigmoid(x)
246
+ return x
247
+
248
+ def forward_chunk(self,
249
+ x: torch.Tensor,
250
+ cache_pad_list: List[torch.Tensor] = None,
251
+ ):
252
+ x, new_cache_pad_list = self.dense_block(x, cache_pad_list)
253
+ x = self.mask_conv(x)
254
+ # x shape: [batch_size, 1, time_steps, dim*2-1]
255
+ x = x.permute(0, 3, 2, 1).squeeze(-1)
256
+ # x shape: [b, f, t]
257
+ x = self.lsigmoid(x)
258
+ return x, new_cache_pad_list
259
+
260
+ def forward_chunk_by_chunk(self,
261
+ x: torch.Tensor,
262
+ ):
263
+ """
264
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
265
+ :return:
266
+ """
267
+ batch_size, channels, time_steps, _ = x.shape
268
+
269
+ cache_pad_list = None
270
+
271
+ outputs = list()
272
+ for idx in range(time_steps):
273
+ chunk = x[:, :, idx:idx+1, :]
274
+
275
+ y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
276
+ outputs.append(y)
277
+
278
+ outputs = torch.concat(outputs, dim=2)
279
+ return outputs
280
+
281
+
282
+ class PhaseDecoder(nn.Module):
283
+ def __init__(self,
284
+ num_blocks: int,
285
+ hidden_size: int,
286
+ out_channels: int = 1,
287
+ ):
288
+ super(PhaseDecoder, self).__init__()
289
+ self.dense_block = CausalConv2dEncoder(
290
+ num_blocks=num_blocks, hidden_size=hidden_size,
291
+ )
292
+
293
+ self.phase_conv = nn.Sequential(
294
+ SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2),
295
+ nn.InstanceNorm2d(hidden_size, affine=True),
296
+ nn.PReLU(hidden_size)
297
+ )
298
+ self.phase_conv_r = nn.Conv2d(hidden_size, out_channels, (1, 2))
299
+ self.phase_conv_i = nn.Conv2d(hidden_size, out_channels, (1, 2))
300
+
301
+ def forward(self,
302
+ x: torch.Tensor,
303
+ ):
304
+ """
305
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
306
+ :return:
307
+ """
308
+ x, _ = self.dense_block(x)
309
+
310
+ x = self.phase_conv(x)
311
+ x_r = self.phase_conv_r(x)
312
+ x_i = self.phase_conv_i(x)
313
+ x = torch.atan2(x_i, x_r)
314
+ x = x.permute(0, 3, 2, 1).squeeze(-1)
315
+ # x shape: [b, f, t]
316
+ return x
317
+
318
+ def forward_chunk(self,
319
+ x: torch.Tensor,
320
+ cache_pad_list: List[torch.Tensor] = None,
321
+ ):
322
+ x, new_cache_pad_list = self.dense_block(x, cache_pad_list)
323
+
324
+ x = self.phase_conv(x)
325
+ x_r = self.phase_conv_r(x)
326
+ x_i = self.phase_conv_i(x)
327
+ x = torch.atan2(x_i, x_r)
328
+ x = x.permute(0, 3, 2, 1).squeeze(-1)
329
+ # x shape: [b, f, t]
330
+ return x, new_cache_pad_list
331
+
332
+ def forward_chunk_by_chunk(self,
333
+ x: torch.Tensor,
334
+ ):
335
+ """
336
+ :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
337
+ :return:
338
+ """
339
+ batch_size, channels, time_steps, _ = x.shape
340
+
341
+ cache_pad_list = None
342
+
343
+ outputs = list()
344
+ for idx in range(time_steps):
345
+ chunk = x[:, :, idx:idx+1, :]
346
+
347
+ y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
348
+ outputs.append(y)
349
+
350
+ outputs = torch.concat(outputs, dim=2)
351
+ return outputs
352
+
353
+
354
+ def main1():
355
+
356
+ encoder = CausalConv2dEncoder(
357
+ num_blocks=3, hidden_size=8,
358
+ )
359
+
360
+ # x shape: [batch_size, channels, time_steps, dim]
361
+ x = torch.rand(size=(1, 8, 200, 32))
362
+ x, new_cache_pad_list = encoder.forward(x)
363
+ print(x.shape)
364
+ for new_cache_pad in new_cache_pad_list:
365
+ print(new_cache_pad.shape)
366
+
367
+ x = torch.rand(size=(1, 8, 200, 32))
368
+ x = encoder.forward_chunk_by_chunk(x)
369
+ print(x.shape)
370
+
371
+ return
372
+
373
+
374
+ def main2():
375
+
376
+ encoder = DenseEncoder(
377
+ num_blocks=3, in_channels=8, out_channels=8
378
+ )
379
+
380
+ # x shape: [batch_size, channels, time_steps, dim]
381
+ x = torch.rand(size=(1, 8, 200, 32))
382
+ x, new_cache_pad_list = encoder.forward_chunk(x)
383
+ print(x.shape)
384
+ for new_cache_pad in new_cache_pad_list:
385
+ print(new_cache_pad.shape)
386
+
387
+ x = torch.rand(size=(1, 8, 200, 32))
388
+ x = encoder.forward_chunk_by_chunk(x)
389
+ print(x.shape)
390
+
391
+ return
392
+
393
+
394
+ def main3():
395
+
396
+ encoder = MaskDecoder(
397
+ num_blocks=3, hidden_size=64, out_channels=1,
398
+ n_fft=512,
399
+ )
400
+
401
+ # 512 // 2 + 1 = 257
402
+ # 129 * 2 - 1 = 257
403
+ # 257 // 2 + 1 = 129
404
+
405
+ # x shape: [batch_size, channels, time_steps, dim]
406
+ x = torch.rand(size=(1, 64, 201, 129))
407
+ x, new_cache_pad_list = encoder.forward_chunk(x)
408
+ print(x.shape)
409
+ for new_cache_pad in new_cache_pad_list:
410
+ print(new_cache_pad.shape)
411
+
412
+ x = torch.rand(size=(1, 64, 201, 129))
413
+ x = encoder.forward_chunk_by_chunk(x)
414
+ print(x.shape)
415
+
416
+ return
417
+
418
+
419
+
420
+ def main():
421
+
422
+ encoder = PhaseDecoder(
423
+ num_blocks=3, hidden_size=64, out_channels=1,
424
+ )
425
+
426
+ # 512 // 2 + 1 = 257
427
+ # 129 * 2 - 1 = 257
428
+ # 257 // 2 + 1 = 129
429
+
430
+ # x shape: [batch_size, channels, time_steps, dim]
431
+ x = torch.rand(size=(1, 64, 201, 129))
432
+ x, new_cache_pad_list = encoder.forward_chunk(x)
433
+ print(x.shape)
434
+ for new_cache_pad in new_cache_pad_list:
435
+ print(new_cache_pad.shape)
436
+
437
+ x = torch.rand(size=(1, 64, 201, 129))
438
+ x = encoder.forward_chunk_by_chunk(x)
439
+ print(x.shape)
440
+
441
+ return
442
+
443
+
444
+ if __name__ == "__main__":
445
+ main()
toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class NXMPNetConfig(PretrainedConfig):
7
+ """
8
+ https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
9
+ """
10
+ def __init__(self,
11
+ sample_rate: int = 8000,
12
+ segment_size: int = 16000,
13
+ n_fft: int = 512,
14
+ win_length: int = 200,
15
+ hop_length: int = 80,
16
+
17
+ dense_num_blocks: int = 4,
18
+ dense_hidden_size: int = 64,
19
+
20
+ mask_num_blocks: int = 4,
21
+ mask_hidden_size: int = 64,
22
+
23
+ phase_num_blocks: int = 4,
24
+ phase_hidden_size: int = 64,
25
+
26
+ tsfm_hidden_size: int = 64,
27
+ tsfm_attention_heads: int = 4,
28
+ tsfm_num_blocks: int = 4,
29
+ tsfm_dropout_rate: float = 0.0,
30
+ tsfm_max_time_relative_position: int = 1024,
31
+ tsfm_max_freq_relative_position: int = 128,
32
+ tsfm_chunk_size: int = 4,
33
+ tsfm_num_left_chunks: int = 128,
34
+ tsfm_num_right_chunks: int = 2,
35
+
36
+ discriminator_dim: int = 32,
37
+ discriminator_in_channel: int = 2,
38
+
39
+ compress_factor: float = 0.3,
40
+
41
+ batch_size: int = 4,
42
+ learning_rate: float = 0.0005,
43
+ adam_b1: float = 0.8,
44
+ adam_b2: float = 0.99,
45
+ lr_decay: float = 0.99,
46
+ seed: int = 1234,
47
+
48
+ **kwargs
49
+ ):
50
+ super(NXMPNetConfig, self).__init__(**kwargs)
51
+ self.sample_rate = sample_rate
52
+ self.segment_size = segment_size
53
+ self.n_fft = n_fft
54
+ self.win_length = win_length
55
+ self.hop_length = hop_length
56
+
57
+ self.dense_num_blocks = dense_num_blocks
58
+ self.dense_hidden_size = dense_hidden_size
59
+
60
+ self.mask_num_blocks = mask_num_blocks
61
+ self.mask_hidden_size = mask_hidden_size
62
+
63
+ self.phase_num_blocks = phase_num_blocks
64
+ self.phase_hidden_size = phase_hidden_size
65
+
66
+ self.tsfm_hidden_size = tsfm_hidden_size
67
+ self.tsfm_attention_heads = tsfm_attention_heads
68
+ self.tsfm_num_blocks = tsfm_num_blocks
69
+ self.tsfm_dropout_rate = tsfm_dropout_rate
70
+ self.tsfm_max_time_relative_position = tsfm_max_time_relative_position
71
+ self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position
72
+ self.tsfm_chunk_size = tsfm_chunk_size
73
+ self.tsfm_num_left_chunks = tsfm_num_left_chunks
74
+ self.tsfm_num_right_chunks = tsfm_num_right_chunks
75
+
76
+ self.discriminator_dim = discriminator_dim
77
+ self.discriminator_in_channel = discriminator_in_channel
78
+
79
+ self.compress_factor = compress_factor
80
+
81
+ self.batch_size = batch_size
82
+ self.learning_rate = learning_rate
83
+ self.adam_b1 = adam_b1
84
+ self.adam_b2 = adam_b2
85
+ self.lr_decay = lr_decay
86
+ self.seed = seed
87
+
88
+
89
+ if __name__ == '__main__':
90
+ pass
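Note: with the defaults above the model operates on 8 kHz audio framed at 100 frames per second. A minimal sanity sketch (assuming only the NXMPNetConfig class defined in this file):

    from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig

    config = NXMPNetConfig()
    spec_bins = config.n_fft // 2 + 1                             # 257 frequency bins per STFT frame
    frames_per_second = config.sample_rate / config.hop_length    # 8000 / 80 = 100 frames per second
    print(spec_bins, frames_per_second)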
toolbox/torchaudio/models/nx_mpnet/discriminator.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchaudio
9
+
10
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
+ from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
12
+ from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d
13
+
14
+
15
+ class MetricDiscriminator(nn.Module):
16
+ def __init__(self, config: NXMPNetConfig):
17
+ super(MetricDiscriminator, self).__init__()
18
+ dim = config.discriminator_dim
19
+ self.in_channel = config.discriminator_in_channel
20
+
21
+ self.n_fft = config.n_fft
22
+ self.win_length = config.win_length
23
+ self.hop_length = config.hop_length
24
+
25
+ self.transform = torchaudio.transforms.Spectrogram(
26
+ n_fft=self.n_fft,
27
+ win_length=self.win_length,
28
+ hop_length=self.hop_length,
29
+ power=1.0,
30
+ window_fn=torch.hann_window,
31
+ # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
32
+ )
33
+
34
+ self.layers = nn.Sequential(
35
+ nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
36
+ nn.InstanceNorm2d(dim, affine=True),
37
+ nn.PReLU(dim),
38
+ nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
39
+ nn.InstanceNorm2d(dim*2, affine=True),
40
+ nn.PReLU(dim*2),
41
+ nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
42
+ nn.InstanceNorm2d(dim*4, affine=True),
43
+ nn.PReLU(dim*4),
44
+ nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
45
+ nn.InstanceNorm2d(dim*8, affine=True),
46
+ nn.PReLU(dim*8),
47
+ nn.AdaptiveMaxPool2d(1),
48
+ nn.Flatten(),
49
+ nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
50
+ nn.Dropout(0.3),
51
+ nn.PReLU(dim*4),
52
+ nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
53
+ LearnableSigmoid1d(1)
54
+ )
55
+
56
+ def forward(self, x, y):
57
+ x = self.transform.forward(x)
58
+ y = self.transform.forward(y)
59
+
60
+ xy = torch.stack((x, y), dim=1)
61
+ return self.layers(xy)
62
+
63
+
64
+ MODEL_FILE = "discriminator.pt"
65
+
66
+
67
+ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
68
+ def __init__(self,
69
+ config: NXMPNetConfig,
70
+ ):
71
+ super(MetricDiscriminatorPretrainedModel, self).__init__(
72
+ config=config,
73
+ )
74
+ self.config = config
75
+
76
+ @classmethod
77
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
78
+ config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
79
+
80
+ model = cls(config)
81
+
82
+ if os.path.isdir(pretrained_model_name_or_path):
83
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
84
+ else:
85
+ ckpt_file = pretrained_model_name_or_path
86
+
87
+ with open(ckpt_file, "rb") as f:
88
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
89
+ model.load_state_dict(state_dict, strict=True)
90
+ return model
91
+
92
+ def save_pretrained(self,
93
+ save_directory: Union[str, os.PathLike],
94
+ state_dict: Optional[dict] = None,
95
+ ):
96
+
97
+ model = self
98
+
99
+ if state_dict is None:
100
+ state_dict = model.state_dict()
101
+
102
+ os.makedirs(save_directory, exist_ok=True)
103
+
104
+ # save state dict
105
+ model_file = os.path.join(save_directory, MODEL_FILE)
106
+ torch.save(state_dict, model_file)
107
+
108
+ # save config
109
+ config_file = os.path.join(save_directory, CONFIG_FILE)
110
+ self.config.to_yaml_file(config_file)
111
+ return save_directory
112
+
113
+
114
+ def main():
115
+ config = NXMPNetConfig()
116
+ discriminator = MetricDiscriminator(config=config)
117
+
118
+ # shape: [batch_size, num_samples]
119
+ # x = torch.ones([4, int(4.5 * 16000)])
120
+ # y = torch.ones([4, int(4.5 * 16000)])
121
+ x = torch.ones([4, 16000])
122
+ y = torch.ones([4, 16000])
123
+
124
+ output = discriminator.forward(x, y)
125
+ print(output.shape)
126
+ print(output)
127
+
128
+ return
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
toolbox/torchaudio/models/nx_mpnet/loss.py ADDED
@@ -0,0 +1,22 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def anti_wrapping_function(x):
8
+
9
+ return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
10
+
11
+
12
+ def phase_losses(phase_r, phase_g):
13
+
14
+ ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
15
+ gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
16
+ iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
17
+
18
+ return ip_loss, gd_loss, iaf_loss
19
+
20
+
21
+ if __name__ == '__main__':
22
+ pass
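The three returned terms penalize instantaneous phase error, group-delay error (difference along the frequency axis, dim=1) and instantaneous-frequency error (difference along the time axis, dim=2), each wrapped into [0, pi]. A minimal usage sketch, assuming phase tensors shaped [batch, freq_bins, time_steps] as elsewhere in this model:

    import math
    import torch
    from toolbox.torchaudio.models.nx_mpnet.loss import phase_losses

    phase_r = torch.rand(2, 257, 201) * 2 * math.pi - math.pi   # reference phase
    phase_g = torch.rand(2, 257, 201) * 2 * math.pi - math.pi   # generated phase
    ip_loss, gd_loss, iaf_loss = phase_losses(phase_r, phase_g)
    print(ip_loss.item(), gd_loss.item(), iaf_loss.item())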
toolbox/torchaudio/models/nx_mpnet/metrics.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from joblib import Parallel, delayed
4
+ import numpy as np
5
+ from pesq import pesq
6
+ from typing import List
7
+
8
+ from pesq import cypesq
9
+
10
+
11
+ def run_pesq(clean_audio: np.ndarray,
12
+ noisy_audio: np.ndarray,
13
+ sample_rate: int = 16000,
14
+ mode: str = "wb",
15
+ ) -> float:
16
+ if sample_rate == 8000 and mode == "wb":
17
+ raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
+ try:
19
+ pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
+ except cypesq.NoUtterancesError as e:
21
+ pesq_score = -1
22
+ except Exception as e:
23
+ print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
+ pesq_score = -1
25
+ return pesq_score
26
+
27
+
28
+ def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
+ noisy_audio_list: List[np.ndarray],
30
+ sample_rate: int = 16000,
31
+ mode: str = "wb",
32
+ n_jobs: int = 4,
33
+ ) -> List[float]:
34
+ parallel = Parallel(n_jobs=n_jobs)
35
+
36
+ parallel_tasks = list()
37
+ for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
+ parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
+ parallel_tasks.append(parallel_task)
40
+
41
+ pesq_score_list = parallel.__call__(parallel_tasks)
42
+ return pesq_score_list
43
+
44
+
45
+ def run_pesq_score(clean_audio_list: List[np.ndarray],
46
+ noisy_audio_list: List[np.ndarray],
47
+ sample_rate: int = 16000,
48
+ mode: str = "wb",
49
+ n_jobs: int = 4,
50
+ ) -> float:
51
+
52
+ pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
+ noisy_audio_list=noisy_audio_list,
54
+ sample_rate=sample_rate,
55
+ mode=mode,
56
+ n_jobs=n_jobs,
57
+ )
58
+
59
+ pesq_score = np.mean(pesq_score_list)
60
+ return pesq_score
61
+
62
+
63
+ def main():
64
+ clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
+ noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
+
67
+ clean_audio_list = list(clean_audio)
68
+ noisy_audio_list = list(noisy_audio)
69
+
70
+ pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
+ print(pesq_score_list)
72
+
73
+ pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
+ print(pesq_score)
75
+
76
+ return
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py ADDED
@@ -0,0 +1,143 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import List, Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
+ from toolbox.torchaudio.models.nx_mpnet.causal_convolution.causal_conv2d import DenseEncoder, MaskDecoder, PhaseDecoder
12
+ from toolbox.torchaudio.models.nx_mpnet.transformers.transformers import TSTransformerEncoder
13
+ from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
14
+
15
+
16
+ class NXMPNet(nn.Module):
17
+ def __init__(self,
18
+ config: NXMPNetConfig,
19
+ ):
20
+ super(NXMPNet, self).__init__()
21
+ self.dense_encoder = DenseEncoder(
22
+ num_blocks=config.dense_num_blocks,
23
+ in_channels=2,
24
+ out_channels=config.dense_hidden_size,
25
+ )
26
+ self.ts_transformer = TSTransformerEncoder(
27
+ input_size=config.dense_hidden_size,
28
+ hidden_size=config.tsfm_hidden_size,
29
+ attention_heads=config.tsfm_attention_heads,
30
+ num_blocks=config.tsfm_num_blocks,
31
+ dropout_rate=config.tsfm_dropout_rate,
32
+ max_time_relative_position=config.tsfm_max_time_relative_position,
33
+ max_freq_relative_position=config.tsfm_max_freq_relative_position,
34
+ chunk_size=config.tsfm_chunk_size,
35
+ num_left_chunks=config.tsfm_num_left_chunks,
36
+ num_right_chunks=config.tsfm_num_right_chunks,
37
+ )
38
+ self.mask_decoder = MaskDecoder(
39
+ num_blocks=config.mask_num_blocks,
40
+ hidden_size=config.mask_hidden_size,
41
+ out_channels=1,
42
+ n_fft=config.n_fft,
43
+ )
44
+ self.phase_decoder = PhaseDecoder(
45
+ num_blocks=config.phase_num_blocks,
46
+ hidden_size=config.phase_hidden_size,
47
+ out_channels=1,
48
+ )
49
+
50
+ def forward(self, noisy_amp, noisy_pha):
51
+ """
52
+ :param noisy_amp: Tensor, shape: [b, f, t]
53
+ :param noisy_pha: Tensor, shape: [b, f, t]
54
+ :return:
55
+ """
56
+ x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1) # [B, 2, T, F]
57
+ # x shape: [b, 2, t, f]
58
+ x = self.dense_encoder.forward(x)
59
+ # x shape: [b, c, t, f//2]
60
+
61
+ x = self.ts_transformer.forward(x)
62
+ # x shape: [b, c, t, f//2]
63
+
64
+ denoised_amp = noisy_amp * self.mask_decoder(x)
65
+ denoised_pha = self.phase_decoder(x)
66
+ denoised_com = torch.stack(
67
+ tensors=(
68
+ denoised_amp * torch.cos(denoised_pha),
69
+ denoised_amp * torch.sin(denoised_pha)
70
+ ),
71
+ dim=-1
72
+ )
73
+
74
+ return denoised_amp, denoised_pha, denoised_com
75
+
76
+
77
+ MODEL_FILE = "generator.pt"
78
+
79
+
80
+ class NXMPNetPretrainedModel(NXMPNet):
81
+ def __init__(self,
82
+ config: NXMPNetConfig,
83
+ ):
84
+ super(NXMPNetPretrainedModel, self).__init__(
85
+ config=config,
86
+ )
87
+ self.config = config
88
+
89
+ @classmethod
90
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
91
+ config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
92
+
93
+ model = cls(config)
94
+
95
+ if os.path.isdir(pretrained_model_name_or_path):
96
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
97
+ else:
98
+ ckpt_file = pretrained_model_name_or_path
99
+
100
+ with open(ckpt_file, "rb") as f:
101
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
102
+ model.load_state_dict(state_dict, strict=True)
103
+ return model
104
+
105
+ def save_pretrained(self,
106
+ save_directory: Union[str, os.PathLike],
107
+ state_dict: Optional[dict] = None,
108
+ ):
109
+
110
+ model = self
111
+
112
+ if state_dict is None:
113
+ state_dict = model.state_dict()
114
+
115
+ os.makedirs(save_directory, exist_ok=True)
116
+
117
+ # save state dict
118
+ model_file = os.path.join(save_directory, MODEL_FILE)
119
+ torch.save(state_dict, model_file)
120
+
121
+ # save config
122
+ config_file = os.path.join(save_directory, CONFIG_FILE)
123
+ self.config.to_yaml_file(config_file)
124
+ return save_directory
125
+
126
+
127
+ def main():
128
+ config = NXMPNetConfig()
129
+
130
+ model = NXMPNet(config)
131
+
132
+ noisy_amp = torch.rand([1, 257, 201], dtype=torch.float32)
133
+ noisy_pha = torch.rand([1, 257, 201], dtype=torch.float32)
134
+
135
+ denoised_amp, denoised_pha, denoised_com = model.forward(noisy_amp, noisy_pha)
136
+ print(denoised_amp.shape)
137
+ print(denoised_pha.shape)
138
+ print(denoised_com.shape)
139
+ return
140
+
141
+
142
+ if __name__ == "__main__":
143
+ main()
toolbox/torchaudio/models/nx_mpnet/transformers/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/nx_mpnet/transformers/attention.py ADDED
@@ -0,0 +1,263 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class MultiHeadSelfAttention(nn.Module):
11
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
12
+ """
13
+ :param n_head: int. the number of heads.
14
+ :param n_feat: int. the number of features.
15
+ :param dropout_rate: float. dropout rate.
16
+ """
17
+ super().__init__()
18
+ assert n_feat % n_head == 0
19
+ # We assume d_v always equals d_k
20
+ self.d_k = n_feat // n_head
21
+ self.h = n_head
22
+ self.linear_q = nn.Linear(n_feat, n_feat)
23
+ self.linear_k = nn.Linear(n_feat, n_feat)
24
+ self.linear_v = nn.Linear(n_feat, n_feat)
25
+ self.linear_out = nn.Linear(n_feat, n_feat)
26
+ self.dropout = nn.Dropout(p=dropout_rate)
27
+
28
+ def forward_qkv(self,
29
+ query: torch.Tensor,
30
+ key: torch.Tensor,
31
+ value: torch.Tensor
32
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
33
+ """
34
+ transform query, key and value.
35
+ :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
36
+ :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
37
+ :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
38
+ :return:
39
+ """
40
+ n_batch = query.size(0)
41
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
42
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
43
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
44
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
45
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
46
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
47
+
48
+ return q, k, v
49
+
50
+ def forward_attention(self,
51
+ value: torch.Tensor,
52
+ scores: torch.Tensor,
53
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
54
+ ) -> torch.Tensor:
55
+ """
56
+ compute attention context vector.
57
+ :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
58
+ :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
59
+ :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
60
+ (batch_size, time1, time2), (0, 0, 0) means fake mask.
61
+ :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
62
+ weighted by the attention score (batch_size, time1, time2).
63
+ """
64
+ n_batch = value.size(0)
65
+ # NOTE: When will `if mask.size(2) > 0` be True?
66
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
67
+ # 1st chunk to ease the onnx export.]
68
+ # 2. pytorch training
69
+ if mask.size(2) > 0: # time2 > 0
70
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
71
+ # For last chunk, time2 might be larger than scores.size(-1)
72
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
73
+ scores = scores.masked_fill(mask, -float('inf'))
74
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
75
+
76
+ # NOTE: When will `if mask.size(2) > 0` be False?
77
+ # 1. onnx(16/-1, -1/-1, 16/0)
78
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
79
+ else:
80
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
81
+
82
+ p_attn = self.dropout(attn)
83
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
84
+ x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
85
+
86
+ return self.linear_out(x) # (batch, time1, n_feat)
87
+
88
+ def forward(self,
89
+ x: torch.Tensor,
90
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
91
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
92
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
93
+
94
+ q, k, v = self.forward_qkv(x, x, x)
95
+
96
+ if cache.size(0) > 0:
97
+ key_cache, value_cache = torch.split(
98
+ cache, cache.size(-1) // 2, dim=-1)
99
+ k = torch.cat([key_cache, k], dim=2)
100
+ v = torch.cat([value_cache, v], dim=2)
101
+ # NOTE: We do cache slicing in encoder.forward_chunk, since it's
102
+ # non-trivial to calculate `next_cache_start` here.
103
+ new_cache = torch.cat((k, v), dim=-1)
104
+
105
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
106
+ return self.forward_attention(v, scores, mask), new_cache
107
+
108
+
109
+ class RelativeMultiHeadSelfAttention(nn.Module):
110
+
111
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
112
+ """
113
+ :param n_head: int. the number of heads.
114
+ :param n_feat: int. the number of features.
115
+ :param dropout_rate: float. dropout rate.
116
+ :param max_relative_position: int. maximum relative position for relative position encoding.
117
+ """
118
+ super().__init__()
119
+ assert n_feat % n_head == 0
120
+ # We assume d_v always equals d_k
121
+ self.d_k = n_feat // n_head
122
+ self.h = n_head
123
+ self.linear_q = nn.Linear(n_feat, n_feat)
124
+ self.linear_k = nn.Linear(n_feat, n_feat)
125
+ self.linear_v = nn.Linear(n_feat, n_feat)
126
+ self.linear_out = nn.Linear(n_feat, n_feat)
127
+ self.dropout = nn.Dropout(p=dropout_rate)
128
+
129
+ # Relative position encoding
130
+ self.max_relative_position = max_relative_position
131
+ self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
132
+
133
+ def forward_qkv(self,
134
+ query: torch.Tensor,
135
+ key: torch.Tensor,
136
+ value: torch.Tensor
137
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
138
+ """
139
+ transform query, key and value.
140
+ :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
141
+ :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
142
+ :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
143
+ :return:
144
+ """
145
+ n_batch = query.size(0)
146
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
147
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
148
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
149
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
150
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
151
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
152
+
153
+ return q, k, v
154
+
155
+ def forward_attention(self,
156
+ value: torch.Tensor,
157
+ scores: torch.Tensor,
158
+ mask: torch.Tensor = None
159
+ ) -> torch.Tensor:
160
+ """
161
+ compute attention context vector.
162
+ :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
163
+ :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
164
+ :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
165
+ :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
166
+ weighted by the attention score (batch_size, query_time_steps, key_time_steps).
167
+ """
168
+ n_batch = value.size(0)
169
+ if mask is not None:
170
+ mask = mask.unsqueeze(1).eq(0)
171
+ # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
172
+ scores = scores.masked_fill(mask, -float('inf'))
173
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
174
+ else:
175
+ attn = torch.softmax(scores, dim=-1)
176
+ # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]
177
+
178
+ p_attn = self.dropout(attn)
179
+
180
+ x = torch.matmul(p_attn, value)
181
+ # x shape: [batch_size, n_head, query_time_steps, d_k]
182
+ x = x.transpose(1, 2)
183
+ # x shape: [batch_size, query_time_steps, n_head, d_k]
184
+
185
+ x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
186
+ # x shape: [batch_size, query_time_steps, n_head * d_k]
187
+ # x shape: [batch_size, query_time_steps, n_feat]
188
+
189
+ x = self.linear_out(x)
190
+ # x shape: [batch_size, query_time_steps, n_feat]
191
+ return x
192
+
193
+ def relative_position_encoding(self, length: int) -> torch.Tensor:
194
+ """
195
+ Generate relative position encoding.
196
+ :param length: int. length of the sequence.
197
+ :return: torch.Tensor. relative position encoding. shape=(length, length, d_k).
198
+ """
199
+ range_vec = torch.arange(length)
200
+ distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
201
+ distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
202
+ final_mat = distance_mat_clipped + self.max_relative_position
203
+ return final_mat
204
+
205
+ def forward(self,
206
+ x: torch.Tensor,
207
+ mask: torch.Tensor = None,
208
+ cache: torch.Tensor = None
209
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
210
+ # attention! self attention.
211
+
212
+ q, k, v = self.forward_qkv(x, x, x)
213
+ # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]
214
+
215
+ if cache is not None:
216
+ key_cache, value_cache = torch.split(
217
+ cache, cache.size(-1) // 2, dim=-1)
218
+ k = torch.cat([key_cache, k], dim=2)
219
+ v = torch.cat([value_cache, v], dim=2)
220
+
221
+ # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
222
+ new_cache = torch.cat((k, v), dim=-1)
223
+
224
+ # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
225
+ native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
226
+
227
+ # Compute relative position encoding
228
+ q_length, k_length = q.size(2), k.size(2)
229
+ relative_position = self.relative_position_encoding(k_length)
230
+
231
+ relative_position = relative_position[-q_length:]
232
+
233
+ relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
234
+
235
+ relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k)
236
+ relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k)
237
+
238
+ relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
239
+ # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
240
+
241
+ # score
242
+ scores = native_scores + relative_position_scores
243
+
244
+ return self.forward_attention(v, scores, mask), new_cache
245
+
246
+
247
+ def main():
248
+ rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
249
+
250
+ x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
251
+ xt, new_cache = rel_attention.forward(x)
252
+
253
+ # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
254
+ # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
255
+ # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
256
+
257
+ print(xt.shape)
258
+ print(new_cache.shape)
259
+ return
260
+
261
+
262
+ if __name__ == '__main__':
263
+ main()
toolbox/torchaudio/models/nx_mpnet/transformers/mask.py ADDED
@@ -0,0 +1,74 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+
6
+ def make_pad_mask(lengths: torch.Tensor,
7
+ max_len: int = 0,
8
+ ) -> torch.Tensor:
9
+ batch_size = lengths.size(0)
10
+ max_len = max_len if max_len > 0 else lengths.max().item()
11
+ seq_range = torch.arange(
12
+ 0,
13
+ max_len,
14
+ dtype=torch.int64,
15
+ device=lengths.device
16
+ )
17
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
18
+ seq_length_expand = lengths.unsqueeze(-1)
19
+ mask = seq_range_expand >= seq_length_expand
20
+ return mask
21
+
22
+
23
+
24
+ def subsequent_chunk_mask(
25
+ size: int,
26
+ chunk_size: int,
27
+ num_left_chunks: int = -1,
28
+ num_right_chunks: int = 0,
29
+ device: torch.device = torch.device("cpu"),
30
+ ) -> torch.Tensor:
31
+ """
32
+ Create mask for subsequent steps (size, size) with chunk size,
33
+ this is for streaming encoder
34
+
35
+ Examples:
36
+ > subsequent_chunk_mask(4, 2)
37
+ [[1, 1, 0, 0],
38
+ [1, 1, 0, 0],
39
+ [1, 1, 1, 1],
40
+ [1, 1, 1, 1]]
41
+
42
+ :param size: int. size of mask.
43
+ :param chunk_size: int. size of chunk.
44
+ :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
45
+ :param num_right_chunks: int. number of right chunks.
46
+ :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
47
+ :return: torch.Tensor. mask
48
+ """
49
+
50
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
51
+ for i in range(size):
52
+ if num_left_chunks < 0:
53
+ start = 0
54
+ else:
55
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
56
+ ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
57
+ ret[i, start:ending] = True
58
+ return ret
59
+
60
+
61
+ def main():
62
+ chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
63
+ print(chunk_mask)
64
+
65
+ chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
66
+ print(chunk_mask)
67
+
68
+ chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
69
+ print(chunk_mask)
70
+ return
71
+
72
+
73
+ if __name__ == '__main__':
74
+ main()
toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py ADDED
@@ -0,0 +1,479 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Dict, Optional, Tuple, List, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from toolbox.torchaudio.models.nx_mpnet.transformers.mask import subsequent_chunk_mask
9
+ from toolbox.torchaudio.models.nx_mpnet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
10
+
11
+
12
+ class PositionwiseFeedForward(nn.Module):
13
+ def __init__(self,
14
+ input_dim: int,
15
+ hidden_units: int,
16
+ dropout_rate: float,
17
+ activation: torch.nn.Module = torch.nn.ReLU()):
18
+ """
19
+ FeedForward are applied on each position of the sequence.
20
+ the output dim is same with the input dim.
21
+
22
+ :param input_dim: int. input dimension.
23
+ :param hidden_units: int. the number of hidden units.
24
+ :param dropout_rate: float. dropout rate.
25
+ :param activation: torch.nn.Module. activation function.
26
+ """
27
+ super(PositionwiseFeedForward, self).__init__()
28
+ self.w_1 = torch.nn.Linear(input_dim, hidden_units)
29
+ self.activation = activation
30
+ self.dropout = torch.nn.Dropout(dropout_rate)
31
+ self.w_2 = torch.nn.Linear(hidden_units, input_dim)
32
+
33
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
34
+ """
35
+ Forward function.
36
+ :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
37
+ :return: output tensor. shape=(batch_size, max_length, dim).
38
+ """
39
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
40
+
41
+
42
+ class TransformerBlock(nn.Module):
43
+ def __init__(self,
44
+ input_dim: int,
45
+ dropout_rate: float = 0.1,
46
+ n_heads: int = 4,
47
+ max_relative_position: int = 5120
48
+ ):
49
+ super().__init__()
50
+ self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
51
+ self.attention = RelativeMultiHeadSelfAttention(
52
+ n_head=n_heads,
53
+ n_feat=input_dim,
54
+ dropout_rate=dropout_rate,
55
+ max_relative_position=max_relative_position,
56
+ )
57
+
58
+ self.dropout1 = nn.Dropout(dropout_rate)
59
+ self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
60
+ self.ffn = PositionwiseFeedForward(
61
+ input_dim=input_dim,
62
+ hidden_units=input_dim,
63
+ dropout_rate=dropout_rate
64
+ )
65
+ self.dropout2 = nn.Dropout(dropout_rate)
66
+ self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
67
+
68
+ def forward(
69
+ self,
70
+ x: torch.Tensor,
71
+ mask: torch.Tensor = None,
72
+ attention_cache: torch.Tensor = None,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ """
75
+
76
+ :param x: torch.Tensor. shape=(batch_size, time, input_dim).
77
+ :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
78
+ :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
79
+ shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
80
+ :return:
81
+ torch.Tensor: Output tensor (batch_size, time, input_dim).
82
+ torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
83
+ """
84
+ xt = self.norm1(x)
85
+
86
+ x_att, new_att_cache = self.attention.forward(
87
+ xt, mask=mask, cache=attention_cache
88
+ )
89
+ x = x + self.dropout1(x_att)
90
+ xt = self.norm2(x)
91
+ xt = self.ffn.forward(xt)
92
+ x = x + self.dropout2(xt)
93
+
94
+ x = self.norm3(x)
95
+
96
+ return x, new_att_cache
97
+
98
+
99
+ class TransformerEncoder(nn.Module):
100
+ """
101
+ https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
102
+ """
103
+ def __init__(self,
104
+ input_size: int = 64,
105
+ hidden_size: int = 256,
106
+ attention_heads: int = 4,
107
+ num_blocks: int = 6,
108
+ dropout_rate: float = 0.1,
109
+ max_relative_position: int = 1024,
110
+ chunk_size: int = 1,
111
+ num_left_chunks: int = 128,
112
+ num_right_chunks: int = 2,
113
+ ):
114
+ super().__init__()
115
+ self.input_size = input_size
116
+ self.hidden_size = hidden_size
117
+
118
+ self.max_relative_position = max_relative_position
119
+ self.chunk_size = chunk_size
120
+ self.num_left_chunks = num_left_chunks
121
+ self.num_right_chunks = num_right_chunks
122
+
123
+ self.input_linear = nn.Linear(
124
+ in_features=self.input_size,
125
+ out_features=self.hidden_size,
126
+ )
127
+
128
+ self.encoder_layer_list = torch.nn.ModuleList([
129
+ TransformerBlock(
130
+ input_dim=hidden_size,
131
+ n_heads=attention_heads,
132
+ dropout_rate=dropout_rate,
133
+ max_relative_position=max_relative_position,
134
+ ) for _ in range(num_blocks)
135
+ ])
136
+
137
+ self.output_linear = nn.Linear(
138
+ in_features=self.hidden_size,
139
+ out_features=self.input_size,
140
+ )
141
+
142
+ def forward(self,
143
+ xs: torch.Tensor,
144
+ ):
145
+ """
146
+ :param xs: Tensor, shape: [batch_size, time_steps, input_size]
147
+ :return: Tensor, shape: [batch_size, time_steps, input_size]
148
+ """
149
+ batch_size, time_steps, _ = xs.shape
150
+ # xs shape: [batch_size, time_steps, input_size]
151
+ xs = self.input_linear.forward(xs)
152
+ # xs shape: [batch_size, time_steps, hidden_size]
153
+
154
+ chunk_masks = subsequent_chunk_mask(
155
+ size=time_steps,
156
+ chunk_size=self.chunk_size,
157
+ num_left_chunks=self.num_left_chunks,
158
+ num_right_chunks=self.num_right_chunks,
159
+ )
160
+ chunk_masks = chunk_masks.to(xs.device)
161
+ # chunk_masks shape: [time_steps, time_steps]
162
+ chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
163
+ # chunk_masks shape: [batch_size, time_steps, time_steps]
164
+
165
+ for encoder_layer in self.encoder_layer_list:
166
+ xs, _ = encoder_layer.forward(xs, chunk_masks)
167
+
168
+ # xs shape: [batch_size, time_steps, hidden_size]
169
+ xs = self.output_linear.forward(xs)
170
+ # xs shape: [batch_size, time_steps, input_size]
171
+
172
+ return xs
173
+
174
+ def forward_chunk(self,
175
+ xs: torch.Tensor,
176
+ max_att_cache_length: int,
177
+ attention_cache: torch.Tensor = None,
178
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
179
+ """
180
+
181
+ :param xs:
182
+ :param max_att_cache_length:
183
+ :param attention_cache: Tensor, [num_layers, ...]
184
+ :return:
185
+ """
186
+ # xs shape: [batch_size, time_steps, input_size]
187
+ xs = self.input_linear.forward(xs)
188
+ # xs shape: [batch_size, time_steps, hidden_size]
189
+
190
+ r_att_cache = []
191
+ for idx, encoder_layer in enumerate(self.encoder_layer_list):
192
+ xs, new_att_cache = encoder_layer.forward(
193
+ x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
194
+ )
195
+ # new_att_cache shape: [batch_size, n_heads, time_steps, dim]
196
+ if new_att_cache.size(2) > max_att_cache_length:
197
+ begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
198
+ end = self.num_right_chunks * self.chunk_size
199
+ new_att_cache = new_att_cache[:, :, -begin:-end, :]
200
+ r_att_cache.append(new_att_cache)
201
+
202
+ r_att_cache = torch.stack(r_att_cache, dim=0)
203
+
204
+ # xs shape: [batch_size, time_steps, hidden_size]
205
+ xs = self.output_linear.forward(xs)
206
+ # xs shape: [batch_size, time_steps, input_size]
207
+
208
+ return xs, r_att_cache
209
+
210
+ def forward_chunk_by_chunk(
211
+ self,
212
+ xs: torch.Tensor,
213
+ ) -> torch.Tensor:
214
+
215
+ batch_size, time_steps, _ = xs.shape
216
+
217
+ # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
218
+ max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
219
+ attention_cache = None
220
+
221
+ outputs = []
222
+ for idx in range(0, time_steps, self.chunk_size):
223
+ begin = idx
224
+ end = begin + self.chunk_size * (self.num_right_chunks + 1)
225
+ chunk_xs = xs[:, begin:end, :]
226
+ # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
227
+
228
+ ys, attention_cache = self.forward_chunk(
229
+ xs=chunk_xs,
230
+ max_att_cache_length=max_att_cache_length,
231
+ attention_cache=attention_cache,
232
+ )
233
+
234
+ # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), input_size]
235
+ ys = ys[:, :self.chunk_size, :]
236
+
237
+ outputs.append(ys)
238
+
239
+ ys = torch.cat(outputs, 1)
240
+ return ys
241
+
242
+
243
+ class TSTransformerBlock(nn.Module):
244
+ def __init__(self,
245
+ input_dim: int,
246
+ dropout_rate: float = 0.1,
247
+ n_heads: int = 4,
248
+ max_time_relative_position: int = 1024,
249
+ max_freq_relative_position: int = 128,
250
+ ):
251
+ super(TSTransformerBlock, self).__init__()
252
+ self.time_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_time_relative_position)
253
+ self.freq_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_freq_relative_position)
254
+
255
+ def forward(self,
256
+ x: torch.Tensor,
257
+ mask: torch.Tensor = None,
258
+ attention_cache: torch.Tensor = None,
259
+ ):
260
+ """
261
+
262
+ :param x: Tensor. shape: [batch_size, hidden_size, time_steps, input_size]
263
+ :param mask: Tensor. shape: [time_steps, time_steps]
264
+ :param attention_cache:
265
+ :return:
266
+ """
267
+ b, c, t, f = x.size()
268
+
269
+ mask = None if mask is None else torch.broadcast_to(mask, size=(b*f, t, t))
270
+
271
+ x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
272
+ x_, new_att_cache = self.time_transformer.forward(x, mask, attention_cache)
273
+ x = x_ + x
274
+ x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
275
+ x_, _ = self.freq_transformer.forward(x)
276
+ x = x_ + x
277
+ x = x.view(b, t, f, c).permute(0, 3, 1, 2)
278
+ return x, new_att_cache
279
+
280
+
281
+ class TSTransformerEncoder(nn.Module):
282
+ def __init__(self,
283
+ input_size: int = 64,
284
+ hidden_size: int = 256,
285
+ attention_heads: int = 4,
286
+ num_blocks: int = 6,
287
+ dropout_rate: float = 0.1,
288
+ max_time_relative_position: int = 1024,
289
+ max_freq_relative_position: int = 128,
290
+ chunk_size: int = 1,
291
+ num_left_chunks: int = 128,
292
+ num_right_chunks: int = 2,
293
+ ):
294
+ super().__init__()
295
+ self.input_size = input_size
296
+ self.hidden_size = hidden_size
297
+
298
+ self.max_time_relative_position = max_time_relative_position
299
+ self.max_freq_relative_position = max_freq_relative_position
300
+ self.chunk_size = chunk_size
301
+ self.num_left_chunks = num_left_chunks
302
+ self.num_right_chunks = num_right_chunks
303
+
304
+ self.input_linear = nn.Linear(
305
+ in_features=self.input_size,
306
+ out_features=self.hidden_size,
307
+ )
308
+
309
+ self.encoder_layer_list = torch.nn.ModuleList([
310
+ TSTransformerBlock(
311
+ input_dim=hidden_size,
312
+ n_heads=attention_heads,
313
+ dropout_rate=dropout_rate,
314
+ max_time_relative_position=max_time_relative_position,
315
+ max_freq_relative_position=max_freq_relative_position,
316
+ ) for _ in range(num_blocks)
317
+ ])
318
+
319
+ self.output_linear = nn.Linear(
320
+ in_features=self.hidden_size,
321
+ out_features=self.input_size,
322
+ )
323
+
324
+ def forward(self,
325
+ xs: torch.Tensor,
326
+ ):
327
+ """
328
+ :param xs: Tensor, shape: [batch_size, channels, time_steps, input_size]
329
+ :return: Tensor, shape: [batch_size, channels, time_steps, input_size]
330
+ """
331
+ batch_size, channels, time_steps, _ = xs.shape
332
+ # xs shape: [batch_size, channels, time_steps, input_size]
333
+ xs = xs.permute(0, 3, 2, 1)
334
+ # xs shape: [batch_size, input_size, time_steps, channels]
335
+ xs = self.input_linear.forward(xs)
336
+ # xs shape: [batch_size, input_size, time_steps, hidden_size]
337
+ xs = xs.permute(0, 3, 2, 1)
338
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
339
+
340
+ chunk_masks = subsequent_chunk_mask(
341
+ size=time_steps,
342
+ chunk_size=self.chunk_size,
343
+ num_left_chunks=self.num_left_chunks,
344
+ num_right_chunks=self.num_right_chunks,
345
+ )
346
+ chunk_masks = chunk_masks.to(xs.device)
347
+ # chunk_masks shape: [time_steps, time_steps]
348
+
349
+ for encoder_layer in self.encoder_layer_list:
350
+ xs, _ = encoder_layer.forward(xs, chunk_masks)
351
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
352
+ xs = xs.permute(0, 3, 2, 1)
353
+ # xs shape: [batch_size, input_size, time_steps, hidden_size]
354
+ xs = self.output_linear.forward(xs)
355
+ # xs shape: [batch_size, input_size, time_steps, channels]
356
+ xs = xs.permute(0, 3, 2, 1)
357
+ # xs shape: [batch_size, channels, time_steps, input_size]
358
+
359
+ return xs
360
+
361
+ def forward_chunk(self,
362
+ xs: torch.Tensor,
363
+ max_att_cache_length: int,
364
+ attention_cache: torch.Tensor = None,
365
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
366
+ """
367
+
368
+ :param xs:
369
+ :param max_att_cache_length:
370
+ :param attention_cache: Tensor, shape: [num_layers, ...]
371
+ :return:
372
+ """
373
+ # xs shape: [batch_size, channels, time_steps, input_size]
374
+ xs = xs.permute(0, 3, 2, 1)
375
+ xs = self.input_linear.forward(xs)
376
+ xs = xs.permute(0, 3, 2, 1)
377
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
378
+
379
+ r_att_cache = []
380
+ for idx, encoder_layer in enumerate(self.encoder_layer_list):
381
+ xs, new_att_cache = encoder_layer.forward(
382
+ x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
383
+ )
384
+ # new_att_cache shape: [b*f, n_heads, time_steps, dim]
385
+ if new_att_cache.size(2) > max_att_cache_length:
386
+ begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
387
+ end = self.num_right_chunks * self.chunk_size
388
+ new_att_cache = new_att_cache[:, :, -begin:-end, :]
389
+ r_att_cache.append(new_att_cache)
390
+
391
+ r_att_cache = torch.stack(r_att_cache, dim=0)
392
+
393
+ # xs shape: [batch_size, hidden_size, time_steps, input_size]
394
+ xs = xs.permute(0, 3, 2, 1)
395
+ xs = self.output_linear.forward(xs)
396
+ xs = xs.permute(0, 3, 2, 1)
397
+ # xs shape: [batch_size, channels, time_steps, input_size]
398
+
399
+ return xs, r_att_cache
400
+
401
+ def forward_chunk_by_chunk(
402
+ self,
403
+ xs: torch.Tensor,
404
+ ) -> torch.Tensor:
405
+
406
+ batch_size, channels, time_steps, _ = xs.shape
407
+
408
+ max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
409
+ attention_cache = None
410
+
411
+ outputs = []
412
+ for idx in range(0, time_steps, self.chunk_size):
413
+ begin = idx
414
+ end = begin + self.chunk_size * (self.num_right_chunks + 1)
415
+ chunk_xs = xs[:, :, begin:end, :]
416
+ # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
417
+
418
+ ys, attention_cache = self.forward_chunk(
419
+ xs=chunk_xs,
420
+ max_att_cache_length=max_att_cache_length,
421
+ attention_cache=attention_cache,
422
+ )
423
+ # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
424
+ ys = ys[:, :, :self.chunk_size, :]
425
+
426
+ outputs.append(ys)
427
+
428
+ ys = torch.cat(outputs, dim=2)
429
+ return ys
430
+
431
+
432
+ def main2():
433
+
434
+ encoder = TransformerEncoder(
435
+ input_size=64,
436
+ hidden_size=256,
437
+ attention_heads=4,
438
+ num_blocks=6,
439
+ dropout_rate=0.1,
440
+ )
441
+ print(encoder)
442
+
443
+ x = torch.ones([4, 200, 64])
444
+
445
+ x = torch.ones([4, 200, 64])
446
+ y = encoder.forward(xs=x)
447
+ print(y.shape)
448
+
449
+ x = torch.ones([4, 200, 64])
450
+ y = encoder.forward_chunk_by_chunk(xs=x)
451
+ print(y.shape)
452
+
453
+ return
454
+
455
+
456
+ def main():
457
+
458
+ encoder = TSTransformerEncoder(
459
+ input_size=8,
460
+ hidden_size=16,
461
+ attention_heads=2,
462
+ num_blocks=2,
463
+ dropout_rate=0.1,
464
+ )
465
+ # print(encoder)
466
+
467
+ x = torch.ones([4, 8, 200, 8])
468
+ y = encoder.forward(xs=x)
469
+ print(y.shape)
470
+
471
+ x = torch.ones([4, 8, 200, 8])
472
+ y = encoder.forward_chunk_by_chunk(xs=x)
473
+ print(y.shape)
474
+
475
+ return
476
+
477
+
478
+ if __name__ == '__main__':
479
+ main()
toolbox/torchaudio/models/nx_mpnet/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class LearnableSigmoid1d(nn.Module):
8
+ def __init__(self, in_features, beta=1):
9
+ super().__init__()
10
+ self.beta = beta
11
+ self.slope = nn.Parameter(torch.ones(in_features))
12
+ self.slope.requires_grad = True
13
+
14
+ def forward(self, x):
15
+ # x shape: [batch_size, time_steps, spec_bins]
16
+ return self.beta * torch.sigmoid(self.slope * x)
17
+
18
+
19
+ class LearnableSigmoid2d(nn.Module):
20
+ def __init__(self, in_features, beta=1):
21
+ super().__init__()
22
+ self.beta = beta
23
+ self.slope = nn.Parameter(torch.ones(in_features, 1))
24
+ self.slope.requires_grad = True
25
+
26
+ def forward(self, x):
27
+ return self.beta * torch.sigmoid(self.slope * x)
28
+
29
+
30
+ def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
31
+
32
+ hann_window = torch.hann_window(win_size).to(y.device)
33
+ stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
34
+ center=center, pad_mode='reflect', normalized=False, return_complex=True)
35
+ stft_spec = torch.view_as_real(stft_spec)
36
+ mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
37
+ pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
38
+ # Magnitude Compression
39
+ mag = torch.pow(mag, compress_factor)
40
+ com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
41
+
42
+ return mag, pha, com
43
+
44
+
45
+ def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
46
+ # Magnitude Decompression
47
+ mag = torch.pow(mag, (1.0/compress_factor))
48
+ com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
49
+ hann_window = torch.hann_window(win_size).to(com.device)
50
+ wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
51
+
52
+ return wav
53
+
54
+
55
+ if __name__ == '__main__':
56
+ pass
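mag_pha_stft and mag_pha_istft are intended to be inverses of each other when the same compress_factor is used, so a compressed magnitude/phase pair can be turned back into a waveform after enhancement. A rough round-trip sketch (assuming the 8 kHz framing from NXMPNetConfig: n_fft=512, hop_size=80, win_size=200):

    import torch
    from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft

    wav = torch.randn(1, 16000)   # 2 seconds at 8 kHz
    mag, pha, com = mag_pha_stft(wav, n_fft=512, hop_size=80, win_size=200, compress_factor=0.3)
    rec = mag_pha_istft(mag, pha, n_fft=512, hop_size=80, win_size=200, compress_factor=0.3)
    print(mag.shape, pha.shape, rec.shape)   # mag/pha: [1, 257, 201], rec: [1, 16000]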