Spaces:
Running
Running
add dfnet2
Browse files- examples/{clean_unet_aishell → clean_unet}/run.sh +0 -0
- examples/{clean_unet_aishell → clean_unet}/step_1_prepare_data.py +0 -0
- examples/{clean_unet_aishell → clean_unet}/step_2_train_model.py +0 -0
- examples/{clean_unet_aishell → clean_unet}/step_3_evaluation.py +0 -0
- examples/{clean_unet_aishell → clean_unet}/yaml/config.yaml +0 -0
- examples/conv_tasnet/step_2_train_model.py +2 -0
- examples/conv_tasnet_gan/run.sh +0 -156
- examples/conv_tasnet_gan/step_1_prepare_data.py +0 -162
- examples/conv_tasnet_gan/step_2_train_model.py +0 -582
- examples/conv_tasnet_gan/yaml/config.yaml +0 -31
- examples/conv_tasnet_gan/yaml/discriminator_config.yaml +0 -10
- examples/dfnet/step_2_train_model.py +2 -0
- examples/dfnet2/step_2_train_model.py +2 -0
- examples/dtln/step_2_train_model.py +2 -0
- examples/frcrn/step_2_train_model.py +2 -0
- examples/lstm/step_2_train_model.py +2 -0
- examples/rnnoise/step_2_train_model.py +3 -0
- toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py +3 -0
- toolbox/torchaudio/modules/utils/ema.py +87 -0
examples/{clean_unet_aishell → clean_unet}/run.sh
RENAMED
File without changes
|
examples/{clean_unet_aishell → clean_unet}/step_1_prepare_data.py
RENAMED
File without changes
|
examples/{clean_unet_aishell → clean_unet}/step_2_train_model.py
RENAMED
File without changes
|
examples/{clean_unet_aishell → clean_unet}/step_3_evaluation.py
RENAMED
File without changes
|
examples/{clean_unet_aishell → clean_unet}/yaml/config.yaml
RENAMED
File without changes
|
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -346,6 +346,7 @@ def main():
|
|
346 |
# evaluation
|
347 |
step_idx += 1
|
348 |
if step_idx % config.eval_steps == 0:
|
|
|
349 |
with torch.no_grad():
|
350 |
torch.cuda.empty_cache()
|
351 |
|
@@ -499,6 +500,7 @@ def main():
|
|
499 |
# early stop
|
500 |
if early_stop_flag:
|
501 |
break
|
|
|
502 |
|
503 |
return
|
504 |
|
|
|
346 |
# evaluation
|
347 |
step_idx += 1
|
348 |
if step_idx % config.eval_steps == 0:
|
349 |
+
model.eval()
|
350 |
with torch.no_grad():
|
351 |
torch.cuda.empty_cache()
|
352 |
|
|
|
500 |
# early stop
|
501 |
if early_stop_flag:
|
502 |
break
|
503 |
+
model.train()
|
504 |
|
505 |
return
|
506 |
|
examples/conv_tasnet_gan/run.sh
DELETED
@@ -1,156 +0,0 @@
|
|
1 |
-
#!/usr/bin/env bash
|
2 |
-
|
3 |
-
: <<'END'
|
4 |
-
|
5 |
-
|
6 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
|
7 |
-
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
-
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
-
--max_epochs 400
|
10 |
-
|
11 |
-
|
12 |
-
END
|
13 |
-
|
14 |
-
|
15 |
-
# params
|
16 |
-
system_version="windows";
|
17 |
-
verbose=true;
|
18 |
-
stage=0 # start from 0 if you need to start from data preparation
|
19 |
-
stop_stage=9
|
20 |
-
|
21 |
-
work_dir="$(pwd)"
|
22 |
-
file_folder_name=file_folder_name
|
23 |
-
final_model_name=final_model_name
|
24 |
-
config_file="yaml/config.yaml"
|
25 |
-
discriminator_config_file="yaml/discriminator_config.yaml"
|
26 |
-
limit=10
|
27 |
-
|
28 |
-
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
29 |
-
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
30 |
-
|
31 |
-
max_count=10000000
|
32 |
-
|
33 |
-
nohup_name=nohup.out
|
34 |
-
|
35 |
-
# model params
|
36 |
-
batch_size=64
|
37 |
-
max_epochs=200
|
38 |
-
save_top_k=10
|
39 |
-
patience=5
|
40 |
-
|
41 |
-
|
42 |
-
# parse options
|
43 |
-
while true; do
|
44 |
-
[ -z "${1:-}" ] && break; # break if there are no arguments
|
45 |
-
case "$1" in
|
46 |
-
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
47 |
-
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
48 |
-
old_value="(eval echo \\$$name)";
|
49 |
-
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
50 |
-
was_bool=true;
|
51 |
-
else
|
52 |
-
was_bool=false;
|
53 |
-
fi
|
54 |
-
|
55 |
-
# Set the variable to the right value-- the escaped quotes make it work if
|
56 |
-
# the option had spaces, like --cmd "queue.pl -sync y"
|
57 |
-
eval "${name}=\"$2\"";
|
58 |
-
|
59 |
-
# Check that Boolean-valued arguments are really Boolean.
|
60 |
-
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
61 |
-
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
62 |
-
exit 1;
|
63 |
-
fi
|
64 |
-
shift 2;
|
65 |
-
;;
|
66 |
-
|
67 |
-
*) break;
|
68 |
-
esac
|
69 |
-
done
|
70 |
-
|
71 |
-
file_dir="${work_dir}/${file_folder_name}"
|
72 |
-
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
73 |
-
evaluation_audio_dir="${file_dir}/evaluation_audio"
|
74 |
-
|
75 |
-
train_dataset="${file_dir}/train.jsonl"
|
76 |
-
valid_dataset="${file_dir}/valid.jsonl"
|
77 |
-
|
78 |
-
$verbose && echo "system_version: ${system_version}"
|
79 |
-
$verbose && echo "file_folder_name: ${file_folder_name}"
|
80 |
-
|
81 |
-
if [ $system_version == "windows" ]; then
|
82 |
-
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
|
83 |
-
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
84 |
-
#source /data/local/bin/nx_denoise/bin/activate
|
85 |
-
alias python3='/data/local/bin/nx_denoise/bin/python3'
|
86 |
-
fi
|
87 |
-
|
88 |
-
|
89 |
-
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
90 |
-
$verbose && echo "stage 1: prepare data"
|
91 |
-
cd "${work_dir}" || exit 1
|
92 |
-
python3 step_1_prepare_data.py \
|
93 |
-
--file_dir "${file_dir}" \
|
94 |
-
--noise_dir "${noise_dir}" \
|
95 |
-
--speech_dir "${speech_dir}" \
|
96 |
-
--train_dataset "${train_dataset}" \
|
97 |
-
--valid_dataset "${valid_dataset}" \
|
98 |
-
--max_count "${max_count}" \
|
99 |
-
|
100 |
-
fi
|
101 |
-
|
102 |
-
|
103 |
-
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
104 |
-
$verbose && echo "stage 2: train model"
|
105 |
-
cd "${work_dir}" || exit 1
|
106 |
-
python3 step_2_train_model.py \
|
107 |
-
--train_dataset "${train_dataset}" \
|
108 |
-
--valid_dataset "${valid_dataset}" \
|
109 |
-
--serialization_dir "${file_dir}" \
|
110 |
-
--config_file "${config_file}" \
|
111 |
-
--discriminator_config_file "${discriminator_config_file}" \
|
112 |
-
|
113 |
-
fi
|
114 |
-
|
115 |
-
|
116 |
-
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
117 |
-
$verbose && echo "stage 3: test model"
|
118 |
-
cd "${work_dir}" || exit 1
|
119 |
-
python3 step_3_evaluation.py \
|
120 |
-
--valid_dataset "${valid_dataset}" \
|
121 |
-
--model_dir "${file_dir}/best" \
|
122 |
-
--evaluation_audio_dir "${evaluation_audio_dir}" \
|
123 |
-
--limit "${limit}" \
|
124 |
-
|
125 |
-
fi
|
126 |
-
|
127 |
-
|
128 |
-
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
129 |
-
$verbose && echo "stage 4: collect files"
|
130 |
-
cd "${work_dir}" || exit 1
|
131 |
-
|
132 |
-
mkdir -p ${final_model_dir}
|
133 |
-
|
134 |
-
cp "${file_dir}/best"/* "${final_model_dir}"
|
135 |
-
cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
|
136 |
-
|
137 |
-
cd "${final_model_dir}/.." || exit 1;
|
138 |
-
|
139 |
-
if [ -e "${final_model_name}.zip" ]; then
|
140 |
-
rm -rf "${final_model_name}_backup.zip"
|
141 |
-
mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
|
142 |
-
fi
|
143 |
-
|
144 |
-
zip -r "${final_model_name}.zip" "${final_model_name}"
|
145 |
-
rm -rf "${final_model_name}"
|
146 |
-
|
147 |
-
fi
|
148 |
-
|
149 |
-
|
150 |
-
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
151 |
-
$verbose && echo "stage 5: clear file_dir"
|
152 |
-
cd "${work_dir}" || exit 1
|
153 |
-
|
154 |
-
rm -rf "${file_dir}";
|
155 |
-
|
156 |
-
fi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/conv_tasnet_gan/step_1_prepare_data.py
DELETED
@@ -1,162 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
import argparse
|
4 |
-
import json
|
5 |
-
import os
|
6 |
-
from pathlib import Path
|
7 |
-
import random
|
8 |
-
import sys
|
9 |
-
|
10 |
-
pwd = os.path.abspath(os.path.dirname(__file__))
|
11 |
-
sys.path.append(os.path.join(pwd, "../../"))
|
12 |
-
|
13 |
-
import librosa
|
14 |
-
import numpy as np
|
15 |
-
from tqdm import tqdm
|
16 |
-
|
17 |
-
|
18 |
-
def get_args():
|
19 |
-
parser = argparse.ArgumentParser()
|
20 |
-
parser.add_argument("--file_dir", default="./", type=str)
|
21 |
-
|
22 |
-
parser.add_argument(
|
23 |
-
"--noise_dir",
|
24 |
-
default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
|
25 |
-
type=str
|
26 |
-
)
|
27 |
-
parser.add_argument(
|
28 |
-
"--speech_dir",
|
29 |
-
default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
|
30 |
-
type=str
|
31 |
-
)
|
32 |
-
|
33 |
-
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
34 |
-
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
35 |
-
|
36 |
-
parser.add_argument("--duration", default=4.0, type=float)
|
37 |
-
parser.add_argument("--min_snr_db", default=-10, type=float)
|
38 |
-
parser.add_argument("--max_snr_db", default=20, type=float)
|
39 |
-
|
40 |
-
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
41 |
-
|
42 |
-
parser.add_argument("--max_count", default=10000, type=int)
|
43 |
-
|
44 |
-
args = parser.parse_args()
|
45 |
-
return args
|
46 |
-
|
47 |
-
|
48 |
-
def filename_generator(data_dir: str):
|
49 |
-
data_dir = Path(data_dir)
|
50 |
-
for filename in data_dir.glob("**/*.wav"):
|
51 |
-
yield filename.as_posix()
|
52 |
-
|
53 |
-
|
54 |
-
def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
|
55 |
-
data_dir = Path(data_dir)
|
56 |
-
for epoch_idx in range(max_epoch):
|
57 |
-
for filename in data_dir.glob("**/*.wav"):
|
58 |
-
signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
|
59 |
-
raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
|
60 |
-
|
61 |
-
if raw_duration < duration:
|
62 |
-
# print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
|
63 |
-
continue
|
64 |
-
if signal.ndim != 1:
|
65 |
-
raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
|
66 |
-
|
67 |
-
signal_length = len(signal)
|
68 |
-
win_size = int(duration * sample_rate)
|
69 |
-
for begin in range(0, signal_length - win_size, win_size):
|
70 |
-
if np.sum(signal[begin: begin+win_size]) == 0:
|
71 |
-
continue
|
72 |
-
row = {
|
73 |
-
"epoch_idx": epoch_idx,
|
74 |
-
"filename": filename.as_posix(),
|
75 |
-
"raw_duration": round(raw_duration, 4),
|
76 |
-
"offset": round(begin / sample_rate, 4),
|
77 |
-
"duration": round(duration, 4),
|
78 |
-
}
|
79 |
-
yield row
|
80 |
-
|
81 |
-
|
82 |
-
def main():
|
83 |
-
args = get_args()
|
84 |
-
|
85 |
-
file_dir = Path(args.file_dir)
|
86 |
-
file_dir.mkdir(exist_ok=True)
|
87 |
-
|
88 |
-
noise_dir = Path(args.noise_dir)
|
89 |
-
speech_dir = Path(args.speech_dir)
|
90 |
-
|
91 |
-
noise_generator = target_second_signal_generator(
|
92 |
-
noise_dir.as_posix(),
|
93 |
-
duration=args.duration,
|
94 |
-
sample_rate=args.target_sample_rate,
|
95 |
-
max_epoch=100000,
|
96 |
-
)
|
97 |
-
speech_generator = target_second_signal_generator(
|
98 |
-
speech_dir.as_posix(),
|
99 |
-
duration=args.duration,
|
100 |
-
sample_rate=args.target_sample_rate,
|
101 |
-
max_epoch=1,
|
102 |
-
)
|
103 |
-
|
104 |
-
dataset = list()
|
105 |
-
|
106 |
-
count = 0
|
107 |
-
process_bar = tqdm(desc="build dataset excel")
|
108 |
-
with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
|
109 |
-
for noise, speech in zip(noise_generator, speech_generator):
|
110 |
-
if count >= args.max_count:
|
111 |
-
break
|
112 |
-
|
113 |
-
noise_filename = noise["filename"]
|
114 |
-
noise_raw_duration = noise["raw_duration"]
|
115 |
-
noise_offset = noise["offset"]
|
116 |
-
noise_duration = noise["duration"]
|
117 |
-
|
118 |
-
speech_filename = speech["filename"]
|
119 |
-
speech_raw_duration = speech["raw_duration"]
|
120 |
-
speech_offset = speech["offset"]
|
121 |
-
speech_duration = speech["duration"]
|
122 |
-
|
123 |
-
random1 = random.random()
|
124 |
-
random2 = random.random()
|
125 |
-
|
126 |
-
row = {
|
127 |
-
"noise_filename": noise_filename,
|
128 |
-
"noise_raw_duration": noise_raw_duration,
|
129 |
-
"noise_offset": noise_offset,
|
130 |
-
"noise_duration": noise_duration,
|
131 |
-
|
132 |
-
"speech_filename": speech_filename,
|
133 |
-
"speech_raw_duration": speech_raw_duration,
|
134 |
-
"speech_offset": speech_offset,
|
135 |
-
"speech_duration": speech_duration,
|
136 |
-
|
137 |
-
"snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
|
138 |
-
|
139 |
-
"random1": random1,
|
140 |
-
}
|
141 |
-
row = json.dumps(row, ensure_ascii=False)
|
142 |
-
if random2 < (1 / 300 / 1):
|
143 |
-
fvalid.write(f"{row}\n")
|
144 |
-
else:
|
145 |
-
ftrain.write(f"{row}\n")
|
146 |
-
|
147 |
-
count += 1
|
148 |
-
duration_seconds = count * args.duration
|
149 |
-
duration_hours = duration_seconds / 3600
|
150 |
-
|
151 |
-
process_bar.update(n=1)
|
152 |
-
process_bar.set_postfix({
|
153 |
-
# "duration_seconds": round(duration_seconds, 4),
|
154 |
-
"duration_hours": round(duration_hours, 4),
|
155 |
-
|
156 |
-
})
|
157 |
-
|
158 |
-
return
|
159 |
-
|
160 |
-
|
161 |
-
if __name__ == "__main__":
|
162 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/conv_tasnet_gan/step_2_train_model.py
DELETED
@@ -1,582 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
"""
|
4 |
-
https://github.com/kaituoxu/Conv-TasNet/tree/master/src
|
5 |
-
|
6 |
-
一般场景:
|
7 |
-
|
8 |
-
目标 SI-SNR ≥ 10 dB,适用于电话通信、基础语音助手等。
|
9 |
-
|
10 |
-
高要求场景(如医疗助听、语音识别):
|
11 |
-
需 SI-SNR ≥ 14 dB,并配合 PESQ ≥ 3.0 和 STOI ≥ 0.851812。
|
12 |
-
|
13 |
-
DeepFilterNet2 模型在 DNS4 数据集,超过500小时的音频上训练了 100 个 epoch。
|
14 |
-
https://arxiv.org/abs/2205.05474
|
15 |
-
|
16 |
-
"""
|
17 |
-
import argparse
|
18 |
-
import json
|
19 |
-
import logging
|
20 |
-
from logging.handlers import TimedRotatingFileHandler
|
21 |
-
import os
|
22 |
-
import platform
|
23 |
-
from pathlib import Path
|
24 |
-
import random
|
25 |
-
import sys
|
26 |
-
import shutil
|
27 |
-
from typing import List
|
28 |
-
|
29 |
-
pwd = os.path.abspath(os.path.dirname(__file__))
|
30 |
-
sys.path.append(os.path.join(pwd, "../../"))
|
31 |
-
|
32 |
-
import numpy as np
|
33 |
-
import torch
|
34 |
-
import torch.nn as nn
|
35 |
-
from torch.nn import functional as F
|
36 |
-
from torch.utils.data.dataloader import DataLoader
|
37 |
-
from tqdm import tqdm
|
38 |
-
|
39 |
-
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
|
40 |
-
from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
|
41 |
-
from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
|
42 |
-
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.modeling_waveform_metric_discriminator import WaveformMetricDiscriminatorPretrainedModel
|
43 |
-
from toolbox.torchaudio.models.discriminators.waveform_metric_discriminator.configuration_waveform_metric_discriminator import WaveformMetricDiscriminatorConfig
|
44 |
-
from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
|
45 |
-
from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
|
46 |
-
from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
|
47 |
-
from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
|
48 |
-
from toolbox.torchaudio.metrics.pesq import run_pesq_score
|
49 |
-
|
50 |
-
|
51 |
-
def get_args():
|
52 |
-
parser = argparse.ArgumentParser()
|
53 |
-
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
54 |
-
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
55 |
-
|
56 |
-
parser.add_argument("--max_epochs", default=200, type=int)
|
57 |
-
|
58 |
-
parser.add_argument("--batch_size", default=64, type=int)
|
59 |
-
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
60 |
-
parser.add_argument("--patience", default=5, type=int)
|
61 |
-
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
62 |
-
parser.add_argument("--seed", default=1234, type=int)
|
63 |
-
|
64 |
-
parser.add_argument("--config_file", default="config.yaml", type=str)
|
65 |
-
parser.add_argument("--discriminator_config_file", default="discriminator_config.yaml", type=str)
|
66 |
-
|
67 |
-
args = parser.parse_args()
|
68 |
-
return args
|
69 |
-
|
70 |
-
|
71 |
-
def logging_config(file_dir: str):
|
72 |
-
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
73 |
-
|
74 |
-
logging.basicConfig(format=fmt,
|
75 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
76 |
-
level=logging.INFO)
|
77 |
-
file_handler = TimedRotatingFileHandler(
|
78 |
-
filename=os.path.join(file_dir, "main.log"),
|
79 |
-
encoding="utf-8",
|
80 |
-
when="D",
|
81 |
-
interval=1,
|
82 |
-
backupCount=7
|
83 |
-
)
|
84 |
-
file_handler.setLevel(logging.INFO)
|
85 |
-
file_handler.setFormatter(logging.Formatter(fmt))
|
86 |
-
logger = logging.getLogger(__name__)
|
87 |
-
logger.addHandler(file_handler)
|
88 |
-
|
89 |
-
return logger
|
90 |
-
|
91 |
-
|
92 |
-
class CollateFunction(object):
|
93 |
-
def __init__(self):
|
94 |
-
pass
|
95 |
-
|
96 |
-
def __call__(self, batch: List[dict]):
|
97 |
-
clean_audios = list()
|
98 |
-
noisy_audios = list()
|
99 |
-
|
100 |
-
for sample in batch:
|
101 |
-
# noise_wave: torch.Tensor = sample["noise_wave"]
|
102 |
-
clean_audio: torch.Tensor = sample["speech_wave"]
|
103 |
-
noisy_audio: torch.Tensor = sample["mix_wave"]
|
104 |
-
# snr_db: float = sample["snr_db"]
|
105 |
-
|
106 |
-
clean_audios.append(clean_audio)
|
107 |
-
noisy_audios.append(noisy_audio)
|
108 |
-
|
109 |
-
clean_audios = torch.stack(clean_audios)
|
110 |
-
noisy_audios = torch.stack(noisy_audios)
|
111 |
-
|
112 |
-
# assert
|
113 |
-
if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
|
114 |
-
raise AssertionError("nan or inf in clean_audios")
|
115 |
-
if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
|
116 |
-
raise AssertionError("nan or inf in noisy_audios")
|
117 |
-
return clean_audios, noisy_audios
|
118 |
-
|
119 |
-
|
120 |
-
collate_fn = CollateFunction()
|
121 |
-
|
122 |
-
|
123 |
-
def main():
|
124 |
-
args = get_args()
|
125 |
-
|
126 |
-
config = ConvTasNetConfig.from_pretrained(
|
127 |
-
pretrained_model_name_or_path=args.config_file,
|
128 |
-
)
|
129 |
-
discriminator_config = WaveformMetricDiscriminatorConfig.from_pretrained(
|
130 |
-
pretrained_model_name_or_path=args.discriminator_config_file,
|
131 |
-
)
|
132 |
-
|
133 |
-
serialization_dir = Path(args.serialization_dir)
|
134 |
-
serialization_dir.mkdir(parents=True, exist_ok=True)
|
135 |
-
|
136 |
-
logger = logging_config(serialization_dir)
|
137 |
-
|
138 |
-
random.seed(args.seed)
|
139 |
-
np.random.seed(args.seed)
|
140 |
-
torch.manual_seed(args.seed)
|
141 |
-
logger.info(f"set seed: {args.seed}")
|
142 |
-
|
143 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
144 |
-
n_gpu = torch.cuda.device_count()
|
145 |
-
logger.info(f"GPU available count: {n_gpu}; device: {device}")
|
146 |
-
|
147 |
-
# datasets
|
148 |
-
train_dataset = DenoiseJsonlDataset(
|
149 |
-
jsonl_file=args.train_dataset,
|
150 |
-
expected_sample_rate=config.sample_rate,
|
151 |
-
max_wave_value=32768.0,
|
152 |
-
min_snr_db=config.min_snr_db,
|
153 |
-
max_snr_db=config.max_snr_db,
|
154 |
-
# skip=825000,
|
155 |
-
)
|
156 |
-
valid_dataset = DenoiseJsonlDataset(
|
157 |
-
jsonl_file=args.valid_dataset,
|
158 |
-
expected_sample_rate=config.sample_rate,
|
159 |
-
max_wave_value=32768.0,
|
160 |
-
min_snr_db=config.min_snr_db,
|
161 |
-
max_snr_db=config.max_snr_db,
|
162 |
-
)
|
163 |
-
train_data_loader = DataLoader(
|
164 |
-
dataset=train_dataset,
|
165 |
-
batch_size=args.batch_size,
|
166 |
-
# shuffle=True,
|
167 |
-
sampler=None,
|
168 |
-
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
169 |
-
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
170 |
-
collate_fn=collate_fn,
|
171 |
-
pin_memory=False,
|
172 |
-
prefetch_factor=2,
|
173 |
-
)
|
174 |
-
valid_data_loader = DataLoader(
|
175 |
-
dataset=valid_dataset,
|
176 |
-
batch_size=args.batch_size,
|
177 |
-
# shuffle=True,
|
178 |
-
sampler=None,
|
179 |
-
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
180 |
-
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
181 |
-
collate_fn=collate_fn,
|
182 |
-
pin_memory=False,
|
183 |
-
prefetch_factor=2,
|
184 |
-
)
|
185 |
-
|
186 |
-
# models
|
187 |
-
logger.info(f"prepare models. config_file: {args.config_file}")
|
188 |
-
model = ConvTasNetPretrainedModel(config).to(device)
|
189 |
-
model.to(device)
|
190 |
-
model.train()
|
191 |
-
|
192 |
-
discriminator = WaveformMetricDiscriminatorPretrainedModel(discriminator_config).to(device)
|
193 |
-
discriminator.to(device)
|
194 |
-
discriminator.train()
|
195 |
-
|
196 |
-
# optimizer
|
197 |
-
logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
|
198 |
-
optimizer = torch.optim.AdamW(model.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
|
199 |
-
discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), config.lr, betas=[config.adam_b1, config.adam_b2])
|
200 |
-
|
201 |
-
# resume training
|
202 |
-
last_step_idx = -1
|
203 |
-
last_epoch = -1
|
204 |
-
for step_idx_str in serialization_dir.glob("steps-*"):
|
205 |
-
step_idx_str = Path(step_idx_str)
|
206 |
-
step_idx = step_idx_str.stem.split("-")[1]
|
207 |
-
step_idx = int(step_idx)
|
208 |
-
if step_idx > last_step_idx:
|
209 |
-
last_step_idx = step_idx
|
210 |
-
|
211 |
-
if last_step_idx != -1:
|
212 |
-
logger.info(f"resume from steps-{last_step_idx}.")
|
213 |
-
model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
|
214 |
-
optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
|
215 |
-
|
216 |
-
discriminator_pt = serialization_dir / f"steps-{last_step_idx}/discriminator.pt"
|
217 |
-
discriminator_optimizer_pth = serialization_dir / f"steps-{last_step_idx}/discriminator_optimizer.pth"
|
218 |
-
|
219 |
-
logger.info(f"load state dict for model.")
|
220 |
-
with open(model_pt.as_posix(), "rb") as f:
|
221 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
222 |
-
model.load_state_dict(state_dict, strict=True)
|
223 |
-
|
224 |
-
if optimizer_pth.exists():
|
225 |
-
logger.info(f"load state dict for optimizer.")
|
226 |
-
with open(optimizer_pth.as_posix(), "rb") as f:
|
227 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
228 |
-
optimizer.load_state_dict(state_dict)
|
229 |
-
|
230 |
-
if discriminator_pt.exists():
|
231 |
-
logger.info(f"load state dict for discriminator.")
|
232 |
-
with open(model_pt.as_posix(), "rb") as f:
|
233 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
234 |
-
discriminator.load_state_dict(state_dict, strict=True)
|
235 |
-
|
236 |
-
if discriminator_optimizer_pth.exists():
|
237 |
-
logger.info(f"load state dict for discriminator_optimizer.")
|
238 |
-
with open(optimizer_pth.as_posix(), "rb") as f:
|
239 |
-
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
240 |
-
discriminator_optimizer.load_state_dict(state_dict)
|
241 |
-
|
242 |
-
if config.lr_scheduler == "CosineAnnealingLR":
|
243 |
-
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
244 |
-
optimizer,
|
245 |
-
last_epoch=last_epoch,
|
246 |
-
# T_max=10 * config.eval_steps,
|
247 |
-
# eta_min=0.01 * config.lr,
|
248 |
-
**config.lr_scheduler_kwargs,
|
249 |
-
)
|
250 |
-
discriminator_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
251 |
-
discriminator_optimizer,
|
252 |
-
last_epoch=last_epoch,
|
253 |
-
# T_max=10 * config.eval_steps,
|
254 |
-
# eta_min=0.01 * config.lr,
|
255 |
-
**config.lr_scheduler_kwargs,
|
256 |
-
)
|
257 |
-
elif config.lr_scheduler == "MultiStepLR":
|
258 |
-
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
259 |
-
optimizer,
|
260 |
-
last_epoch=last_epoch,
|
261 |
-
milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
|
262 |
-
)
|
263 |
-
discriminator_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
264 |
-
discriminator_optimizer,
|
265 |
-
last_epoch=last_epoch,
|
266 |
-
milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
|
267 |
-
)
|
268 |
-
else:
|
269 |
-
raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
|
270 |
-
|
271 |
-
ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
|
272 |
-
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
273 |
-
neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
|
274 |
-
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
275 |
-
fft_size_list=[256, 512, 1024],
|
276 |
-
win_size_list=[120, 240, 480],
|
277 |
-
hop_size_list=[25, 50, 100],
|
278 |
-
factor_sc=1.5,
|
279 |
-
factor_mag=1.0,
|
280 |
-
reduction="mean"
|
281 |
-
).to(device)
|
282 |
-
pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)
|
283 |
-
|
284 |
-
# training loop
|
285 |
-
|
286 |
-
# state
|
287 |
-
average_pesq_score = 1000000000
|
288 |
-
average_loss = 1000000000
|
289 |
-
average_ae_loss = 1000000000
|
290 |
-
average_neg_si_snr_loss = 1000000000
|
291 |
-
average_neg_stoi_loss = 1000000000
|
292 |
-
average_mr_stft_loss = 1000000000
|
293 |
-
average_pesq_loss = 1000000000
|
294 |
-
average_discriminator_g_loss = 1000000000
|
295 |
-
average_discriminator_d_loss = 1000000000
|
296 |
-
|
297 |
-
model_list = list()
|
298 |
-
best_epoch_idx = None
|
299 |
-
best_step_idx = None
|
300 |
-
best_metric = None
|
301 |
-
patience_count = 0
|
302 |
-
|
303 |
-
step_idx = 0 if last_step_idx == -1 else last_step_idx
|
304 |
-
|
305 |
-
logger.info("training")
|
306 |
-
for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
|
307 |
-
# train
|
308 |
-
model.train()
|
309 |
-
|
310 |
-
total_pesq_score = 0.
|
311 |
-
total_loss = 0.
|
312 |
-
total_ae_loss = 0.
|
313 |
-
total_neg_si_snr_loss = 0.
|
314 |
-
total_neg_stoi_loss = 0.
|
315 |
-
total_mr_stft_loss = 0.
|
316 |
-
total_pesq_loss = 0.
|
317 |
-
total_discriminator_g_loss = 0.
|
318 |
-
total_discriminator_d_loss = 0.
|
319 |
-
total_batches = 0.
|
320 |
-
|
321 |
-
progress_bar_train = tqdm(
|
322 |
-
initial=step_idx,
|
323 |
-
desc="Training; epoch-{}".format(epoch_idx),
|
324 |
-
)
|
325 |
-
for train_batch in train_data_loader:
|
326 |
-
clean_audios, noisy_audios = train_batch
|
327 |
-
clean_audios: torch.Tensor = clean_audios.to(device)
|
328 |
-
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
329 |
-
one_labels = torch.ones(clean_audios.shape[0]).to(device)
|
330 |
-
|
331 |
-
denoise_audios = model.forward(noisy_audios)
|
332 |
-
denoise_audios = torch.squeeze(denoise_audios, dim=1)
|
333 |
-
|
334 |
-
if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
|
335 |
-
raise AssertionError("nan or inf in denoise_audios")
|
336 |
-
|
337 |
-
# Discriminator
|
338 |
-
clean_audio_list = torch.split(clean_audios, 1, dim=0)
|
339 |
-
enhanced_audio_list = torch.split(denoise_audios, 1, dim=0)
|
340 |
-
clean_audio_list = [t.squeeze().detach().cpu().numpy() for t in clean_audio_list]
|
341 |
-
enhanced_audio_list = [t.squeeze().detach().cpu().numpy() for t in enhanced_audio_list]
|
342 |
-
|
343 |
-
pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb")
|
344 |
-
|
345 |
-
metric_r = discriminator.forward(clean_audios, clean_audios)
|
346 |
-
metric_g = discriminator.forward(denoise_audios.detach(), clean_audios)
|
347 |
-
loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
|
348 |
-
|
349 |
-
if -1 in pesq_score_list:
|
350 |
-
# print("-1 in batch_pesq_score!")
|
351 |
-
loss_disc_g = 0
|
352 |
-
else:
|
353 |
-
pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
|
354 |
-
loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
|
355 |
-
|
356 |
-
discriminator_d_loss = loss_disc_r + loss_disc_g
|
357 |
-
discriminator_optimizer.zero_grad()
|
358 |
-
discriminator_d_loss.backward()
|
359 |
-
discriminator_optimizer.step()
|
360 |
-
discriminator_lr_scheduler.step()
|
361 |
-
|
362 |
-
# Generator
|
363 |
-
ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
|
364 |
-
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
365 |
-
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
366 |
-
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
367 |
-
pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
|
368 |
-
|
369 |
-
metric_g = discriminator.forward(denoise_audios, clean_audios)
|
370 |
-
discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
|
371 |
-
|
372 |
-
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss + 0.2 * discriminator_g_loss
|
373 |
-
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
374 |
-
logger.info(f"find nan or inf in loss.")
|
375 |
-
continue
|
376 |
-
|
377 |
-
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
378 |
-
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
379 |
-
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
|
380 |
-
|
381 |
-
optimizer.zero_grad()
|
382 |
-
loss.backward()
|
383 |
-
optimizer.step()
|
384 |
-
lr_scheduler.step()
|
385 |
-
|
386 |
-
total_pesq_score += pesq_score
|
387 |
-
total_loss += loss.item()
|
388 |
-
total_ae_loss += ae_loss.item()
|
389 |
-
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
390 |
-
total_neg_stoi_loss += neg_stoi_loss.item()
|
391 |
-
total_mr_stft_loss += mr_stft_loss.item()
|
392 |
-
total_pesq_loss += pesq_loss.item()
|
393 |
-
total_discriminator_g_loss += discriminator_g_loss.item()
|
394 |
-
total_discriminator_d_loss += discriminator_d_loss.item()
|
395 |
-
total_batches += 1
|
396 |
-
|
397 |
-
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
398 |
-
average_loss = round(total_loss / total_batches, 4)
|
399 |
-
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
400 |
-
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
401 |
-
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
402 |
-
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
403 |
-
average_pesq_loss = round(total_pesq_loss / total_batches, 4)
|
404 |
-
average_discriminator_g_loss = round(total_discriminator_g_loss / total_batches, 4)
|
405 |
-
average_discriminator_d_loss = round(total_discriminator_d_loss / total_batches, 4)
|
406 |
-
|
407 |
-
progress_bar_train.update(1)
|
408 |
-
progress_bar_train.set_postfix({
|
409 |
-
"lr": lr_scheduler.get_last_lr()[0],
|
410 |
-
"pesq_score": average_pesq_score,
|
411 |
-
"loss": average_loss,
|
412 |
-
"ae_loss": average_ae_loss,
|
413 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
414 |
-
"neg_stoi_loss": average_neg_stoi_loss,
|
415 |
-
"mr_stft_loss": average_mr_stft_loss,
|
416 |
-
"pesq_loss": average_pesq_loss,
|
417 |
-
"disc_g_loss": average_discriminator_g_loss,
|
418 |
-
"disc_d_loss": average_discriminator_d_loss,
|
419 |
-
|
420 |
-
})
|
421 |
-
|
422 |
-
# evaluation
|
423 |
-
step_idx += 1
|
424 |
-
if step_idx % config.eval_steps == 0:
|
425 |
-
with torch.no_grad():
|
426 |
-
torch.cuda.empty_cache()
|
427 |
-
|
428 |
-
total_pesq_score = 0.
|
429 |
-
total_loss = 0.
|
430 |
-
total_ae_loss = 0.
|
431 |
-
total_neg_si_snr_loss = 0.
|
432 |
-
total_neg_stoi_loss = 0.
|
433 |
-
total_mr_stft_loss = 0.
|
434 |
-
total_pesq_loss = 0.
|
435 |
-
total_batches = 0.
|
436 |
-
|
437 |
-
progress_bar_train.close()
|
438 |
-
progress_bar_eval = tqdm(
|
439 |
-
desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
|
440 |
-
)
|
441 |
-
for eval_batch in valid_data_loader:
|
442 |
-
clean_audios, noisy_audios = eval_batch
|
443 |
-
clean_audios = clean_audios.to(device)
|
444 |
-
noisy_audios = noisy_audios.to(device)
|
445 |
-
|
446 |
-
denoise_audios = model.forward(noisy_audios)
|
447 |
-
denoise_audios = torch.squeeze(denoise_audios, dim=1)
|
448 |
-
|
449 |
-
# Generator
|
450 |
-
ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
|
451 |
-
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
452 |
-
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
453 |
-
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
454 |
-
pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
|
455 |
-
|
456 |
-
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
457 |
-
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
458 |
-
logger.info(f"find nan or inf in loss.")
|
459 |
-
continue
|
460 |
-
|
461 |
-
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
462 |
-
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
463 |
-
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
|
464 |
-
|
465 |
-
total_pesq_score += pesq_score
|
466 |
-
total_loss += loss.item()
|
467 |
-
total_ae_loss += ae_loss.item()
|
468 |
-
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
469 |
-
total_neg_stoi_loss += neg_stoi_loss.item()
|
470 |
-
total_mr_stft_loss += mr_stft_loss.item()
|
471 |
-
total_pesq_loss += pesq_loss.item()
|
472 |
-
total_batches += 1
|
473 |
-
|
474 |
-
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
475 |
-
average_loss = round(total_loss / total_batches, 4)
|
476 |
-
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
477 |
-
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
478 |
-
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
479 |
-
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
480 |
-
average_pesq_loss = round(total_pesq_loss / total_batches, 4)
|
481 |
-
|
482 |
-
progress_bar_eval.update(1)
|
483 |
-
progress_bar_eval.set_postfix({
|
484 |
-
"lr": lr_scheduler.get_last_lr()[0],
|
485 |
-
"pesq_score": average_pesq_score,
|
486 |
-
"loss": average_loss,
|
487 |
-
"ae_loss": average_ae_loss,
|
488 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
489 |
-
"neg_stoi_loss": average_neg_stoi_loss,
|
490 |
-
"mr_stft_loss": average_mr_stft_loss,
|
491 |
-
"pesq_loss": average_pesq_loss,
|
492 |
-
})
|
493 |
-
|
494 |
-
total_pesq_score = 0.
|
495 |
-
total_loss = 0.
|
496 |
-
total_ae_loss = 0.
|
497 |
-
total_neg_si_snr_loss = 0.
|
498 |
-
total_neg_stoi_loss = 0.
|
499 |
-
total_mr_stft_loss = 0.
|
500 |
-
total_pesq_loss = 0.
|
501 |
-
total_discriminator_g_loss = 0.
|
502 |
-
total_discriminator_d_loss = 0.
|
503 |
-
total_batches = 0.
|
504 |
-
|
505 |
-
progress_bar_eval.close()
|
506 |
-
progress_bar_train = tqdm(
|
507 |
-
initial=progress_bar_train.n,
|
508 |
-
postfix=progress_bar_train.postfix,
|
509 |
-
desc=progress_bar_train.desc,
|
510 |
-
)
|
511 |
-
|
512 |
-
# save path
|
513 |
-
save_dir = serialization_dir / "steps-{}".format(step_idx)
|
514 |
-
save_dir.mkdir(parents=True, exist_ok=False)
|
515 |
-
|
516 |
-
# save models
|
517 |
-
model.save_pretrained(save_dir.as_posix())
|
518 |
-
discriminator.save_pretrained(save_dir.as_posix())
|
519 |
-
|
520 |
-
# save optim
|
521 |
-
torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
|
522 |
-
torch.save(discriminator_optimizer.state_dict(), (save_dir / "discriminator_optimizer.pth").as_posix())
|
523 |
-
|
524 |
-
model_list.append(save_dir)
|
525 |
-
if len(model_list) >= args.num_serialized_models_to_keep:
|
526 |
-
model_to_delete: Path = model_list.pop(0)
|
527 |
-
shutil.rmtree(model_to_delete.as_posix())
|
528 |
-
|
529 |
-
# save metric
|
530 |
-
if best_metric is None:
|
531 |
-
best_epoch_idx = epoch_idx
|
532 |
-
best_step_idx = step_idx
|
533 |
-
best_metric = average_pesq_score
|
534 |
-
elif average_pesq_score > best_metric:
|
535 |
-
# greater is better.
|
536 |
-
best_epoch_idx = epoch_idx
|
537 |
-
best_step_idx = step_idx
|
538 |
-
best_metric = average_pesq_score
|
539 |
-
else:
|
540 |
-
pass
|
541 |
-
|
542 |
-
metrics = {
|
543 |
-
"epoch_idx": epoch_idx,
|
544 |
-
"best_epoch_idx": best_epoch_idx,
|
545 |
-
"best_step_idx": best_step_idx,
|
546 |
-
"pesq_score": average_pesq_score,
|
547 |
-
"loss": average_loss,
|
548 |
-
"ae_loss": average_ae_loss,
|
549 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
550 |
-
"neg_stoi_loss": average_neg_stoi_loss,
|
551 |
-
"mr_stft_loss": average_mr_stft_loss,
|
552 |
-
"pesq_loss": average_pesq_loss,
|
553 |
-
}
|
554 |
-
metrics_filename = save_dir / "metrics_epoch.json"
|
555 |
-
with open(metrics_filename, "w", encoding="utf-8") as f:
|
556 |
-
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
557 |
-
|
558 |
-
# save best
|
559 |
-
best_dir = serialization_dir / "best"
|
560 |
-
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
561 |
-
if best_dir.exists():
|
562 |
-
shutil.rmtree(best_dir)
|
563 |
-
shutil.copytree(save_dir, best_dir)
|
564 |
-
|
565 |
-
# early stop
|
566 |
-
early_stop_flag = False
|
567 |
-
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
568 |
-
patience_count = 0
|
569 |
-
else:
|
570 |
-
patience_count += 1
|
571 |
-
if patience_count >= args.patience:
|
572 |
-
early_stop_flag = True
|
573 |
-
|
574 |
-
# early stop
|
575 |
-
if early_stop_flag:
|
576 |
-
break
|
577 |
-
|
578 |
-
return
|
579 |
-
|
580 |
-
|
581 |
-
if __name__ == "__main__":
|
582 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/conv_tasnet_gan/yaml/config.yaml
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
model_name: "conv_tasnet_gan"
|
2 |
-
|
3 |
-
sample_rate: 8000
|
4 |
-
segment_size: 4
|
5 |
-
|
6 |
-
win_size: 20
|
7 |
-
freq_bins: 256
|
8 |
-
bottleneck_channels: 128
|
9 |
-
num_speakers: 1
|
10 |
-
num_blocks: 2
|
11 |
-
num_sub_blocks: 4
|
12 |
-
sub_blocks_channels: 256
|
13 |
-
sub_blocks_kernel_size: 3
|
14 |
-
|
15 |
-
norm_type: "gLN"
|
16 |
-
causal: false
|
17 |
-
mask_nonlinear: "relu"
|
18 |
-
|
19 |
-
min_snr_db: -10
|
20 |
-
max_snr_db: 20
|
21 |
-
|
22 |
-
lr: 0.005
|
23 |
-
adam_b1: 0.8
|
24 |
-
adam_b2: 0.99
|
25 |
-
|
26 |
-
lr_scheduler: "CosineAnnealingLR"
|
27 |
-
lr_scheduler_kwargs:
|
28 |
-
T_max: 250000
|
29 |
-
eta_min: 0.00005
|
30 |
-
|
31 |
-
eval_steps: 25000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/conv_tasnet_gan/yaml/discriminator_config.yaml
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
model_name: "conv_tasnet_gan"
|
2 |
-
|
3 |
-
sample_rate: 8000
|
4 |
-
segment_size: 16000
|
5 |
-
n_fft: 512
|
6 |
-
win_size: 200
|
7 |
-
hop_size: 80
|
8 |
-
|
9 |
-
discriminator_dim: 32
|
10 |
-
discriminator_in_channel: 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/dfnet/step_2_train_model.py
CHANGED
@@ -315,6 +315,7 @@ def main():
|
|
315 |
# evaluation
|
316 |
step_idx += 1
|
317 |
if step_idx % config.eval_steps == 0:
|
|
|
318 |
with torch.no_grad():
|
319 |
torch.cuda.empty_cache()
|
320 |
|
@@ -451,6 +452,7 @@ def main():
|
|
451 |
# early stop
|
452 |
if early_stop_flag:
|
453 |
break
|
|
|
454 |
|
455 |
return
|
456 |
|
|
|
315 |
# evaluation
|
316 |
step_idx += 1
|
317 |
if step_idx % config.eval_steps == 0:
|
318 |
+
model.eval()
|
319 |
with torch.no_grad():
|
320 |
torch.cuda.empty_cache()
|
321 |
|
|
|
452 |
# early stop
|
453 |
if early_stop_flag:
|
454 |
break
|
455 |
+
model.train()
|
456 |
|
457 |
return
|
458 |
|
examples/dfnet2/step_2_train_model.py
CHANGED
@@ -318,6 +318,7 @@ def main():
|
|
318 |
# evaluation
|
319 |
step_idx += 1
|
320 |
if step_idx % config.eval_steps == 0:
|
|
|
321 |
with torch.no_grad():
|
322 |
torch.cuda.empty_cache()
|
323 |
|
@@ -457,6 +458,7 @@ def main():
|
|
457 |
# early stop
|
458 |
if early_stop_flag:
|
459 |
break
|
|
|
460 |
|
461 |
return
|
462 |
|
|
|
318 |
# evaluation
|
319 |
step_idx += 1
|
320 |
if step_idx % config.eval_steps == 0:
|
321 |
+
model.eval()
|
322 |
with torch.no_grad():
|
323 |
torch.cuda.empty_cache()
|
324 |
|
|
|
458 |
# early stop
|
459 |
if early_stop_flag:
|
460 |
break
|
461 |
+
model.train()
|
462 |
|
463 |
return
|
464 |
|
examples/dtln/step_2_train_model.py
CHANGED
@@ -301,6 +301,7 @@ def main():
|
|
301 |
# evaluation
|
302 |
step_idx += 1
|
303 |
if step_idx % config.eval_steps == 0:
|
|
|
304 |
with torch.no_grad():
|
305 |
torch.cuda.empty_cache()
|
306 |
|
@@ -424,6 +425,7 @@ def main():
|
|
424 |
# early stop
|
425 |
if early_stop_flag:
|
426 |
break
|
|
|
427 |
|
428 |
return
|
429 |
|
|
|
301 |
# evaluation
|
302 |
step_idx += 1
|
303 |
if step_idx % config.eval_steps == 0:
|
304 |
+
model.eval()
|
305 |
with torch.no_grad():
|
306 |
torch.cuda.empty_cache()
|
307 |
|
|
|
425 |
# early stop
|
426 |
if early_stop_flag:
|
427 |
break
|
428 |
+
model.train()
|
429 |
|
430 |
return
|
431 |
|
examples/frcrn/step_2_train_model.py
CHANGED
@@ -305,6 +305,7 @@ def main():
|
|
305 |
# evaluation
|
306 |
step_idx += 1
|
307 |
if step_idx % config.eval_steps == 0:
|
|
|
308 |
with torch.no_grad():
|
309 |
torch.cuda.empty_cache()
|
310 |
|
@@ -428,6 +429,7 @@ def main():
|
|
428 |
# early stop
|
429 |
if early_stop_flag:
|
430 |
break
|
|
|
431 |
|
432 |
return
|
433 |
|
|
|
305 |
# evaluation
|
306 |
step_idx += 1
|
307 |
if step_idx % config.eval_steps == 0:
|
308 |
+
model.eval()
|
309 |
with torch.no_grad():
|
310 |
torch.cuda.empty_cache()
|
311 |
|
|
|
429 |
# early stop
|
430 |
if early_stop_flag:
|
431 |
break
|
432 |
+
model.train()
|
433 |
|
434 |
return
|
435 |
|
examples/lstm/step_2_train_model.py
CHANGED
@@ -314,6 +314,7 @@ def main():
|
|
314 |
# evaluation
|
315 |
step_idx += 1
|
316 |
if step_idx % config.eval_steps == 0:
|
|
|
317 |
with torch.no_grad():
|
318 |
torch.cuda.empty_cache()
|
319 |
|
@@ -435,6 +436,7 @@ def main():
|
|
435 |
# early stop
|
436 |
if early_stop_flag:
|
437 |
break
|
|
|
438 |
return
|
439 |
|
440 |
|
|
|
314 |
# evaluation
|
315 |
step_idx += 1
|
316 |
if step_idx % config.eval_steps == 0:
|
317 |
+
model.eval()
|
318 |
with torch.no_grad():
|
319 |
torch.cuda.empty_cache()
|
320 |
|
|
|
436 |
# early stop
|
437 |
if early_stop_flag:
|
438 |
break
|
439 |
+
model.train()
|
440 |
return
|
441 |
|
442 |
|
examples/rnnoise/step_2_train_model.py
CHANGED
@@ -314,6 +314,7 @@ def main():
|
|
314 |
# evaluation
|
315 |
step_idx += 1
|
316 |
if step_idx % config.eval_steps == 0:
|
|
|
317 |
with torch.no_grad():
|
318 |
torch.cuda.empty_cache()
|
319 |
|
@@ -435,6 +436,8 @@ def main():
|
|
435 |
# early stop
|
436 |
if early_stop_flag:
|
437 |
break
|
|
|
|
|
438 |
return
|
439 |
|
440 |
|
|
|
314 |
# evaluation
|
315 |
step_idx += 1
|
316 |
if step_idx % config.eval_steps == 0:
|
317 |
+
model.eval()
|
318 |
with torch.no_grad():
|
319 |
torch.cuda.empty_cache()
|
320 |
|
|
|
436 |
# early stop
|
437 |
if early_stop_flag:
|
438 |
break
|
439 |
+
model.train()
|
440 |
+
|
441 |
return
|
442 |
|
443 |
|
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py
CHANGED
@@ -1047,6 +1047,9 @@ class DfNet2(nn.Module):
|
|
1047 |
feat_spec = feat_spec[..., :self.df_decoder.df_bins]
|
1048 |
# feat_spec shape: [b, 2, t, df_bins]
|
1049 |
|
|
|
|
|
|
|
1050 |
return spec, feat_erb, feat_spec
|
1051 |
|
1052 |
def forward(self,
|
|
|
1047 |
feat_spec = feat_spec[..., :self.df_decoder.df_bins]
|
1048 |
# feat_spec shape: [b, 2, t, df_bins]
|
1049 |
|
1050 |
+
spec = spec.detach()
|
1051 |
+
feat_erb = feat_erb.detach()
|
1052 |
+
feat_spec = feat_spec.detach()
|
1053 |
return spec, feat_erb, feat_spec
|
1054 |
|
1055 |
def forward(self,
|
toolbox/torchaudio/modules/utils/ema.py
CHANGED
@@ -1,8 +1,95 @@
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
3 |
import torch.nn as nn
|
4 |
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
class ExponentialMovingAverage(nn.Module):
|
7 |
def __init__(self):
|
8 |
super().__init__()
|
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
+
import math
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
import torch.nn as nn
|
7 |
|
8 |
|
9 |
+
def _calculate_norm_alpha(sample_rate: int, hop_size: int, tau: float) -> float:
    """Exponential decay factor alpha for a given tau (decay window size [s]).

    The frame period is ``dt = hop_size / sample_rate`` and the returned
    factor is ``exp(-dt / tau)``.
    """
    dt = hop_size / sample_rate
    return math.exp(-dt / tau)


def get_norm_alpha(sample_rate: int, hop_size: int, norm_tau: float) -> float:
    """Return the decay factor for ``norm_tau`` rounded to as few decimals as possible.

    The raw factor is rounded to 3 decimals; while the rounded value is
    not strictly below 1.0, one more decimal is kept and the rounding is
    retried.

    Args:
        sample_rate: sample rate used to derive the frame period.
        hop_size: hop size (in samples) used to derive the frame period.
        norm_tau: decay window size, forwarded to ``_calculate_norm_alpha``.

    Returns:
        The rounded decay factor, strictly below 1.0.

    Raises:
        ValueError: if the raw factor is already >= 1.0 (for example
            ``hop_size == 0`` or a negative ``norm_tau``). No amount of
            rounding can bring such a value below 1.0, so the search
            would otherwise never terminate.
    """
    a_ = _calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)

    if a_ >= 1.0:
        raise ValueError(
            "decay factor must be < 1.0, got {} "
            "(sample_rate={}, hop_size={}, norm_tau={})".format(a_, sample_rate, hop_size, norm_tau)
        )

    precision = 3
    a = round(a_, precision)
    # A value very close to 1.0 rounds up to 1.0 at low precision; keep more decimals.
    while a >= 1.0:
        precision += 1
        a = round(a_, precision)

    return a
|
26 |
+
|
27 |
+
|
28 |
+
MEAN_NORM_INIT = [-60., -90.]


def make_erb_norm_state(erb_bins: int, channels: int) -> np.ndarray:
    """Create the initial state array used by ERB normalization.

    A linear ramp from ``MEAN_NORM_INIT[0]`` to ``MEAN_NORM_INIT[1]`` with
    ``erb_bins`` points is built once and stacked ``channels`` times.

    Returns:
        Array of shape (channels, erb_bins).
    """
    ramp = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)

    # state shape: (audio_channels, erb_bins)
    return np.tile(ramp, (channels, 1))
