HoneyTian committed on
Commit
8ed9309
·
1 Parent(s): ad1f7b5
examples/clean_unet_aishell/run.sh ADDED
@@ -0,0 +1,178 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
7
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
9
+
10
+
11
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
12
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
+
15
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
+
19
+
20
+ END
21
+
22
+
23
+ # params
24
+ system_version="windows";
25
+ verbose=true;
26
+ stage=0 # start from 0 if you need to start from data preparation
27
+ stop_stage=9
28
+
29
+ work_dir="$(pwd)"
30
+ file_folder_name=file_folder_name
31
+ final_model_name=final_model_name
32
+ config_file="yaml/config.yaml"
33
+ limit=10
34
+
35
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
36
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
37
+
38
+ nohup_name=nohup.out
39
+
40
+ # model params
41
+ batch_size=64
42
+ max_epochs=200
43
+ save_top_k=10
44
+ patience=5
45
+
46
+
47
+ # parse options
48
+ while true; do
49
+ [ -z "${1:-}" ] && break; # break if there are no arguments
50
+ case "$1" in
51
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
52
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
53
+ old_value="(eval echo \\$$name)";
54
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
55
+ was_bool=true;
56
+ else
57
+ was_bool=false;
58
+ fi
59
+
60
+ # Set the variable to the right value-- the escaped quotes make it work if
61
+ # the option had spaces, like --cmd "queue.pl -sync y"
62
+ eval "${name}=\"$2\"";
63
+
64
+ # Check that Boolean-valued arguments are really Boolean.
65
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
66
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
67
+ exit 1;
68
+ fi
69
+ shift 2;
70
+ ;;
71
+
72
+ *) break;
73
+ esac
74
+ done
75
+
76
+ file_dir="${work_dir}/${file_folder_name}"
77
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
78
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
79
+
80
+ dataset="${file_dir}/dataset.xlsx"
81
+ train_dataset="${file_dir}/train.xlsx"
82
+ valid_dataset="${file_dir}/valid.xlsx"
83
+
84
+ $verbose && echo "system_version: ${system_version}"
85
+ $verbose && echo "file_folder_name: ${file_folder_name}"
86
+
87
+ if [ $system_version == "windows" ]; then
88
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
89
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
90
+ #source /data/local/bin/nx_denoise/bin/activate
91
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
92
+ fi
93
+
94
+
95
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
96
+ $verbose && echo "stage 1: prepare data"
97
+ cd "${work_dir}" || exit 1
98
+ python3 step_1_prepare_data.py \
99
+ --file_dir "${file_dir}" \
100
+ --noise_dir "${noise_dir}" \
101
+ --speech_dir "${speech_dir}" \
102
+ --train_dataset "${train_dataset}" \
103
+ --valid_dataset "${valid_dataset}" \
104
+
105
+ fi
106
+
107
+
108
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
109
+ $verbose && echo "stage 2: train model"
110
+ cd "${work_dir}" || exit 1
111
+ python3 step_2_train_model.py \
112
+ --train_dataset "${train_dataset}" \
113
+ --valid_dataset "${valid_dataset}" \
114
+ --serialization_dir "${file_dir}" \
115
+ --config_file "${config_file}" \
116
+
117
+ fi
118
+
119
+
120
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
121
+ $verbose && echo "stage 3: test model"
122
+ cd "${work_dir}" || exit 1
123
+ python3 step_3_evaluation.py \
124
+ --valid_dataset "${valid_dataset}" \
125
+ --model_dir "${file_dir}/best" \
126
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
127
+ --limit "${limit}" \
128
+
129
+ fi
130
+
131
+
132
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
133
+ $verbose && echo "stage 4: export model"
134
+ cd "${work_dir}" || exit 1
135
+ python3 step_5_export_models.py \
136
+ --vocabulary_dir "${vocabulary_dir}" \
137
+ --model_dir "${file_dir}/best" \
138
+ --serialization_dir "${file_dir}" \
139
+
140
+ fi
141
+
142
+
143
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
144
+ $verbose && echo "stage 5: collect files"
145
+ cd "${work_dir}" || exit 1
146
+
147
+ mkdir -p ${final_model_dir}
148
+
149
+ cp "${file_dir}/best"/* "${final_model_dir}"
150
+ cp -r "${file_dir}/vocabulary" "${final_model_dir}"
151
+
152
+ cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
153
+
154
+ cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
155
+ cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
156
+ cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
157
+ cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
158
+
159
+ cd "${final_model_dir}/.." || exit 1;
160
+
161
+ if [ -e "${final_model_name}.zip" ]; then
162
+ rm -rf "${final_model_name}_backup.zip"
163
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
164
+ fi
165
+
166
+ zip -r "${final_model_name}.zip" "${final_model_name}"
167
+ rm -rf "${final_model_name}"
168
+
169
+ fi
170
+
171
+
172
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
173
+ $verbose && echo "stage 6: clear file_dir"
174
+ cd "${work_dir}" || exit 1
175
+
176
+ rm -rf "${file_dir}";
177
+
178
+ fi
examples/clean_unet_aishell/step_1_prepare_data.py ADDED
@@ -0,0 +1,202 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--scale", default=1, type=float)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def filename_generator(data_dir: str):
52
+ data_dir = Path(data_dir)
53
+ for filename in data_dir.glob("**/*.wav"):
54
+ yield filename.as_posix()
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000):
58
+ data_dir = Path(data_dir)
59
+ for filename in data_dir.glob("**/*.wav"):
60
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
61
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
62
+
63
+ if raw_duration < duration:
64
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
65
+ continue
66
+ if signal.ndim != 1:
67
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
68
+
69
+ signal_length = len(signal)
70
+ win_size = int(duration * sample_rate)
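+ # slide a non-overlapping window of `duration` seconds over the signal and emit one row
+ # per window; a trailing remainder shorter than one window is dropped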
71
+ for begin in range(0, signal_length - win_size, win_size):
72
+ row = {
73
+ "filename": filename.as_posix(),
74
+ "raw_duration": round(raw_duration, 4),
75
+ "offset": round(begin / sample_rate, 4),
76
+ "duration": round(duration, 4),
77
+ }
78
+ yield row
79
+
80
+
81
+ def get_dataset(args):
82
+ file_dir = Path(args.file_dir)
83
+ file_dir.mkdir(exist_ok=True)
84
+
85
+ noise_dir = Path(args.noise_dir)
86
+ speech_dir = Path(args.speech_dir)
87
+
88
+ noise_generator = target_second_signal_generator(
89
+ noise_dir.as_posix(),
90
+ duration=args.duration,
91
+ sample_rate=args.target_sample_rate
92
+ )
93
+ speech_generator = target_second_signal_generator(
94
+ speech_dir.as_posix(),
95
+ duration=args.duration,
96
+ sample_rate=args.target_sample_rate
97
+ )
98
+
99
+ dataset = list()
100
+
101
+ count = 0
102
+ process_bar = tqdm(desc="build dataset excel")
103
+ for noise, speech in zip(noise_generator, speech_generator):
104
+ flag = random.random()
105
+ if flag > args.scale:
106
+ continue
107
+
108
+ noise_filename = noise["filename"]
109
+ noise_raw_duration = noise["raw_duration"]
110
+ noise_offset = noise["offset"]
111
+ noise_duration = noise["duration"]
112
+
113
+ speech_filename = speech["filename"]
114
+ speech_raw_duration = speech["raw_duration"]
115
+ speech_offset = speech["offset"]
116
+ speech_duration = speech["duration"]
117
+
118
+ random1 = random.random()
119
+ random2 = random.random()
120
+
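+ # random1 is used later to shuffle the rows; random2 sends roughly 80% of windows to TRAIN
+ # and the rest to TEST; snr_db is drawn uniformly from [min_snr_db, max_snr_db]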
121
+ row = {
122
+ "noise_filename": noise_filename,
123
+ "noise_raw_duration": noise_raw_duration,
124
+ "noise_offset": noise_offset,
125
+ "noise_duration": noise_duration,
126
+
127
+ "speech_filename": speech_filename,
128
+ "speech_raw_duration": speech_raw_duration,
129
+ "speech_offset": speech_offset,
130
+ "speech_duration": speech_duration,
131
+
132
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
133
+
134
+ "random1": random1,
135
+ "random2": random2,
136
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
137
+ }
138
+ dataset.append(row)
139
+ count += 1
140
+ duration_seconds = count * args.duration
141
+ duration_hours = duration_seconds / 3600
142
+
143
+ process_bar.update(n=1)
144
+ process_bar.set_postfix({
145
+ # "duration_seconds": round(duration_seconds, 4),
146
+ "duration_hours": round(duration_hours, 4),
147
+
148
+ })
149
+
150
+ dataset = pd.DataFrame(dataset)
151
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
152
+ dataset.to_excel(
153
+ file_dir / "dataset.xlsx",
154
+ index=False,
155
+ )
156
+ return
157
+
158
+
159
+
160
+ def split_dataset(args):
161
+ """分割训练集, 测试集"""
162
+ file_dir = Path(args.file_dir)
163
+ file_dir.mkdir(exist_ok=True)
164
+
165
+ df = pd.read_excel(file_dir / "dataset.xlsx")
166
+
167
+ train = list()
168
+ test = list()
169
+
170
+ for i, row in df.iterrows():
171
+ flag = row["flag"]
172
+ if flag == "TRAIN":
173
+ train.append(row)
174
+ else:
175
+ test.append(row)
176
+
177
+ train = pd.DataFrame(train)
178
+ train.to_excel(
179
+ args.train_dataset,
180
+ index=False,
181
+ # encoding="utf_8_sig"
182
+ )
183
+ test = pd.DataFrame(test)
184
+ test.to_excel(
185
+ args.valid_dataset,
186
+ index=False,
187
+ # encoding="utf_8_sig"
188
+ )
189
+
190
+ return
191
+
192
+
193
+ def main():
194
+ args = get_args()
195
+
196
+ get_dataset(args)
197
+ split_dataset(args)
198
+ return
199
+
200
+
201
+ if __name__ == "__main__":
202
+ main()
examples/clean_unet_aishell/step_2_train_model.py ADDED
@@ -0,0 +1,429 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/NVIDIA/CleanUNet/blob/main/train.py
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch.nn import functional as F
24
+ from torch.utils.data.dataloader import DataLoader
25
+ from tqdm import tqdm
26
+
27
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
+ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
29
+ from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
+ # The training loop below also relies on the MP-SENet helpers (metric discriminator, PESQ
+ # scoring, phase losses and the magnitude/phase STFT utilities), mirroring
+ # examples/mpnet_aishell/step_2_train_model.py in this commit.
+ from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel, batch_pesq
+ from toolbox.torchaudio.models.mpnet.modeling_mpnet import phase_losses, pesq_score
+ from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
30
+
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
35
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
36
+
37
+ parser.add_argument("--max_epochs", default=100, type=int)
38
+
39
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
40
+ parser.add_argument("--patience", default=5, type=int)
41
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
42
+
43
+ parser.add_argument("--config_file", default="config.yaml", type=str)
44
+
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def logging_config(file_dir: str):
50
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
51
+
52
+ logging.basicConfig(format=fmt,
53
+ datefmt="%m/%d/%Y %H:%M:%S",
54
+ level=logging.INFO)
55
+ file_handler = TimedRotatingFileHandler(
56
+ filename=os.path.join(file_dir, "main.log"),
57
+ encoding="utf-8",
58
+ when="D",
59
+ interval=1,
60
+ backupCount=7
61
+ )
62
+ file_handler.setLevel(logging.INFO)
63
+ file_handler.setFormatter(logging.Formatter(fmt))
64
+ logger = logging.getLogger(__name__)
65
+ logger.addHandler(file_handler)
66
+
67
+ return logger
68
+
69
+
70
+ class CollateFunction(object):
71
+ def __init__(self):
72
+ pass
73
+
74
+ def __call__(self, batch: List[dict]):
75
+ clean_audios = list()
76
+ noisy_audios = list()
77
+
78
+ for sample in batch:
79
+ # noise_wave: torch.Tensor = sample["noise_wave"]
80
+ clean_audio: torch.Tensor = sample["speech_wave"]
81
+ noisy_audio: torch.Tensor = sample["mix_wave"]
82
+ # snr_db: float = sample["snr_db"]
83
+
84
+ clean_audios.append(clean_audio)
85
+ noisy_audios.append(noisy_audio)
86
+
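+ # torch.stack requires every sample in the batch to have the same length, which holds
+ # because the dataset yields fixed-duration windows (see step_1_prepare_data.py)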
87
+ clean_audios = torch.stack(clean_audios)
88
+ noisy_audios = torch.stack(noisy_audios)
89
+
90
+ # assert
91
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
92
+ raise AssertionError("nan or inf in clean_audios")
93
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
94
+ raise AssertionError("nan or inf in noisy_audios")
95
+ return clean_audios, noisy_audios
96
+
97
+
98
+ collate_fn = CollateFunction()
99
+
100
+
101
+ def main():
102
+ args = get_args()
103
+
104
+ config = CleanUnetConfig.from_pretrained(
105
+ pretrained_model_name_or_path=args.config_file,
106
+ )
107
+
108
+ serialization_dir = Path(args.serialization_dir)
109
+ serialization_dir.mkdir(parents=True, exist_ok=True)
110
+
111
+ logger = logging_config(serialization_dir)
112
+
113
+ random.seed(config.seed)
114
+ np.random.seed(config.seed)
115
+ torch.manual_seed(config.seed)
116
+ logger.info(f"set seed: {config.seed}")
117
+
118
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119
+ n_gpu = torch.cuda.device_count()
120
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
121
+
122
+ # datasets
123
+ train_dataset = DenoiseExcelDataset(
124
+ excel_file=args.train_dataset,
125
+ expected_sample_rate=8000,
126
+ max_wave_value=32768.0,
127
+ )
128
+ valid_dataset = DenoiseExcelDataset(
129
+ excel_file=args.valid_dataset,
130
+ expected_sample_rate=8000,
131
+ max_wave_value=32768.0,
132
+ )
133
+ train_data_loader = DataLoader(
134
+ dataset=train_dataset,
135
+ batch_size=config.batch_size,
136
+ shuffle=True,
137
+ sampler=None,
138
+ # On Linux the data can be loaded with multiple worker processes; on Windows it cannot.
139
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
140
+ collate_fn=collate_fn,
141
+ pin_memory=False,
142
+ # prefetch_factor=64,
143
+ )
144
+ valid_data_loader = DataLoader(
145
+ dataset=valid_dataset,
146
+ batch_size=config.batch_size,
147
+ shuffle=True,
148
+ sampler=None,
149
+ # On Linux the data can be loaded with multiple worker processes; on Windows it cannot.
150
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
151
+ collate_fn=collate_fn,
152
+ pin_memory=False,
153
+ # prefetch_factor=64,
154
+ )
155
+
156
+ # models
157
+ logger.info(f"prepare models. config_file: {args.config_file}")
158
+ model = CleanUNetPretrainedModel(config).to(device)
159
+
160
+ # optimizer
161
+ logger.info("prepare optimizer, lr_scheduler")
162
+ optim_g = torch.optim.AdamW(model.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
163
+
164
+ # resume training
165
+ last_epoch = -1
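+ # scan serialization_dir for existing `epoch-N` checkpoints and resume from the largest N;
+ # -1 means no checkpoint was found and training starts from scratch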
166
+ for epoch_i in serialization_dir.glob("epoch-*"):
167
+ epoch_i = Path(epoch_i)
168
+ epoch_idx = epoch_i.stem.split("-")[1]
169
+ epoch_idx = int(epoch_idx)
170
+ if epoch_idx > last_epoch:
171
+ last_epoch = epoch_idx
172
+
173
+ if last_epoch != -1:
174
+ logger.info(f"resume from epoch-{last_epoch}.")
175
+ generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
176
+ discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
177
+ optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
178
+ optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
179
+
180
+ logger.info(f"load state dict for generator.")
181
+ with open(generator_pt.as_posix(), "rb") as f:
182
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
183
+ generator.load_state_dict(state_dict, strict=True)
184
+ logger.info(f"load state dict for discriminator.")
185
+ with open(discriminator_pt.as_posix(), "rb") as f:
186
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
187
+ discriminator.load_state_dict(state_dict, strict=True)
188
+
189
+ logger.info(f"load state dict for optim_g.")
190
+ with open(optim_g_pth.as_posix(), "rb") as f:
191
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
+ optim_g.load_state_dict(state_dict)
193
+ logger.info(f"load state dict for optim_d.")
194
+ with open(optim_d_pth.as_posix(), "rb") as f:
195
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
196
+ optim_d.load_state_dict(state_dict)
197
+
198
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
199
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
200
+
201
+ # training loop
202
+
203
+ # state
204
+ loss_d = 10000000000
205
+ loss_g = 10000000000
206
+ pesq_metric = 10000000000
207
+ mag_err = 10000000000
208
+ pha_err = 10000000000
209
+ com_err = 10000000000
210
+ stft_err = 10000000000
211
+
212
+ model_list = list()
213
+ best_idx_epoch = None
214
+ best_metric = None
215
+ patience_count = 0
216
+
217
+ logger.info("training")
218
+ for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
219
+ # train
220
+ generator.train()
221
+ discriminator.train()
222
+
223
+ total_loss_d = 0.
224
+ total_loss_g = 0.
225
+ total_batches = 0.
226
+ progress_bar = tqdm(
227
+ total=len(train_data_loader),
228
+ desc="Training; epoch: {}".format(idx_epoch),
229
+ )
230
+ for batch in train_data_loader:
231
+ clean_audio, noisy_audio = batch
232
+ clean_audio = clean_audio.to(device)
233
+ noisy_audio = noisy_audio.to(device)
234
+ one_labels = torch.ones(clean_audio.shape[0]).to(device)
235
+
236
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
237
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
238
+
239
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
240
+
241
+ audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
242
+ mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
243
+
244
+ audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
245
+ batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)
246
+
247
+ # Discriminator
248
+ optim_d.zero_grad()
249
+ metric_r = discriminator.forward(clean_mag, clean_mag)
250
+ metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
251
+ loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
252
+
253
+ if batch_pesq_score is not None:
254
+ loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
255
+ else:
256
+ # print("pesq is None!")
257
+ loss_disc_g = 0
258
+
259
+ loss_disc_all = loss_disc_r + loss_disc_g
260
+ loss_disc_all.backward()
261
+ optim_d.step()
262
+
263
+ # Generator
264
+ optim_g.zero_grad()
265
+ # L2 Magnitude Loss
266
+ loss_mag = F.mse_loss(clean_mag, mag_g)
267
+ # Anti-wrapping Phase Loss
268
+ loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
269
+ loss_pha = loss_ip + loss_gd + loss_iaf
270
+ # L2 Complex Loss
271
+ loss_com = F.mse_loss(clean_com, com_g) * 2
272
+ # L2 Consistency Loss
273
+ loss_stft = F.mse_loss(com_g, com_g_hat) * 2
274
+ # Time Loss
275
+ loss_time = F.l1_loss(clean_audio, audio_g)
276
+ # Metric Loss
277
+ metric_g = discriminator.forward(clean_mag, mag_g_hat)
278
+ loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
279
+
280
+ loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2
281
+
282
+ loss_gen_all.backward()
283
+ optim_g.step()
284
+
285
+ total_loss_d += loss_disc_all.item()
286
+ total_loss_g += loss_gen_all.item()
287
+ total_batches += 1
288
+
289
+ loss_d = round(total_loss_d / total_batches, 4)
290
+ loss_g = round(total_loss_g / total_batches, 4)
291
+
292
+ progress_bar.update(1)
293
+ progress_bar.set_postfix({
294
+ "loss_d": loss_d,
295
+ "loss_g": loss_g,
296
+ })
297
+
298
+ # evaluation
299
+ generator.eval()
300
+ discriminator.eval()
301
+
302
+ torch.cuda.empty_cache()
303
+ total_pesq_score = 0.
304
+ total_mag_err = 0.
305
+ total_pha_err = 0.
306
+ total_com_err = 0.
307
+ total_stft_err = 0.
308
+ total_batches = 0.
309
+
310
+ progress_bar = tqdm(
311
+ total=len(valid_data_loader),
312
+ desc="Evaluation; epoch: {}".format(idx_epoch),
313
+ )
314
+ with torch.no_grad():
315
+ for batch in valid_data_loader:
316
+ clean_audio, noisy_audio = batch
317
+ clean_audio = clean_audio.to(device)
318
+ noisy_audio = noisy_audio.to(device)
319
+
320
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
321
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
322
+
323
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
324
+
325
+ audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
326
+ mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
327
+
328
+ total_pesq_score += pesq_score(
329
+ torch.split(clean_audio, 1, dim=0),
330
+ torch.split(audio_g, 1, dim=0),
331
+ config
332
+ ).item()
333
+ total_mag_err += F.mse_loss(clean_mag, mag_g).item()
334
+ val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
335
+ total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
336
+ total_com_err += F.mse_loss(clean_com, com_g).item()
337
+ total_stft_err += F.mse_loss(com_g, com_g_hat).item()
338
+
339
+ total_batches += 1
340
+
341
+ pesq_metric = round(total_pesq_score / total_batches, 4)
342
+ mag_err = round(total_mag_err / total_batches, 4)
343
+ pha_err = round(total_pha_err / total_batches, 4)
344
+ com_err = round(total_com_err / total_batches, 4)
345
+ stft_err = round(total_stft_err / total_batches, 4)
346
+
347
+ progress_bar.update(1)
348
+ progress_bar.set_postfix({
349
+ "pesq_metric": pesq_metric,
350
+ "mag_err": mag_err,
351
+ "pha_err": pha_err,
352
+ "com_err": com_err,
353
+ "stft_err": stft_err,
354
+ })
355
+
356
+ # scheduler
357
+ scheduler_g.step()
358
+ scheduler_d.step()
359
+
360
+ # save path
361
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
362
+ epoch_dir.mkdir(parents=True, exist_ok=False)
363
+
364
+ # save models
365
+ generator.save_pretrained(epoch_dir.as_posix())
366
+ discriminator.save_pretrained(epoch_dir.as_posix())
367
+
368
+ # save optim
369
+ torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
370
+ torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
371
+
372
+ model_list.append(epoch_dir)
373
+ if len(model_list) >= args.num_serialized_models_to_keep:
374
+ model_to_delete: Path = model_list.pop(0)
375
+ shutil.rmtree(model_to_delete.as_posix())
376
+
377
+ # save metric
378
+ if best_metric is None:
379
+ best_idx_epoch = idx_epoch
380
+ best_metric = pesq_metric
381
+ elif pesq_metric > best_metric:
382
+ # great is better.
383
+ best_idx_epoch = idx_epoch
384
+ best_metric = pesq_metric
385
+ else:
386
+ pass
387
+
388
+ metrics = {
389
+ "idx_epoch": idx_epoch,
390
+ "best_idx_epoch": best_idx_epoch,
391
+ "loss_d": loss_d,
392
+ "loss_g": loss_g,
393
+
394
+ "pesq_metric": pesq_metric,
395
+ "mag_err": mag_err,
396
+ "pha_err": pha_err,
397
+ "com_err": com_err,
398
+ "stft_err": stft_err,
399
+
400
+ }
401
+ metrics_filename = epoch_dir / "metrics_epoch.json"
402
+ with open(metrics_filename, "w", encoding="utf-8") as f:
403
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
404
+
405
+ # save best
406
+ best_dir = serialization_dir / "best"
407
+ if best_idx_epoch == idx_epoch:
408
+ if best_dir.exists():
409
+ shutil.rmtree(best_dir)
410
+ shutil.copytree(epoch_dir, best_dir)
411
+
412
+ # early stop
413
+ early_stop_flag = False
414
+ if best_idx_epoch == idx_epoch:
415
+ patience_count = 0
416
+ else:
417
+ patience_count += 1
418
+ if patience_count >= args.patience:
419
+ early_stop_flag = True
420
+
421
+ # early stop
422
+ if early_stop_flag:
423
+ break
424
+
425
+ return
426
+
427
+
428
+ if __name__ == "__main__":
429
+ main()
examples/clean_unet_aishell/step_3_evaluation.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
examples/clean_unet_aishell/yaml/config.yaml ADDED
@@ -0,0 +1,13 @@
1
+ model_name: "clean_unet"
2
+
3
+ channels_input: 1
4
+ channels_output: 1
5
+ channels_h: 64
6
+ max_h: 768
7
+ encoder_n_layers: 8
8
+ kernel_size: 4
9
+ stride: 2
10
+ tsfm_n_layers: 5
11
+ tsfm_n_head: 8
12
+ tsfm_d_model: 512
13
+ tsfm_d_inner: 2048
examples/mpnet_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"

- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"

examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -20,18 +20,12 @@ sys.path.append(os.path.join(pwd, "../../"))

  import numpy as np
  import torch
- from torch.distributed import init_process_group
- import torch.multiprocessing as mp
- from torch.nn.parallel import DistributedDataParallel
- import torch.nn as nn
  from torch.nn import functional as F
- from torch.utils.data import DistributedSampler
  from torch.utils.data.dataloader import DataLoader
- import torchaudio
  from tqdm import tqdm

  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
- from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
+ from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
  from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel, batch_pesq
  from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
  from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
examples/mpnet_aishell/step_3_evaluation.py CHANGED
@@ -1,6 +1,184 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
 
4
 
5
  if __name__ == '__main__':
6
- pass
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/yxlu-0102/MP-SENet/blob/main/inference.py
5
+ """
6
+ import argparse
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+ import sys
11
+ import uuid
12
+
13
+ pwd = os.path.abspath(os.path.dirname(__file__))
14
+ sys.path.append(os.path.join(pwd, "../../"))
15
+
16
+ import librosa
17
+ import numpy as np
18
+ import pandas as pd
19
+ from scipy.io import wavfile
20
+ import torch
21
+ import torch.nn as nn
22
+ import torchaudio
23
+ from tqdm import tqdm
24
+
25
+ from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
26
+ from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel
27
+ from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
28
+
29
+
30
+ def get_args():
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
33
+ parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
34
+ parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
35
+
36
+ parser.add_argument("--limit", default=10, type=int)
37
+
38
+ args = parser.parse_args()
39
+ return args
40
+
41
+
42
+ def logging_config():
43
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
44
+
45
+ logging.basicConfig(format=fmt,
46
+ datefmt="%m/%d/%Y %H:%M:%S",
47
+ level=logging.INFO)
48
+ stream_handler = logging.StreamHandler()
49
+ stream_handler.setLevel(logging.INFO)
50
+ stream_handler.setFormatter(logging.Formatter(fmt))
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+ return logger
55
+
56
+
57
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
58
+ l1 = len(speech)
59
+ l2 = len(noise)
60
+ l = min(l1, l2)
61
+ speech = speech[:l]
62
+ noise = noise[:l]
63
+
64
+ # np.float32, value between (-1, 1).
65
+
66
+ speech_power = np.mean(np.square(speech))
67
+ noise_power = speech_power / (10 ** (snr_db / 10))
68
+
69
+ noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
70
+
71
+ noisy_signal = speech + noise_adjusted
72
+
73
+ return noisy_signal
74
+
75
+
76
+ def save_audios(noise_audio: torch.Tensor,
77
+ clean_audio: torch.Tensor,
78
+ noisy_audio: torch.Tensor,
79
+ enhanced_audio: torch.Tensor,
80
+ output_dir: str,
81
+ sample_rate: int = 8000,
82
+ ):
83
+ basename = uuid.uuid4().__str__()
84
+ output_dir = Path(output_dir) / basename
85
+ output_dir.mkdir(parents=True, exist_ok=True)
86
+
87
+ filename = output_dir / "noise_audio.wav"
88
+ torchaudio.save(filename, noise_audio, sample_rate)
89
+ filename = output_dir / "clean_audio.wav"
90
+ torchaudio.save(filename, clean_audio, sample_rate)
91
+ filename = output_dir / "noisy_audio.wav"
92
+ torchaudio.save(filename, noisy_audio, sample_rate)
93
+
94
+ filename = output_dir / "enhanced_audio.wav"
95
+ torchaudio.save(filename, enhanced_audio, sample_rate)
96
+
97
+ return output_dir.as_posix()
98
+
99
+
100
+ def main():
101
+ args = get_args()
102
+
103
+ logger = logging_config()
104
+
105
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+ n_gpu = torch.cuda.device_count()
107
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
108
+
109
+ logger.info("prepare model")
110
+ config = MPNetConfig.from_pretrained(
111
+ pretrained_model_name_or_path=args.model_dir,
112
+ )
113
+ generator = MPNetPretrainedModel.from_pretrained(
114
+ pretrained_model_name_or_path=args.model_dir,
115
+ )
116
+ generator.to(device)
117
+ generator.eval()
118
+
119
+ logger.info("read excel")
120
+ df = pd.read_excel(args.valid_dataset)
121
+
122
+ progress_bar = tqdm(total=len(df), desc="Evaluation")
123
+ for idx, row in df.iterrows():
124
+ noise_filename = row["noise_filename"]
125
+ noise_offset = row["noise_offset"]
126
+ noise_duration = row["noise_duration"]
127
+
128
+ speech_filename = row["speech_filename"]
129
+ speech_offset = row["speech_offset"]
130
+ speech_duration = row["speech_duration"]
131
+
132
+ snr_db = row["snr_db"]
133
+
134
+ noise_audio, _ = librosa.load(
135
+ noise_filename,
136
+ sr=8000,
137
+ offset=noise_offset,
138
+ duration=noise_duration,
139
+ )
140
+ clean_audio, _ = librosa.load(
141
+ speech_filename,
142
+ sr=8000,
143
+ offset=speech_offset,
144
+ duration=speech_duration,
145
+ )
146
+ noisy_audio: np.ndarray = mix_speech_and_noise(
147
+ speech=clean_audio,
148
+ noise=noise_audio,
149
+ snr_db=snr_db,
150
+ )
151
+ noise_audio = torch.tensor(noise_audio, dtype=torch.float32)
152
+ clean_audio = torch.tensor(clean_audio, dtype=torch.float32)
153
+ noisy_audio: torch.Tensor = torch.tensor(noisy_audio, dtype=torch.float32)
154
+
155
+ noise_audio = noise_audio.unsqueeze(dim=0)
156
+ clean_audio = clean_audio.unsqueeze(dim=0)
157
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
158
+
159
+ # inference
160
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
161
+ noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
162
+ )
163
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
164
+ audio_g = mag_pha_istft(
165
+ mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
166
+ )
167
+ enhanced_audio = audio_g.detach()
168
+
169
+ save_audios(
170
+ noise_audio, clean_audio, noisy_audio,
171
+ enhanced_audio,
172
+ args.evaluation_audio_dir
173
+ )
174
+
175
+ progress_bar.update(1)
176
+
177
+ if idx > args.limit:
178
+ break
179
+
180
+ return
181
 
182
 
183
  if __name__ == '__main__':
184
+ main()
toolbox/torchaudio/models/clean_unet/configuration_clean_unet.py ADDED
@@ -0,0 +1,41 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class CleanUnetConfig(PretrainedConfig):
7
+ def __init__(self,
8
+ channels_input: int = 1,
9
+ channels_output: int = 1,
10
+
11
+ channels_h: int = 64,
12
+ max_h: int = 768,
13
+
14
+ encoder_n_layers: int = 8,
15
+ kernel_size: int = 4,
16
+ stride: int = 2,
17
+ tsfm_n_layers: int = 5,
18
+ tsfm_n_head: int = 8,
19
+ tsfm_d_model: int = 512,
20
+ tsfm_d_inner: int = 2048,
21
+
22
+ **kwargs
23
+ ):
24
+ super(CleanUnetConfig, self).__init__(**kwargs)
25
+ self.channels_input = channels_input
26
+ self.channels_output = channels_output
27
+
28
+ self.channels_h = channels_h
29
+ self.max_h = max_h
30
+
31
+ self.encoder_n_layers = encoder_n_layers
32
+ self.kernel_size = kernel_size
33
+ self.stride = stride
34
+ self.tsfm_n_layers = tsfm_n_layers
35
+ self.tsfm_n_head = tsfm_n_head
36
+ self.tsfm_d_model = tsfm_d_model
37
+ self.tsfm_d_inner = tsfm_d_inner
38
+
39
+
40
+ if __name__ == "__main__":
41
+ pass
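+ # Minimal usage sketch (hypothetical call; assumes PretrainedConfig.from_pretrained accepts a
+ # yaml file, as done in step_2_train_model.py):
+ # config = CleanUnetConfig.from_pretrained("examples/clean_unet_aishell/yaml/config.yaml")
+ # print(config.encoder_n_layers, config.tsfm_d_model)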
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -2,8 +2,287 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://arxiv.org/abs/2202.07790
5
  """
6
 
7
 
8
  if __name__ == '__main__':
9
- pass
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://arxiv.org/abs/2202.07790
5
+
6
+ https://github.com/nvidia/cleanunet
7
+
8
+ https://huggingface.co/spaces/fsoft-ai-center/Speech-Enhancement/blob/main/src/model.py
9
+
10
  """
11
+ import os
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
20
+ from toolbox.torchaudio.models.clean_unet.transformer import TransformerEncoder
21
+ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
22
+
23
+
24
+ def weight_scaling_init(layer):
25
+ """
26
+ weight rescaling initialization from https://arxiv.org/abs/1911.13254
27
+ """
28
+ w = layer.weight.detach()
29
+ alpha = 10.0 * w.std()
30
+ layer.weight.data /= torch.sqrt(alpha)
31
+ layer.bias.data /= torch.sqrt(alpha)
32
+
33
+
34
+ def print_size(net, keyword=None):
35
+ """
36
+ Print the number of parameters of a network
37
+ """
38
+
39
+ if net is not None and isinstance(net, torch.nn.Module):
40
+ module_parameters = filter(lambda p: p.requires_grad, net.parameters())
41
+ params = sum([np.prod(p.size()) for p in module_parameters])
42
+
43
+ print("{} Parameters: {:.6f}M".format(
44
+ net.__class__.__name__, params / 1e6), flush=True, end="; ")
45
+
46
+ if keyword is not None:
47
+ keyword_parameters = [p for name, p in net.named_parameters() if p.requires_grad and keyword in name]
48
+ params = sum([np.prod(p.size()) for p in keyword_parameters])
49
+ print("{} Parameters: {:.6f}M".format(
50
+ keyword, params / 1e6), flush=True, end="; ")
51
+
52
+ print(" ")
53
+
54
+
55
+ # CleanUNet architecture
56
+
57
+ def padding(x, D, K, S):
58
+ """padding zeroes to x so that denoised audio has the same length"""
59
+
60
+ L = x.shape[-1]
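+ # simulate the D encoder layers: each conv with kernel K and stride S maps a length L to
+ # 1 + ceil((L - K) / S); then invert the map ((L - 1) * S + K per layer) to find the
+ # smallest length the decoder reproduces exactly, and zero-pad x up to that length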
61
+ for _ in range(D):
62
+ if L < K:
63
+ L = 1
64
+ else:
65
+ L = 1 + np.ceil((L - K) / S)
66
+
67
+ for _ in range(D):
68
+ L = (L - 1) * S + K
69
+
70
+ L = int(L)
71
+ x = F.pad(x, (0, L - x.shape[-1]))
72
+ return x
73
+
74
+
75
+ class CleanUNet(nn.Module):
76
+ """
77
+ CleanUNet architecture.
78
+ """
79
+
80
+ def __init__(self,
81
+ channels_input=1, channels_output=1,
82
+ channels_h=64, max_h=768,
83
+ encoder_n_layers=8, kernel_size=4, stride=2,
84
+ tsfm_n_layers=3,
85
+ tsfm_n_head=8,
86
+ tsfm_d_model=512,
87
+ tsfm_d_inner=2048):
88
+ """
89
+ Parameters:
90
+ channels_input (int): input channels
91
+ channels_output (int): output channels
92
+ channels_H (int): middle channels H that controls capacity
93
+ max_H (int): maximum H
94
+ encoder_n_layers (int): number of encoder/decoder layers D
95
+ kernel_size (int): kernel size K
96
+ stride (int): stride S
97
+ tsfm_n_layers (int): number of self attention blocks N
98
+ tsfm_n_head (int): number of heads in each self attention block
99
+ tsfm_d_model (int): d_model of self attention
100
+ tsfm_d_inner (int): d_inner of self attention
101
+ """
102
+
103
+ super(CleanUNet, self).__init__()
104
+
105
+ self.channels_input = channels_input
106
+ self.channels_output = channels_output
107
+ self.channels_h = channels_h
108
+ self.max_h = max_h
109
+ self.encoder_n_layers = encoder_n_layers
110
+ self.kernel_size = kernel_size
111
+ self.stride = stride
112
+
113
+ self.tsfm_n_layers = tsfm_n_layers
114
+ self.tsfm_n_head = tsfm_n_head
115
+ self.tsfm_d_model = tsfm_d_model
116
+ self.tsfm_d_inner = tsfm_d_inner
117
+
118
+ # encoder and decoder
119
+ self.encoder = nn.ModuleList()
120
+ self.decoder = nn.ModuleList()
121
+
122
+ for i in range(encoder_n_layers):
123
+ self.encoder.append(nn.Sequential(
124
+ nn.Conv1d(channels_input, channels_h, kernel_size, stride),
125
+ nn.ReLU(),
126
+ nn.Conv1d(channels_h, channels_h * 2, 1),
127
+ nn.GLU(dim=1)
128
+ ))
129
+ channels_input = channels_h
130
+
131
+ if i == 0:
132
+ # no relu at end
133
+ self.decoder.append(nn.Sequential(
134
+ nn.Conv1d(channels_h, channels_h * 2, 1),
135
+ nn.GLU(dim=1),
136
+ nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride)
137
+ ))
138
+ else:
139
+ self.decoder.insert(0, nn.Sequential(
140
+ nn.Conv1d(channels_h, channels_h * 2, 1),
141
+ nn.GLU(dim=1),
142
+ nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
143
+ nn.ReLU()
144
+ ))
145
+ channels_output = channels_h
146
+
147
+ # double H but keep below max_H
148
+ channels_h *= 2
149
+ channels_h = min(channels_h, max_h)
150
+
151
+ # self attention block
152
+ self.tsfm_conv1 = nn.Conv1d(channels_output, tsfm_d_model, kernel_size=1)
153
+ self.tsfm_encoder = TransformerEncoder(d_word_vec=tsfm_d_model,
154
+ n_layers=tsfm_n_layers,
155
+ n_head=tsfm_n_head,
156
+ d_k=tsfm_d_model // tsfm_n_head,
157
+ d_v=tsfm_d_model // tsfm_n_head,
158
+ d_model=tsfm_d_model,
159
+ d_inner=tsfm_d_inner,
160
+ dropout=0.0,
161
+ n_position=0,
162
+ scale_emb=False)
163
+ self.tsfm_conv2 = nn.Conv1d(tsfm_d_model, channels_output, kernel_size=1)
164
+
165
+ # weight scaling initialization
166
+ for layer in self.modules():
167
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
168
+ weight_scaling_init(layer)
169
+
170
+ def forward(self, noisy_audio):
171
+ # (B, L) -> (B, C, L)
172
+ if len(noisy_audio.shape) == 2:
173
+ noisy_audio = noisy_audio.unsqueeze(1)
174
+ B, C, L = noisy_audio.shape
175
+ assert C == 1
176
+
177
+ # normalization and padding
178
+ std = noisy_audio.std(dim=2, keepdim=True) + 1e-3
179
+ noisy_audio /= std
180
+ x = padding(noisy_audio, self.encoder_n_layers, self.kernel_size, self.stride)
181
+
182
+ # encoder
183
+ skip_connections = []
184
+ for downsampling_block in self.encoder:
185
+ x = downsampling_block(x)
186
+ skip_connections.append(x)
187
+ skip_connections = skip_connections[::-1]
188
+
189
+ # attention mask for causal inference; for non-causal, set attn_mask to None
190
+ len_s = x.shape[-1] # length at bottleneck
191
+ attn_mask = (1 - torch.triu(torch.ones((1, len_s, len_s), device=x.device), diagonal=1)).bool()
192
+
193
+ x = self.tsfm_conv1(x) # C 1024 -> 512
194
+ x = x.permute(0, 2, 1)
195
+ x = self.tsfm_encoder(x, src_mask=attn_mask)
196
+ x = x.permute(0, 2, 1)
197
+ x = self.tsfm_conv2(x) # C 512 -> 1024
198
+
199
+ # decoder
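+ # skip_connections was reversed above, so skip_connections[i] comes from the matching
+ # encoder depth; it is cropped to x's current length before being added back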
200
+ for i, upsampling_block in enumerate(self.decoder):
201
+ skip_i = skip_connections[i]
202
+ x += skip_i[:, :, :x.shape[-1]]
203
+ x = upsampling_block(x)
204
+
205
+ x = x[:, :, :L] * std
206
+ return x
207
+
208
+
209
+ MODEL_FILE = "model.pt"
210
+
211
+
212
+ class CleanUNetPretrainedModel(CleanUNet):
213
+ def __init__(self,
214
+ config: CleanUnetConfig,
215
+ ):
216
+ super(CleanUNetPretrainedModel, self).__init__(
217
+ channels_input=config.channels_input,
218
+ channels_output=config.channels_output,
219
+ channels_h=config.channels_h,
220
+ max_h=config.max_h,
221
+ encoder_n_layers=config.encoder_n_layers,
222
+ kernel_size=config.kernel_size,
223
+ stride=config.stride,
224
+ tsfm_n_layers=config.tsfm_n_layers,
225
+ tsfm_n_head=config.tsfm_n_head,
226
+ tsfm_d_model=config.tsfm_d_model,
227
+ tsfm_d_inner=config.tsfm_d_inner,
228
+ )
229
+ self.config = config
230
+
231
+ @classmethod
232
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
233
+ config = CleanUnetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
234
+
235
+ model = cls(config)
236
+
237
+ if os.path.isdir(pretrained_model_name_or_path):
238
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
239
+ else:
240
+ ckpt_file = pretrained_model_name_or_path
241
+
242
+ with open(ckpt_file, "rb") as f:
243
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
244
+ model.load_state_dict(state_dict, strict=True)
245
+ return model
246
+
247
+ def save_pretrained(self,
248
+ save_directory: Union[str, os.PathLike],
249
+ state_dict: Optional[dict] = None,
250
+ ):
251
+
252
+ model = self
253
+
254
+ if state_dict is None:
255
+ state_dict = model.state_dict()
256
+
257
+ os.makedirs(save_directory, exist_ok=True)
258
+
259
+ # save state dict
260
+ model_file = os.path.join(save_directory, MODEL_FILE)
261
+ torch.save(state_dict, model_file)
262
+
263
+ # save config
264
+ config_file = os.path.join(save_directory, CONFIG_FILE)
265
+ self.config.to_yaml_file(config_file)
266
+ return save_directory
267
+
268
+
269
+ def main():
270
+
271
+ config = CleanUnetConfig()
272
+ model = CleanUNetPretrainedModel(config)
273
+
274
+ print_size(model, keyword="tsfm")
275
+
276
+ input_data = torch.ones([4, 1, int(4.5 * 16000)])
277
+ output = model(input_data)
278
+ print(output.shape)
279
+
280
+ # y = torch.rand([4, 1, int(4.5 * 16000)])
281
+ # loss = torch.nn.MSELoss()(y, output)
282
+ # loss.backward()
283
+ # print(loss.item())
284
+ return
285
 
286
 
287
  if __name__ == '__main__':
288
+ main()
toolbox/torchaudio/models/clean_unet/transformer.py ADDED
@@ -0,0 +1,216 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ # Transformer (encoder) https://github.com/jadore801120/attention-is-all-you-need-pytorch
11
+ # Original Copyright 2017 Victor Huang
12
+ # MIT License (https://opensource.org/licenses/MIT)
13
+
14
+ class ScaledDotProductAttention(nn.Module):
15
+ """
16
+ Scaled Dot-Product Attention
17
+ """
18
+
19
+ def __init__(self, temperature, attn_dropout=0.1):
20
+ super().__init__()
21
+ self.temperature = temperature
22
+ self.dropout = nn.Dropout(attn_dropout)
23
+
24
+ def forward(self, q, k, v, mask=None):
25
+
26
+ attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
27
+
28
+ if mask is not None:
29
+ attn = attn.masked_fill(mask == 0, -1e9)
30
+
31
+ attn = self.dropout(F.softmax(attn, dim=-1))
32
+ output = torch.matmul(attn, v)
33
+
34
+ return output, attn
35
+
36
+
37
+ class MultiHeadAttention(nn.Module):
38
+ """
39
+ Multi-Head Attention module
40
+ """
41
+
42
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
43
+ super().__init__()
44
+
45
+ self.n_head = n_head
46
+ self.d_k = d_k
47
+ self.d_v = d_v
48
+
49
+ self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
50
+ self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
51
+ self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
52
+ self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
53
+
54
+ self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
55
+
56
+ self.dropout = nn.Dropout(dropout)
57
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
58
+
59
+ def forward(self, q, k, v, mask=None):
60
+
61
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
62
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
63
+
64
+ residual = q
65
+
66
+ # Pass through the pre-attention projection: b x lq x (n*dv)
67
+ # Separate different heads: b x lq x n x dv
68
+ q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
69
+ k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
70
+ v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
71
+
72
+ # Transpose for attention dot product: b x n x lq x dv
73
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
74
+
75
+ if mask is not None:
76
+ mask = mask.unsqueeze(1) # For head axis broadcasting.
77
+
78
+ q, attn = self.attention(q, k, v, mask=mask)
79
+
80
+ # Transpose to move the head dimension back: b x lq x n x dv
81
+ # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
82
+ q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
83
+ q = self.dropout(self.fc(q))
84
+ q += residual
85
+
86
+ q = self.layer_norm(q)
87
+
88
+ return q, attn
89
+
90
+
91
+ class PositionwiseFeedForward(nn.Module):
92
+ """
93
+ A two-feed-forward-layer module
94
+ """
95
+
96
+ def __init__(self, d_in, d_hid, dropout=0.1):
97
+ super().__init__()
98
+ self.w_1 = nn.Linear(d_in, d_hid) # position-wise
99
+ self.w_2 = nn.Linear(d_hid, d_in) # position-wise
100
+ self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
101
+ self.dropout = nn.Dropout(dropout)
102
+
103
+ def forward(self, x):
104
+
105
+ residual = x
106
+
107
+ x = self.w_2(F.relu(self.w_1(x)))
108
+ x = self.dropout(x)
109
+ x += residual
110
+
111
+ x = self.layer_norm(x)
112
+
113
+ return x
114
+
115
+
116
+ def get_subsequent_mask(seq):
117
+ """
118
+ For masking out the subsequent info.
119
+ """
120
+ sz_b, len_s = seq.size()
121
+ subsequent_mask = (1 - torch.triu(
122
+ torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
123
+ return subsequent_mask
124
+
125
+
126
+ class PositionalEncoding(nn.Module):
127
+
128
+ def __init__(self, d_hid, n_position=200):
129
+ super(PositionalEncoding, self).__init__()
130
+
131
+ # Not a parameter
132
+ self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
133
+
134
+ def _get_sinusoid_encoding_table(self, n_position, d_hid):
135
+ """
136
+ Sinusoid position encoding table
137
+ """
138
+ # TODO: make it with torch instead of numpy
139
+
140
+ def get_position_angle_vec(position):
141
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
142
+
143
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
144
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
145
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
146
+
147
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
148
+
149
+ def forward(self, x):
150
+ return x + self.pos_table[:, :x.size(1)].clone().detach()
151
+
152
+
153
+ class EncoderLayer(nn.Module):
154
+ """
155
+ Compose with two layers
156
+ """
157
+
158
+ def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
159
+ super(EncoderLayer, self).__init__()
160
+ self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
161
+ self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
162
+
163
+ def forward(self, enc_input, slf_attn_mask=None):
164
+ enc_output, enc_slf_attn = self.slf_attn(
165
+ enc_input, enc_input, enc_input, mask=slf_attn_mask)
166
+ enc_output = self.pos_ffn(enc_output)
167
+ return enc_output, enc_slf_attn
168
+
169
+
170
+ class TransformerEncoder(nn.Module):
171
+ """
172
+ A encoder model with self attention mechanism.
173
+ """
174
+
175
+ def __init__(
176
+ self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
177
+ d_model=512, d_inner=2048, dropout=0.1, n_position=624, scale_emb=False):
178
+
179
+ super().__init__()
180
+
181
+ # self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
182
+ if n_position > 0:
183
+ self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
184
+ else:
185
+ self.position_enc = lambda x: x
186
+ self.dropout = nn.Dropout(p=dropout)
187
+ self.layer_stack = nn.ModuleList([
188
+ EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
189
+ for _ in range(n_layers)])
190
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
191
+ self.scale_emb = scale_emb
192
+ self.d_model = d_model
193
+
194
+ def forward(self, src_seq, src_mask, return_attns=False):
195
+
196
+ enc_slf_attn_list = []
197
+
198
+ # -- Forward
199
+ # enc_output = self.src_word_emb(src_seq)
200
+ enc_output = src_seq
201
+ if self.scale_emb:
202
+ enc_output *= self.d_model ** 0.5
203
+ enc_output = self.dropout(self.position_enc(enc_output))
204
+ enc_output = self.layer_norm(enc_output)
205
+
206
+ for enc_layer in self.layer_stack:
207
+ enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
208
+ enc_slf_attn_list += [enc_slf_attn] if return_attns else []
209
+
210
+ if return_attns:
211
+ return enc_output, enc_slf_attn_list
212
+ return enc_output
213
+
214
+
215
+ if __name__ == '__main__':
216
+ pass
toolbox/torchaudio/models/clean_unet/yaml/config.yaml ADDED
@@ -0,0 +1,13 @@
1
+ model_name: "clean_unet"
2
+
3
+ channels_input: 1
4
+ channels_output: 1
5
+ channels_h: 64
6
+ max_h: 768
7
+ encoder_n_layers: 8
8
+ kernel_size: 4
9
+ stride: 2
10
+ tsfm_n_layers: 5
11
+ tsfm_n_head: 8
12
+ tsfm_d_model: 512
13
+ tsfm_d_inner: 2048
toolbox/torchaudio/models/mpnet/{configuation_mpnet.py → configuration_mpnet.py} RENAMED
File without changes
toolbox/torchaudio/models/mpnet/modeling_mpnet.py CHANGED
@@ -3,6 +3,8 @@
  """
  https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py

+ https://huggingface.co/spaces/JacobLinCool/MP-SENet
+
  https://arxiv.org/abs/2305.13686
  https://github.com/yxlu-0102/MP-SENet

@@ -19,7 +21,7 @@ import torch.nn as nn
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
  from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock
  from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock
- from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
+ from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
  from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d
