update

- examples/conv_tasnet/step_1_prepare_data.py +3 -1
- examples/dfnet/run.sh +153 -0
- examples/dfnet/step_1_prepare_data.py +164 -0
- examples/dfnet/step_2_train_model.py +440 -0
- examples/dfnet/yaml/config.yaml +53 -0
- examples/frcrn/step_1_prepare_data.py +6 -3
- examples/mpnet/step_1_prepare_data.py +2 -0
- toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py +2 -0
- toolbox/torchaudio/models/dfnet/modeling_dfnet.py +103 -38
examples/conv_tasnet/step_1_prepare_data.py
CHANGED
@@ -107,7 +107,7 @@ def main():
     process_bar = tqdm(desc="build dataset excel")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
-            if count >= args.max_count:
+            if count >= args.max_count > 0:
                 break

             noise_filename = noise["filename"]
@@ -124,6 +124,8 @@ def main():
             random2 = random.random()

             row = {
+                "count": count,
+
                 "noise_filename": noise_filename,
                 "noise_raw_duration": noise_raw_duration,
                 "noise_offset": noise_offset,
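Note: `count >= args.max_count > 0` relies on Python's chained comparisons, i.e. it means `count >= args.max_count and args.max_count > 0`, so a non-positive `--max_count` now disables the cap instead of stopping immediately. A minimal illustration:

    # Chained comparison: both sub-comparisons must hold.
    def should_stop(count: int, max_count: int) -> bool:
        return count >= max_count > 0

    assert should_stop(10, 10) is True    # cap reached
    assert should_stop(10, 0) is False    # max_count <= 0 means "no limit"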
examples/dfnet/run.sh
ADDED
@@ -0,0 +1,153 @@
+#!/usr/bin/env bash
+
+: <<'END'
+
+
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn \
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
+
+
+END
+
+
+# params
+system_version="windows";
+verbose=true;
+stage=0 # start from 0 if you need to start from data preparation
+stop_stage=9
+
+work_dir="$(pwd)"
+file_folder_name=file_folder_name
+final_model_name=final_model_name
+config_file="yaml/config.yaml"
+limit=10
+
+noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+max_count=10000000
+
+nohup_name=nohup.out
+
+# model params
+batch_size=64
+max_epochs=200
+save_top_k=10
+patience=5
+
+
+# parse options
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+      old_value="$(eval echo \$$name)";
+      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval "${name}=\"$2\"";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+
+    *) break;
+  esac
+done
+
+file_dir="${work_dir}/${file_folder_name}"
+final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+train_dataset="${file_dir}/train.jsonl"
+valid_dataset="${file_dir}/valid.jsonl"
+
+$verbose && echo "system_version: ${system_version}"
+$verbose && echo "file_folder_name: ${file_folder_name}"
+
+if [ $system_version == "windows" ]; then
+  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+  #source /data/local/bin/nx_denoise/bin/activate
+  alias python3='/data/local/bin/nx_denoise/bin/python3'
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  $verbose && echo "stage 1: prepare data"
+  cd "${work_dir}" || exit 1
+  python3 step_1_prepare_data.py \
+  --file_dir "${file_dir}" \
+  --noise_dir "${noise_dir}" \
+  --speech_dir "${speech_dir}" \
+  --train_dataset "${train_dataset}" \
+  --valid_dataset "${valid_dataset}" \
+  --max_count "${max_count}" \
+
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  $verbose && echo "stage 2: train model"
+  cd "${work_dir}" || exit 1
+  python3 step_2_train_model.py \
+  --train_dataset "${train_dataset}" \
+  --valid_dataset "${valid_dataset}" \
+  --serialization_dir "${file_dir}" \
+  --config_file "${config_file}" \
+
+fi
+
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  $verbose && echo "stage 3: test model"
+  cd "${work_dir}" || exit 1
+  python3 step_3_evaluation.py \
+  --valid_dataset "${valid_dataset}" \
+  --model_dir "${file_dir}/best" \
+  --evaluation_audio_dir "${evaluation_audio_dir}" \
+  --limit "${limit}" \
+
+fi
+
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  $verbose && echo "stage 4: collect files"
+  cd "${work_dir}" || exit 1
+
+  mkdir -p ${final_model_dir}
+
+  cp "${file_dir}/best"/* "${final_model_dir}"
+  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+  cd "${final_model_dir}/.." || exit 1;
+
+  if [ -e "${final_model_name}.zip" ]; then
+    rm -rf "${final_model_name}_backup.zip"
+    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+  fi
+
+  zip -r "${final_model_name}.zip" "${final_model_name}"
+  rm -rf "${final_model_name}"
+
+fi
+
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  $verbose && echo "stage 5: clear file_dir"
+  cd "${work_dir}" || exit 1
+
+  rm -rf "${file_dir}";
+
+fi
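Note: the `while true; do ... case "$1" in --*)` block is the Kaldi-style option parser: any `--foo-bar value` pair on the command line overrides a shell variable `foo_bar` that is already declared above the loop, and unknown option names abort the script. For example:

    # Override defaults declared at the top of run.sh; dashes map to underscores.
    sh run.sh --stage 1 --stop_stage 2 \
      --file_folder_name file_dir \
      --max_count 100000   # overrides max_count=10000000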
examples/dfnet/step_1_prepare_data.py
ADDED
@@ -0,0 +1,164 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import os
+from pathlib import Path
+import random
+import sys
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import librosa
+import numpy as np
+from tqdm import tqdm
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--file_dir", default="./", type=str)
+
+    parser.add_argument(
+        "--noise_dir",
+        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+        type=str
+    )
+    parser.add_argument(
+        "--speech_dir",
+        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+        type=str
+    )
+
+    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+    parser.add_argument("--duration", default=4.0, type=float)
+    parser.add_argument("--min_snr_db", default=-10, type=float)
+    parser.add_argument("--max_snr_db", default=20, type=float)
+
+    parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+    parser.add_argument("--max_count", default=10000, type=int)
+
+    args = parser.parse_args()
+    return args
+
+
+def filename_generator(data_dir: str):
+    data_dir = Path(data_dir)
+    for filename in data_dir.glob("**/*.wav"):
+        yield filename.as_posix()
+
+
+def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
+    data_dir = Path(data_dir)
+    for epoch_idx in range(max_epoch):
+        for filename in data_dir.glob("**/*.wav"):
+            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+            if raw_duration < duration:
+                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                continue
+            if signal.ndim != 1:
+                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+            signal_length = len(signal)
+            win_size = int(duration * sample_rate)
+            for begin in range(0, signal_length - win_size, win_size):
+                if np.sum(signal[begin: begin+win_size]) == 0:
+                    continue
+                row = {
+                    "epoch_idx": epoch_idx,
+                    "filename": filename.as_posix(),
+                    "raw_duration": round(raw_duration, 4),
+                    "offset": round(begin / sample_rate, 4),
+                    "duration": round(duration, 4),
+                }
+                yield row
+
+
+def main():
+    args = get_args()
+
+    file_dir = Path(args.file_dir)
+    file_dir.mkdir(exist_ok=True)
+
+    noise_dir = Path(args.noise_dir)
+    speech_dir = Path(args.speech_dir)
+
+    noise_generator = target_second_signal_generator(
+        noise_dir.as_posix(),
+        duration=args.duration,
+        sample_rate=args.target_sample_rate,
+        max_epoch=100000,
+    )
+    speech_generator = target_second_signal_generator(
+        speech_dir.as_posix(),
+        duration=args.duration,
+        sample_rate=args.target_sample_rate,
+        max_epoch=1,
+    )
+
+    dataset = list()
+
+    count = 0
+    process_bar = tqdm(desc="build dataset excel")
+    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+        for noise, speech in zip(noise_generator, speech_generator):
+            if count >= args.max_count > 0:
+                break
+
+            noise_filename = noise["filename"]
+            noise_raw_duration = noise["raw_duration"]
+            noise_offset = noise["offset"]
+            noise_duration = noise["duration"]
+
+            speech_filename = speech["filename"]
+            speech_raw_duration = speech["raw_duration"]
+            speech_offset = speech["offset"]
+            speech_duration = speech["duration"]
+
+            random1 = random.random()
+            random2 = random.random()
+
+            row = {
+                "count": count,
+
+                "noise_filename": noise_filename,
+                "noise_raw_duration": noise_raw_duration,
+                "noise_offset": noise_offset,
+                "noise_duration": noise_duration,
+
+                "speech_filename": speech_filename,
+                "speech_raw_duration": speech_raw_duration,
+                "speech_offset": speech_offset,
+                "speech_duration": speech_duration,
+
+                "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+                "random1": random1,
+            }
+            row = json.dumps(row, ensure_ascii=False)
+            if random2 < (1 / 300 / 1):
+                fvalid.write(f"{row}\n")
+            else:
+                ftrain.write(f"{row}\n")
+
+            count += 1
+            duration_seconds = count * args.duration
+            duration_hours = duration_seconds / 3600
+
+            process_bar.update(n=1)
+            process_bar.set_postfix({
+                # "duration_seconds": round(duration_seconds, 4),
+                "duration_hours": round(duration_hours, 4),
+
+            })
+
+    return
+
+
+if __name__ == "__main__":
+    main()
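Note: `target_second_signal_generator` cuts every wav file into non-overlapping `duration`-second windows and skips all-zero windows; the noise corpus is cycled for up to 100000 epochs while the speech corpus makes a single pass, so `zip()` pairs each speech window with a (possibly repeated) noise window. Roughly one row in 300 (`random2 < 1 / 300`) lands in valid.jsonl; the rest go to train.jsonl. A sketch of the windowing arithmetic:

    # e.g. a 36.5 s file at 8 kHz with duration=4.0 yields full windows at
    # offsets 0, 4, ..., 32 s; the trailing partial window is dropped.
    sample_rate, duration = 8000, 4.0
    signal_length = int(36.5 * sample_rate)
    win_size = int(duration * sample_rate)
    offsets = [begin / sample_rate for begin in range(0, signal_length - win_size, win_size)]
    print(offsets)  # [0.0, 4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0]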
examples/dfnet/step_2_train_model.py
ADDED
@@ -0,0 +1,440 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import json
+import logging
+from logging.handlers import TimedRotatingFileHandler
+import os
+import platform
+from pathlib import Path
+import random
+import sys
+import shutil
+from typing import List
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
+from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
+from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+from toolbox.torchaudio.metrics.pesq import run_pesq_score
+from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
+from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+    parser.add_argument("--patience", default=5, type=int)
+    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+    parser.add_argument("--config_file", default="config.yaml", type=str)
+
+    args = parser.parse_args()
+    return args
+
+
+def logging_config(file_dir: str):
+    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+    logging.basicConfig(format=fmt,
+                        datefmt="%m/%d/%Y %H:%M:%S",
+                        level=logging.INFO)
+    file_handler = TimedRotatingFileHandler(
+        filename=os.path.join(file_dir, "main.log"),
+        encoding="utf-8",
+        when="D",
+        interval=1,
+        backupCount=7
+    )
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(logging.Formatter(fmt))
+    logger = logging.getLogger(__name__)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+class CollateFunction(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, batch: List[dict]):
+        clean_audios = list()
+        noisy_audios = list()
+        snr_db_list = list()
+
+        for sample in batch:
+            # noise_wave: torch.Tensor = sample["noise_wave"]
+            clean_audio: torch.Tensor = sample["speech_wave"]
+            noisy_audio: torch.Tensor = sample["mix_wave"]
+            snr_db: float = sample["snr_db"]
+
+            clean_audios.append(clean_audio)
+            noisy_audios.append(noisy_audio)
+            snr_db_list.append(snr_db)
+
+        clean_audios = torch.stack(clean_audios)
+        noisy_audios = torch.stack(noisy_audios)
+        snr_db_list = torch.stack(snr_db_list)
+
+        # assert
+        if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+            raise AssertionError("nan or inf in clean_audios")
+        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
+            raise AssertionError("nan or inf in noisy_audios")
+        return clean_audios, noisy_audios, snr_db_list
+
+
+collate_fn = CollateFunction()
+
+
+def main():
+    args = get_args()
+
+    config = DfNetConfig.from_pretrained(
+        pretrained_model_name_or_path=args.config_file,
+    )
+
+    serialization_dir = Path(args.serialization_dir)
+    serialization_dir.mkdir(parents=True, exist_ok=True)
+
+    logger = logging_config(serialization_dir)
+
+    random.seed(config.seed)
+    np.random.seed(config.seed)
+    torch.manual_seed(config.seed)
+    logger.info(f"set seed: {config.seed}")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n_gpu = torch.cuda.device_count()
+    logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+    # datasets
+    train_dataset = DenoiseJsonlDataset(
+        jsonl_file=args.train_dataset,
+        expected_sample_rate=config.sample_rate,
+        max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+        # skip=225000,
+    )
+    valid_dataset = DenoiseJsonlDataset(
+        jsonl_file=args.valid_dataset,
+        expected_sample_rate=config.sample_rate,
+        max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+    )
+    train_data_loader = DataLoader(
+        dataset=train_dataset,
+        batch_size=config.batch_size,
+        # shuffle=True,
+        sampler=None,
+        # Multiple worker subprocesses can be used to load data on Linux, but not on Windows.
+        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+        collate_fn=collate_fn,
+        pin_memory=False,
+        prefetch_factor=2,
+    )
+    valid_data_loader = DataLoader(
+        dataset=valid_dataset,
+        batch_size=config.batch_size,
+        # shuffle=True,
+        sampler=None,
+        # Multiple worker subprocesses can be used to load data on Linux, but not on Windows.
+        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+        collate_fn=collate_fn,
+        pin_memory=False,
+        prefetch_factor=2,
+    )
+
+    # models
+    logger.info(f"prepare models. config_file: {args.config_file}")
+    model = DfNetPretrainedModel(config).to(device)
+    model.to(device)
+    model.train()
+
+    # optimizer
+    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+    optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+
+    # resume training
+    last_step_idx = -1
+    last_epoch = -1
+    for step_idx_str in serialization_dir.glob("steps-*"):
+        step_idx_str = Path(step_idx_str)
+        step_idx = step_idx_str.stem.split("-")[1]
+        step_idx = int(step_idx)
+        if step_idx > last_step_idx:
+            last_step_idx = step_idx
+            # last_epoch = 1
+
+    if last_step_idx != -1:
+        logger.info(f"resume from steps-{last_step_idx}.")
+        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
+
+        logger.info(f"load state dict for model.")
+        with open(model_pt.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+
+        logger.info(f"load state dict for optimizer.")
+        with open(optimizer_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        optimizer.load_state_dict(state_dict)
+
+    if config.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            last_epoch=last_epoch,
+            # T_max=10 * config.eval_steps,
+            # eta_min=0.01 * config.lr,
+            **config.lr_scheduler_kwargs,
+        )
+    elif config.lr_scheduler == "MultiStepLR":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            last_epoch=last_epoch,
+            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+        )
+    else:
+        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
+    mr_stft_loss_fn = MultiResolutionSTFTLoss(
+        fft_size_list=[256, 512, 1024],
+        win_size_list=[256, 512, 1024],
+        hop_size_list=[128, 256, 512],
+        factor_sc=1.5,
+        factor_mag=1.0,
+        reduction="mean"
+    ).to(device)
+    lsnr_loss_fn = nn.L1Loss(reduction="mean")
+
+    # training loop
+
+    # state
+    average_pesq_score = 1000000000
+    average_loss = 1000000000
+    average_neg_si_snr_loss = 1000000000
+    average_mask_loss = 1000000000
+
+    model_list = list()
+    best_epoch_idx = None
+    best_step_idx = None
+    best_metric = None
+    patience_count = 0
+
+    step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+    logger.info("training")
+    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+        # train
+        model.train()
+
+        total_pesq_score = 0.
+        total_loss = 0.
+        total_neg_si_snr_loss = 0.
+        total_mask_loss = 0.
+        total_batches = 0.
+
+        progress_bar_train = tqdm(
+            initial=step_idx,
+            desc="Training; epoch-{}".format(epoch_idx),
+        )
+        for train_batch in train_data_loader:
+            clean_audios, noisy_audios, snr_db_list = train_batch
+            clean_audios: torch.Tensor = clean_audios.to(device)
+            noisy_audios: torch.Tensor = noisy_audios.to(device)
+            snr_db_list: torch.Tensor = snr_db_list.to(device)
+
+            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+
+            neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
+            mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+            # mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+            # neg_si_snr_loss = lsnr_loss_fn.forward(lsnr, snr_db_list)
+
+            loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                logger.info(f"find nan or inf in loss.")
+                continue
+
+            denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
+            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+            optimizer.step()
+            lr_scheduler.step()
+
+            total_pesq_score += pesq_score
+            total_loss += loss.item()
+            total_neg_si_snr_loss += neg_si_snr_loss.item()
+            total_mask_loss += mask_loss.item()
+            total_batches += 1
+
+            average_pesq_score = round(total_pesq_score / total_batches, 4)
+            average_loss = round(total_loss / total_batches, 4)
+            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+            average_mask_loss = round(total_mask_loss / total_batches, 4)
+
+            progress_bar_train.update(1)
+            progress_bar_train.set_postfix({
+                "lr": lr_scheduler.get_last_lr()[0],
+                "pesq_score": average_pesq_score,
+                "loss": average_loss,
+                "neg_si_snr_loss": average_neg_si_snr_loss,
+                "mask_loss": average_mask_loss,
+            })
+
+            # evaluation
+            step_idx += 1
+            if step_idx % config.eval_steps == 0:
+                with torch.no_grad():
+                    torch.cuda.empty_cache()
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_neg_si_snr_loss = 0.
+                    total_mask_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_train.close()
+                    progress_bar_eval = tqdm(
+                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
+                    )
+                    for eval_batch in valid_data_loader:
+                        clean_audios, noisy_audios, snr_db_list = eval_batch
+                        clean_audios: torch.Tensor = clean_audios.to(device)
+                        noisy_audios: torch.Tensor = noisy_audios.to(device)
+                        snr_db_list: torch.Tensor = snr_db_list.to(device)
+
+                        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+
+                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
+                        mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+
+                        loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                            logger.info(f"find nan or inf in loss.")
+                            continue
+
+                        denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
+                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+                        total_pesq_score += pesq_score
+                        total_loss += loss.item()
+                        total_neg_si_snr_loss += neg_si_snr_loss.item()
+                        total_mask_loss += mask_loss.item()
+                        total_batches += 1
+
+                        average_pesq_score = round(total_pesq_score / total_batches, 4)
+                        average_loss = round(total_loss / total_batches, 4)
+                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+                        average_mask_loss = round(total_mask_loss / total_batches, 4)
+
+                        progress_bar_eval.update(1)
+                        progress_bar_eval.set_postfix({
+                            "lr": lr_scheduler.get_last_lr()[0],
+                            "pesq_score": average_pesq_score,
+                            "loss": average_loss,
+                            "neg_si_snr_loss": average_neg_si_snr_loss,
+                            "mask_loss": average_mask_loss,
+                        })
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_neg_si_snr_loss = 0.
+                    total_mask_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_eval.close()
+                    progress_bar_train = tqdm(
+                        initial=progress_bar_train.n,
+                        postfix=progress_bar_train.postfix,
+                        desc=progress_bar_train.desc,
+                    )
+
+                    # save path
+                    save_dir = serialization_dir / "steps-{}".format(step_idx)
+                    save_dir.mkdir(parents=True, exist_ok=False)
+
+                    # save models
+                    model.save_pretrained(save_dir.as_posix())
+
+                    model_list.append(save_dir)
+                    if len(model_list) >= args.num_serialized_models_to_keep:
+                        model_to_delete: Path = model_list.pop(0)
+                        shutil.rmtree(model_to_delete.as_posix())
+
+                    # save optim
+                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
+
+                    # save metric
+                    if best_metric is None:
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
+                        best_metric = average_pesq_score
+                    elif average_pesq_score > best_metric:
+                        # great is better.
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
+                        best_metric = average_pesq_score
+                    else:
+                        pass
+
+                    metrics = {
+                        "epoch_idx": epoch_idx,
+                        "best_epoch_idx": best_epoch_idx,
+                        "best_step_idx": best_step_idx,
+                        "pesq_score": average_pesq_score,
+                        "loss": average_loss,
+                        "neg_si_snr_loss": average_neg_si_snr_loss,
+                        "mask_loss": average_mask_loss,
+                    }
+                    metrics_filename = save_dir / "metrics_epoch.json"
+                    with open(metrics_filename, "w", encoding="utf-8") as f:
+                        json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                    # save best
+                    best_dir = serialization_dir / "best"
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                        if best_dir.exists():
+                            shutil.rmtree(best_dir)
+                        shutil.copytree(save_dir, best_dir)
+
+                    # early stop
+                    early_stop_flag = False
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                        patience_count = 0
+                    else:
+                        patience_count += 1
+                        if patience_count >= args.patience:
+                            early_stop_flag = True
+
+                    # early stop
+                    if early_stop_flag:
+                        break
+
+    return
+
+
+if __name__ == "__main__":
+    main()
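Note: training resumes from the highest existing `steps-N` directory under `--serialization_dir`; `model.pt` and `optimizer.pth` are reloaded and `step_idx` continues from N. A minimal sketch of that scan:

    from pathlib import Path

    def latest_checkpoint_step(serialization_dir: str) -> int:
        # "steps-3000" -> 3000; returns -1 when no checkpoint exists yet.
        last = -1
        for p in Path(serialization_dir).glob("steps-*"):
            last = max(last, int(p.stem.split("-")[1]))
        return last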
examples/dfnet/yaml/config.yaml
ADDED
@@ -0,0 +1,53 @@
+model_name: "dfnet"
+
+# spec
+sample_rate: 8000
+n_fft: 512
+win_length: 200
+hop_length: 80
+
+spec_bins: 256
+
+# model
+conv_channels: 64
+conv_kernel_size_input:
+  - 3
+  - 3
+conv_kernel_size_inner:
+  - 1
+  - 3
+conv_lookahead: 0
+
+convt_kernel_size_inner:
+  - 1
+  - 3
+
+embedding_hidden_size: 256
+encoder_combine_op: "concat"
+
+encoder_emb_skip_op: "none"
+encoder_emb_linear_groups: 16
+encoder_emb_hidden_size: 256
+
+encoder_linear_groups: 32
+
+lsnr_max: 30
+lsnr_min: -15
+norm_tau: 1.
+
+decoder_emb_num_layers: 3
+decoder_emb_skip_op: "none"
+decoder_emb_linear_groups: 16
+decoder_emb_hidden_size: 256
+
+df_decoder_hidden_size: 256
+df_num_layers: 2
+df_order: 5
+df_bins: 96
+df_gru_skip: "grouped_linear"
+df_decoder_linear_groups: 16
+df_pathway_kernel_size_t: 5
+df_lookahead: 2
+
+# runtime
+use_post_filter: true
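Note: `spec_bins: 256` is `n_fft // 2`, which matches the `cmp_spec = cmp_spec[:, :, :-1, :]` truncation in modeling_dfnet.py (257 frequency bins reduced to 256). Loading the file the same way step_2_train_model.py does:

    from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig

    config = DfNetConfig.from_pretrained(
        pretrained_model_name_or_path="yaml/config.yaml",
    )
    assert config.sample_rate == 8000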
examples/frcrn/step_1_prepare_data.py
CHANGED
@@ -39,7 +39,7 @@ def get_args():

     parser.add_argument("--target_sample_rate", default=8000, type=int)

-    parser.add_argument("--
+    parser.add_argument("--scale", default=1, type=float)

     args = parser.parse_args()
     return args
@@ -107,8 +107,9 @@ def main():
     process_bar = tqdm(desc="build dataset excel")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
-
-
+            flag = random.random()
+            if flag > args.scale:
+                continue

             noise_filename = noise["filename"]
             noise_raw_duration = noise["raw_duration"]
@@ -124,6 +125,8 @@ def main():
             random2 = random.random()

             row = {
+                "count": count,
+
                 "noise_filename": noise_filename,
                 "noise_raw_duration": noise_raw_duration,
                 "noise_offset": noise_offset,
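Note: the new `--scale` option acts as a keep probability: each (noise, speech) pair survives with probability `min(scale, 1)`, so `--scale 0.1` keeps roughly 10% of pairs and the default of 1 keeps everything:

    import random

    def keep_pair(scale: float) -> bool:
        # flag > scale => skip; equivalently, keep when flag <= scale.
        return random.random() <= scale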
examples/mpnet/step_1_prepare_data.py
CHANGED
@@ -119,6 +119,8 @@ def get_dataset(args):
        random2 = random.random()

        row = {
+            "count": count,
+
            "noise_filename": noise_filename,
            "noise_raw_duration": noise_raw_duration,
            "noise_offset": noise_offset,
|
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py
CHANGED
@@ -35,6 +35,8 @@ class DenoiseJsonlDataset(IterableDataset):
         self.buffer_samples: List[dict] = list()

     def __iter__(self):
+        self.buffer_samples = list()
+
         iterable_source = self.iterable_source()

         try:
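Note: an IterableDataset is re-iterated at the start of every epoch (and once per DataLoader worker), so clearing `buffer_samples` at the top of `__iter__` prevents samples buffered during a previous, possibly truncated, iteration from leaking into the next epoch. A minimal sketch of the pattern, assuming the buffer is used for shuffling:

    from torch.utils.data import IterableDataset

    class BufferedShuffleDataset(IterableDataset):
        def __init__(self, source, buffer_size: int = 1000):
            self.source = source
            self.buffer_size = buffer_size
            self.buffer = []

        def __iter__(self):
            self.buffer = []  # start every epoch with an empty buffer
            for sample in self.source:
                self.buffer.append(sample)
                if len(self.buffer) >= self.buffer_size:
                    yield self.buffer.pop(0)
            yield from self.buffer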
toolbox/torchaudio/models/dfnet/modeling_dfnet.py
CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 import torchaudio

 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
@@ -752,11 +753,11 @@ class DeepFiltering(nn.Module):
         coefs: torch.Tensor,
     ):
         # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        spec_u = self.spec_unfold(torch.view_as_complex(spec))
+        spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
         # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order]

         # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
-        coefs = torch.view_as_complex(coefs)
+        coefs = torch.view_as_complex(coefs.contiguous())
         # coefs shape: [batch_size, df_order, time_steps, df_bins]
         spec_f = spec_u.narrow(-2, 0, self.df_bins)
         # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
@@ -789,6 +790,13 @@ class DfNet(nn.Module):
         super(DfNet, self).__init__()
         self.config = config

+        self.freq_bins = self.config.nfft // 2 + 1
+
+        self.nfft = config.nfft
+        self.win_size = config.win_size
+        self.hop_size = config.hop_size
+        self.win_type = config.win_type
+
         self.stft = ConvSTFT(
             nfft=config.nfft,
             win_size=config.win_size,
@@ -820,32 +828,41 @@ class DfNet(nn.Module):
         self.mask = Mask(use_post_filter=config.use_post_filter)

     def forward(self,
-
+                noisy: torch.Tensor,
     ):
-
-
-
-
+        if noisy.dim() == 2:
+            noisy = torch.unsqueeze(noisy, dim=1)
+        _, _, n_samples = noisy.shape
+        remainder = (n_samples - self.win_size) % self.hop_size
+        if remainder > 0:
+            n_samples_pad = self.hop_size - remainder
+            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
+
+        # [batch_size, freq_bins * 2, time_steps]
+        cmp_spec = self.stft.forward(noisy)
+        # [batch_size, 1, freq_bins * 2, time_steps]
+        cmp_spec = torch.unsqueeze(cmp_spec, 1)
+
+        # [batch_size, 2, freq_bins, time_steps]
+        cmp_spec = torch.cat([
+            cmp_spec[:, :, :self.freq_bins, :],
+            cmp_spec[:, :, self.freq_bins:, :],
+        ], dim=1)
+        # n//2+1 -> n//2; 257 -> 256
+        cmp_spec = cmp_spec[:, :, :-1, :]
+
+        spec = torch.unsqueeze(cmp_spec, dim=4)
+        # [batch_size, 2, freq_bins, time_steps, 1]
+        spec = spec.permute(0, 4, 3, 2, 1)
+        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
+
+        feat_power = torch.sum(torch.square(spec), dim=-1)
         # feat_power shape: [batch_size, 1, time_steps, spec_bins]
-        feat_power = feat_power.detach()

-
-        feat_spec
-        # spec shape: [batch_size, spec_bins, time_steps, 2]
-        feat_spec = feat_spec.permute(0, 3, 2, 1)
-        # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
+        feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
+        # feat_spec shape: [batch_size, 2, time_steps, freq_bins]
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
         # feat_spec shape: [batch_size, 2, time_steps, df_bins]
-        feat_spec = feat_spec.detach()
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        spec = torch.unsqueeze(spec_complex, dim=1)
-        # spec shape: [batch_size, 1, spec_bins, time_steps]
-        spec = spec.permute(0, 1, 3, 2)
-        # spec shape: [batch_size, 1, time_steps, spec_bins]
-        spec = torch.view_as_real(spec)
-        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        spec = spec.detach()

         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)

@@ -865,7 +882,7 @@ class DfNet(nn.Module):
         # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]

         spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        #
+        # est_spec shape: [batch_size, 1, time_steps, spec_bins, 2]

         spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]

@@ -874,10 +891,68 @@ class DfNet(nn.Module):
         # spec_e shape: [batch_size, spec_bins, time_steps, 2]

         mask = torch.squeeze(mask, dim=1)
-
+        est_mask = mask.permute(0, 2, 1)
         # mask shape: [batch_size, spec_bins, time_steps]

-
+        b, _, t, _ = spec_e.shape
+        est_spec = torch.cat(tensors=[
+            torch.concat(tensors=[
+                spec_e[..., 0],
+                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
+            ], dim=1),
+            torch.concat(tensors=[
+                spec_e[..., 1],
+                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
+            ], dim=1),
+        ], dim=1)
+        # est_spec shape: [b, n+2, t]
+        est_wav = self.istft.forward(est_spec)
+        est_wav = torch.squeeze(est_wav, dim=1)
+        est_wav = est_wav[:, :n_samples]
+        # est_wav shape: [b, n_samples]
+        return est_spec, est_wav, est_mask, lsnr
+
+    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        """
+
+        :param est_mask: torch.Tensor, shape: [b, n+2, t]
+        :param clean:
+        :param noisy:
+        :return:
+        """
+        clean_stft = self.stft(clean)
+        clean_re = clean_stft[:, :self.freq_bins, :]
+        clean_im = clean_stft[:, self.freq_bins:, :]
+
+        noisy_stft = self.stft(noisy)
+        noisy_re = noisy_stft[:, :self.freq_bins, :]
+        noisy_im = noisy_stft[:, self.freq_bins:, :]
+
+        noisy_power = noisy_re ** 2 + noisy_im ** 2
+
+        sr = clean_re
+        yr = noisy_re
+        si = clean_im
+        yi = noisy_im
+        y_pow = noisy_power
+        # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
+        gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
+        # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
+        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)
+
+        gth_mask_re[gth_mask_re > 2] = 1
+        gth_mask_re[gth_mask_re < -2] = -1
+        gth_mask_im[gth_mask_im > 2] = 1
+        gth_mask_im[gth_mask_im < -2] = -1
+
+        mask_re = est_mask[:, :self.freq_bins, :]
+        mask_im = est_mask[:, self.freq_bins:, :]
+
+        loss_re = F.mse_loss(gth_mask_re, mask_re)
+        loss_im = F.mse_loss(gth_mask_im, mask_im)
+
+        loss = loss_re + loss_im
+        return loss


 class DfNetPretrainedModel(DfNet):
@@ -928,22 +1003,12 @@ class DfNetPretrainedModel(DfNet):

 def main():

-    transformer = torchaudio.transforms.Spectrogram(
-        n_fft=512,
-        win_length=200,
-        hop_length=80,
-        window_fn=torch.hamming_window,
-        power=None,
-    )
-
     config = DfNetConfig()
     model = DfNetPretrainedModel(config=config)

-
-    spec_complex = transformer.forward(inputs)
-    spec_complex = spec_complex[:, :-1, :]
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

-    output = model.forward(
+    output = model.forward(noisy)
     print(output[1].shape)
     return

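Note: `torch.view_as_complex` requires the input's last dimension to have size 2 and stride 1, which permuted views generally violate; that is what the added `.contiguous()` calls guarantee:

    import torch

    x = torch.randn(3, 2, 5).permute(0, 2, 1)  # shape [3, 5, 2], last-dim stride 2
    try:
        torch.view_as_complex(x)               # RuntimeError: last dimension must have stride 1
    except RuntimeError as e:
        print(e)
    y = torch.view_as_complex(x.contiguous())  # fine after materializing memory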
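Note: the target in `mask_loss_fn` is the complex ratio mask M = S / Y, computed as S * conj(Y) / |Y|^2 and then clamped; its real and imaginary parts are exactly the two expressions quoted in the code comments:

    import torch

    s = torch.complex(torch.randn(4), torch.randn(4))  # clean spectrum S
    y = torch.complex(torch.randn(4), torch.randn(4))  # noisy spectrum Y
    y_pow = y.real ** 2 + y.imag ** 2

    m = s * torch.conj(y) / (y_pow + 1e-8)
    m_re = (s.real * y.real + s.imag * y.imag) / (y_pow + 1e-8)
    m_im = (s.imag * y.real - s.real * y.imag) / (y_pow + 1e-8)
    assert torch.allclose(m.real, m_re) and torch.allclose(m.imag, m_im)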