|
38 |
+
|
39 |
+
|
40 |
+
def erb_normalize(erb_feat: np.ndarray, alpha: float, state: np.ndarray = None):
    """Normalize ERB features with an exponentially smoothed running mean.

    For every time step ``j`` the state is updated as
    ``state = x * (1 - alpha) + state * alpha`` (``x`` being the frame at
    step ``j``), then the frame is replaced by ``(x - state) / 40``.

    Args:
        erb_feat: array of shape (batch_size, time_steps, erb_bins). The
            input is not modified; a copy is normalized and returned.
        alpha: smoothing factor of the running mean.
        state: optional array indexed as ``state[batch][bin]``. It is
            updated in place, so a caller that passes it in sees the
            state after the last time step. When ``None`` a fresh state
            is created with ``make_erb_norm_state``.

    Returns:
        The normalized copy of ``erb_feat`` (same shape).
    """
    erb_feat = np.copy(erb_feat)
    batch_size, time_steps, erb_bins = erb_feat.shape

    if state is None:
        state = make_erb_norm_state(erb_bins, batch_size)

    # View on the part of the state that is used; writing through it keeps
    # the in-place update visible to the caller.
    running_mean = state[:batch_size, :erb_bins]

    # The recursion is sequential in time only; batch and bin axes are
    # independent, so they are processed as whole arrays.
    for j in range(time_steps):
        frame = erb_feat[:, j, :]
        running_mean[...] = frame * (1. - alpha) + running_mean * alpha
        frame[...] = frame - running_mean
        frame[...] = frame / 40.

    return erb_feat
|
61 |
+
|
62 |
+
|
63 |
+
UNIT_NORM_INIT = [0.001, 0.0001]


def make_spec_norm_state(df_bins: int, channels: int) -> np.ndarray:
    """Create the initial state array used by spectrum normalization.

    A linear ramp from ``UNIT_NORM_INIT[0]`` to ``UNIT_NORM_INIT[1]`` with
    ``df_bins`` points is built once and stacked ``channels`` times.

    Returns:
        Array of shape (channels, df_bins).
    """
    ramp = np.linspace(UNIT_NORM_INIT[0], UNIT_NORM_INIT[1], df_bins)

    # state shape: (audio_channels, df_bins)
    return np.tile(ramp, (channels, 1))
|
73 |
+
|
74 |
+
|
75 |
+
def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None):
    """Normalize spectral features by an exponentially smoothed magnitude.

    For every time step ``j`` the state is updated as
    ``state = |x| * (1 - alpha) + state * alpha`` (``x`` being the frame
    at step ``j``), then the frame is divided by ``sqrt(state)``.

    Args:
        spec_feat: array of shape (batch_size, time_steps, df_bins). The
            input is not modified; a copy is normalized and returned.
        alpha: smoothing factor of the running magnitude.
        state: optional array indexed as ``state[batch][bin]``. It is
            updated in place, so a caller that passes it in sees the
            state after the last time step. When ``None`` a fresh state
            is created with ``make_spec_norm_state``.

    Returns:
        The normalized copy of ``spec_feat`` (same shape).
    """
    spec_feat = np.copy(spec_feat)
    batch_size, time_steps, df_bins = spec_feat.shape

    if state is None:
        state = make_spec_norm_state(df_bins, batch_size)

    # View on the part of the state that is used; writing through it keeps
    # the in-place update visible to the caller.
    running_mag = state[:batch_size, :df_bins]

    # The recursion is sequential in time only; batch and bin axes are
    # independent, so they are processed as whole arrays.
    for j in range(time_steps):
        frame = spec_feat[:, j, :]
        running_mag[...] = np.abs(frame) * (1. - alpha) + running_mag * alpha
        frame[...] = frame / np.sqrt(running_mag)

    return spec_feat
|
91 |
+
|
92 |
+
|
93 |
class ExponentialMovingAverage(nn.Module):
|
94 |
def __init__(self):
|
95 |
super().__init__()
|