HoneyTian commited on
Commit
1af34cd
·
0 Parent(s):

first commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +5 -0
  2. .gitattributes +35 -0
  3. .gitignore +25 -0
  4. Dockerfile +24 -0
  5. README.md +129 -0
  6. examples/clean_unet/run.sh +181 -0
  7. examples/clean_unet/step_1_prepare_data.py +201 -0
  8. examples/clean_unet/step_2_train_model.py +419 -0
  9. examples/clean_unet/step_3_evaluation.py +6 -0
  10. examples/clean_unet/yaml/config.yaml +14 -0
  11. examples/conv_tasnet/run.sh +154 -0
  12. examples/conv_tasnet/step_1_prepare_data.py +164 -0
  13. examples/conv_tasnet/step_2_train_model.py +509 -0
  14. examples/conv_tasnet/yaml/config.yaml +28 -0
  15. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py +90 -0
  16. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py +129 -0
  17. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py +71 -0
  18. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py +93 -0
  19. examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py +77 -0
  20. examples/data_preprocess/dns_challenge_to_8k/process_musan.py +8 -0
  21. examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py +70 -0
  22. examples/dfnet/run.sh +156 -0
  23. examples/dfnet/step_1_prepare_data.py +164 -0
  24. examples/dfnet/step_2_train_model.py +461 -0
  25. examples/dfnet/yaml/config.yaml +74 -0
  26. examples/dfnet2/run.sh +164 -0
  27. examples/dfnet2/step_1_prepare_data.py +164 -0
  28. examples/dfnet2/step_2_train_model.py +469 -0
  29. examples/dfnet2/yaml/config.yaml +75 -0
  30. examples/dtln/run.sh +171 -0
  31. examples/dtln/step_1_prepare_data.py +164 -0
  32. examples/dtln/step_2_train_model.py +437 -0
  33. examples/dtln/yaml/config-1024.yaml +29 -0
  34. examples/dtln/yaml/config-256.yaml +29 -0
  35. examples/dtln/yaml/config-512.yaml +29 -0
  36. examples/dtln_mp3_to_wav/run.sh +168 -0
  37. examples/dtln_mp3_to_wav/step_1_prepare_data.py +127 -0
  38. examples/dtln_mp3_to_wav/step_2_train_model.py +445 -0
  39. examples/dtln_mp3_to_wav/yaml/config-1024.yaml +29 -0
  40. examples/dtln_mp3_to_wav/yaml/config-256.yaml +29 -0
  41. examples/dtln_mp3_to_wav/yaml/config-512.yaml +29 -0
  42. examples/frcrn/run.sh +159 -0
  43. examples/frcrn/step_1_prepare_data.py +164 -0
  44. examples/frcrn/step_2_train_model.py +457 -0
  45. examples/frcrn/yaml/config-10.yaml +31 -0
  46. examples/frcrn/yaml/config-14.yaml +31 -0
  47. examples/frcrn/yaml/config-20.yaml +31 -0
  48. examples/frcrn_mp3_to_wav/run.sh +156 -0
  49. examples/frcrn_mp3_to_wav/step_1_prepare_data.py +127 -0
  50. examples/frcrn_mp3_to_wav/step_2_train_model.py +442 -0
.dockerignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ .git/
3
+ .idea/
4
+
5
+ /examples/
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ .gradio/
3
+ .git/
4
+ .idea/
5
+
6
+ **/evaluation_audio/
7
+ **/file_dir/
8
+ **/flagged/
9
+ **/log/
10
+ **/logs/
11
+ **/__pycache__/
12
+
13
+ /data/
14
+ /docs/
15
+ /dotenv/
16
+ /hub_datasets/
17
+ /script/
18
+ /thirdparty/
19
+ /trained_models/
20
+ /temp/
21
+
22
+ **/*.wav
23
+ **/*.xlsx
24
+
25
+ requirements-python-3-9-9.txt
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12
2
+
3
+ WORKDIR /code
4
+
5
+ COPY . /code
6
+
7
+ RUN apt-get update
8
+ RUN apt-get install -y ffmpeg build-essential
9
+
10
+ RUN pip install --upgrade pip
11
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
12
+
13
+ RUN useradd -m -u 1000 user
14
+
15
+ USER user
16
+
17
+ ENV HOME=/home/user \
18
+ PATH=/home/user/.local/bin:$PATH
19
+
20
+ WORKDIR $HOME/app
21
+
22
+ COPY --chown=user . $HOME/app
23
+
24
+ CMD ["python3", "main.py"]
README.md ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: NX Denoise
3
+ emoji: 🐢
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: docker
7
+ pinned: false
8
+ license: apache-2.0
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
+ ## NX Denoise
13
+
14
+
15
+ ### datasets
16
+
17
+ ```text
18
+
19
+ AISHELL (15G)
20
+ https://openslr.trmal.net/resources/33/
21
+
22
+ AISHELL-3 (19G)
23
+ http://www.openslr.org/93/
24
+
25
+ DNS3
26
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
27
+ 噪音数据来源于 DEMAND, FreeSound, AudioSet.
28
+
29
+ MS-SNSD
30
+ https://github.com/microsoft/MS-SNSD
31
+ 噪音数据来源于 DEMAND, FreeSound.
32
+
33
+ MUSAN
34
+ https://www.openslr.org/17/
35
+ 其中包含 music, noise, speech.
36
+ music 是一些纯音乐, noise 包含 free-sound, sound-bible, sound-bible部分也许可以做为补充部分.
37
+ 总的来说, 有用的不部不多, 可能噪音数据仍然需要自己收集为主, 更加可靠.
38
+
39
+ CHiME-4
40
+ https://www.chimechallenge.org/challenges/chime4/download.html
41
+
42
+ freesound
43
+ https://freesound.org/
44
+
45
+ AudioSet
46
+ https://research.google.com/audioset/index.html
47
+ ```
48
+
49
+
50
+ ### ### 创建训练容器
51
+
52
+ ```text
53
+ 在容器中训练模型,需要能够从容器中访问到 GPU,参考:
54
+ https://hub.docker.com/r/ollama/ollama
55
+
56
+ docker run -itd \
57
+ --name nx_denoise \
58
+ --network host \
59
+ --gpus all \
60
+ --privileged \
61
+ --ipc=host \
62
+ -v /data/tianxing/HuggingDatasets/nx_noise/data:/data/tianxing/HuggingDatasets/nx_noise/data \
63
+ -v /data/tianxing/PycharmProjects/nx_denoise:/data/tianxing/PycharmProjects/nx_denoise \
64
+ python:3.12
65
+
66
+
67
+ 查看GPU
68
+ nvidia-smi
69
+ watch -n 1 -d nvidia-smi
70
+
71
+
72
+ ```
73
+
74
+ ```text
75
+ 在容器中访问 GPU
76
+
77
+ 参考:
78
+ https://blog.csdn.net/footless_bird/article/details/136291344
79
+ 步骤:
80
+ # 安装
81
+ yum install -y nvidia-container-toolkit
82
+
83
+ # 编辑文件 /etc/docker/daemon.json
84
+ cat /etc/docker/daemon.json
85
+ {
86
+ "data-root": "/data/lib/docker",
87
+ "default-runtime": "nvidia",
88
+ "runtimes": {
89
+ "nvidia": {
90
+ "path": "/usr/bin/nvidia-container-runtime",
91
+ "runtimeArgs": []
92
+ }
93
+ },
94
+ "registry-mirrors": [
95
+ "https://docker.m.daocloud.io",
96
+ "https://dockerproxy.com",
97
+ "https://docker.mirrors.ustc.edu.cn",
98
+ "https://docker.nju.edu.cn"
99
+ ]
100
+ }
101
+
102
+ # 重启 docker
103
+ systemctl restart docker
104
+ systemctl daemon-reload
105
+
106
+ # 测试容器内能否访问 GPU.
107
+ docker run --gpus all python:3.12-slim nvidia-smi
108
+
109
+ # 通过这种方式启动容器, 在容器中, 可以查看到 GPU. 但是容器中没有 GPU驱动 nvidia-smi 不工作.
110
+ docker run -it --privileged python:3.12-slim /bin/bash
111
+ apt update
112
+ apt install -y pciutils
113
+ lspci | grep -i nvidia
114
+ #00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
115
+
116
+ # 网上看的是这种启动容器的方式, 但是进去后仍然是 nvidia-smi 不工作.
117
+ docker run \
118
+ --device /dev/nvidia0:/dev/nvidia0 \
119
+ --device /dev/nvidiactl:/dev/nvidiactl \
120
+ --device /dev/nvidia-uvm:/dev/nvidia-uvm \
121
+ -v /usr/local/nvidia:/usr/local/nvidia \
122
+ -it --privileged python:3.12-slim /bin/bash
123
+
124
+
125
+ # 这种方式进入容器, nvidia-smi 可以工作. 应该关键是 --gpus all 参数.
126
+ docker run -itd --gpus all --name open_unsloth python:3.12-slim /bin/bash
127
+ docker run -itd --gpus all --name Qwen2-7B-Instruct python:3.12-slim /bin/bash
128
+
129
+ ```
examples/clean_unet/run.sh ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
7
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
9
+
10
+
11
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
12
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
+
15
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
+
19
+
20
+ END
21
+
22
+
23
+ # params
24
+ system_version="windows";
25
+ verbose=true;
26
+ stage=0 # start from 0 if you need to start from data preparation
27
+ stop_stage=9
28
+
29
+ work_dir="$(pwd)"
30
+ file_folder_name=file_folder_name
31
+ final_model_name=final_model_name
32
+ config_file="yaml/config.yaml"
33
+ limit=10
34
+
35
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
36
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
37
+
38
+ max_count=10000000
39
+
40
+ nohup_name=nohup.out
41
+
42
+ # model params
43
+ batch_size=64
44
+ max_epochs=200
45
+ save_top_k=10
46
+ patience=5
47
+
48
+
49
+ # parse options
50
+ while true; do
51
+ [ -z "${1:-}" ] && break; # break if there are no arguments
52
+ case "$1" in
53
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
54
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
55
+ old_value="(eval echo \\$$name)";
56
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
57
+ was_bool=true;
58
+ else
59
+ was_bool=false;
60
+ fi
61
+
62
+ # Set the variable to the right value-- the escaped quotes make it work if
63
+ # the option had spaces, like --cmd "queue.pl -sync y"
64
+ eval "${name}=\"$2\"";
65
+
66
+ # Check that Boolean-valued arguments are really Boolean.
67
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
68
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
69
+ exit 1;
70
+ fi
71
+ shift 2;
72
+ ;;
73
+
74
+ *) break;
75
+ esac
76
+ done
77
+
78
+ file_dir="${work_dir}/${file_folder_name}"
79
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
80
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
81
+
82
+ dataset="${file_dir}/dataset.xlsx"
83
+ train_dataset="${file_dir}/train.xlsx"
84
+ valid_dataset="${file_dir}/valid.xlsx"
85
+
86
+ $verbose && echo "system_version: ${system_version}"
87
+ $verbose && echo "file_folder_name: ${file_folder_name}"
88
+
89
+ if [ $system_version == "windows" ]; then
90
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
91
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
92
+ #source /data/local/bin/nx_denoise/bin/activate
93
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
94
+ fi
95
+
96
+
97
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
98
+ $verbose && echo "stage 1: prepare data"
99
+ cd "${work_dir}" || exit 1
100
+ python3 step_1_prepare_data.py \
101
+ --file_dir "${file_dir}" \
102
+ --noise_dir "${noise_dir}" \
103
+ --speech_dir "${speech_dir}" \
104
+ --train_dataset "${train_dataset}" \
105
+ --valid_dataset "${valid_dataset}" \
106
+ --max_count "${max_count}" \
107
+
108
+ fi
109
+
110
+
111
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
112
+ $verbose && echo "stage 2: train model"
113
+ cd "${work_dir}" || exit 1
114
+ python3 step_2_train_model.py \
115
+ --train_dataset "${train_dataset}" \
116
+ --valid_dataset "${valid_dataset}" \
117
+ --serialization_dir "${file_dir}" \
118
+ --config_file "${config_file}" \
119
+
120
+ fi
121
+
122
+
123
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
124
+ $verbose && echo "stage 3: test model"
125
+ cd "${work_dir}" || exit 1
126
+ python3 step_3_evaluation.py \
127
+ --valid_dataset "${valid_dataset}" \
128
+ --model_dir "${file_dir}/best" \
129
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
130
+ --limit "${limit}" \
131
+
132
+ fi
133
+
134
+
135
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
136
+ $verbose && echo "stage 4: export model"
137
+ cd "${work_dir}" || exit 1
138
+ python3 step_5_export_models.py \
139
+ --vocabulary_dir "${vocabulary_dir}" \
140
+ --model_dir "${file_dir}/best" \
141
+ --serialization_dir "${file_dir}" \
142
+
143
+ fi
144
+
145
+
146
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
147
+ $verbose && echo "stage 5: collect files"
148
+ cd "${work_dir}" || exit 1
149
+
150
+ mkdir -p ${final_model_dir}
151
+
152
+ cp "${file_dir}/best"/* "${final_model_dir}"
153
+ cp -r "${file_dir}/vocabulary" "${final_model_dir}"
154
+
155
+ cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
156
+
157
+ cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
158
+ cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
159
+ cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
160
+ cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
161
+
162
+ cd "${final_model_dir}/.." || exit 1;
163
+
164
+ if [ -e "${final_model_name}.zip" ]; then
165
+ rm -rf "${final_model_name}_backup.zip"
166
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
167
+ fi
168
+
169
+ zip -r "${final_model_name}.zip" "${final_model_name}"
170
+ rm -rf "${final_model_name}"
171
+
172
+ fi
173
+
174
+
175
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
176
+ $verbose && echo "stage 6: clear file_dir"
177
+ cd "${work_dir}" || exit 1
178
+
179
+ rm -rf "${file_dir}";
180
+
181
+ fi
examples/clean_unet/step_1_prepare_data.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--max_count", default=10000, type=int)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def filename_generator(data_dir: str):
52
+ data_dir = Path(data_dir)
53
+ for filename in data_dir.glob("**/*.wav"):
54
+ yield filename.as_posix()
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
58
+ data_dir = Path(data_dir)
59
+ for filename in data_dir.glob("**/*.wav"):
60
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
61
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
62
+
63
+ if raw_duration < duration:
64
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
65
+ continue
66
+ if signal.ndim != 1:
67
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
68
+
69
+ signal_length = len(signal)
70
+ win_size = int(duration * sample_rate)
71
+ for begin in range(0, signal_length - win_size, win_size):
72
+ row = {
73
+ "filename": filename.as_posix(),
74
+ "raw_duration": round(raw_duration, 4),
75
+ "offset": round(begin / sample_rate, 4),
76
+ "duration": round(duration, 4),
77
+ }
78
+ yield row
79
+
80
+
81
+ def get_dataset(args):
82
+ file_dir = Path(args.file_dir)
83
+ file_dir.mkdir(exist_ok=True)
84
+
85
+ noise_dir = Path(args.noise_dir)
86
+ speech_dir = Path(args.speech_dir)
87
+
88
+ noise_generator = target_second_signal_generator(
89
+ noise_dir.as_posix(),
90
+ duration=args.duration,
91
+ sample_rate=args.target_sample_rate
92
+ )
93
+ speech_generator = target_second_signal_generator(
94
+ speech_dir.as_posix(),
95
+ duration=args.duration,
96
+ sample_rate=args.target_sample_rate
97
+ )
98
+
99
+ dataset = list()
100
+
101
+ count = 0
102
+ process_bar = tqdm(desc="build dataset excel")
103
+ for noise, speech in zip(noise_generator, speech_generator):
104
+ if count >= args.max_count:
105
+ break
106
+
107
+ noise_filename = noise["filename"]
108
+ noise_raw_duration = noise["raw_duration"]
109
+ noise_offset = noise["offset"]
110
+ noise_duration = noise["duration"]
111
+
112
+ speech_filename = speech["filename"]
113
+ speech_raw_duration = speech["raw_duration"]
114
+ speech_offset = speech["offset"]
115
+ speech_duration = speech["duration"]
116
+
117
+ random1 = random.random()
118
+ random2 = random.random()
119
+
120
+ row = {
121
+ "noise_filename": noise_filename,
122
+ "noise_raw_duration": noise_raw_duration,
123
+ "noise_offset": noise_offset,
124
+ "noise_duration": noise_duration,
125
+
126
+ "speech_filename": speech_filename,
127
+ "speech_raw_duration": speech_raw_duration,
128
+ "speech_offset": speech_offset,
129
+ "speech_duration": speech_duration,
130
+
131
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
132
+
133
+ "random1": random1,
134
+ "random2": random2,
135
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
136
+ }
137
+ dataset.append(row)
138
+ count += 1
139
+ duration_seconds = count * args.duration
140
+ duration_hours = duration_seconds / 3600
141
+
142
+ process_bar.update(n=1)
143
+ process_bar.set_postfix({
144
+ # "duration_seconds": round(duration_seconds, 4),
145
+ "duration_hours": round(duration_hours, 4),
146
+
147
+ })
148
+
149
+ dataset = pd.DataFrame(dataset)
150
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
151
+ dataset.to_excel(
152
+ file_dir / "dataset.xlsx",
153
+ index=False,
154
+ )
155
+ return
156
+
157
+
158
+
159
+ def split_dataset(args):
160
+ """分割训练集, 测试集"""
161
+ file_dir = Path(args.file_dir)
162
+ file_dir.mkdir(exist_ok=True)
163
+
164
+ df = pd.read_excel(file_dir / "dataset.xlsx")
165
+
166
+ train = list()
167
+ test = list()
168
+
169
+ for i, row in df.iterrows():
170
+ flag = row["flag"]
171
+ if flag == "TRAIN":
172
+ train.append(row)
173
+ else:
174
+ test.append(row)
175
+
176
+ train = pd.DataFrame(train)
177
+ train.to_excel(
178
+ args.train_dataset,
179
+ index=False,
180
+ # encoding="utf_8_sig"
181
+ )
182
+ test = pd.DataFrame(test)
183
+ test.to_excel(
184
+ args.valid_dataset,
185
+ index=False,
186
+ # encoding="utf_8_sig"
187
+ )
188
+
189
+ return
190
+
191
+
192
+ def main():
193
+ args = get_args()
194
+
195
+ get_dataset(args)
196
+ split_dataset(args)
197
+ return
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
examples/clean_unet/step_2_train_model.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/NVIDIA/CleanUNet/blob/main/train.py
5
+
6
+ https://github.com/NVIDIA/CleanUNet/blob/main/configs/DNS-large-full.json
7
+ """
8
+ import argparse
9
+ import json
10
+ import logging
11
+ from logging.handlers import TimedRotatingFileHandler
12
+ import os
13
+ import platform
14
+ from pathlib import Path
15
+ import random
16
+ import sys
17
+ import shutil
18
+ from typing import List
19
+
20
+ pwd = os.path.abspath(os.path.dirname(__file__))
21
+ sys.path.append(os.path.join(pwd, "../../"))
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.nn import functional as F
27
+ from torch.utils.data.dataloader import DataLoader
28
+ from tqdm import tqdm
29
+
30
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
31
+ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig
32
+ from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
33
+ from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
34
+ from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
35
+ from toolbox.torchaudio.models.clean_unet.metrics import run_pesq_score
36
+
37
+ torch.autograd.set_detect_anomaly(True)
38
+
39
+
40
+ def get_args():
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
43
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
44
+
45
+ parser.add_argument("--max_epochs", default=100, type=int)
46
+
47
+ parser.add_argument("--batch_size", default=64, type=int)
48
+ parser.add_argument("--learning_rate", default=2e-4, type=float)
49
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
50
+ parser.add_argument("--patience", default=5, type=int)
51
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
52
+ parser.add_argument("--seed", default=0, type=int)
53
+
54
+ parser.add_argument("--config_file", default="config.yaml", type=str)
55
+
56
+ args = parser.parse_args()
57
+ return args
58
+
59
+
60
+ def logging_config(file_dir: str):
61
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
62
+
63
+ logging.basicConfig(format=fmt,
64
+ datefmt="%m/%d/%Y %H:%M:%S",
65
+ level=logging.INFO)
66
+ file_handler = TimedRotatingFileHandler(
67
+ filename=os.path.join(file_dir, "main.log"),
68
+ encoding="utf-8",
69
+ when="D",
70
+ interval=1,
71
+ backupCount=7
72
+ )
73
+ file_handler.setLevel(logging.INFO)
74
+ file_handler.setFormatter(logging.Formatter(fmt))
75
+ logger = logging.getLogger(__name__)
76
+ logger.addHandler(file_handler)
77
+
78
+ return logger
79
+
80
+
81
+ class CollateFunction(object):
82
+ def __init__(self):
83
+ pass
84
+
85
+ def __call__(self, batch: List[dict]):
86
+ clean_audios = list()
87
+ noisy_audios = list()
88
+
89
+ for sample in batch:
90
+ # noise_wave: torch.Tensor = sample["noise_wave"]
91
+ clean_audio: torch.Tensor = sample["speech_wave"]
92
+ noisy_audio: torch.Tensor = sample["mix_wave"]
93
+ # snr_db: float = sample["snr_db"]
94
+
95
+ clean_audios.append(clean_audio)
96
+ noisy_audios.append(noisy_audio)
97
+
98
+ clean_audios = torch.stack(clean_audios)
99
+ noisy_audios = torch.stack(noisy_audios)
100
+
101
+ # assert
102
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
103
+ raise AssertionError("nan or inf in clean_audios")
104
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
105
+ raise AssertionError("nan or inf in noisy_audios")
106
+ return clean_audios, noisy_audios
107
+
108
+
109
+ collate_fn = CollateFunction()
110
+
111
+
112
+ def main():
113
+ args = get_args()
114
+
115
+ config = CleanUNetConfig.from_pretrained(
116
+ pretrained_model_name_or_path=args.config_file,
117
+ )
118
+
119
+ serialization_dir = Path(args.serialization_dir)
120
+ serialization_dir.mkdir(parents=True, exist_ok=True)
121
+
122
+ logger = logging_config(serialization_dir)
123
+
124
+ random.seed(args.seed)
125
+ np.random.seed(args.seed)
126
+ torch.manual_seed(args.seed)
127
+ logger.info(f"set seed: {args.seed}")
128
+
129
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
+ n_gpu = torch.cuda.device_count()
131
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
132
+
133
+ # datasets
134
+ train_dataset = DenoiseExcelDataset(
135
+ excel_file=args.train_dataset,
136
+ expected_sample_rate=8000,
137
+ max_wave_value=32768.0,
138
+ )
139
+ valid_dataset = DenoiseExcelDataset(
140
+ excel_file=args.valid_dataset,
141
+ expected_sample_rate=8000,
142
+ max_wave_value=32768.0,
143
+ )
144
+ train_data_loader = DataLoader(
145
+ dataset=train_dataset,
146
+ batch_size=args.batch_size,
147
+ shuffle=True,
148
+ sampler=None,
149
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
150
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
151
+ collate_fn=collate_fn,
152
+ pin_memory=False,
153
+ # prefetch_factor=64,
154
+ )
155
+ valid_data_loader = DataLoader(
156
+ dataset=valid_dataset,
157
+ batch_size=args.batch_size,
158
+ shuffle=True,
159
+ sampler=None,
160
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
161
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
162
+ collate_fn=collate_fn,
163
+ pin_memory=False,
164
+ # prefetch_factor=64,
165
+ )
166
+
167
+ # models
168
+ logger.info(f"prepare models. config_file: {args.config_file}")
169
+ model = CleanUNetPretrainedModel(config).to(device)
170
+
171
+ # optimizer
172
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
173
+ optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)
174
+
175
+ # resume training
176
+ last_epoch = -1
177
+ for epoch_i in serialization_dir.glob("epoch-*"):
178
+ epoch_i = Path(epoch_i)
179
+ epoch_idx = epoch_i.stem.split("-")[1]
180
+ epoch_idx = int(epoch_idx)
181
+ if epoch_idx > last_epoch:
182
+ last_epoch = epoch_idx
183
+
184
+ if last_epoch != -1:
185
+ logger.info(f"resume from epoch-{last_epoch}.")
186
+ model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
187
+ optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
188
+
189
+ logger.info(f"load state dict for model.")
190
+ with open(model_pt.as_posix(), "rb") as f:
191
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
+ model.load_state_dict(state_dict, strict=True)
193
+
194
+ logger.info(f"load state dict for optimizer.")
195
+ with open(optimizer_pth.as_posix(), "rb") as f:
196
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
197
+ optimizer.load_state_dict(state_dict)
198
+
199
+ lr_scheduler = LinearWarmupCosineDecay(
200
+ optimizer,
201
+ lr_max=args.learning_rate,
202
+ n_iter=250000,
203
+ iteration=250000,
204
+ divider=25,
205
+ warmup_proportion=0.05,
206
+ phase=("linear", "cosine"),
207
+ )
208
+
209
+ # ae_loss_fn = nn.MSELoss(reduction="mean")
210
+ ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
211
+
212
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
213
+ fft_sizes=[256, 512, 1024],
214
+ hop_sizes=[25, 50, 120],
215
+ win_lengths=[120, 240, 600],
216
+ sc_lambda=0.5,
217
+ mag_lambda=0.5,
218
+ band="full"
219
+ ).to(device)
220
+
221
+ # training loop
222
+
223
+ # state
224
+ average_pesq_score = 10000000000
225
+ average_loss = 10000000000
226
+ average_ae_loss = 10000000000
227
+ average_sc_loss = 10000000000
228
+ average_mag_loss = 10000000000
229
+
230
+ model_list = list()
231
+ best_idx_epoch = None
232
+ best_metric = None
233
+ patience_count = 0
234
+
235
+ logger.info("training")
236
+ for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
237
+ # train
238
+ model.train()
239
+
240
+ total_pesq_score = 0.
241
+ total_loss = 0.
242
+ total_ae_loss = 0.
243
+ total_sc_loss = 0.
244
+ total_mag_loss = 0.
245
+ total_batches = 0.
246
+
247
+ progress_bar = tqdm(
248
+ total=len(train_data_loader),
249
+ desc="Training; epoch: {}".format(idx_epoch),
250
+ )
251
+ for batch in train_data_loader:
252
+ clean_audios, noisy_audios = batch
253
+ clean_audios = clean_audios.to(device)
254
+ noisy_audios = noisy_audios.to(device)
255
+
256
+ enhanced_audios = model.forward(noisy_audios)
257
+ enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
258
+
259
+ ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
260
+ sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
261
+
262
+ loss = ae_loss + sc_loss + mag_loss
263
+
264
+ enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
265
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
266
+ pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")
267
+
268
+ optimizer.zero_grad()
269
+ loss.backward()
270
+ optimizer.step()
271
+ lr_scheduler.step()
272
+
273
+ total_pesq_score += pesq_score
274
+ total_loss += loss.item()
275
+ total_ae_loss += ae_loss.item()
276
+ total_sc_loss += sc_loss.item()
277
+ total_mag_loss += mag_loss.item()
278
+ total_batches += 1
279
+
280
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
281
+ average_loss = round(total_loss / total_batches, 4)
282
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
283
+ average_sc_loss = round(total_sc_loss / total_batches, 4)
284
+ average_mag_loss = round(total_mag_loss / total_batches, 4)
285
+
286
+ progress_bar.update(1)
287
+ progress_bar.set_postfix({
288
+ "pesq_score": average_pesq_score,
289
+ "loss": average_loss,
290
+ "ae_loss": average_ae_loss,
291
+ "sc_loss": average_sc_loss,
292
+ "mag_loss": average_mag_loss,
293
+ })
294
+
295
+ # evaluation
296
+ model.eval()
297
+
298
+ torch.cuda.empty_cache()
299
+
300
+ total_pesq_score = 0.
301
+ total_loss = 0.
302
+ total_ae_loss = 0.
303
+ total_sc_loss = 0.
304
+ total_mag_loss = 0.
305
+ total_batches = 0.
306
+
307
+ progress_bar = tqdm(
308
+ total=len(valid_data_loader),
309
+ desc="Evaluation; epoch: {}".format(idx_epoch),
310
+ )
311
+ with torch.no_grad():
312
+ for batch in valid_data_loader:
313
+ clean_audios, noisy_audios = batch
314
+ clean_audios = clean_audios.to(device)
315
+ noisy_audios = noisy_audios.to(device)
316
+
317
+ enhanced_audios = model.forward(noisy_audios)
318
+ enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
319
+
320
+ ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
321
+ sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
322
+
323
+ loss = ae_loss + sc_loss + mag_loss
324
+
325
+ enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
326
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
327
+ pesq_score = run_pesq_score(clean_audios_list_r, enhanced_audios_list_r, sample_rate=8000, mode="nb")
328
+
329
+ total_pesq_score += pesq_score
330
+ total_loss += loss.item()
331
+ total_ae_loss += ae_loss.item()
332
+ total_sc_loss += sc_loss.item()
333
+ total_mag_loss += mag_loss.item()
334
+ total_batches += 1
335
+
336
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
337
+ average_loss = round(total_loss / total_batches, 4)
338
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
339
+ average_sc_loss = round(total_sc_loss / total_batches, 4)
340
+ average_mag_loss = round(total_mag_loss / total_batches, 4)
341
+
342
+ progress_bar.update(1)
343
+ progress_bar.set_postfix({
344
+ "pesq_score": average_pesq_score,
345
+ "loss": average_loss,
346
+ "ae_loss": average_ae_loss,
347
+ "sc_loss": average_sc_loss,
348
+ "mag_loss": average_mag_loss,
349
+ })
350
+
351
+ # scheduler
352
+ lr_scheduler.step()
353
+
354
+ # save path
355
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
356
+ epoch_dir.mkdir(parents=True, exist_ok=False)
357
+
358
+ # save models
359
+ model.save_pretrained(epoch_dir.as_posix())
360
+
361
+ model_list.append(epoch_dir)
362
+ if len(model_list) >= args.num_serialized_models_to_keep:
363
+ model_to_delete: Path = model_list.pop(0)
364
+ shutil.rmtree(model_to_delete.as_posix())
365
+
366
+ # save optim
367
+ torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
368
+
369
+ # save metric
370
+ if best_metric is None:
371
+ best_idx_epoch = idx_epoch
372
+ best_metric = average_pesq_score
373
+ elif average_pesq_score > best_metric:
374
+ # great is better.
375
+ best_idx_epoch = idx_epoch
376
+ best_metric = average_pesq_score
377
+ else:
378
+ pass
379
+
380
+ metrics = {
381
+ "idx_epoch": idx_epoch,
382
+ "best_idx_epoch": best_idx_epoch,
383
+
384
+ "pesq_score": average_pesq_score,
385
+ "loss": average_loss,
386
+ "ae_loss": average_ae_loss,
387
+ "sc_loss": average_sc_loss,
388
+ "mag_loss": average_mag_loss,
389
+
390
+ }
391
+ metrics_filename = epoch_dir / "metrics_epoch.json"
392
+ with open(metrics_filename, "w", encoding="utf-8") as f:
393
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
394
+
395
+ # save best
396
+ best_dir = serialization_dir / "best"
397
+ if best_idx_epoch == idx_epoch:
398
+ if best_dir.exists():
399
+ shutil.rmtree(best_dir)
400
+ shutil.copytree(epoch_dir, best_dir)
401
+
402
+ # early stop
403
+ early_stop_flag = False
404
+ if best_idx_epoch == idx_epoch:
405
+ patience_count = 0
406
+ else:
407
+ patience_count += 1
408
+ if patience_count >= args.patience:
409
+ early_stop_flag = True
410
+
411
+ # early stop
412
+ if early_stop_flag:
413
+ break
414
+
415
+ return
416
+
417
+
418
+ if __name__ == "__main__":
419
+ main()
examples/clean_unet/step_3_evaluation.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
examples/clean_unet/yaml/config.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_name: "clean_unet"
2
+
3
+ channels_input: 1
4
+ channels_output: 1
5
+ channels_h: 64
6
+ max_h: 768
7
+ encoder_n_layers: 8
8
+ kernel_size: 4
9
+ stride: 2
10
+ tsfm_n_layers: 5
11
+ tsfm_n_head: 8
12
+ tsfm_d_model: 512
13
+ tsfm_d_inner: 2048
14
+
examples/conv_tasnet/run.sh ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
7
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
9
+ --max_epochs 400
10
+
11
+
12
+ END
13
+
14
+
15
+ # params
16
+ system_version="windows";
17
+ verbose=true;
18
+ stage=0 # start from 0 if you need to start from data preparation
19
+ stop_stage=9
20
+
21
+ work_dir="$(pwd)"
22
+ file_folder_name=file_folder_name
23
+ final_model_name=final_model_name
24
+ config_file="yaml/config.yaml"
25
+ limit=10
26
+
27
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
28
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
29
+
30
+ max_count=10000000
31
+
32
+ nohup_name=nohup.out
33
+
34
+ # model params
35
+ batch_size=64
36
+ max_epochs=200
37
+ save_top_k=10
38
+ patience=5
39
+
40
+
41
+ # parse options
42
+ while true; do
43
+ [ -z "${1:-}" ] && break; # break if there are no arguments
44
+ case "$1" in
45
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
46
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
47
+ old_value="(eval echo \\$$name)";
48
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
49
+ was_bool=true;
50
+ else
51
+ was_bool=false;
52
+ fi
53
+
54
+ # Set the variable to the right value-- the escaped quotes make it work if
55
+ # the option had spaces, like --cmd "queue.pl -sync y"
56
+ eval "${name}=\"$2\"";
57
+
58
+ # Check that Boolean-valued arguments are really Boolean.
59
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
60
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
61
+ exit 1;
62
+ fi
63
+ shift 2;
64
+ ;;
65
+
66
+ *) break;
67
+ esac
68
+ done
69
+
70
+ file_dir="${work_dir}/${file_folder_name}"
71
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
72
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
73
+
74
+ train_dataset="${file_dir}/train.jsonl"
75
+ valid_dataset="${file_dir}/valid.jsonl"
76
+
77
+ $verbose && echo "system_version: ${system_version}"
78
+ $verbose && echo "file_folder_name: ${file_folder_name}"
79
+
80
+ if [ $system_version == "windows" ]; then
81
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
82
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
83
+ #source /data/local/bin/nx_denoise/bin/activate
84
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
85
+ fi
86
+
87
+
88
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
89
+ $verbose && echo "stage 1: prepare data"
90
+ cd "${work_dir}" || exit 1
91
+ python3 step_1_prepare_data.py \
92
+ --file_dir "${file_dir}" \
93
+ --noise_dir "${noise_dir}" \
94
+ --speech_dir "${speech_dir}" \
95
+ --train_dataset "${train_dataset}" \
96
+ --valid_dataset "${valid_dataset}" \
97
+ --max_count "${max_count}" \
98
+
99
+ fi
100
+
101
+
102
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
103
+ $verbose && echo "stage 2: train model"
104
+ cd "${work_dir}" || exit 1
105
+ python3 step_2_train_model.py \
106
+ --train_dataset "${train_dataset}" \
107
+ --valid_dataset "${valid_dataset}" \
108
+ --serialization_dir "${file_dir}" \
109
+ --config_file "${config_file}" \
110
+
111
+ fi
112
+
113
+
114
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
115
+ $verbose && echo "stage 3: test model"
116
+ cd "${work_dir}" || exit 1
117
+ python3 step_3_evaluation.py \
118
+ --valid_dataset "${valid_dataset}" \
119
+ --model_dir "${file_dir}/best" \
120
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
121
+ --limit "${limit}" \
122
+
123
+ fi
124
+
125
+
126
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
127
+ $verbose && echo "stage 4: collect files"
128
+ cd "${work_dir}" || exit 1
129
+
130
+ mkdir -p ${final_model_dir}
131
+
132
+ cp "${file_dir}/best"/* "${final_model_dir}"
133
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
134
+
135
+ cd "${final_model_dir}/.." || exit 1;
136
+
137
+ if [ -e "${final_model_name}.zip" ]; then
138
+ rm -rf "${final_model_name}_backup.zip"
139
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
140
+ fi
141
+
142
+ zip -r "${final_model_name}.zip" "${final_model_name}"
143
+ rm -rf "${final_model_name}"
144
+
145
+ fi
146
+
147
+
148
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
149
+ $verbose && echo "stage 5: clear file_dir"
150
+ cd "${work_dir}" || exit 1
151
+
152
+ rm -rf "${file_dir}";
153
+
154
+ fi
examples/conv_tasnet/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--file_dir", default="./", type=str)
21
+
22
+ parser.add_argument(
23
+ "--noise_dir",
24
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--speech_dir",
29
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
30
+ type=str
31
+ )
32
+
33
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
34
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
35
+
36
+ parser.add_argument("--duration", default=4.0, type=float)
37
+ parser.add_argument("--min_snr_db", default=-10, type=float)
38
+ parser.add_argument("--max_snr_db", default=20, type=float)
39
+
40
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
41
+
42
+ parser.add_argument("--max_count", default=10000, type=int)
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def filename_generator(data_dir: str):
49
+ data_dir = Path(data_dir)
50
+ for filename in data_dir.glob("**/*.wav"):
51
+ yield filename.as_posix()
52
+
53
+
54
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ if np.sum(signal[begin: begin+win_size]) == 0:
71
+ continue
72
+ row = {
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(begin / sample_rate, 4),
77
+ "duration": round(duration, 4),
78
+ }
79
+ yield row
80
+
81
+
82
+ def main():
83
+ args = get_args()
84
+
85
+ file_dir = Path(args.file_dir)
86
+ file_dir.mkdir(exist_ok=True)
87
+
88
+ noise_dir = Path(args.noise_dir)
89
+ speech_dir = Path(args.speech_dir)
90
+
91
+ noise_generator = target_second_signal_generator(
92
+ noise_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate,
95
+ max_epoch=100000,
96
+ )
97
+ speech_generator = target_second_signal_generator(
98
+ speech_dir.as_posix(),
99
+ duration=args.duration,
100
+ sample_rate=args.target_sample_rate,
101
+ max_epoch=1,
102
+ )
103
+
104
+ dataset = list()
105
+
106
+ count = 0
107
+ process_bar = tqdm(desc="build dataset excel")
108
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
109
+ for noise, speech in zip(noise_generator, speech_generator):
110
+ if count >= args.max_count > 0:
111
+ break
112
+
113
+ noise_filename = noise["filename"]
114
+ noise_raw_duration = noise["raw_duration"]
115
+ noise_offset = noise["offset"]
116
+ noise_duration = noise["duration"]
117
+
118
+ speech_filename = speech["filename"]
119
+ speech_raw_duration = speech["raw_duration"]
120
+ speech_offset = speech["offset"]
121
+ speech_duration = speech["duration"]
122
+
123
+ random1 = random.random()
124
+ random2 = random.random()
125
+
126
+ row = {
127
+ "count": count,
128
+
129
+ "noise_filename": noise_filename,
130
+ "noise_raw_duration": noise_raw_duration,
131
+ "noise_offset": noise_offset,
132
+ "noise_duration": noise_duration,
133
+
134
+ "speech_filename": speech_filename,
135
+ "speech_raw_duration": speech_raw_duration,
136
+ "speech_offset": speech_offset,
137
+ "speech_duration": speech_duration,
138
+
139
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
140
+
141
+ "random1": random1,
142
+ }
143
+ row = json.dumps(row, ensure_ascii=False)
144
+ if random2 < (1 / 300 / 1):
145
+ fvalid.write(f"{row}\n")
146
+ else:
147
+ ftrain.write(f"{row}\n")
148
+
149
+ count += 1
150
+ duration_seconds = count * args.duration
151
+ duration_hours = duration_seconds / 3600
152
+
153
+ process_bar.update(n=1)
154
+ process_bar.set_postfix({
155
+ # "duration_seconds": round(duration_seconds, 4),
156
+ "duration_hours": round(duration_hours, 4),
157
+
158
+ })
159
+
160
+ return
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
examples/conv_tasnet/step_2_train_model.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/kaituoxu/Conv-TasNet/tree/master/src
5
+
6
+ 一般场景:
7
+
8
+ 目标 SI-SNR ≥ 10 dB,适用于电话通信、基础语音助手等。
9
+
10
+ 高要求场景(如医疗助听、语音识别):
11
+ 需 SI-SNR ≥ 14 dB,并配合 PESQ ≥ 3.0 和 STOI ≥ 0.851812。
12
+
13
+ DeepFilterNet2 模型在 DNS4 数据集,超过500小时的音频上训练了 100 个 epoch。
14
+ https://arxiv.org/abs/2205.05474
15
+
16
+ """
17
+ import argparse
18
+ import json
19
+ import logging
20
+ from logging.handlers import TimedRotatingFileHandler
21
+ import os
22
+ import platform
23
+ from pathlib import Path
24
+ import random
25
+ import sys
26
+ import shutil
27
+ from typing import List
28
+
29
+ pwd = os.path.abspath(os.path.dirname(__file__))
30
+ sys.path.append(os.path.join(pwd, "../../"))
31
+
32
+ import numpy as np
33
+ import torch
34
+ import torch.nn as nn
35
+ from torch.nn import functional as F
36
+ from torch.utils.data.dataloader import DataLoader
37
+ from tqdm import tqdm
38
+
39
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
40
+ from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
41
+ from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
42
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
43
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
44
+ from toolbox.torchaudio.losses.perceptual import NegSTOILoss, PesqLoss
45
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
46
+
47
+
48
+ def get_args():
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
51
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
52
+
53
+ parser.add_argument("--max_epochs", default=200, type=int)
54
+
55
+ parser.add_argument("--batch_size", default=8, type=int)
56
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
57
+ parser.add_argument("--patience", default=5, type=int)
58
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
59
+ parser.add_argument("--seed", default=1234, type=int)
60
+
61
+ parser.add_argument("--config_file", default="config.yaml", type=str)
62
+
63
+ args = parser.parse_args()
64
+ return args
65
+
66
+
67
+ def logging_config(file_dir: str):
68
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
69
+
70
+ logging.basicConfig(format=fmt,
71
+ datefmt="%m/%d/%Y %H:%M:%S",
72
+ level=logging.INFO)
73
+ file_handler = TimedRotatingFileHandler(
74
+ filename=os.path.join(file_dir, "main.log"),
75
+ encoding="utf-8",
76
+ when="D",
77
+ interval=1,
78
+ backupCount=7
79
+ )
80
+ file_handler.setLevel(logging.INFO)
81
+ file_handler.setFormatter(logging.Formatter(fmt))
82
+ logger = logging.getLogger(__name__)
83
+ logger.addHandler(file_handler)
84
+
85
+ return logger
86
+
87
+
88
+ class CollateFunction(object):
89
+ def __init__(self):
90
+ pass
91
+
92
+ def __call__(self, batch: List[dict]):
93
+ clean_audios = list()
94
+ noisy_audios = list()
95
+
96
+ for sample in batch:
97
+ # noise_wave: torch.Tensor = sample["noise_wave"]
98
+ clean_audio: torch.Tensor = sample["speech_wave"]
99
+ noisy_audio: torch.Tensor = sample["mix_wave"]
100
+ # snr_db: float = sample["snr_db"]
101
+
102
+ clean_audios.append(clean_audio)
103
+ noisy_audios.append(noisy_audio)
104
+
105
+ clean_audios = torch.stack(clean_audios)
106
+ noisy_audios = torch.stack(noisy_audios)
107
+
108
+ # assert
109
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
110
+ raise AssertionError("nan or inf in clean_audios")
111
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
112
+ raise AssertionError("nan or inf in noisy_audios")
113
+ return clean_audios, noisy_audios
114
+
115
+
116
+ collate_fn = CollateFunction()
117
+
118
+
119
+ def main():
120
+ args = get_args()
121
+
122
+ config = ConvTasNetConfig.from_pretrained(
123
+ pretrained_model_name_or_path=args.config_file,
124
+ )
125
+
126
+ serialization_dir = Path(args.serialization_dir)
127
+ serialization_dir.mkdir(parents=True, exist_ok=True)
128
+
129
+ logger = logging_config(serialization_dir)
130
+
131
+ random.seed(args.seed)
132
+ np.random.seed(args.seed)
133
+ torch.manual_seed(args.seed)
134
+ logger.info(f"set seed: {args.seed}")
135
+
136
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137
+ n_gpu = torch.cuda.device_count()
138
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
139
+
140
+ # datasets
141
+ train_dataset = DenoiseJsonlDataset(
142
+ jsonl_file=args.train_dataset,
143
+ expected_sample_rate=config.sample_rate,
144
+ max_wave_value=32768.0,
145
+ min_snr_db=config.min_snr_db,
146
+ max_snr_db=config.max_snr_db,
147
+ # skip=225000,
148
+ )
149
+ valid_dataset = DenoiseJsonlDataset(
150
+ jsonl_file=args.valid_dataset,
151
+ expected_sample_rate=config.sample_rate,
152
+ max_wave_value=32768.0,
153
+ min_snr_db=config.min_snr_db,
154
+ max_snr_db=config.max_snr_db,
155
+ )
156
+ train_data_loader = DataLoader(
157
+ dataset=train_dataset,
158
+ batch_size=args.batch_size,
159
+ # shuffle=True,
160
+ sampler=None,
161
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
162
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
163
+ collate_fn=collate_fn,
164
+ pin_memory=False,
165
+ prefetch_factor=2,
166
+ )
167
+ valid_data_loader = DataLoader(
168
+ dataset=valid_dataset,
169
+ batch_size=args.batch_size,
170
+ # shuffle=True,
171
+ sampler=None,
172
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
173
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
174
+ collate_fn=collate_fn,
175
+ pin_memory=False,
176
+ prefetch_factor=2,
177
+ )
178
+
179
+ # models
180
+ logger.info(f"prepare models. config_file: {args.config_file}")
181
+ model = ConvTasNetPretrainedModel(config).to(device)
182
+ model.to(device)
183
+ model.train()
184
+
185
+ # optimizer
186
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
187
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
188
+
189
+ # resume training
190
+ last_step_idx = -1
191
+ last_epoch = -1
192
+ for step_idx_str in serialization_dir.glob("steps-*"):
193
+ step_idx_str = Path(step_idx_str)
194
+ step_idx = step_idx_str.stem.split("-")[1]
195
+ step_idx = int(step_idx)
196
+ if step_idx > last_step_idx:
197
+ last_step_idx = step_idx
198
+ last_epoch = 1
199
+
200
+ if last_step_idx != -1:
201
+ logger.info(f"resume from steps-{last_step_idx}.")
202
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
203
+ optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
204
+
205
+ logger.info(f"load state dict for model.")
206
+ with open(model_pt.as_posix(), "rb") as f:
207
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
208
+ model.load_state_dict(state_dict, strict=True)
209
+
210
+ logger.info(f"load state dict for optimizer.")
211
+ with open(optimizer_pth.as_posix(), "rb") as f:
212
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
213
+ optimizer.load_state_dict(state_dict)
214
+
215
+ if config.lr_scheduler == "CosineAnnealingLR":
216
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
217
+ optimizer,
218
+ last_epoch=last_epoch,
219
+ # T_max=10 * config.eval_steps,
220
+ # eta_min=0.01 * config.lr,
221
+ **config.lr_scheduler_kwargs,
222
+ )
223
+ elif config.lr_scheduler == "MultiStepLR":
224
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
225
+ optimizer,
226
+ last_epoch=last_epoch,
227
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
228
+ )
229
+ else:
230
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
231
+
232
+ ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
233
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
234
+ neg_stoi_loss_fn = NegSTOILoss(sample_rate=config.sample_rate, reduction="mean").to(device)
235
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
236
+ fft_size_list=[256, 512, 1024],
237
+ win_size_list=[120, 240, 480],
238
+ hop_size_list=[25, 50, 100],
239
+ factor_sc=1.5,
240
+ factor_mag=1.0,
241
+ reduction="mean"
242
+ ).to(device)
243
+ pesq_loss_fn = PesqLoss(0.5, sample_rate=config.sample_rate).to(device)
244
+
245
+ # training loop
246
+
247
+ # state
248
+ average_pesq_score = 1000000000
249
+ average_loss = 1000000000
250
+ average_ae_loss = 1000000000
251
+ average_neg_si_snr_loss = 1000000000
252
+ average_neg_stoi_loss = 1000000000
253
+
254
+ model_list = list()
255
+ best_epoch_idx = None
256
+ best_step_idx = None
257
+ best_metric = None
258
+ patience_count = 0
259
+
260
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
261
+
262
+ logger.info("training")
263
+ for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
264
+ # train
265
+ model.train()
266
+
267
+ total_pesq_score = 0.
268
+ total_loss = 0.
269
+ total_ae_loss = 0.
270
+ total_neg_si_snr_loss = 0.
271
+ total_neg_stoi_loss = 0.
272
+ total_mr_stft_loss = 0.
273
+ total_pesq_loss = 0.
274
+ total_batches = 0.
275
+
276
+ progress_bar_train = tqdm(
277
+ initial=step_idx,
278
+ desc="Training; epoch-{}".format(epoch_idx),
279
+ )
280
+ for train_batch in train_data_loader:
281
+ clean_audios, noisy_audios = train_batch
282
+ clean_audios: torch.Tensor = clean_audios.to(device)
283
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
284
+
285
+ denoise_audios = model.forward(noisy_audios)
286
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
287
+
288
+ if torch.any(torch.isnan(denoise_audios)) or torch.any(torch.isinf(denoise_audios)):
289
+ raise AssertionError("nan or inf in denoise_audios")
290
+
291
+ ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
292
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
293
+ neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
294
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
295
+ pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
296
+
297
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
298
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
299
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
300
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
301
+ # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
302
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
303
+ loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss
304
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
305
+ logger.info(f"find nan or inf in loss.")
306
+ continue
307
+
308
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
309
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
310
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
311
+
312
+ optimizer.zero_grad()
313
+ loss.backward()
314
+ optimizer.step()
315
+ lr_scheduler.step()
316
+
317
+ total_pesq_score += pesq_score
318
+ total_loss += loss.item()
319
+ total_ae_loss += ae_loss.item()
320
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
321
+ total_neg_stoi_loss += neg_stoi_loss.item()
322
+ total_mr_stft_loss += mr_stft_loss.item()
323
+ total_pesq_loss += pesq_loss.item()
324
+ total_batches += 1
325
+
326
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
327
+ average_loss = round(total_loss / total_batches, 4)
328
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
329
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
330
+ average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
331
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
332
+ average_pesq_loss = round(total_pesq_loss / total_batches, 4)
333
+
334
+ progress_bar_train.update(1)
335
+ progress_bar_train.set_postfix({
336
+ "lr": lr_scheduler.get_last_lr()[0],
337
+ "pesq_score": average_pesq_score,
338
+ "loss": average_loss,
339
+ "ae_loss": average_ae_loss,
340
+ "neg_si_snr_loss": average_neg_si_snr_loss,
341
+ "neg_stoi_loss": average_neg_stoi_loss,
342
+ "mr_stft_loss": average_mr_stft_loss,
343
+ "pesq_loss": average_pesq_loss,
344
+ })
345
+
346
+ # evaluation
347
+ step_idx += 1
348
+ if step_idx % config.eval_steps == 0:
349
+ model.eval()
350
+ with torch.no_grad():
351
+ torch.cuda.empty_cache()
352
+
353
+ total_pesq_score = 0.
354
+ total_loss = 0.
355
+ total_ae_loss = 0.
356
+ total_neg_si_snr_loss = 0.
357
+ total_neg_stoi_loss = 0.
358
+ total_mr_stft_loss = 0.
359
+ total_pesq_loss = 0.
360
+ total_batches = 0.
361
+
362
+ progress_bar_train.close()
363
+ progress_bar_eval = tqdm(
364
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
365
+ )
366
+ for eval_batch in valid_data_loader:
367
+ clean_audios, noisy_audios = eval_batch
368
+ clean_audios = clean_audios.to(device)
369
+ noisy_audios = noisy_audios.to(device)
370
+
371
+ denoise_audios = model.forward(noisy_audios)
372
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
373
+
374
+ ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
375
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
376
+ neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
377
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
378
+ pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
379
+
380
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
381
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
382
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
383
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
384
+ # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
385
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
386
+ loss = 0.1 * ae_loss + 0.1 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.2 * neg_stoi_loss + 0.2 * pesq_loss
387
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
388
+ logger.info(f"find nan or inf in loss.")
389
+ continue
390
+
391
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
392
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
393
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
394
+
395
+ total_pesq_score += pesq_score
396
+ total_loss += loss.item()
397
+ total_ae_loss += ae_loss.item()
398
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
399
+ total_neg_stoi_loss += neg_stoi_loss.item()
400
+ total_mr_stft_loss += mr_stft_loss.item()
401
+ total_pesq_loss += pesq_loss.item()
402
+ total_batches += 1
403
+
404
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
405
+ average_loss = round(total_loss / total_batches, 4)
406
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
407
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
408
+ average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
409
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
410
+ average_pesq_loss = round(total_pesq_loss / total_batches, 4)
411
+
412
+ progress_bar_eval.update(1)
413
+ progress_bar_eval.set_postfix({
414
+ "lr": lr_scheduler.get_last_lr()[0],
415
+ "pesq_score": average_pesq_score,
416
+ "loss": average_loss,
417
+ "ae_loss": average_ae_loss,
418
+ "neg_si_snr_loss": average_neg_si_snr_loss,
419
+ "neg_stoi_loss": average_neg_stoi_loss,
420
+ "mr_stft_loss": average_mr_stft_loss,
421
+ "pesq_loss": average_pesq_loss,
422
+ })
423
+
424
+ total_pesq_score = 0.
425
+ total_loss = 0.
426
+ total_ae_loss = 0.
427
+ total_neg_si_snr_loss = 0.
428
+ total_neg_stoi_loss = 0.
429
+ total_mr_stft_loss = 0.
430
+ total_pesq_loss = 0.
431
+ total_batches = 0.
432
+
433
+ progress_bar_eval.close()
434
+ progress_bar_train = tqdm(
435
+ initial=progress_bar_train.n,
436
+ postfix=progress_bar_train.postfix,
437
+ desc=progress_bar_train.desc,
438
+ )
439
+
440
+ # save path
441
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
442
+ save_dir.mkdir(parents=True, exist_ok=False)
443
+
444
+ # save models
445
+ model.save_pretrained(save_dir.as_posix())
446
+
447
+ model_list.append(save_dir)
448
+ if len(model_list) > args.num_serialized_models_to_keep:
449
+ model_to_delete: Path = model_list.pop(0)
450
+ shutil.rmtree(model_to_delete.as_posix())
451
+
452
+ # save optim
453
+ torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
454
+
455
+ # save metric
456
+ if best_metric is None:
457
+ best_epoch_idx = epoch_idx
458
+ best_step_idx = step_idx
459
+ best_metric = average_pesq_score
460
+ elif average_pesq_score > best_metric:
461
+ # greater is better.
462
+ best_epoch_idx = epoch_idx
463
+ best_step_idx = step_idx
464
+ best_metric = average_pesq_score
465
+ else:
466
+ pass
467
+
468
+ metrics = {
469
+ "epoch_idx": epoch_idx,
470
+ "best_epoch_idx": best_epoch_idx,
471
+ "best_step_idx": best_step_idx,
472
+ "pesq_score": average_pesq_score,
473
+ "loss": average_loss,
474
+ "ae_loss": average_ae_loss,
475
+ "neg_si_snr_loss": average_neg_si_snr_loss,
476
+ "neg_stoi_loss": average_neg_stoi_loss,
477
+ "mr_stft_loss": average_mr_stft_loss,
478
+ "pesq_loss": average_pesq_loss,
479
+ }
480
+ metrics_filename = save_dir / "metrics_epoch.json"
481
+ with open(metrics_filename, "w", encoding="utf-8") as f:
482
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
483
+
484
+ # save best
485
+ best_dir = serialization_dir / "best"
486
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
487
+ if best_dir.exists():
488
+ shutil.rmtree(best_dir)
489
+ shutil.copytree(save_dir, best_dir)
490
+
491
+ # early stop
492
+ early_stop_flag = False
493
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
494
+ patience_count = 0
495
+ else:
496
+ patience_count += 1
497
+ if patience_count >= args.patience:
498
+ early_stop_flag = True
499
+
500
+ # early stop
501
+ if early_stop_flag:
502
+ break
503
+ model.train()
504
+
505
+ return
506
+
507
+
508
+ if __name__ == "__main__":
509
+ main()
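
The resume block above rediscovers the latest checkpoint purely from "steps-<N>" directory names. A minimal, self-contained sketch of that scan (the directory path below is illustrative; only the "steps-*" naming convention comes from the script):

from pathlib import Path

def find_last_step(serialization_dir: str) -> int:
    last_step_idx = -1
    for step_dir in Path(serialization_dir).glob("steps-*"):
        step_idx = int(step_dir.stem.split("-")[1])
        last_step_idx = max(last_step_idx, step_idx)
    return last_step_idx

print(find_last_step("./serialization_dir"))  # -1 when no checkpoint exists yet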
examples/conv_tasnet/yaml/config.yaml ADDED
@@ -0,0 +1,28 @@
1
+ model_name: "conv_tasnet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 4
5
+
6
+ win_size: 20
7
+ freq_bins: 256
8
+ bottleneck_channels: 256
9
+ num_speakers: 1
10
+ num_blocks: 4
11
+ num_sub_blocks: 8
12
+ sub_blocks_channels: 512
13
+ sub_blocks_kernel_size: 3
14
+
15
+ norm_type: "gLN"
16
+ causal: false
17
+ mask_nonlinear: "relu"
18
+
19
+ min_snr_db: -10
20
+ max_snr_db: 20
21
+
22
+ lr: 0.005
23
+ lr_scheduler: "CosineAnnealingLR"
24
+ lr_scheduler_kwargs:
25
+ T_max: 250000
26
+ eta_min: 0.00005
27
+
28
+ eval_steps: 25000
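
The training script consumes this file through ConvTasNetConfig.from_pretrained, which is not part of this diff. A minimal sketch of inspecting the same keys with plain PyYAML (assuming PyYAML is installed):

import yaml

with open("examples/conv_tasnet/yaml/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["sample_rate"])          # 8000
print(config["lr_scheduler"])         # CosineAnnealingLR
print(config["lr_scheduler_kwargs"])  # {'T_max': 250000, 'eta_min': 5e-05}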
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_emotional_speech.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ 247M
16
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2
17
+
18
+
19
+ """
20
+ import argparse
21
+ import os
22
+ from pathlib import Path
23
+ import sys
24
+
25
+ import numpy as np
26
+ from tqdm import tqdm
27
+
28
+ pwd = os.path.abspath(os.path.dirname(__file__))
29
+ sys.path.append(os.path.join(pwd, "../../"))
30
+
31
+ import librosa
32
+ from scipy.io import wavfile
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser()
37
+
38
+ parser.add_argument(
39
+ "--data_dir",
40
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
41
+ type=str
42
+ )
43
+ parser.add_argument(
44
+ "--output_dir",
45
+ default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
46
+ type=str
47
+ )
48
+ parser.add_argument("--sample_rate", default=8000, type=int)
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def main():
54
+ args = get_args()
55
+
56
+ data_dir = Path(args.data_dir)
57
+ output_dir = Path(args.output_dir)
58
+ output_dir.mkdir(parents=True, exist_ok=True)
59
+
60
+ # finished_set
61
+ finished_set = set()
62
+ for filename in tqdm(output_dir.glob("**/*.wav")):
63
+ name = filename.stem
64
+ finished_set.add(name)
65
+ print(f"finished_set count: {len(finished_set)}")
66
+
67
+ for filename in tqdm(data_dir.glob("**/*.wav")):
68
+ label = filename.parts[-2]
69
+ name = filename.stem
70
+ # print(f"filename: {filename.as_posix()}")
71
+ if name in finished_set:
72
+ continue
73
+
74
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
75
+
76
+ signal = signal * (1 << 15)
77
+ signal = np.array(signal, dtype=np.int16)
78
+
79
+ to_file = output_dir / f"{label}/{name}.wav"
80
+ to_file.parent.mkdir(parents=True, exist_ok=True)
81
+ wavfile.write(
82
+ to_file.as_posix(),
83
+ rate=args.sample_rate,
84
+ data=signal,
85
+ )
86
+ return
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
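
This script, like the others in this folder, converts librosa's float output to 16-bit PCM by scaling with 1 << 15. A slightly safer variant of that conversion (the clipping is an addition, not in the script; without it a sample at exactly 1.0 overflows int16):

import numpy as np

def float_to_int16(signal: np.ndarray) -> np.ndarray:
    # clip to the largest float that still maps into int16 after scaling
    signal = np.clip(signal, -1.0, 32767 / 32768)
    return (signal * (1 << 15)).astype(np.int16)

print(float_to_int16(np.array([0.0, 0.5, 1.0])))  # [    0 16384 32767]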
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py ADDED
@@ -0,0 +1,129 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ 12G
16
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.french_data.tar.bz2
17
+
18
+ 43G
19
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.german_speech.tar.bz2
20
+
21
+ 7.9G
22
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.italian_speech.tar.bz2
23
+
24
+ 12G
25
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.mandarin_speech.tar.bz2
26
+
27
+ 3.1G
28
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.russian_speech.tar.bz2
29
+
30
+ 9.7G
31
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.spanish_speech.tar.bz2
32
+
33
+ 617M
34
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.singing_voice.tar.bz2
35
+
36
+ """
37
+ import argparse
38
+ import os
39
+ from pathlib import Path
40
+ import sys
41
+
42
+ import numpy as np
43
+ from tqdm import tqdm
44
+
45
+ pwd = os.path.abspath(os.path.dirname(__file__))
46
+ sys.path.append(os.path.join(pwd, "../../"))
47
+
48
+ import librosa
49
+ from scipy.io import wavfile
50
+
51
+
52
+ def get_args():
53
+ parser = argparse.ArgumentParser()
54
+
55
+ parser.add_argument(
56
+ "--data_dir",
57
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
58
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
59
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
60
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.french_data\datasets\clean\french_data",
61
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
62
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
63
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
64
+ # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.spanish_speech\datasets\clean\spanish_speech",
65
+ type=str
66
+ )
67
+ parser.add_argument(
68
+ "--output_dir",
69
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
70
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
71
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
72
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-french-speech-8k",
73
+ default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
74
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
75
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
76
+ # default=r"E:\programmer\asr_datasets\denoise\dns-clean-spanish-speech-8k",
77
+ type=str
78
+ )
79
+ parser.add_argument("--sample_rate", default=8000, type=int)
80
+ args = parser.parse_args()
81
+ return args
82
+
83
+
84
+ def main():
85
+ args = get_args()
86
+
87
+ data_dir = Path(args.data_dir)
88
+ output_dir = Path(args.output_dir)
89
+ output_dir.mkdir(parents=True, exist_ok=True)
90
+
91
+ # finished_set
92
+ finished_set = set()
93
+ for filename in tqdm(output_dir.glob("**/*.wav")):
94
+ filename = Path(filename)
95
+ relative_name = filename.relative_to(output_dir)
96
+ relative_name_ = relative_name.as_posix()
97
+ finished_set.add(relative_name_)
98
+ print(f"finished_set count: {len(finished_set)}")
99
+
100
+ for filename in tqdm(data_dir.glob("**/*.wav")):
101
+ relative_name = filename.relative_to(data_dir)
102
+ relative_name_ = relative_name.as_posix()
103
+ if relative_name_ in finished_set:
104
+ continue
105
+ finished_set.add(relative_name_)
106
+
107
+ try:
108
+ signal, _ = librosa.load(filename.as_posix(), mono=False, sr=args.sample_rate)
109
+ except Exception:
110
+ print(f"skip file: {filename.as_posix()}")
111
+ continue
112
+ if signal.ndim != 1:
113
+ raise AssertionError(f"expected mono audio, got ndim={signal.ndim}: {filename.as_posix()}")
114
+
115
+ signal = signal * (1 << 15)
116
+ signal = np.array(signal, dtype=np.int16)
117
+
118
+ to_file = output_dir / relative_name.as_posix()
119
+ to_file.parent.mkdir(parents=True, exist_ok=True)
120
+ wavfile.write(
121
+ to_file.as_posix(),
122
+ rate=args.sample_rate,
123
+ data=signal,
124
+ )
125
+ return
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
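
Unlike the emotional-speech script, which keys finished files by bare stem, this one keys them by path relative to output_dir, so identical stems under different language subfolders cannot collide. That bookkeeping, isolated as a sketch:

from pathlib import Path

def finished_relative_names(output_dir: str) -> set:
    root = Path(output_dir)
    # a written output marks its source file as done via the shared relative path
    return {p.relative_to(root).as_posix() for p in root.glob("**/*.wav")}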
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_demand.py ADDED
@@ -0,0 +1,71 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ """
10
+ import argparse
11
+ import os
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+
17
+ import numpy as np
18
+
19
+ pwd = os.path.abspath(os.path.dirname(__file__))
20
+ sys.path.append(os.path.join(pwd, "../../"))
21
+
22
+ import librosa
23
+ from scipy.io import wavfile
24
+
25
+
26
+ def get_args():
27
+ parser = argparse.ArgumentParser()
28
+
29
+ parser.add_argument(
30
+ "--data_dir",
31
+ default=r"E:\programmer\asr_datasets\dns-challenge\DEMAND\demand",
32
+ type=str
33
+ )
34
+ parser.add_argument(
35
+ "--output_dir",
36
+ default=r"E:\programmer\asr_datasets\denoise\demand-8k",
37
+ type=str
38
+ )
39
+ parser.add_argument("--sample_rate", default=8000, type=int)
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ def main():
45
+ args = get_args()
46
+
47
+ data_dir = Path(args.data_dir)
48
+ output_dir = Path(args.output_dir)
49
+ output_dir.mkdir(parents=True, exist_ok=False)
50
+
51
+ for filename in data_dir.glob("**/ch01.wav"):
52
+ label = filename.parts[-2]
53
+ name = filename.stem
54
+
55
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
56
+
57
+ signal = signal * (1 << 15)
58
+ signal = np.array(signal, dtype=np.int16)
59
+
60
+ to_file = output_dir / f"{label}/{name}.wav"
61
+ to_file.parent.mkdir(parents=True, exist_ok=True)
62
+ wavfile.write(
63
+ to_file.as_posix(),
64
+ rate=args.sample_rate,
65
+ data=signal,
66
+ )
67
+ return
68
+
69
+
70
+ if __name__ == '__main__':
71
+ main()
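
DEMAND records each noise scene with an array of synchronized microphones; the glob above deliberately keeps only ch01.wav per scene. A sketch of that selection, yielding (scene label, path) pairs:

from pathlib import Path

def first_channel_files(data_dir: str):
    # each scene folder holds ch01.wav, ch02.wav, ...; keep channel 1 only
    for filename in Path(data_dir).glob("**/ch01.wav"):
        yield filename.parts[-2], filename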
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_impulse_responses.py ADDED
@@ -0,0 +1,93 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ 247M
16
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.emotional_speech.tar.bz2
17
+
18
+ 240M
19
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.impulse_responses.tar.bz2
20
+
21
+
22
+ """
23
+ import argparse
24
+ import os
25
+ from pathlib import Path
26
+ import sys
27
+
28
+ import numpy as np
29
+ from tqdm import tqdm
30
+
31
+ pwd = os.path.abspath(os.path.dirname(__file__))
32
+ sys.path.append(os.path.join(pwd, "../../"))
33
+
34
+ import librosa
35
+ from scipy.io import wavfile
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser()
40
+
41
+ parser.add_argument(
42
+ "--data_dir",
43
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.emotional_speech\datasets\clean\emotional_speech",
44
+ type=str
45
+ )
46
+ parser.add_argument(
47
+ "--output_dir",
48
+ default=r"E:\programmer\asr_datasets\denoise\dns-clean-emotional-speech-8k",
49
+ type=str
50
+ )
51
+ parser.add_argument("--sample_rate", default=8000, type=int)
52
+ args = parser.parse_args()
53
+ return args
54
+
55
+
56
+ def main():
57
+ args = get_args()
58
+
59
+ data_dir = Path(args.data_dir)
60
+ output_dir = Path(args.output_dir)
61
+ output_dir.mkdir(parents=True, exist_ok=True)
62
+
63
+ # finished_set
64
+ finished_set = set()
65
+ for filename in tqdm(output_dir.glob("**/*.wav")):
66
+ name = filename.stem
67
+ finished_set.add(name)
68
+ print(f"finished_set count: {len(finished_set)}")
69
+
70
+ for filename in tqdm(data_dir.glob("**/*.wav")):
71
+ label = filename.parts[-2]
72
+ name = filename.stem
73
+ # print(f"filename: {filename.as_posix()}")
74
+ if name in finished_set:
75
+ continue
76
+
77
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
78
+
79
+ signal = signal * (1 << 15)
80
+ signal = np.array(signal, dtype=np.int16)
81
+
82
+ to_file = output_dir / f"{label}/{name}.wav"
83
+ to_file.parent.mkdir(parents=True, exist_ok=True)
84
+ wavfile.write(
85
+ to_file.as_posix(),
86
+ rate=args.sample_rate,
87
+ data=signal,
88
+ )
89
+ return
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
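
This script only resamples the impulse responses (note that its argparse defaults still point at the emotional-speech folders, so the paths are presumably overridden on the command line); how the responses are applied is not shown in this diff. For illustration only, a room impulse response is commonly applied by convolution:

import numpy as np
from scipy.signal import fftconvolve

def apply_rir(speech: np.ndarray, rir: np.ndarray) -> np.ndarray:
    # convolve, trim back to the original length, and restore the dry peak level
    reverberant = fftconvolve(speech, rir, mode="full")[: len(speech)]
    peak = np.max(np.abs(reverberant))
    return reverberant if peak == 0 else reverberant * (np.max(np.abs(speech)) / peak)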
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_noise.py ADDED
@@ -0,0 +1,77 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/microsoft/DNS-Challenge/blob/master/download-dns-challenge-3.sh
5
+
6
+ 1.2G
7
+ wget https://dns3public.blob.core.windows.net/dns3archive/DEMAND.tar.bz2
8
+
9
+ 14G
10
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.noise.tar.bz2
11
+
12
+ 38G
13
+ wget https://dns3public.blob.core.windows.net/dns3archive/datasets/datasets.clean.read_speech.tar.bz2
14
+
15
+ """
16
+ import argparse
17
+ import os
18
+ from pathlib import Path
19
+ import sys
20
+
21
+ import numpy as np
22
+ from tqdm import tqdm
23
+
24
+ pwd = os.path.abspath(os.path.dirname(__file__))
25
+ sys.path.append(os.path.join(pwd, "../../"))
26
+
27
+ import librosa
28
+ from scipy.io import wavfile
29
+
30
+
31
+ def get_args():
32
+ parser = argparse.ArgumentParser()
33
+
34
+ parser.add_argument(
35
+ "--data_dir",
36
+ default=r"E:\programmer\asr_datasets\dns-challenge\datasets.noise\datasets",
37
+ type=str
38
+ )
39
+ parser.add_argument(
40
+ "--output_dir",
41
+ default=r"E:\programmer\asr_datasets\denoise\dns-noise-8k",
42
+ type=str
43
+ )
44
+ parser.add_argument("--sample_rate", default=8000, type=int)
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def main():
50
+ args = get_args()
51
+
52
+ data_dir = Path(args.data_dir)
53
+ output_dir = Path(args.output_dir)
54
+ output_dir.mkdir(parents=True, exist_ok=True)
55
+
56
+ for filename in tqdm(data_dir.glob("**/*.wav")):
57
+ label = filename.parts[-2]
58
+ name = filename.stem
59
+ # print(f"filename: {filename.as_posix()}")
60
+
61
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
62
+
63
+ signal = signal * (1 << 15)
64
+ signal = np.array(signal, dtype=np.int16)
65
+
66
+ to_file = output_dir / f"{label}/{name}.wav"
67
+ to_file.parent.mkdir(parents=True, exist_ok=True)
68
+ wavfile.write(
69
+ to_file.as_posix(),
70
+ rate=args.sample_rate,
71
+ data=signal,
72
+ )
73
+ return
74
+
75
+
76
+ if __name__ == '__main__':
77
+ main()
examples/data_preprocess/dns_challenge_to_8k/process_musan.py ADDED
@@ -0,0 +1,8 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://www.openslr.org/17/
5
+ """
6
+
7
+ if __name__ == '__main__':
8
+ pass
examples/data_preprocess/ms_snsd_to_8k/process_ms_snsd.py ADDED
@@ -0,0 +1,70 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ MS-SNSD
5
+ https://github.com/microsoft/MS-SNSD
6
+ """
7
+ import argparse
8
+ import os
9
+ from pathlib import Path
10
+ import sys
11
+
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+ pwd = os.path.abspath(os.path.dirname(__file__))
16
+ sys.path.append(os.path.join(pwd, "../../"))
17
+
18
+ import librosa
19
+ from scipy.io import wavfile
20
+
21
+
22
+ def get_args():
23
+ parser = argparse.ArgumentParser()
24
+
25
+ parser.add_argument(
26
+ "--data_dir",
27
+ default=r"E:\programmer\asr_datasets\MS-SNSD",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--output_dir",
32
+ default=r"E:\programmer\asr_datasets\denoise\ms-snsd-noise-8k",
33
+ type=str
34
+ )
35
+ parser.add_argument("--sample_rate", default=8000, type=int)
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+
43
+ data_dir = Path(args.data_dir)
44
+ output_dir = Path(args.output_dir)
45
+ output_dir.mkdir(parents=True, exist_ok=True)
46
+
47
+ for filename in tqdm(data_dir.glob("**/*.wav")):
48
+ label = filename.parts[-2]
49
+ name = filename.stem
50
+
51
+ if label not in ["noise_train", "noise_test", "clean_train", "clean_test"]:
52
+ continue
53
+
54
+ signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
55
+
56
+ signal = signal * (1 << 15)
57
+ signal = np.array(signal, dtype=np.int16)
58
+
59
+ to_file = output_dir / f"{label}/{name}.wav"
60
+ to_file.parent.mkdir(parents=True, exist_ok=True)
61
+ wavfile.write(
62
+ to_file.as_posix(),
63
+ rate=args.sample_rate,
64
+ data=signal,
65
+ )
66
+ return
67
+
68
+
69
+ if __name__ == "__main__":
70
+ main()
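
All of the preprocessing scripts above repeat the same load, scale, and write steps. As a refactoring sketch (not project code), the shared core could be factored out once:

from pathlib import Path

import librosa
import numpy as np
from scipy.io import wavfile

def resample_to_int16_wav(src: Path, dst: Path, sample_rate: int = 8000) -> None:
    signal, _ = librosa.load(src.as_posix(), sr=sample_rate)  # float32, mono
    signal = np.array(signal * (1 << 15), dtype=np.int16)
    dst.parent.mkdir(parents=True, exist_ok=True)
    wavfile.write(dst.as_posix(), rate=sample_rate, data=signal)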
examples/dfnet/run.sh ADDED
@@ -0,0 +1,156 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
6
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
7
+ --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
8
+
9
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-nx-dns3 \
10
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
12
+
13
+
14
+ END
15
+
16
+
17
+ # params
18
+ system_version="windows";
19
+ verbose=true;
20
+ stage=0 # start from 0 if you need to start from data preparation
21
+ stop_stage=9
22
+
23
+ work_dir="$(pwd)"
24
+ file_folder_name=file_folder_name
25
+ final_model_name=final_model_name
26
+ config_file="yaml/config.yaml"
27
+ limit=10
28
+
29
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
30
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
31
+
32
+ max_count=10000000
33
+
34
+ nohup_name=nohup.out
35
+
36
+ # model params
37
+ batch_size=64
38
+ max_epochs=200
39
+ save_top_k=10
40
+ patience=5
41
+
42
+
43
+ # parse options
44
+ while true; do
45
+ [ -z "${1:-}" ] && break; # break if there are no arguments
46
+ case "$1" in
47
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
48
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
49
+ old_value="(eval echo \\$$name)";
50
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
51
+ was_bool=true;
52
+ else
53
+ was_bool=false;
54
+ fi
55
+
56
+ # Set the variable to the right value-- the escaped quotes make it work if
57
+ # the option had spaces, like --cmd "queue.pl -sync y"
58
+ eval "${name}=\"$2\"";
59
+
60
+ # Check that Boolean-valued arguments are really Boolean.
61
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
62
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
63
+ exit 1;
64
+ fi
65
+ shift 2;
66
+ ;;
67
+
68
+ *) break;
69
+ esac
70
+ done
71
+
72
+ file_dir="${work_dir}/${file_folder_name}"
73
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
74
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
75
+
76
+ train_dataset="${file_dir}/train.jsonl"
77
+ valid_dataset="${file_dir}/valid.jsonl"
78
+
79
+ $verbose && echo "system_version: ${system_version}"
80
+ $verbose && echo "file_folder_name: ${file_folder_name}"
81
+
82
+ if [ $system_version == "windows" ]; then
83
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
84
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
85
+ #source /data/local/bin/nx_denoise/bin/activate
86
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
87
+ fi
88
+
89
+
90
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
91
+ $verbose && echo "stage 1: prepare data"
92
+ cd "${work_dir}" || exit 1
93
+ python3 step_1_prepare_data.py \
94
+ --file_dir "${file_dir}" \
95
+ --noise_dir "${noise_dir}" \
96
+ --speech_dir "${speech_dir}" \
97
+ --train_dataset "${train_dataset}" \
98
+ --valid_dataset "${valid_dataset}" \
99
+ --max_count "${max_count}" \
100
+
101
+ fi
102
+
103
+
104
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
105
+ $verbose && echo "stage 2: train model"
106
+ cd "${work_dir}" || exit 1
107
+ python3 step_2_train_model.py \
108
+ --train_dataset "${train_dataset}" \
109
+ --valid_dataset "${valid_dataset}" \
110
+ --serialization_dir "${file_dir}" \
111
+ --config_file "${config_file}" \
112
+
113
+ fi
114
+
115
+
116
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
117
+ $verbose && echo "stage 3: test model"
118
+ cd "${work_dir}" || exit 1
119
+ python3 step_3_evaluation.py \
120
+ --valid_dataset "${valid_dataset}" \
121
+ --model_dir "${file_dir}/best" \
122
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
123
+ --limit "${limit}" \
124
+
125
+ fi
126
+
127
+
128
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
129
+ $verbose && echo "stage 4: collect files"
130
+ cd "${work_dir}" || exit 1
131
+
132
+ mkdir -p ${final_model_dir}
133
+
134
+ cp "${file_dir}/best"/* "${final_model_dir}"
135
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
136
+
137
+ cd "${final_model_dir}/.." || exit 1;
138
+
139
+ if [ -e "${final_model_name}.zip" ]; then
140
+ rm -rf "${final_model_name}_backup.zip"
141
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
142
+ fi
143
+
144
+ zip -r "${final_model_name}.zip" "${final_model_name}"
145
+ rm -rf "${final_model_name}"
146
+
147
+ fi
148
+
149
+
150
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
151
+ $verbose && echo "stage 5: clear file_dir"
152
+ cd "${work_dir}" || exit 1
153
+
154
+ rm -rf "${file_dir}";
155
+
156
+ fi
examples/dfnet/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--file_dir", default="./", type=str)
21
+
22
+ parser.add_argument(
23
+ "--noise_dir",
24
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--speech_dir",
29
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
30
+ type=str
31
+ )
32
+
33
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
34
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
35
+
36
+ parser.add_argument("--duration", default=4.0, type=float)
37
+ parser.add_argument("--min_snr_db", default=-10, type=float)
38
+ parser.add_argument("--max_snr_db", default=20, type=float)
39
+
40
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
41
+
42
+ parser.add_argument("--max_count", default=10000, type=int)
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def filename_generator(data_dir: str):
49
+ data_dir = Path(data_dir)
50
+ for filename in data_dir.glob("**/*.wav"):
51
+ yield filename.as_posix()
52
+
53
+
54
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 20000):
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ if np.sum(signal[begin: begin+win_size]) == 0:
71
+ continue
72
+ row = {
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(begin / sample_rate, 4),
77
+ "duration": round(duration, 4),
78
+ }
79
+ yield row
80
+
81
+
82
+ def main():
83
+ args = get_args()
84
+
85
+ file_dir = Path(args.file_dir)
86
+ file_dir.mkdir(exist_ok=True)
87
+
88
+ noise_dir = Path(args.noise_dir)
89
+ speech_dir = Path(args.speech_dir)
90
+
91
+ noise_generator = target_second_signal_generator(
92
+ noise_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate,
95
+ max_epoch=100000,
96
+ )
97
+ speech_generator = target_second_signal_generator(
98
+ speech_dir.as_posix(),
99
+ duration=args.duration,
100
+ sample_rate=args.target_sample_rate,
101
+ max_epoch=1,
102
+ )
103
+
104
+ dataset = list()
105
+
106
+ count = 0
107
+ process_bar = tqdm(desc="build dataset jsonl")
108
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
109
+ for noise, speech in zip(noise_generator, speech_generator):
110
+ if count >= args.max_count > 0:
111
+ break
112
+
113
+ noise_filename = noise["filename"]
114
+ noise_raw_duration = noise["raw_duration"]
115
+ noise_offset = noise["offset"]
116
+ noise_duration = noise["duration"]
117
+
118
+ speech_filename = speech["filename"]
119
+ speech_raw_duration = speech["raw_duration"]
120
+ speech_offset = speech["offset"]
121
+ speech_duration = speech["duration"]
122
+
123
+ random1 = random.random()
124
+ random2 = random.random()
125
+
126
+ row = {
127
+ "count": count,
128
+
129
+ "noise_filename": noise_filename,
130
+ "noise_raw_duration": noise_raw_duration,
131
+ "noise_offset": noise_offset,
132
+ "noise_duration": noise_duration,
133
+
134
+ "speech_filename": speech_filename,
135
+ "speech_raw_duration": speech_raw_duration,
136
+ "speech_offset": speech_offset,
137
+ "speech_duration": speech_duration,
138
+
139
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
140
+
141
+ "random1": random1,
142
+ }
143
+ row = json.dumps(row, ensure_ascii=False)
144
+ if random2 < (1 / 300 / 1):
145
+ fvalid.write(f"{row}\n")
146
+ else:
147
+ ftrain.write(f"{row}\n")
148
+
149
+ count += 1
150
+ duration_seconds = count * args.duration
151
+ duration_hours = duration_seconds / 3600
152
+
153
+ process_bar.update(n=1)
154
+ process_bar.set_postfix({
155
+ # "duration_seconds": round(duration_seconds, 4),
156
+ "duration_hours": round(duration_hours, 4),
157
+
158
+ })
159
+
160
+ return
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
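
step_1 only records a sampled snr_db per (noise, speech) pair; the mixing itself happens inside DenoiseJsonlDataset, which is not part of this diff. A common way to scale noise to a target SNR, shown here as an assumption about what that dataset does:

import numpy as np

def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
    speech_power = np.mean(speech ** 2) + 1e-8
    noise_power = np.mean(noise ** 2) + 1e-8
    # scale the noise so that 10 * log10(speech_power / noise_power) equals snr_db
    target_noise_power = speech_power / (10 ** (snr_db / 10))
    return speech + noise * np.sqrt(target_noise_power / noise_power)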
examples/dfnet/step_2_train_model.py ADDED
@@ -0,0 +1,461 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/Rikorose/DeepFilterNet
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
20
+ pwd = os.path.abspath(os.path.dirname(__file__))
21
+ sys.path.append(os.path.join(pwd, "../../"))
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.nn import functional as F
27
+ from torch.utils.data.dataloader import DataLoader
28
+ from tqdm import tqdm
29
+
30
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
31
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
32
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
33
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
34
+ from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
35
+ from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser()
40
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
41
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
42
+
43
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
44
+ parser.add_argument("--patience", default=10, type=int)
45
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ clean_audios = list()
80
+ noisy_audios = list()
81
+ snr_db_list = list()
82
+
83
+ for sample in batch:
84
+ # noise_wave: torch.Tensor = sample["noise_wave"]
85
+ clean_audio: torch.Tensor = sample["speech_wave"]
86
+ noisy_audio: torch.Tensor = sample["mix_wave"]
87
+ # snr_db: float = sample["snr_db"]
88
+
89
+ clean_audios.append(clean_audio)
90
+ noisy_audios.append(noisy_audio)
91
+
92
+ clean_audios = torch.stack(clean_audios)
93
+ noisy_audios = torch.stack(noisy_audios)
94
+
95
+ # assert
96
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
97
+ raise AssertionError("nan or inf in clean_audios")
98
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
99
+ raise AssertionError("nan or inf in noisy_audios")
100
+ return clean_audios, noisy_audios
101
+
102
+
103
+ collate_fn = CollateFunction()
104
+
105
+
106
+ def main():
107
+ args = get_args()
108
+
109
+ config = DfNetConfig.from_pretrained(
110
+ pretrained_model_name_or_path=args.config_file,
111
+ )
112
+
113
+ serialization_dir = Path(args.serialization_dir)
114
+ serialization_dir.mkdir(parents=True, exist_ok=True)
115
+
116
+ logger = logging_config(serialization_dir)
117
+
118
+ random.seed(config.seed)
119
+ np.random.seed(config.seed)
120
+ torch.manual_seed(config.seed)
121
+ logger.info(f"set seed: {config.seed}")
122
+
123
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
+ n_gpu = torch.cuda.device_count()
125
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
126
+
127
+ # datasets
128
+ train_dataset = DenoiseJsonlDataset(
129
+ jsonl_file=args.train_dataset,
130
+ expected_sample_rate=config.sample_rate,
131
+ max_wave_value=32768.0,
132
+ min_snr_db=config.min_snr_db,
133
+ max_snr_db=config.max_snr_db,
134
+ # skip=225000,
135
+ )
136
+ valid_dataset = DenoiseJsonlDataset(
137
+ jsonl_file=args.valid_dataset,
138
+ expected_sample_rate=config.sample_rate,
139
+ max_wave_value=32768.0,
140
+ min_snr_db=config.min_snr_db,
141
+ max_snr_db=config.max_snr_db,
142
+ )
143
+ train_data_loader = DataLoader(
144
+ dataset=train_dataset,
145
+ batch_size=config.batch_size,
146
+ # shuffle=True,
147
+ sampler=None,
148
+ # On Linux, multiple worker subprocesses can be used to load data; on Windows they cannot.
149
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
150
+ collate_fn=collate_fn,
151
+ pin_memory=False,
152
+ prefetch_factor=None if platform.system() == "Windows" else 2,
153
+ )
154
+ valid_data_loader = DataLoader(
155
+ dataset=valid_dataset,
156
+ batch_size=config.batch_size,
157
+ # shuffle=True,
158
+ sampler=None,
159
+ # On Linux, multiple worker subprocesses can be used to load data; on Windows they cannot.
160
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
161
+ collate_fn=collate_fn,
162
+ pin_memory=False,
163
+ prefetch_factor=None if platform.system() == "Windows" else 2,
164
+ )
165
+
166
+ # models
167
+ logger.info(f"prepare models. config_file: {args.config_file}")
168
+ model = DfNetPretrainedModel(config).to(device)
169
+ model.to(device)
170
+ model.train()
171
+
172
+ # optimizer
173
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
174
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
175
+
176
+ # resume training
177
+ last_step_idx = -1
178
+ last_epoch = -1
179
+ for step_idx_str in serialization_dir.glob("steps-*"):
180
+ step_idx_str = Path(step_idx_str)
181
+ step_idx = step_idx_str.stem.split("-")[1]
182
+ step_idx = int(step_idx)
183
+ if step_idx > last_step_idx:
184
+ last_step_idx = step_idx
185
+ # last_epoch = 1
186
+
187
+ if last_step_idx != -1:
188
+ logger.info(f"resume from steps-{last_step_idx}.")
189
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
190
+
191
+ logger.info(f"load state dict for model.")
192
+ with open(model_pt.as_posix(), "rb") as f:
193
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
194
+ model.load_state_dict(state_dict, strict=True)
195
+
196
+ if config.lr_scheduler == "CosineAnnealingLR":
197
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
198
+ optimizer,
199
+ last_epoch=last_epoch,
200
+ # T_max=10 * config.eval_steps,
201
+ # eta_min=0.01 * config.lr,
202
+ **config.lr_scheduler_kwargs,
203
+ )
204
+ elif config.lr_scheduler == "MultiStepLR":
205
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
206
+ optimizer,
207
+ last_epoch=last_epoch,
208
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
209
+ )
210
+ else:
211
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
212
+
213
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
214
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
215
+ fft_size_list=[256, 512, 1024],
216
+ win_size_list=[256, 512, 1024],
217
+ hop_size_list=[128, 256, 512],
218
+ factor_sc=1.5,
219
+ factor_mag=1.0,
220
+ reduction="mean"
221
+ ).to(device)
222
+
223
+ # training loop
224
+
225
+ # state
226
+ average_pesq_score = 1000000000
227
+ average_loss = 1000000000
228
+ average_mr_stft_loss = 1000000000
229
+ average_neg_si_snr_loss = 1000000000
230
+ average_mask_loss = 1000000000
231
+ average_lsnr_loss = 1000000000
232
+
233
+ model_list = list()
234
+ best_epoch_idx = None
235
+ best_step_idx = None
236
+ best_metric = None
237
+ patience_count = 0
238
+
239
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
240
+
241
+ logger.info("training")
242
+ early_stop_flag = False
243
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
244
+ if early_stop_flag:
245
+ break
246
+
247
+ # train
248
+ model.train()
249
+
250
+ total_pesq_score = 0.
251
+ total_loss = 0.
252
+ total_mr_stft_loss = 0.
253
+ total_neg_si_snr_loss = 0.
254
+ total_mask_loss = 0.
255
+ total_lsnr_loss = 0.
256
+ total_batches = 0.
257
+
258
+ progress_bar_train = tqdm(
259
+ initial=step_idx,
260
+ desc="Training; epoch-{}".format(epoch_idx),
261
+ )
262
+ for train_batch in train_data_loader:
263
+ clean_audios, noisy_audios = train_batch
264
+ clean_audios: torch.Tensor = clean_audios.to(device)
265
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
266
+
267
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
268
+
269
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
270
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
271
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
272
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
273
+
274
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
275
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
276
+ logger.info(f"find nan or inf in loss.")
277
+ continue
278
+
279
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
280
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
281
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
282
+
283
+ optimizer.zero_grad()
284
+ loss.backward()
285
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
286
+ optimizer.step()
287
+ lr_scheduler.step()
288
+
289
+ total_pesq_score += pesq_score
290
+ total_loss += loss.item()
291
+ total_mr_stft_loss += mr_stft_loss.item()
292
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
293
+ total_mask_loss += mask_loss.item()
294
+ total_lsnr_loss += lsnr_loss.item()
295
+ total_batches += 1
296
+
297
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
298
+ average_loss = round(total_loss / total_batches, 4)
299
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
300
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
301
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
302
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
303
+
304
+ progress_bar_train.update(1)
305
+ progress_bar_train.set_postfix({
306
+ "lr": lr_scheduler.get_last_lr()[0],
307
+ "pesq_score": average_pesq_score,
308
+ "loss": average_loss,
309
+ "mr_stft_loss": average_mr_stft_loss,
310
+ "neg_si_snr_loss": average_neg_si_snr_loss,
311
+ "mask_loss": average_mask_loss,
312
+ "lsnr_loss": average_lsnr_loss,
313
+ })
314
+
315
+ # evaluation
316
+ step_idx += 1
317
+ if step_idx % config.eval_steps == 0:
318
+ model.eval()
319
+ with torch.no_grad():
320
+ torch.cuda.empty_cache()
321
+
322
+ total_pesq_score = 0.
323
+ total_loss = 0.
324
+ total_mr_stft_loss = 0.
325
+ total_neg_si_snr_loss = 0.
326
+ total_mask_loss = 0.
327
+ total_lsnr_loss = 0.
328
+ total_batches = 0.
329
+
330
+ progress_bar_train.close()
331
+ progress_bar_eval = tqdm(
332
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
333
+ )
334
+ for eval_batch in valid_data_loader:
335
+ clean_audios, noisy_audios = eval_batch
336
+ clean_audios: torch.Tensor = clean_audios.to(device)
337
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
338
+
339
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
340
+
341
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
342
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
343
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
344
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
345
+
346
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
347
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
348
+ logger.info(f"find nan or inf in loss.")
349
+ continue
350
+
351
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
352
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
353
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
354
+
355
+ total_pesq_score += pesq_score
356
+ total_loss += loss.item()
357
+ total_mr_stft_loss += mr_stft_loss.item()
358
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
359
+ total_mask_loss += mask_loss.item()
360
+ total_lsnr_loss += lsnr_loss.item()
361
+ total_batches += 1
362
+
363
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
364
+ average_loss = round(total_loss / total_batches, 4)
365
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
366
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
367
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
368
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
369
+
370
+ progress_bar_eval.update(1)
371
+ progress_bar_eval.set_postfix({
372
+ "lr": lr_scheduler.get_last_lr()[0],
373
+ "pesq_score": average_pesq_score,
374
+ "loss": average_loss,
375
+ "mr_stft_loss": average_mr_stft_loss,
376
+ "neg_si_snr_loss": average_neg_si_snr_loss,
377
+ "mask_loss": average_mask_loss,
378
+ "lsnr_loss": average_lsnr_loss,
379
+ })
380
+
381
+ total_pesq_score = 0.
382
+ total_loss = 0.
383
+ total_mr_stft_loss = 0.
384
+ total_neg_si_snr_loss = 0.
385
+ total_mask_loss = 0.
386
+ total_lsnr_loss = 0.
387
+ total_batches = 0.
388
+
389
+ progress_bar_eval.close()
390
+ progress_bar_train = tqdm(
391
+ initial=progress_bar_train.n,
392
+ postfix=progress_bar_train.postfix,
393
+ desc=progress_bar_train.desc,
394
+ )
395
+
396
+ # save path
397
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
398
+ save_dir.mkdir(parents=True, exist_ok=False)
399
+
400
+ # save models
401
+ model.save_pretrained(save_dir.as_posix())
402
+
403
+ model_list.append(save_dir)
404
+ if len(model_list) > args.num_serialized_models_to_keep:
405
+ model_to_delete: Path = model_list.pop(0)
406
+ shutil.rmtree(model_to_delete.as_posix())
407
+
408
+ # save metric
409
+ if best_metric is None:
410
+ best_epoch_idx = epoch_idx
411
+ best_step_idx = step_idx
412
+ best_metric = average_pesq_score
413
+ elif average_pesq_score >= best_metric:
414
+ # greater is better.
415
+ best_epoch_idx = epoch_idx
416
+ best_step_idx = step_idx
417
+ best_metric = average_pesq_score
418
+ else:
419
+ pass
420
+
421
+ metrics = {
422
+ "epoch_idx": epoch_idx,
423
+ "best_epoch_idx": best_epoch_idx,
424
+ "best_step_idx": best_step_idx,
425
+ "pesq_score": average_pesq_score,
426
+ "loss": average_loss,
427
+ "mr_stft_loss": average_mr_stft_loss,
428
+ "neg_si_snr_loss": average_neg_si_snr_loss,
429
+ "mask_loss": average_mask_loss,
430
+ "lsnr_loss": average_lsnr_loss,
431
+ }
432
+ metrics_filename = save_dir / "metrics_epoch.json"
433
+ with open(metrics_filename, "w", encoding="utf-8") as f:
434
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
435
+
436
+ # save best
437
+ best_dir = serialization_dir / "best"
438
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
439
+ if best_dir.exists():
440
+ shutil.rmtree(best_dir)
441
+ shutil.copytree(save_dir, best_dir)
442
+
443
+ # early stop
444
+ early_stop_flag = False
445
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
446
+ patience_count = 0
447
+ else:
448
+ patience_count += 1
449
+ if patience_count >= args.patience:
450
+ early_stop_flag = True
451
+
452
+ # early stop
453
+ if early_stop_flag:
454
+ break
455
+ model.train()
456
+
457
+ return
458
+
459
+
460
+ if __name__ == "__main__":
461
+ main()
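
The loss above mixes multi-resolution STFT, negative SI-SNR, mask, and LSNR terms. For reference, a sketch of what a negative SI-SNR loss computes (the toolbox's NegativeSISNRLoss is not in this diff and may differ in details such as clamping):

import torch

def neg_si_snr(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # zero-mean both signals, project the estimate onto the target,
    # then compare the energies of the projection and the residual
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    s_target = dot * target / (torch.sum(target ** 2, dim=-1, keepdim=True) + eps)
    e_noise = estimate - s_target
    si_snr = 10 * torch.log10(
        (torch.sum(s_target ** 2, dim=-1) + eps) / (torch.sum(e_noise ** 2, dim=-1) + eps)
    )
    return -si_snr.mean()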
examples/dfnet/yaml/config.yaml ADDED
@@ -0,0 +1,74 @@
1
+ model_name: "dfnet"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ nfft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+
+ spec_bins: 256
+
+ # model
+ conv_channels: 64
+ conv_kernel_size_input:
+ - 3
+ - 3
+ conv_kernel_size_inner:
+ - 1
+ - 3
+ conv_lookahead: 0
+
+ convt_kernel_size_inner:
+ - 1
+ - 3
+
+ embedding_hidden_size: 256
+ encoder_combine_op: "concat"
+
+ encoder_emb_skip_op: "none"
+ encoder_emb_linear_groups: 16
+ encoder_emb_hidden_size: 256
+
+ encoder_linear_groups: 32
+
+ decoder_emb_num_layers: 3
+ decoder_emb_skip_op: "none"
+ decoder_emb_linear_groups: 16
+ decoder_emb_hidden_size: 256
+
+ df_decoder_hidden_size: 256
+ df_num_layers: 2
+ df_order: 5
+ df_bins: 96
+ df_gru_skip: "grouped_linear"
+ df_decoder_linear_groups: 16
+ df_pathway_kernel_size_t: 5
+ df_lookahead: 2
+
+ # lsnr
+ n_frame: 3
+ lsnr_max: 30
+ lsnr_min: -15
+ norm_tau: 1.
+
+ # data
+ min_snr_db: -10
+ max_snr_db: 20
+
+ # train
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+ T_max: 250000
+ eta_min: 0.0001
+
+ max_epochs: 100
+ clip_grad_norm: 10.0
+ seed: 1234
+
+ num_workers: 8
+ batch_size: 64
+ eval_steps: 10000
+
+ # runtime
+ use_post_filter: true
examples/dfnet2/run.sh ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dfnet2-nx2-dns3 --final_model_name dfnet2-nx2-dns3 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/"
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0 # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ max_count=-1
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
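+ # "--foo bar" sets foo="bar"; the option name must match a variable declared above (Kaldi-style parse_options).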
+ while true; do
+ [ -z "${1:-}" ] && break; # break if there are no arguments
+ case "$1" in
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+ old_value="$(eval echo \$$name)";
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+ was_bool=true;
+ else
+ was_bool=false;
+ fi
+
+ # Set the variable to the right value-- the escaped quotes make it work if
+ # the option had spaces, like --cmd "queue.pl -sync y"
+ eval "${name}=\"$2\"";
+
+ # Check that Boolean-valued arguments are really Boolean.
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+ exit 1;
+ fi
+ shift 2;
+ ;;
+
+ *) break;
+ esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ train_dataset="${file_dir}/train.jsonl"
+ valid_dataset="${file_dir}/valid.jsonl"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+ #source /data/local/bin/nx_denoise/bin/activate
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ $verbose && echo "stage 1: prepare data"
+ cd "${work_dir}" || exit 1
+ python3 step_1_prepare_data.py \
+ --file_dir "${file_dir}" \
+ --noise_dir "${noise_dir}" \
+ --speech_dir "${speech_dir}" \
+ --train_dataset "${train_dataset}" \
+ --valid_dataset "${valid_dataset}" \
+ --max_count "${max_count}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ $verbose && echo "stage 2: train model"
+ cd "${work_dir}" || exit 1
+ python3 step_2_train_model.py \
+ --train_dataset "${train_dataset}" \
+ --valid_dataset "${valid_dataset}" \
+ --serialization_dir "${file_dir}" \
+ --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ $verbose && echo "stage 3: test model"
+ cd "${work_dir}" || exit 1
+ python3 step_3_evaluation.py \
+ --valid_dataset "${valid_dataset}" \
+ --model_dir "${file_dir}/best" \
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
+ --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ $verbose && echo "stage 4: collect files"
+ cd "${work_dir}" || exit 1
+
+ mkdir -p "${final_model_dir}"
+
+ cp "${file_dir}/best"/* "${final_model_dir}"
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+ cd "${final_model_dir}/.." || exit 1;
+
+ if [ -e "${final_model_name}.zip" ]; then
+ rm -rf "${final_model_name}_backup.zip"
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+ fi
+
+ zip -r "${final_model_name}.zip" "${final_model_name}"
+ rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ $verbose && echo "stage 5: clear file_dir"
+ cd "${work_dir}" || exit 1
+
+ rm -rf "${file_dir}";
+
+ fi
examples/dfnet2/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--file_dir", default="./", type=str)
+
+ parser.add_argument(
+ "--noise_dir",
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+ type=str
+ )
+ parser.add_argument(
+ "--speech_dir",
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+ type=str
+ )
+
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+ parser.add_argument("--duration", default=2.0, type=float)
+ parser.add_argument("--min_snr_db", default=-10, type=float)
+ parser.add_argument("--max_snr_db", default=20, type=float)
+
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+ parser.add_argument("--max_count", default=-1, type=int)
+
+ args = parser.parse_args()
+ return args
+
+
+ def filename_generator(data_dir: str):
+ data_dir = Path(data_dir)
+ for filename in data_dir.glob("**/*.wav"):
+ yield filename.as_posix()
+
+
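+ # Slice every wav under data_dir into non-overlapping windows of `duration` seconds,
+ # cycling over the directory up to max_epoch times; all-zero windows are skipped.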
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 20000):
+ data_dir = Path(data_dir)
+ for epoch_idx in range(max_epoch):
+ for filename in data_dir.glob("**/*.wav"):
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+ if raw_duration < duration:
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+ continue
+ if signal.ndim != 1:
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+ signal_length = len(signal)
+ win_size = int(duration * sample_rate)
+ for begin in range(0, signal_length - win_size, win_size):
+ if np.sum(signal[begin: begin+win_size]) == 0:
+ continue
+ row = {
+ "epoch_idx": epoch_idx,
+ "filename": filename.as_posix(),
+ "raw_duration": round(raw_duration, 4),
+ "offset": round(begin / sample_rate, 4),
+ "duration": round(duration, 4),
+ }
+ yield row
+
+
+ def main():
+ args = get_args()
+
+ file_dir = Path(args.file_dir)
+ file_dir.mkdir(exist_ok=True)
+
+ noise_dir = Path(args.noise_dir)
+ speech_dir = Path(args.speech_dir)
+
+ noise_generator = target_second_signal_generator(
+ noise_dir.as_posix(),
+ duration=args.duration,
+ sample_rate=args.target_sample_rate,
+ max_epoch=100000,
+ )
+ speech_generator = target_second_signal_generator(
+ speech_dir.as_posix(),
+ duration=args.duration,
+ sample_rate=args.target_sample_rate,
+ max_epoch=1,
+ )
+
+ count = 0
+ process_bar = tqdm(desc="build dataset jsonl")
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+ for noise, speech in zip(noise_generator, speech_generator):
+ if count >= args.max_count > 0:
+ break
+
+ noise_filename = noise["filename"]
+ noise_raw_duration = noise["raw_duration"]
+ noise_offset = noise["offset"]
+ noise_duration = noise["duration"]
+
+ speech_filename = speech["filename"]
+ speech_raw_duration = speech["raw_duration"]
+ speech_offset = speech["offset"]
+ speech_duration = speech["duration"]
+
+ random1 = random.random()
+ random2 = random.random()
+
+ row = {
+ "count": count,
+
+ "noise_filename": noise_filename,
+ "noise_raw_duration": noise_raw_duration,
+ "noise_offset": noise_offset,
+ "noise_duration": noise_duration,
+
+ "speech_filename": speech_filename,
+ "speech_raw_duration": speech_raw_duration,
+ "speech_offset": speech_offset,
+ "speech_duration": speech_duration,
+
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+ "random1": random1,
+ }
+ row = json.dumps(row, ensure_ascii=False)
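+ # roughly 1 in 300 rows goes to the validation split; everything else goes to train.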
+ if random2 < (1 / 300 / 1):
+ fvalid.write(f"{row}\n")
+ else:
+ ftrain.write(f"{row}\n")
+
+ count += 1
+ duration_seconds = count * args.duration
+ duration_hours = duration_seconds / 3600
+
+ process_bar.update(n=1)
+ process_bar.set_postfix({
+ # "duration_seconds": round(duration_seconds, 4),
+ "duration_hours": round(duration_hours, 4),
+
+ })
+
+ return
+
+
+ if __name__ == "__main__":
+ main()
examples/dfnet2/step_2_train_model.py ADDED
@@ -0,0 +1,469 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://github.com/Rikorose/DeepFilterNet
+ """
+ import argparse
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
+ from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
+ from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2, DfNet2PretrainedModel
+
+
+ def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
+ parser.add_argument("--patience", default=30, type=int)
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+ parser.add_argument("--config_file", default="config.yaml", type=str)
+
+ args = parser.parse_args()
+ return args
+
+
+ def logging_config(file_dir: str):
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+ logging.basicConfig(format=fmt,
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO)
+ file_handler = TimedRotatingFileHandler(
+ filename=os.path.join(file_dir, "main.log"),
+ encoding="utf-8",
+ when="D",
+ interval=1,
+ backupCount=7
+ )
+ file_handler.setLevel(logging.INFO)
+ file_handler.setFormatter(logging.Formatter(fmt))
+ logger = logging.getLogger(__name__)
+ logger.addHandler(file_handler)
+
+ return logger
+
+
+ class CollateFunction(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, batch: List[dict]):
+ clean_audios = list()
+ noisy_audios = list()
+ snr_db_list = list()
+
+ for sample in batch:
+ # noise_wave: torch.Tensor = sample["noise_wave"]
+ clean_audio: torch.Tensor = sample["speech_wave"]
+ noisy_audio: torch.Tensor = sample["mix_wave"]
+ # snr_db: float = sample["snr_db"]
+
+ clean_audios.append(clean_audio)
+ noisy_audios.append(noisy_audio)
+
+ clean_audios = torch.stack(clean_audios)
+ noisy_audios = torch.stack(noisy_audios)
+
+ # assert
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+ raise AssertionError("nan or inf in clean_audios")
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
+ raise AssertionError("nan or inf in noisy_audios")
+ return clean_audios, noisy_audios
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+ args = get_args()
+
+ config = DfNet2Config.from_pretrained(
+ pretrained_model_name_or_path=args.config_file,
+ )
+
+ serialization_dir = Path(args.serialization_dir)
+ serialization_dir.mkdir(parents=True, exist_ok=True)
+
+ logger = logging_config(serialization_dir)
+
+ random.seed(config.seed)
+ np.random.seed(config.seed)
+ torch.manual_seed(config.seed)
+ logger.info(f"set seed: {config.seed}")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ n_gpu = torch.cuda.device_count()
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+ # datasets
+ train_dataset = DenoiseJsonlDataset(
+ jsonl_file=args.train_dataset,
+ expected_sample_rate=config.sample_rate,
+ max_wave_value=32768.0,
+ min_snr_db=config.min_snr_db,
+ max_snr_db=config.max_snr_db,
+ # skip=225000,
+ )
+ valid_dataset = DenoiseJsonlDataset(
+ jsonl_file=args.valid_dataset,
+ expected_sample_rate=config.sample_rate,
+ max_wave_value=32768.0,
+ min_snr_db=config.min_snr_db,
+ max_snr_db=config.max_snr_db,
+ )
+ train_data_loader = DataLoader(
+ dataset=train_dataset,
+ batch_size=config.batch_size,
+ # shuffle=True,
+ sampler=None,
+ # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+ collate_fn=collate_fn,
+ pin_memory=False,
+ prefetch_factor=None if platform.system() == "Windows" else 2,
+ )
+ valid_data_loader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=config.batch_size,
+ # shuffle=True,
+ sampler=None,
+ # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+ collate_fn=collate_fn,
+ pin_memory=False,
+ prefetch_factor=None if platform.system() == "Windows" else 2,
+ )
+
+ # models
+ logger.info(f"prepare models. config_file: {args.config_file}")
+ model = DfNet2PretrainedModel(config).to(device)
+ model.train()
+
+ # optimizer
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+
+ # resume training
+ last_step_idx = -1
+ last_epoch = -1
+ for step_idx_str in serialization_dir.glob("steps-*"):
+ step_idx_str = Path(step_idx_str)
+ step_idx = step_idx_str.stem.split("-")[1]
+ step_idx = int(step_idx)
+ if step_idx > last_step_idx:
+ last_step_idx = step_idx
+ # last_epoch = 1
+
+ if last_step_idx != -1:
+ logger.info(f"resume from steps-{last_step_idx}.")
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+
+ logger.info("load state dict for model.")
+ with open(model_pt.as_posix(), "rb") as f:
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
+ model.load_state_dict(state_dict, strict=True)
+
+ if config.lr_scheduler == "CosineAnnealingLR":
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer,
+ last_epoch=last_epoch,
+ # T_max=10 * config.eval_steps,
+ # eta_min=0.01 * config.lr,
+ **config.lr_scheduler_kwargs,
+ )
+ elif config.lr_scheduler == "MultiStepLR":
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer,
+ last_epoch=last_epoch,
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+ )
+ else:
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
+ fft_size_list=[256, 512, 1024],
+ win_size_list=[256, 512, 1024],
+ hop_size_list=[128, 256, 512],
+ factor_sc=1.5,
+ factor_mag=1.0,
+ reduction="mean"
+ ).to(device)
+
+ # training loop
+
+ # state
+ average_pesq_score = 1000000000
+ average_loss = 1000000000
+ average_mr_stft_loss = 1000000000
+ average_neg_si_snr_loss = 1000000000
+ average_mask_loss = 1000000000
+ average_lsnr_loss = 1000000000
+
+ model_list = list()
+ best_epoch_idx = None
+ best_step_idx = None
+ best_metric = None
+ patience_count = 0
+
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+ logger.info("training")
+ early_stop_flag = False
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+ if early_stop_flag:
+ break
+
+ # train
+ model.train()
+
+ total_pesq_score = 0.
+ total_loss = 0.
+ total_mr_stft_loss = 0.
+ total_neg_si_snr_loss = 0.
+ total_mask_loss = 0.
+ total_lsnr_loss = 0.
+ total_batches = 0.
+
+ progress_bar_train = tqdm(
+ initial=step_idx,
+ desc="Training; epoch-{}".format(epoch_idx),
+ )
+ for train_batch in train_data_loader:
+ clean_audios, noisy_audios = train_batch
+ clean_audios: torch.Tensor = clean_audios.to(device)
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
+
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+ # est_wav shape: [b, 1, n_samples]
+ est_wav = torch.squeeze(est_wav, dim=1)
+ # est_wav shape: [b, n_samples]
+
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
+
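+ # total objective: MR-STFT, negative SI-SNR and mask losses weighted equally; the LSNR term is down-weighted to 0.01.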
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ logger.info("find nan or inf in loss. continue.")
+ continue
+
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
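+ # narrow-band PESQ between the clean and denoised waveforms; this is also the model-selection metric.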
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+
+ total_pesq_score += pesq_score
+ total_loss += loss.item()
+ total_mr_stft_loss += mr_stft_loss.item()
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
+ total_mask_loss += mask_loss.item()
+ total_lsnr_loss += lsnr_loss.item()
+ total_batches += 1
+
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
+ average_loss = round(total_loss / total_batches, 4)
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
+
+ progress_bar_train.update(1)
+ progress_bar_train.set_postfix({
+ "lr": lr_scheduler.get_last_lr()[0],
+ "pesq_score": average_pesq_score,
+ "loss": average_loss,
+ "mr_stft_loss": average_mr_stft_loss,
+ "neg_si_snr_loss": average_neg_si_snr_loss,
+ "mask_loss": average_mask_loss,
+ "lsnr_loss": average_lsnr_loss,
+ })
+
+ # evaluation
+ step_idx += 1
+ if step_idx % config.eval_steps == 0:
+ with torch.no_grad():
+ torch.cuda.empty_cache()
+
+ model.eval()
+
+ total_pesq_score = 0.
+ total_loss = 0.
+ total_mr_stft_loss = 0.
+ total_neg_si_snr_loss = 0.
+ total_mask_loss = 0.
+ total_lsnr_loss = 0.
+ total_batches = 0.
+
+ progress_bar_train.close()
+ progress_bar_eval = tqdm(
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
+ )
+ for eval_batch in valid_data_loader:
+ clean_audios, noisy_audios = eval_batch
+ clean_audios: torch.Tensor = clean_audios.to(device)
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
+
+ est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+ # est_wav shape: [b, 1, n_samples]
+ est_wav = torch.squeeze(est_wav, dim=1)
+ # est_wav shape: [b, n_samples]
+
+ mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+ lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
+
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.01 * lsnr_loss
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+ logger.info("find nan or inf in loss. continue.")
+ continue
+
+ denoise_audios_list_r = list(est_wav.detach().cpu().numpy())
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+ total_pesq_score += pesq_score
+ total_loss += loss.item()
+ total_mr_stft_loss += mr_stft_loss.item()
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
+ total_mask_loss += mask_loss.item()
+ total_lsnr_loss += lsnr_loss.item()
+ total_batches += 1
+
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
+ average_loss = round(total_loss / total_batches, 4)
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
+ average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
+
+ progress_bar_eval.update(1)
+ progress_bar_eval.set_postfix({
+ "lr": lr_scheduler.get_last_lr()[0],
+ "pesq_score": average_pesq_score,
+ "loss": average_loss,
+ "mr_stft_loss": average_mr_stft_loss,
+ "neg_si_snr_loss": average_neg_si_snr_loss,
+ "mask_loss": average_mask_loss,
+ "lsnr_loss": average_lsnr_loss,
+ })
+
+ model.train()
+
+ total_pesq_score = 0.
+ total_loss = 0.
+ total_mr_stft_loss = 0.
+ total_neg_si_snr_loss = 0.
+ total_mask_loss = 0.
+ total_lsnr_loss = 0.
+ total_batches = 0.
+
+ progress_bar_eval.close()
+ progress_bar_train = tqdm(
+ initial=progress_bar_train.n,
+ postfix=progress_bar_train.postfix,
+ desc=progress_bar_train.desc,
+ )
+
+ # save path
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
+ save_dir.mkdir(parents=True, exist_ok=False)
+
+ # save models
+ model.save_pretrained(save_dir.as_posix())
+
+ model_list.append(save_dir)
+ if len(model_list) >= args.num_serialized_models_to_keep:
+ model_to_delete: Path = model_list.pop(0)
+ shutil.rmtree(model_to_delete.as_posix())
+
+ # save metric
+ if best_metric is None:
+ best_epoch_idx = epoch_idx
+ best_step_idx = step_idx
+ best_metric = average_pesq_score
+ elif average_pesq_score >= best_metric:
+ # greater is better.
+ best_epoch_idx = epoch_idx
+ best_step_idx = step_idx
+ best_metric = average_pesq_score
+ else:
+ pass
+
+ metrics = {
+ "epoch_idx": epoch_idx,
+ "best_epoch_idx": best_epoch_idx,
+ "best_step_idx": best_step_idx,
+ "pesq_score": average_pesq_score,
+ "loss": average_loss,
+ "mr_stft_loss": average_mr_stft_loss,
+ "neg_si_snr_loss": average_neg_si_snr_loss,
+ "mask_loss": average_mask_loss,
+ "lsnr_loss": average_lsnr_loss,
+ }
+ metrics_filename = save_dir / "metrics_epoch.json"
+ with open(metrics_filename, "w", encoding="utf-8") as f:
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+ # save best
+ best_dir = serialization_dir / "best"
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+ if best_dir.exists():
+ shutil.rmtree(best_dir)
+ shutil.copytree(save_dir, best_dir)
+
+ # early stop
+ early_stop_flag = False
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+ patience_count = 0
+ else:
+ patience_count += 1
+ if patience_count >= args.patience:
+ early_stop_flag = True
+
+ # early stop
+ if early_stop_flag:
+ break
+
+ return
+
+
+ if __name__ == "__main__":
+ main()
examples/dfnet2/yaml/config.yaml ADDED
@@ -0,0 +1,75 @@
+ model_name: "dfnet2"
+
+ # spec
+ sample_rate: 8000
+ nfft: 512
+ win_size: 200
+ hop_size: 80
+
+ spec_bins: 256
+ erb_bins: 32
+ min_freq_bins_for_erb: 2
+ use_ema_norm: true
+
+ # model
+ conv_channels: 64
+ conv_kernel_size_input:
+ - 3
+ - 3
+ conv_kernel_size_inner:
+ - 1
+ - 3
+ convt_kernel_size_inner:
+ - 1
+ - 3
+
+ embedding_hidden_size: 256
+ encoder_combine_op: "concat"
+
+ encoder_emb_skip_op: "none"
+ encoder_emb_linear_groups: 16
+ encoder_emb_hidden_size: 256
+
+ encoder_linear_groups: 32
+
+ decoder_emb_num_layers: 3
+ decoder_emb_skip_op: "none"
+ decoder_emb_linear_groups: 16
+ decoder_emb_hidden_size: 256
+
+ df_decoder_hidden_size: 256
+ df_num_layers: 2
+ df_order: 5
+ df_bins: 96
+ df_gru_skip: "grouped_linear"
+ df_decoder_linear_groups: 16
+ df_pathway_kernel_size_t: 5
+ df_lookahead: 2
+
+ # lsnr
+ n_frame: 3
+ lsnr_max: 30
+ lsnr_min: -15
+ norm_tau: 1.
+
+ # data
+ min_snr_db: -5
+ max_snr_db: 40
+
+ # train
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+ T_max: 250000
+ eta_min: 0.0001
+
+ max_epochs: 100
+ clip_grad_norm: 10.0
+ seed: 1234
+
+ num_workers: 8
+ batch_size: 96
+ eval_steps: 10000
+
+ # runtime
+ use_post_filter: true
examples/dtln/run.sh ADDED
@@ -0,0 +1,171 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \
+ --config_file "yaml/config-256.yaml" \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
+
+
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
+ --config_file "yaml/config-512.yaml" \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
+ --config_file "yaml/config-1024.yaml" \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3 --final_model_name dtln-256-nx2-dns3 \
+ --config_file "yaml/config-256.yaml" \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0 # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ max_count=-1
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+ [ -z "${1:-}" ] && break; # break if there are no arguments
+ case "$1" in
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+ old_value="$(eval echo \$$name)";
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+ was_bool=true;
+ else
+ was_bool=false;
+ fi
+
+ # Set the variable to the right value-- the escaped quotes make it work if
+ # the option had spaces, like --cmd "queue.pl -sync y"
+ eval "${name}=\"$2\"";
+
+ # Check that Boolean-valued arguments are really Boolean.
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+ exit 1;
+ fi
+ shift 2;
+ ;;
+
+ *) break;
+ esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ train_dataset="${file_dir}/train.jsonl"
+ valid_dataset="${file_dir}/valid.jsonl"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+ #source /data/local/bin/nx_denoise/bin/activate
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ $verbose && echo "stage 1: prepare data"
+ cd "${work_dir}" || exit 1
+ python3 step_1_prepare_data.py \
+ --file_dir "${file_dir}" \
+ --noise_dir "${noise_dir}" \
+ --speech_dir "${speech_dir}" \
+ --train_dataset "${train_dataset}" \
+ --valid_dataset "${valid_dataset}" \
+ --max_count "${max_count}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ $verbose && echo "stage 2: train model"
+ cd "${work_dir}" || exit 1
+ python3 step_2_train_model.py \
+ --train_dataset "${train_dataset}" \
+ --valid_dataset "${valid_dataset}" \
+ --serialization_dir "${file_dir}" \
+ --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ $verbose && echo "stage 3: test model"
+ cd "${work_dir}" || exit 1
+ python3 step_3_evaluation.py \
+ --valid_dataset "${valid_dataset}" \
+ --model_dir "${file_dir}/best" \
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
+ --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ $verbose && echo "stage 4: collect files"
+ cd "${work_dir}" || exit 1
+
+ mkdir -p "${final_model_dir}"
+
+ cp "${file_dir}/best"/* "${final_model_dir}"
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+ cd "${final_model_dir}/.." || exit 1;
+
+ if [ -e "${final_model_name}.zip" ]; then
+ rm -rf "${final_model_name}_backup.zip"
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+ fi
+
+ zip -r "${final_model_name}.zip" "${final_model_name}"
+ rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ $verbose && echo "stage 5: clear file_dir"
+ cd "${work_dir}" || exit 1
+
+ rm -rf "${file_dir}";
+
+ fi
examples/dtln/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--file_dir", default="./", type=str)
+
+ parser.add_argument(
+ "--noise_dir",
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+ type=str
+ )
+ parser.add_argument(
+ "--speech_dir",
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+ type=str
+ )
+
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+ parser.add_argument("--duration", default=2.0, type=float)
+ parser.add_argument("--min_snr_db", default=-10, type=float)
+ parser.add_argument("--max_snr_db", default=20, type=float)
+
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+ parser.add_argument("--max_count", default=-1, type=int)
+
+ args = parser.parse_args()
+ return args
+
+
+ def filename_generator(data_dir: str):
+ data_dir = Path(data_dir)
+ for filename in data_dir.glob("**/*.wav"):
+ yield filename.as_posix()
+
+
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 20000):
+ data_dir = Path(data_dir)
+ for epoch_idx in range(max_epoch):
+ for filename in data_dir.glob("**/*.wav"):
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+ if raw_duration < duration:
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+ continue
+ if signal.ndim != 1:
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+ signal_length = len(signal)
+ win_size = int(duration * sample_rate)
+ for begin in range(0, signal_length - win_size, win_size):
+ if np.sum(signal[begin: begin+win_size]) == 0:
+ continue
+ row = {
+ "epoch_idx": epoch_idx,
+ "filename": filename.as_posix(),
+ "raw_duration": round(raw_duration, 4),
+ "offset": round(begin / sample_rate, 4),
+ "duration": round(duration, 4),
+ }
+ yield row
+
+
+ def main():
+ args = get_args()
+
+ file_dir = Path(args.file_dir)
+ file_dir.mkdir(exist_ok=True)
+
+ noise_dir = Path(args.noise_dir)
+ speech_dir = Path(args.speech_dir)
+
+ noise_generator = target_second_signal_generator(
+ noise_dir.as_posix(),
+ duration=args.duration,
+ sample_rate=args.target_sample_rate,
+ max_epoch=100000,
+ )
+ speech_generator = target_second_signal_generator(
+ speech_dir.as_posix(),
+ duration=args.duration,
+ sample_rate=args.target_sample_rate,
+ max_epoch=1,
+ )
+
+ count = 0
+ process_bar = tqdm(desc="build dataset jsonl")
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+ for noise, speech in zip(noise_generator, speech_generator):
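+ # chained comparison: stop early only when max_count is positive and has been reached.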
+ if count >= args.max_count > 0:
+ break
+
+ noise_filename = noise["filename"]
+ noise_raw_duration = noise["raw_duration"]
+ noise_offset = noise["offset"]
+ noise_duration = noise["duration"]
+
+ speech_filename = speech["filename"]
+ speech_raw_duration = speech["raw_duration"]
+ speech_offset = speech["offset"]
+ speech_duration = speech["duration"]
+
+ random1 = random.random()
+ random2 = random.random()
+
+ row = {
+ "count": count,
+
+ "noise_filename": noise_filename,
+ "noise_raw_duration": noise_raw_duration,
+ "noise_offset": noise_offset,
+ "noise_duration": noise_duration,
+
+ "speech_filename": speech_filename,
+ "speech_raw_duration": speech_raw_duration,
+ "speech_offset": speech_offset,
+ "speech_duration": speech_duration,
+
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+ "random1": random1,
+ }
+ row = json.dumps(row, ensure_ascii=False)
+ if random2 < (1 / 300 / 1):
+ fvalid.write(f"{row}\n")
+ else:
+ ftrain.write(f"{row}\n")
+
+ count += 1
+ duration_seconds = count * args.duration
+ duration_hours = duration_seconds / 3600
+
+ process_bar.update(n=1)
+ process_bar.set_postfix({
+ # "duration_seconds": round(duration_seconds, 4),
+ "duration_hours": round(duration_hours, 4),
+
+ })
+
+ return
+
+
+ if __name__ == "__main__":
+ main()
examples/dtln/step_2_train_model.py ADDED
@@ -0,0 +1,437 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://github.com/breizhn/DTLN
+
+ """
+ import argparse
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
+ from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
+ from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
+
+
+ def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
+ parser.add_argument("--patience", default=30, type=int)
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+ parser.add_argument("--config_file", default="config.yaml", type=str)
+
+ args = parser.parse_args()
+ return args
+
+
+ def logging_config(file_dir: str):
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+ logging.basicConfig(format=fmt,
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO)
+ file_handler = TimedRotatingFileHandler(
+ filename=os.path.join(file_dir, "main.log"),
+ encoding="utf-8",
+ when="D",
+ interval=1,
+ backupCount=7
+ )
+ file_handler.setLevel(logging.INFO)
+ file_handler.setFormatter(logging.Formatter(fmt))
+ logger = logging.getLogger(__name__)
+ logger.addHandler(file_handler)
+
+ return logger
+
+
+ class CollateFunction(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, batch: List[dict]):
+ clean_audios = list()
+ noisy_audios = list()
+ snr_db_list = list()
+
+ for sample in batch:
+ # noise_wave: torch.Tensor = sample["noise_wave"]
+ clean_audio: torch.Tensor = sample["speech_wave"]
+ noisy_audio: torch.Tensor = sample["mix_wave"]
+ # snr_db: float = sample["snr_db"]
+
+ clean_audios.append(clean_audio)
+ noisy_audios.append(noisy_audio)
+
+ clean_audios = torch.stack(clean_audios)
+ noisy_audios = torch.stack(noisy_audios)
+
+ # assert
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+ raise AssertionError("nan or inf in clean_audios")
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
+ raise AssertionError("nan or inf in noisy_audios")
+ return clean_audios, noisy_audios
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+ args = get_args()
+
+ config = DTLNConfig.from_pretrained(
+ pretrained_model_name_or_path=args.config_file,
+ )
+
+ serialization_dir = Path(args.serialization_dir)
+ serialization_dir.mkdir(parents=True, exist_ok=True)
+
+ logger = logging_config(serialization_dir)
+
+ random.seed(config.seed)
+ np.random.seed(config.seed)
+ torch.manual_seed(config.seed)
+ logger.info(f"set seed: {config.seed}")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ n_gpu = torch.cuda.device_count()
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+ # datasets
+ train_dataset = DenoiseJsonlDataset(
+ jsonl_file=args.train_dataset,
+ expected_sample_rate=config.sample_rate,
+ max_wave_value=32768.0,
+ min_snr_db=config.min_snr_db,
+ max_snr_db=config.max_snr_db,
+ # skip=225000,
+ )
+ valid_dataset = DenoiseJsonlDataset(
+ jsonl_file=args.valid_dataset,
+ expected_sample_rate=config.sample_rate,
+ max_wave_value=32768.0,
+ min_snr_db=config.min_snr_db,
+ max_snr_db=config.max_snr_db,
+ )
+ train_data_loader = DataLoader(
+ dataset=train_dataset,
+ batch_size=config.batch_size,
+ # shuffle=True,
+ sampler=None,
+ # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+ collate_fn=collate_fn,
+ pin_memory=False,
+ prefetch_factor=None if platform.system() == "Windows" else 2,
+ )
+ valid_data_loader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=config.batch_size,
+ # shuffle=True,
+ sampler=None,
+ # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+ collate_fn=collate_fn,
+ pin_memory=False,
+ prefetch_factor=None if platform.system() == "Windows" else 2,
+ )
+
+ # models
+ logger.info(f"prepare models. config_file: {args.config_file}")
+ model = DTLNPretrainedModel(config).to(device)
+ model.train()
+
+ # optimizer
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+
+ # resume training
+ last_step_idx = -1
+ last_epoch = -1
+ for step_idx_str in serialization_dir.glob("steps-*"):
+ step_idx_str = Path(step_idx_str)
+ step_idx = step_idx_str.stem.split("-")[1]
+ step_idx = int(step_idx)
+ if step_idx > last_step_idx:
+ last_step_idx = step_idx
+ # last_epoch = 1
+
+ if last_step_idx != -1:
+ logger.info(f"resume from steps-{last_step_idx}.")
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+
+ logger.info("load state dict for model.")
+ with open(model_pt.as_posix(), "rb") as f:
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
+ model.load_state_dict(state_dict, strict=True)
+
+ if config.lr_scheduler == "CosineAnnealingLR":
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer,
+ last_epoch=last_epoch,
+ # T_max=10 * config.eval_steps,
+ # eta_min=0.01 * config.lr,
+ **config.lr_scheduler_kwargs,
+ )
+ elif config.lr_scheduler == "MultiStepLR":
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer,
+ last_epoch=last_epoch,
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+ )
+ else:
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
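+ # spectral convergence (factor_sc) plus magnitude (factor_mag) loss, averaged over three STFT resolutions.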
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
214
+ fft_size_list=[256, 512, 1024],
215
+ win_size_list=[256, 512, 1024],
216
+ hop_size_list=[128, 256, 512],
217
+ factor_sc=1.5,
218
+ factor_mag=1.0,
219
+ reduction="mean"
220
+ ).to(device)
221
+
222
+ # training loop
223
+
224
+ # state
225
+ average_pesq_score = 1000000000
226
+ average_loss = 1000000000
227
+ average_mr_stft_loss = 1000000000
228
+ average_neg_si_snr_loss = 1000000000
229
+
230
+ model_list = list()
231
+ best_epoch_idx = None
232
+ best_step_idx = None
233
+ best_metric = None
234
+ patience_count = 0
235
+
236
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
237
+
238
+ logger.info("training")
239
+ early_stop_flag = False
240
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
241
+ if early_stop_flag:
242
+ break
243
+
244
+ # train
245
+ model.train()
246
+
247
+ total_pesq_score = 0.
248
+ total_loss = 0.
249
+ total_mr_stft_loss = 0.
250
+ total_neg_si_snr_loss = 0.
251
+ total_batches = 0.
252
+
253
+ progress_bar_train = tqdm(
254
+ initial=step_idx,
255
+ desc="Training; epoch-{}".format(epoch_idx),
256
+ )
257
+ for train_batch in train_data_loader:
258
+ clean_audios, noisy_audios = train_batch
259
+ clean_audios: torch.Tensor = clean_audios.to(device)
260
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
261
+
262
+ denoise_audios = model.forward(noisy_audios)
263
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
264
+
265
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
266
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
267
+
268
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
269
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
270
+ logger.info(f"find nan or inf in loss.")
271
+ continue
272
+
273
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
274
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
275
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
276
+
277
+ optimizer.zero_grad()
278
+ loss.backward()
279
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
280
+ optimizer.step()
281
+ lr_scheduler.step()
282
+
283
+ total_pesq_score += pesq_score
284
+ total_loss += loss.item()
285
+ total_mr_stft_loss += mr_stft_loss.item()
286
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
287
+ total_batches += 1
288
+
289
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
290
+ average_loss = round(total_loss / total_batches, 4)
291
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
292
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
293
+
294
+ progress_bar_train.update(1)
295
+ progress_bar_train.set_postfix({
296
+ "lr": lr_scheduler.get_last_lr()[0],
297
+ "pesq_score": average_pesq_score,
298
+ "loss": average_loss,
299
+ "mr_stft_loss": average_mr_stft_loss,
300
+ "neg_si_snr_loss": average_neg_si_snr_loss,
301
+ })
302
+
303
+ # evaluation
304
+ step_idx += 1
305
+ if step_idx % config.eval_steps == 0:
306
+ model.eval()
307
+ with torch.no_grad():
308
+ torch.cuda.empty_cache()
309
+
310
+ total_pesq_score = 0.
311
+ total_loss = 0.
312
+ total_mr_stft_loss = 0.
313
+ total_neg_si_snr_loss = 0.
314
+ total_batches = 0.
315
+
316
+ progress_bar_train.close()
317
+ progress_bar_eval = tqdm(
318
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
319
+ )
320
+ for eval_batch in valid_data_loader:
321
+ clean_audios, noisy_audios = eval_batch
322
+ clean_audios: torch.Tensor = clean_audios.to(device)
323
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
324
+
325
+ denoise_audios = model.forward(noisy_audios)
326
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
327
+
328
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
329
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
330
+
331
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
332
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
333
+ logger.info(f"find nan or inf in loss.")
334
+ continue
335
+
336
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
337
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
338
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
339
+
340
+ total_pesq_score += pesq_score
341
+ total_loss += loss.item()
342
+ total_mr_stft_loss += mr_stft_loss.item()
343
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
344
+ total_batches += 1
345
+
346
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
347
+ average_loss = round(total_loss / total_batches, 4)
348
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
349
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
350
+
351
+ progress_bar_eval.update(1)
352
+ progress_bar_eval.set_postfix({
353
+ "lr": lr_scheduler.get_last_lr()[0],
354
+ "pesq_score": average_pesq_score,
355
+ "loss": average_loss,
356
+ "mr_stft_loss": average_mr_stft_loss,
357
+ "neg_si_snr_loss": average_neg_si_snr_loss,
358
+
359
+ })
360
+
361
+ total_pesq_score = 0.
362
+ total_loss = 0.
363
+ total_mr_stft_loss = 0.
364
+ total_neg_si_snr_loss = 0.
365
+ total_batches = 0.
366
+
367
+ progress_bar_eval.close()
368
+ progress_bar_train = tqdm(
369
+ initial=progress_bar_train.n,
370
+ postfix=progress_bar_train.postfix,
371
+ desc=progress_bar_train.desc,
372
+ )
373
+
374
+ # save path
375
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
376
+ save_dir.mkdir(parents=True, exist_ok=False)
377
+
378
+ # save models
379
+ model.save_pretrained(save_dir.as_posix())
380
+
381
+ model_list.append(save_dir)
382
+ if len(model_list) >= args.num_serialized_models_to_keep:
383
+ model_to_delete: Path = model_list.pop(0)
384
+ shutil.rmtree(model_to_delete.as_posix())
385
+
386
+ # save metric
387
+ if best_metric is None:
388
+ best_epoch_idx = epoch_idx
389
+ best_step_idx = step_idx
390
+ best_metric = average_pesq_score
391
+ elif average_pesq_score >= best_metric:
392
+ # great is better.
393
+ best_epoch_idx = epoch_idx
394
+ best_step_idx = step_idx
395
+ best_metric = average_pesq_score
396
+ else:
397
+ pass
398
+
399
+ metrics = {
400
+ "epoch_idx": epoch_idx,
401
+ "best_epoch_idx": best_epoch_idx,
402
+ "best_step_idx": best_step_idx,
403
+ "pesq_score": average_pesq_score,
404
+ "loss": average_loss,
405
+ "mr_stft_loss": average_mr_stft_loss,
406
+ "neg_si_snr_loss": average_neg_si_snr_loss,
407
+ }
408
+ metrics_filename = save_dir / "metrics_epoch.json"
409
+ with open(metrics_filename, "w", encoding="utf-8") as f:
410
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
411
+
412
+ # save best
413
+ best_dir = serialization_dir / "best"
414
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
415
+ if best_dir.exists():
416
+ shutil.rmtree(best_dir)
417
+ shutil.copytree(save_dir, best_dir)
418
+
419
+ # early stop
420
+ early_stop_flag = False
421
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
422
+ patience_count = 0
423
+ else:
424
+ patience_count += 1
425
+ if patience_count >= args.patience:
426
+ early_stop_flag = True
427
+
428
+ # early stop
429
+ if early_stop_flag:
430
+ break
431
+ model.train()
432
+
433
+ return
434
+
435
+
436
+ if __name__ == "__main__":
437
+ main()
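For reference, run_pesq_score is imported from the project toolbox and its body is not part of this diff. A minimal sketch of what a batch PESQ helper of this shape presumably does (the wrapper name and the averaging are assumptions; only the `pesq` package call itself is standard):

    # Sketch only: assumes the third-party `pesq` package (pip install pesq).
    import numpy as np
    from pesq import pesq

    def run_pesq_score_sketch(clean_list, denoise_list, sample_rate=8000, mode="nb"):
        # Score each (reference, degraded) pair, then average over the batch.
        scores = [
            pesq(sample_rate, np.asarray(ref), np.asarray(deg), mode)
            for ref, deg in zip(clean_list, denoise_list)
        ]
        return float(np.mean(scores))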
examples/dtln/yaml/config-1024.yaml ADDED
@@ -0,0 +1,29 @@
1
+ model_name: "DTLN"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ fft_size: 512
6
+ hop_size: 128
7
+ win_type: hann
8
+
9
+ # data
10
+ min_snr_db: -5
11
+ max_snr_db: 25
12
+
13
+ # model
14
+ encoder_size: 1024
15
+
16
+ # train
17
+ lr: 0.001
18
+ lr_scheduler: "CosineAnnealingLR"
19
+ lr_scheduler_kwargs:
20
+ T_max: 250000
21
+ eta_min: 0.0001
22
+
23
+ max_epochs: 100
24
+ clip_grad_norm: 10.0
25
+ seed: 1234
26
+
27
+ num_workers: 4
28
+ batch_size: 64
29
+ eval_steps: 15000
examples/dtln/yaml/config-256.yaml ADDED
@@ -0,0 +1,29 @@
1
+ model_name: "DTLN"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ fft_size: 256
6
+ hop_size: 128
7
+ win_type: hann
8
+
9
+ # data
10
+ min_snr_db: -5
11
+ max_snr_db: 25
12
+
13
+ # model
14
+ encoder_size: 256
15
+
16
+ # train
17
+ lr: 0.001
18
+ lr_scheduler: "CosineAnnealingLR"
19
+ lr_scheduler_kwargs:
20
+ T_max: 250000
21
+ eta_min: 0.0001
22
+
23
+ max_epochs: 100
24
+ clip_grad_norm: 10.0
25
+ seed: 1234
26
+
27
+ num_workers: 4
28
+ batch_size: 64
29
+ eval_steps: 15000
examples/dtln/yaml/config-512.yaml ADDED
@@ -0,0 +1,29 @@
1
+ model_name: "DTLN"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ fft_size: 512
6
+ hop_size: 128
7
+ win_type: hann
8
+
9
+ # data
10
+ min_snr_db: -5
11
+ max_snr_db: 25
12
+
13
+ # model
14
+ encoder_size: 512
15
+
16
+ # train
17
+ lr: 0.001
18
+ lr_scheduler: "CosineAnnealingLR"
19
+ lr_scheduler_kwargs:
20
+ T_max: 250000
21
+ eta_min: 0.0001
22
+
23
+ max_epochs: 100
24
+ clip_grad_norm: 10.0
25
+ seed: 1234
26
+
27
+ num_workers: 4
28
+ batch_size: 64
29
+ eval_steps: 15000
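The three DTLN configs above differ in encoder_size (256/512/1024); fft_size is 256 for the smallest variant and 512 otherwise. The analysis window and hop implied at 8 kHz follow directly from those values; a quick sanity check in plain Python:

    # Frame and hop durations implied by the YAML values above (8 kHz audio).
    for fft_size, hop_size in [(256, 128), (512, 128)]:
        frame_ms = 1000 * fft_size / 8000    # 32 ms and 64 ms
        hop_ms = 1000 * hop_size / 8000      # 16 ms in both cases
        overlap = 1 - hop_size / fft_size    # 50% and 75% frame overlap
        print(fft_size, frame_ms, hop_ms, overlap)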
examples/dtln_mp3_to_wav/run.sh ADDED
@@ -0,0 +1,168 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \
6
+ --config_file "yaml/config-256.yaml" \
7
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
9
+
10
+
11
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
12
+ --config_file "yaml/config-512.yaml" \
13
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
14
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
15
+
16
+
17
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
18
+ --config_file "yaml/config-1024.yaml" \
19
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
20
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
21
+
22
+
23
+ bash run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3-mp3 --final_model_name dtln-256-nx2-dns3-mp3 \
24
+ --config_file "yaml/config-256.yaml" \
25
+ --audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \
26
+
27
+
28
+ END
29
+
30
+
31
+ # params
32
+ system_version="windows";
33
+ verbose=true;
34
+ stage=0 # start from 0 if you need to start from data preparation
35
+ stop_stage=9
36
+
37
+ work_dir="$(pwd)"
38
+ file_folder_name=file_folder_name
39
+ final_model_name=final_model_name
40
+ config_file="yaml/config.yaml"
41
+ limit=10
42
+
43
+ audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data
44
+
45
+ max_count=-1
46
+
47
+ nohup_name=nohup.out
48
+
49
+ # model params
50
+ batch_size=64
51
+ max_epochs=200
52
+ save_top_k=10
53
+ patience=5
54
+
55
+
56
+ # parse options
57
+ while true; do
58
+ [ -z "${1:-}" ] && break; # break if there are no arguments
59
+ case "$1" in
60
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
61
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
62
+ old_value="$(eval echo \\$$name)";
63
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
64
+ was_bool=true;
65
+ else
66
+ was_bool=false;
67
+ fi
68
+
69
+ # Set the variable to the right value-- the escaped quotes make it work if
70
+ # the option had spaces, like --cmd "queue.pl -sync y"
71
+ eval "${name}=\"$2\"";
72
+
73
+ # Check that Boolean-valued arguments are really Boolean.
74
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
75
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
76
+ exit 1;
77
+ fi
78
+ shift 2;
79
+ ;;
80
+
81
+ *) break;
82
+ esac
83
+ done
84
+
85
+ file_dir="${work_dir}/${file_folder_name}"
86
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
87
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
88
+
89
+ train_dataset="${file_dir}/train.jsonl"
90
+ valid_dataset="${file_dir}/valid.jsonl"
91
+
92
+ $verbose && echo "system_version: ${system_version}"
93
+ $verbose && echo "file_folder_name: ${file_folder_name}"
94
+
95
+ if [ $system_version == "windows" ]; then
96
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
97
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
98
+ #source /data/local/bin/nx_denoise/bin/activate
99
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
100
+ fi
101
+
102
+
103
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
104
+ $verbose && echo "stage 1: prepare data"
105
+ cd "${work_dir}" || exit 1
106
+ python3 step_1_prepare_data.py \
107
+ --file_dir "${file_dir}" \
108
+ --audio_dir "${audio_dir}" \
109
+ --train_dataset "${train_dataset}" \
110
+ --valid_dataset "${valid_dataset}" \
111
+ --max_count "${max_count}" \
112
+
113
+ fi
114
+
115
+
116
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
117
+ $verbose && echo "stage 2: train model"
118
+ cd "${work_dir}" || exit 1
119
+ python3 step_2_train_model.py \
120
+ --train_dataset "${train_dataset}" \
121
+ --valid_dataset "${valid_dataset}" \
122
+ --serialization_dir "${file_dir}" \
123
+ --config_file "${config_file}" \
124
+
125
+ fi
126
+
127
+
128
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
129
+ $verbose && echo "stage 3: test model"
130
+ cd "${work_dir}" || exit 1
131
+ python3 step_3_evaluation.py \
132
+ --valid_dataset "${valid_dataset}" \
133
+ --model_dir "${file_dir}/best" \
134
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
135
+ --limit "${limit}" \
136
+
137
+ fi
138
+
139
+
140
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
141
+ $verbose && echo "stage 4: collect files"
142
+ cd "${work_dir}" || exit 1
143
+
144
+ mkdir -p ${final_model_dir}
145
+
146
+ cp "${file_dir}/best"/* "${final_model_dir}"
147
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
148
+
149
+ cd "${final_model_dir}/.." || exit 1;
150
+
151
+ if [ -e "${final_model_name}.zip" ]; then
152
+ rm -rf "${final_model_name}_backup.zip"
153
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
154
+ fi
155
+
156
+ zip -r "${final_model_name}.zip" "${final_model_name}"
157
+ rm -rf "${final_model_name}"
158
+
159
+ fi
160
+
161
+
162
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
163
+ $verbose && echo "stage 5: clear file_dir"
164
+ cd "${work_dir}" || exit 1
165
+
166
+ rm -rf "${file_dir}";
167
+
168
+ fi
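Each stage guard above fires only when stage <= N <= stop_stage, so a contiguous range of stages can be selected from the command line; for example, `bash run.sh --stage 3 --stop_stage 4 ...` runs only model evaluation and file collection against an existing file_dir.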
examples/dtln_mp3_to_wav/step_1_prepare_data.py ADDED
@@ -0,0 +1,127 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--file_dir", default="./", type=str)
21
+
22
+ parser.add_argument(
23
+ "--audio_dir",
24
+ default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
25
+ type=str
26
+ )
27
+
28
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
29
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
30
+
31
+ parser.add_argument("--duration", default=4.0, type=float)
32
+
33
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
34
+
35
+ parser.add_argument("--max_count", default=-1, type=int)
36
+
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 1):
42
+ data_dir = Path(data_dir)
43
+ for epoch_idx in range(max_epoch):
44
+ for filename in data_dir.glob("**/*.wav"):
45
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
46
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
47
+
48
+ if raw_duration < duration:
49
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
50
+ continue
51
+ if signal.ndim != 1:
52
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
53
+
54
+ signal_length = len(signal)
55
+ win_size = int(duration * sample_rate)
56
+ for begin in range(0, signal_length - win_size, win_size):
57
+ if np.sum(signal[begin: begin+win_size]) == 0:
58
+ continue
59
+ row = {
60
+ "epoch_idx": epoch_idx,
61
+ "filename": filename.as_posix(),
62
+ "raw_duration": round(raw_duration, 4),
63
+ "offset": round(begin / sample_rate, 4),
64
+ "duration": round(duration, 4),
65
+ }
66
+ yield row
67
+
68
+
69
+ def main():
70
+ args = get_args()
71
+
72
+ file_dir = Path(args.file_dir)
73
+ file_dir.mkdir(exist_ok=True)
74
+
75
+ audio_dir = Path(args.audio_dir)
76
+
77
+ audio_generator = target_second_signal_generator(
78
+ audio_dir.as_posix(),
79
+ duration=args.duration,
80
+ sample_rate=args.target_sample_rate,
81
+ max_epoch=1,
82
+ )
83
+ count = 0
84
+ process_bar = tqdm(desc="build dataset jsonl")
85
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
86
+ for audio in audio_generator:
87
+ if count >= args.max_count > 0:
88
+ break
89
+
90
+ filename = audio["filename"]
91
+ raw_duration = audio["raw_duration"]
92
+ offset = audio["offset"]
93
+ duration = audio["duration"]
94
+
95
+ random1 = random.random()
96
+ random2 = random.random()
97
+
98
+ row = {
99
+ "count": count,
100
+
101
+ "filename": filename,
102
+ "raw_duration": raw_duration,
103
+ "offset": offset,
104
+ "duration": duration,
105
+
106
+ "random1": random1,
107
+ }
108
+ row = json.dumps(row, ensure_ascii=False)
109
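+ # roughly 1 window in 300 is routed to the validation split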
+ if random2 < (1 / 300):
110
+ fvalid.write(f"{row}\n")
111
+ else:
112
+ ftrain.write(f"{row}\n")
113
+
114
+ count += 1
115
+ duration_seconds = count * args.duration
116
+ duration_hours = duration_seconds / 3600
117
+
118
+ process_bar.update(n=1)
119
+ process_bar.set_postfix({
120
+ "duration_hours": round(duration_hours, 4),
121
+ })
122
+
123
+ return
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
examples/dtln_mp3_to_wav/step_2_train_model.py ADDED
@@ -0,0 +1,445 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/breizhn/DTLN
5
+
6
+ """
7
+ import argparse
8
+ import json
9
+ import logging
10
+ from logging.handlers import TimedRotatingFileHandler
11
+ import os
12
+ import platform
13
+ from pathlib import Path
14
+ import random
15
+ import sys
16
+ import shutil
17
+ from typing import List
18
+
19
+ pwd = os.path.abspath(os.path.dirname(__file__))
20
+ sys.path.append(os.path.join(pwd, "../../"))
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.nn import functional as F
26
+ from torch.utils.data.dataloader import DataLoader
27
+ from tqdm import tqdm
28
+
29
+ from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
30
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
31
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
32
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
33
+ from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
34
+ from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
35
+
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
40
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
41
+
42
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
43
+ parser.add_argument("--patience", default=30, type=int)
44
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
45
+
46
+ parser.add_argument("--config_file", default="config.yaml", type=str)
47
+
48
+ args = parser.parse_args()
49
+ return args
50
+
51
+
52
+ def logging_config(file_dir: str):
53
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
54
+
55
+ logging.basicConfig(format=fmt,
56
+ datefmt="%m/%d/%Y %H:%M:%S",
57
+ level=logging.INFO)
58
+ file_handler = TimedRotatingFileHandler(
59
+ filename=os.path.join(file_dir, "main.log"),
60
+ encoding="utf-8",
61
+ when="D",
62
+ interval=1,
63
+ backupCount=7
64
+ )
65
+ file_handler.setLevel(logging.INFO)
66
+ file_handler.setFormatter(logging.Formatter(fmt))
67
+ logger = logging.getLogger(__name__)
68
+ logger.addHandler(file_handler)
69
+
70
+ return logger
71
+
72
+
73
+ class CollateFunction(object):
74
+ def __init__(self):
75
+ pass
76
+
77
+ def __call__(self, batch: List[dict]):
78
+ mp3_waveform_list = list()
79
+ wav_waveform_list = list()
80
+
81
+ for sample in batch:
82
+ mp3_waveform: torch.Tensor = sample["mp3_waveform"]
83
+ wav_waveform: torch.Tensor = sample["wav_waveform"]
84
+
85
+ mp3_waveform_list.append(mp3_waveform)
86
+ wav_waveform_list.append(wav_waveform)
87
+
88
+ mp3_waveform_list = torch.stack(mp3_waveform_list)
89
+ wav_waveform_list = torch.stack(wav_waveform_list)
90
+
91
+ # assert
92
+ if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)):
93
+ raise AssertionError("nan or inf in mp3_waveform_list")
94
+ if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)):
95
+ raise AssertionError("nan or inf in wav_waveform_list")
96
+
97
+ return mp3_waveform_list, wav_waveform_list
98
+
99
+
100
+ collate_fn = CollateFunction()
101
+
102
+
103
+ def main():
104
+ args = get_args()
105
+
106
+ config = DTLNConfig.from_pretrained(
107
+ pretrained_model_name_or_path=args.config_file,
108
+ )
109
+
110
+ serialization_dir = Path(args.serialization_dir)
111
+ serialization_dir.mkdir(parents=True, exist_ok=True)
112
+
113
+ logger = logging_config(serialization_dir)
114
+
115
+ random.seed(config.seed)
116
+ np.random.seed(config.seed)
117
+ torch.manual_seed(config.seed)
118
+ logger.info(f"set seed: {config.seed}")
119
+
120
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
121
+ n_gpu = torch.cuda.device_count()
122
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
123
+
124
+ # datasets
125
+ train_dataset = Mp3ToWavJsonlDataset(
126
+ jsonl_file=args.train_dataset,
127
+ expected_sample_rate=config.sample_rate,
128
+ max_wave_value=32768.0,
129
+ # skip=225000,
130
+ )
131
+ valid_dataset = Mp3ToWavJsonlDataset(
132
+ jsonl_file=args.valid_dataset,
133
+ expected_sample_rate=config.sample_rate,
134
+ max_wave_value=32768.0,
135
+ )
136
+ train_data_loader = DataLoader(
137
+ dataset=train_dataset,
138
+ batch_size=config.batch_size,
139
+ # shuffle=True,
140
+ sampler=None,
141
+ # On Linux the DataLoader can use multiple worker processes; on Windows it cannot.
142
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
143
+ collate_fn=collate_fn,
144
+ pin_memory=False,
145
+ prefetch_factor=None if platform.system() == "Windows" else 2,
146
+ )
147
+ valid_data_loader = DataLoader(
148
+ dataset=valid_dataset,
149
+ batch_size=config.batch_size,
150
+ # shuffle=True,
151
+ sampler=None,
152
+ # On Linux the DataLoader can use multiple worker processes; on Windows it cannot.
153
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
154
+ collate_fn=collate_fn,
155
+ pin_memory=False,
156
+ prefetch_factor=None if platform.system() == "Windows" else 2,
157
+ )
158
+
159
+ # models
160
+ logger.info(f"prepare models. config_file: {args.config_file}")
161
+ model = DTLNPretrainedModel(config).to(device)
163
+ model.train()
164
+
165
+ # optimizer
166
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
167
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
168
+
169
+ # resume training
170
+ last_step_idx = -1
171
+ last_epoch = -1
172
+ for step_idx_str in serialization_dir.glob("steps-*"):
173
+ step_idx_str = Path(step_idx_str)
174
+ step_idx = step_idx_str.stem.split("-")[1]
175
+ step_idx = int(step_idx)
176
+ if step_idx > last_step_idx:
177
+ last_step_idx = step_idx
178
+ # last_epoch = 1
179
+
180
+ if last_step_idx != -1:
181
+ logger.info(f"resume from steps-{last_step_idx}.")
182
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
183
+
184
+ logger.info(f"load state dict for model.")
185
+ with open(model_pt.as_posix(), "rb") as f:
186
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
187
+ model.load_state_dict(state_dict, strict=True)
188
+
189
+ if config.lr_scheduler == "CosineAnnealingLR":
190
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
191
+ optimizer,
192
+ last_epoch=last_epoch,
193
+ # T_max=10 * config.eval_steps,
194
+ # eta_min=0.01 * config.lr,
195
+ **config.lr_scheduler_kwargs,
196
+ )
197
+ elif config.lr_scheduler == "MultiStepLR":
198
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
199
+ optimizer,
200
+ last_epoch=last_epoch,
201
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
202
+ )
203
+ else:
204
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
205
+
206
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
207
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
208
+ fft_size_list=[256, 512, 1024],
209
+ win_size_list=[256, 512, 1024],
210
+ hop_size_list=[128, 256, 512],
211
+ factor_sc=1.5,
212
+ factor_mag=1.0,
213
+ reduction="mean"
214
+ ).to(device)
215
+ audio_l1_loss_fn = nn.L1Loss(reduction="mean")
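+ # factor_sc / factor_mag above weight the spectral-convergence and log-magnitude
+ # terms of the multi-resolution STFT loss (naming assumed from the toolbox API).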
216
+
217
+ # training loop
218
+
219
+ # state
220
+ average_pesq_score = 1000000000
221
+ average_loss = 1000000000
222
+ average_mr_stft_loss = 1000000000
223
+ average_audio_l1_loss = 1000000000
224
+ average_neg_si_snr_loss = 1000000000
225
+
226
+ model_list = list()
227
+ best_epoch_idx = None
228
+ best_step_idx = None
229
+ best_metric = None
230
+ patience_count = 0
231
+
232
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
233
+
234
+ logger.info("training")
235
+ early_stop_flag = False
236
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
237
+ if early_stop_flag:
238
+ break
239
+
240
+ # train
241
+ model.train()
242
+
243
+ total_pesq_score = 0.
244
+ total_loss = 0.
245
+ total_mr_stft_loss = 0.
246
+ total_audio_l1_loss = 0.
247
+ total_neg_si_snr_loss = 0.
248
+ total_batches = 0.
249
+
250
+ progress_bar_train = tqdm(
251
+ initial=step_idx,
252
+ desc="Training; epoch-{}".format(epoch_idx),
253
+ )
254
+ for train_batch in train_data_loader:
255
+ mp3_audios, wav_audios = train_batch
256
+ noisy_audios: torch.Tensor = mp3_audios.to(device)
257
+ clean_audios: torch.Tensor = wav_audios.to(device)
258
+
259
+ denoise_audios = model.forward(noisy_audios)
260
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
261
+
262
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
263
+ audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
264
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
265
+
266
+ loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
267
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
268
+ logger.info("found nan or inf in loss.")
269
+ continue
270
+
271
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
272
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
273
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
274
+
275
+ optimizer.zero_grad()
276
+ loss.backward()
277
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
278
+ optimizer.step()
279
+ lr_scheduler.step()
280
+
281
+ total_pesq_score += pesq_score
282
+ total_loss += loss.item()
283
+ total_mr_stft_loss += mr_stft_loss.item()
284
+ total_audio_l1_loss += audio_l1_loss.item()
285
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
286
+ total_batches += 1
287
+
288
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
289
+ average_loss = round(total_loss / total_batches, 4)
290
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
291
+ average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
292
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
293
+
294
+ progress_bar_train.update(1)
295
+ progress_bar_train.set_postfix({
296
+ "lr": lr_scheduler.get_last_lr()[0],
297
+ "pesq_score": average_pesq_score,
298
+ "loss": average_loss,
299
+ "mr_stft_loss": average_mr_stft_loss,
300
+ "audio_l1_loss": average_audio_l1_loss,
301
+ "neg_si_snr_loss": average_neg_si_snr_loss,
302
+ })
303
+
304
+ # evaluation
305
+ step_idx += 1
306
+ if step_idx % config.eval_steps == 0:
307
+ model.eval()
308
+ with torch.no_grad():
309
+ torch.cuda.empty_cache()
310
+
311
+ total_pesq_score = 0.
312
+ total_loss = 0.
313
+ total_mr_stft_loss = 0.
314
+ total_audio_l1_loss = 0.
315
+ total_neg_si_snr_loss = 0.
316
+ total_batches = 0.
317
+
318
+ progress_bar_train.close()
319
+ progress_bar_eval = tqdm(
320
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
321
+ )
322
+ for eval_batch in valid_data_loader:
323
+ mp3_audios, wav_audios = eval_batch
324
+ noisy_audios: torch.Tensor = mp3_audios.to(device)
325
+ clean_audios: torch.Tensor = wav_audios.to(device)
326
+
327
+ denoise_audios = model.forward(noisy_audios)
328
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
329
+
330
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
331
+ audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
332
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
333
+
334
+ loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
335
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
336
+ logger.info("found nan or inf in loss.")
337
+ continue
338
+
339
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
340
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
341
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
342
+
343
+ total_pesq_score += pesq_score
344
+ total_loss += loss.item()
345
+ total_mr_stft_loss += mr_stft_loss.item()
346
+ total_audio_l1_loss += audio_l1_loss.item()
347
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
348
+ total_batches += 1
349
+
350
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
351
+ average_loss = round(total_loss / total_batches, 4)
352
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
353
+ average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
354
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
355
+
356
+ progress_bar_eval.update(1)
357
+ progress_bar_eval.set_postfix({
358
+ "lr": lr_scheduler.get_last_lr()[0],
359
+ "pesq_score": average_pesq_score,
360
+ "loss": average_loss,
361
+ "mr_stft_loss": average_mr_stft_loss,
362
+ "audio_l1_loss": average_audio_l1_loss,
363
+ "neg_si_snr_loss": average_neg_si_snr_loss,
365
+ })
366
+
367
+ total_pesq_score = 0.
368
+ total_loss = 0.
369
+ total_mr_stft_loss = 0.
370
+ total_audio_l1_loss = 0.
371
+ total_neg_si_snr_loss = 0.
372
+ total_batches = 0.
373
+
374
+ progress_bar_eval.close()
375
+ progress_bar_train = tqdm(
376
+ initial=progress_bar_train.n,
377
+ postfix=progress_bar_train.postfix,
378
+ desc=progress_bar_train.desc,
379
+ )
380
+
381
+ # save path
382
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
383
+ save_dir.mkdir(parents=True, exist_ok=False)
384
+
385
+ # save models
386
+ model.save_pretrained(save_dir.as_posix())
387
+
388
+ model_list.append(save_dir)
389
+ if len(model_list) > args.num_serialized_models_to_keep:
390
+ model_to_delete: Path = model_list.pop(0)
391
+ shutil.rmtree(model_to_delete.as_posix())
392
+
393
+ # save metric
394
+ if best_metric is None:
395
+ best_epoch_idx = epoch_idx
396
+ best_step_idx = step_idx
397
+ best_metric = average_pesq_score
398
+ elif average_pesq_score >= best_metric:
399
+ # greater is better.
400
+ best_epoch_idx = epoch_idx
401
+ best_step_idx = step_idx
402
+ best_metric = average_pesq_score
403
+ else:
404
+ pass
405
+
406
+ metrics = {
407
+ "epoch_idx": epoch_idx,
408
+ "best_epoch_idx": best_epoch_idx,
409
+ "best_step_idx": best_step_idx,
410
+ "pesq_score": average_pesq_score,
411
+ "loss": average_loss,
412
+ "mr_stft_loss": average_mr_stft_loss,
413
+ "audio_l1_loss": average_audio_l1_loss,
414
+ "neg_si_snr_loss": average_neg_si_snr_loss,
415
+ }
416
+ metrics_filename = save_dir / "metrics_epoch.json"
417
+ with open(metrics_filename, "w", encoding="utf-8") as f:
418
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
419
+
420
+ # save best
421
+ best_dir = serialization_dir / "best"
422
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
423
+ if best_dir.exists():
424
+ shutil.rmtree(best_dir)
425
+ shutil.copytree(save_dir, best_dir)
426
+
427
+ # early stop
428
+ early_stop_flag = False
429
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
430
+ patience_count = 0
431
+ else:
432
+ patience_count += 1
433
+ if patience_count >= args.patience:
434
+ early_stop_flag = True
435
+
436
+ # early stop
437
+ if early_stop_flag:
438
+ break
439
+ model.train()
440
+
441
+ return
442
+
443
+
444
+ if __name__ == "__main__":
445
+ main()
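NegativeSISNRLoss above also comes from the toolbox and is not shown in this diff. A minimal self-contained sketch of negative SI-SNR, for reference (an illustration under standard definitions, not the toolbox implementation):

    import torch

    def neg_si_snr_sketch(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # Zero-mean both signals, project the estimate onto the target,
        # and return the negated scale-invariant SNR averaged over the batch.
        estimate = estimate - estimate.mean(dim=-1, keepdim=True)
        target = target - target.mean(dim=-1, keepdim=True)
        dot = torch.sum(estimate * target, dim=-1, keepdim=True)
        s_target = dot * target / (torch.sum(target ** 2, dim=-1, keepdim=True) + eps)
        e_noise = estimate - s_target
        si_snr = 10 * torch.log10(
            (torch.sum(s_target ** 2, dim=-1) + eps) / (torch.sum(e_noise ** 2, dim=-1) + eps)
        )
        return -si_snr.mean()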
examples/dtln_mp3_to_wav/yaml/config-1024.yaml ADDED
@@ -0,0 +1,29 @@
1
+ model_name: "DTLN"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ fft_size: 512
6
+ hop_size: 128
7
+ win_type: hann
8
+
9
+ # data
10
+ min_snr_db: -5
11
+ max_snr_db: 25
12
+
13
+ # model
14
+ encoder_size: 1024
15
+
16
+ # train
17
+ lr: 0.001
18
+ lr_scheduler: "CosineAnnealingLR"
19
+ lr_scheduler_kwargs:
20
+ T_max: 250000
21
+ eta_min: 0.0001
22
+
23
+ max_epochs: 100
24
+ clip_grad_norm: 10.0
25
+ seed: 1234
26
+
27
+ num_workers: 4
28
+ batch_size: 64
29
+ eval_steps: 15000
examples/dtln_mp3_to_wav/yaml/config-256.yaml ADDED
@@ -0,0 +1,29 @@
1
+ model_name: "DTLN"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ fft_size: 256
6
+ hop_size: 128
7
+ win_type: hann
8
+
9
+ # data
10
+ min_snr_db: -5
11
+ max_snr_db: 25
12
+
13
+ # model
14
+ encoder_size: 256
15
+
16
+ # train
17
+ lr: 0.001
18
+ lr_scheduler: "CosineAnnealingLR"
19
+ lr_scheduler_kwargs:
20
+ T_max: 250000
21
+ eta_min: 0.0001
22
+
23
+ max_epochs: 100
24
+ clip_grad_norm: 10.0
25
+ seed: 1234
26
+
27
+ num_workers: 4
28
+ batch_size: 64
29
+ eval_steps: 15000
examples/dtln_mp3_to_wav/yaml/config-512.yaml ADDED
@@ -0,0 +1,29 @@
1
+ model_name: "DTLN"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ fft_size: 512
6
+ hop_size: 128
7
+ win_type: hann
8
+
9
+ # data
10
+ min_snr_db: -5
11
+ max_snr_db: 25
12
+
13
+ # model
14
+ encoder_size: 512
15
+
16
+ # train
17
+ lr: 0.001
18
+ lr_scheduler: "CosineAnnealingLR"
19
+ lr_scheduler_kwargs:
20
+ T_max: 250000
21
+ eta_min: 0.0001
22
+
23
+ max_epochs: 100
24
+ clip_grad_norm: 10.0
25
+ seed: 1234
26
+
27
+ num_workers: 4
28
+ batch_size: 64
29
+ eval_steps: 15000
examples/frcrn/run.sh ADDED
@@ -0,0 +1,159 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
7
+ --config_file "yaml/config-10.yaml" \
8
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
9
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
10
+
11
+
12
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
13
+ --config_file "yaml/config-10.yaml" \
14
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
15
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
16
+
17
+ END
18
+
19
+
20
+ # params
21
+ system_version="windows";
22
+ verbose=true;
23
+ stage=0 # start from 0 if you need to start from data preparation
24
+ stop_stage=9
25
+
26
+ work_dir="$(pwd)"
27
+ file_folder_name=file_folder_name
28
+ final_model_name=final_model_name
29
+ config_file="yaml/config.yaml"
30
+ limit=10
31
+
32
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
33
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
34
+
35
+ max_count=10000000
36
+
37
+ nohup_name=nohup.out
38
+
39
+ # model params
40
+ batch_size=64
41
+ max_epochs=200
42
+ save_top_k=10
43
+ patience=5
44
+
45
+
46
+ # parse options
47
+ while true; do
48
+ [ -z "${1:-}" ] && break; # break if there are no arguments
49
+ case "$1" in
50
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
51
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
52
+ old_value="$(eval echo \\$$name)";
53
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
54
+ was_bool=true;
55
+ else
56
+ was_bool=false;
57
+ fi
58
+
59
+ # Set the variable to the right value-- the escaped quotes make it work if
60
+ # the option had spaces, like --cmd "queue.pl -sync y"
61
+ eval "${name}=\"$2\"";
62
+
63
+ # Check that Boolean-valued arguments are really Boolean.
64
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
65
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
66
+ exit 1;
67
+ fi
68
+ shift 2;
69
+ ;;
70
+
71
+ *) break;
72
+ esac
73
+ done
74
+
75
+ file_dir="${work_dir}/${file_folder_name}"
76
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
77
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
78
+
79
+ train_dataset="${file_dir}/train.jsonl"
80
+ valid_dataset="${file_dir}/valid.jsonl"
81
+
82
+ $verbose && echo "system_version: ${system_version}"
83
+ $verbose && echo "file_folder_name: ${file_folder_name}"
84
+
85
+ if [ $system_version == "windows" ]; then
86
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
87
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
88
+ #source /data/local/bin/nx_denoise/bin/activate
89
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
90
+ fi
91
+
92
+
93
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
94
+ $verbose && echo "stage 1: prepare data"
95
+ cd "${work_dir}" || exit 1
96
+ python3 step_1_prepare_data.py \
97
+ --file_dir "${file_dir}" \
98
+ --noise_dir "${noise_dir}" \
99
+ --speech_dir "${speech_dir}" \
100
+ --train_dataset "${train_dataset}" \
101
+ --valid_dataset "${valid_dataset}" \
102
+ --max_count "${max_count}" \
103
+
104
+ fi
105
+
106
+
107
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
108
+ $verbose && echo "stage 2: train model"
109
+ cd "${work_dir}" || exit 1
110
+ python3 step_2_train_model.py \
111
+ --train_dataset "${train_dataset}" \
112
+ --valid_dataset "${valid_dataset}" \
113
+ --serialization_dir "${file_dir}" \
114
+ --config_file "${config_file}" \
115
+
116
+ fi
117
+
118
+
119
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
120
+ $verbose && echo "stage 3: test model"
121
+ cd "${work_dir}" || exit 1
122
+ python3 step_3_evaluation.py \
123
+ --valid_dataset "${valid_dataset}" \
124
+ --model_dir "${file_dir}/best" \
125
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
126
+ --limit "${limit}" \
127
+
128
+ fi
129
+
130
+
131
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
132
+ $verbose && echo "stage 4: collect files"
133
+ cd "${work_dir}" || exit 1
134
+
135
+ mkdir -p ${final_model_dir}
136
+
137
+ cp "${file_dir}/best"/* "${final_model_dir}"
138
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
139
+
140
+ cd "${final_model_dir}/.." || exit 1;
141
+
142
+ if [ -e "${final_model_name}.zip" ]; then
143
+ rm -rf "${final_model_name}_backup.zip"
144
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
145
+ fi
146
+
147
+ zip -r "${final_model_name}.zip" "${final_model_name}"
148
+ rm -rf "${final_model_name}"
149
+
150
+ fi
151
+
152
+
153
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
154
+ $verbose && echo "stage 5: clear file_dir"
155
+ cd "${work_dir}" || exit 1
156
+
157
+ rm -rf "${file_dir}";
158
+
159
+ fi
examples/frcrn/step_1_prepare_data.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ import sys
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ from tqdm import tqdm
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--file_dir", default="./", type=str)
21
+
22
+ parser.add_argument(
23
+ "--noise_dir",
24
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--speech_dir",
29
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
30
+ type=str
31
+ )
32
+
33
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
34
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
35
+
36
+ parser.add_argument("--duration", default=2.0, type=float)
37
+ parser.add_argument("--min_snr_db", default=-10, type=float)
38
+ parser.add_argument("--max_snr_db", default=20, type=float)
39
+
40
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
41
+
42
+ parser.add_argument("--max_count", default=-1, type=int)
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def filename_generator(data_dir: str):
49
+ data_dir = Path(data_dir)
50
+ for filename in data_dir.glob("**/*.wav"):
51
+ yield filename.as_posix()
52
+
53
+
54
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 20000):
55
+ data_dir = Path(data_dir)
56
+ for epoch_idx in range(max_epoch):
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ if np.sum(signal[begin: begin+win_size]) == 0:
71
+ continue
72
+ row = {
73
+ "epoch_idx": epoch_idx,
74
+ "filename": filename.as_posix(),
75
+ "raw_duration": round(raw_duration, 4),
76
+ "offset": round(begin / sample_rate, 4),
77
+ "duration": round(duration, 4),
78
+ }
79
+ yield row
80
+
81
+
82
+ def main():
83
+ args = get_args()
84
+
85
+ file_dir = Path(args.file_dir)
86
+ file_dir.mkdir(exist_ok=True)
87
+
88
+ noise_dir = Path(args.noise_dir)
89
+ speech_dir = Path(args.speech_dir)
90
+
91
+ noise_generator = target_second_signal_generator(
92
+ noise_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate,
95
+ max_epoch=100000,
96
+ )
97
+ speech_generator = target_second_signal_generator(
98
+ speech_dir.as_posix(),
99
+ duration=args.duration,
100
+ sample_rate=args.target_sample_rate,
101
+ max_epoch=1,
102
+ )
103
+
106
+ count = 0
107
+ process_bar = tqdm(desc="build dataset jsonl")
108
+ with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
109
+ for noise, speech in zip(noise_generator, speech_generator):
110
+ if count >= args.max_count > 0:
111
+ break
112
+
113
+ noise_filename = noise["filename"]
114
+ noise_raw_duration = noise["raw_duration"]
115
+ noise_offset = noise["offset"]
116
+ noise_duration = noise["duration"]
117
+
118
+ speech_filename = speech["filename"]
119
+ speech_raw_duration = speech["raw_duration"]
120
+ speech_offset = speech["offset"]
121
+ speech_duration = speech["duration"]
122
+
123
+ random1 = random.random()
124
+ random2 = random.random()
125
+
126
+ row = {
127
+ "count": count,
128
+
129
+ "noise_filename": noise_filename,
130
+ "noise_raw_duration": noise_raw_duration,
131
+ "noise_offset": noise_offset,
132
+ "noise_duration": noise_duration,
133
+
134
+ "speech_filename": speech_filename,
135
+ "speech_raw_duration": speech_raw_duration,
136
+ "speech_offset": speech_offset,
137
+ "speech_duration": speech_duration,
138
+
139
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
140
+
141
+ "random1": random1,
142
+ }
143
+ row = json.dumps(row, ensure_ascii=False)
144
+ if random2 < (1 / 300):
145
+ fvalid.write(f"{row}\n")
146
+ else:
147
+ ftrain.write(f"{row}\n")
148
+
149
+ count += 1
150
+ duration_seconds = count * args.duration
151
+ duration_hours = duration_seconds / 3600
152
+
153
+ process_bar.update(n=1)
154
+ process_bar.set_postfix({
155
+ # "duration_seconds": round(duration_seconds, 4),
156
+ "duration_hours": round(duration_hours, 4),
158
+ })
159
+
160
+ return
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
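The snr_db stored in each row is presumably consumed by DenoiseJsonlDataset at mix time. The standard way to scale noise so the mixture hits a target SNR looks like this (a sketch under that assumption, not the dataset's actual code):

    import numpy as np

    def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float, eps: float = 1e-8) -> np.ndarray:
        # Scale the noise so that 10 * log10(P_speech / P_noise) equals snr_db,
        # then add it to the speech.
        p_speech = np.mean(speech ** 2)
        p_noise = np.mean(noise ** 2) + eps
        scale = np.sqrt(p_speech / (p_noise * 10.0 ** (snr_db / 10.0)))
        return speech + scale * noise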
examples/frcrn/step_2_train_model.py ADDED
@@ -0,0 +1,457 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://arxiv.org/abs/2206.07293
5
+
6
+ In the FRCRN paper:
7
+ training on the WSJ0 dataset for 120 epochs reached PESQ 3.62, STOI 98.24, SI-SNR 21.33.
8
+
9
+ WSJ0 contains roughly 80 hours of clean English speech recordings.
10
+
11
+ Our audio totals roughly 1300 hours, so about 10 epochs should be enough.
12
+ """
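+ # Back-of-envelope check of the estimate above: 120 epochs * 80 h = 9600 h of
+ # audio seen in training; 9600 h / 1300 h per epoch is about 7.4, hence ~10 epochs.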
13
+ import argparse
14
+ import json
15
+ import logging
16
+ from logging.handlers import TimedRotatingFileHandler
17
+ import os
18
+ import platform
19
+ from pathlib import Path
20
+ import random
21
+ import sys
22
+ import shutil
23
+ from typing import List
24
+
25
+ pwd = os.path.abspath(os.path.dirname(__file__))
26
+ sys.path.append(os.path.join(pwd, "../../"))
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn as nn
31
+ from torch.nn import functional as F
32
+ from torch.utils.data.dataloader import DataLoader
33
+ from tqdm import tqdm
34
+
35
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
36
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
37
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
38
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
39
+ from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
40
+ from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
41
+
42
+
43
+ def get_args():
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
46
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
47
+
48
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
49
+ parser.add_argument("--patience", default=30, type=int)
50
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
51
+
52
+ parser.add_argument("--config_file", default="config.yaml", type=str)
53
+
54
+ args = parser.parse_args()
55
+ return args
56
+
57
+
58
+ def logging_config(file_dir: str):
59
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
60
+
61
+ logging.basicConfig(format=fmt,
62
+ datefmt="%m/%d/%Y %H:%M:%S",
63
+ level=logging.INFO)
64
+ file_handler = TimedRotatingFileHandler(
65
+ filename=os.path.join(file_dir, "main.log"),
66
+ encoding="utf-8",
67
+ when="D",
68
+ interval=1,
69
+ backupCount=7
70
+ )
71
+ file_handler.setLevel(logging.INFO)
72
+ file_handler.setFormatter(logging.Formatter(fmt))
73
+ logger = logging.getLogger(__name__)
74
+ logger.addHandler(file_handler)
75
+
76
+ return logger
77
+
78
+
79
+ class CollateFunction(object):
80
+ def __init__(self):
81
+ pass
82
+
83
+ def __call__(self, batch: List[dict]):
84
+ clean_audios = list()
85
+ noisy_audios = list()
86
+
87
+ for sample in batch:
88
+ # noise_wave: torch.Tensor = sample["noise_wave"]
89
+ clean_audio: torch.Tensor = sample["speech_wave"]
90
+ noisy_audio: torch.Tensor = sample["mix_wave"]
91
+ # snr_db: float = sample["snr_db"]
92
+
93
+ clean_audios.append(clean_audio)
94
+ noisy_audios.append(noisy_audio)
95
+
96
+ clean_audios = torch.stack(clean_audios)
97
+ noisy_audios = torch.stack(noisy_audios)
98
+
99
+ # assert
100
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
101
+ raise AssertionError("nan or inf in clean_audios")
102
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
103
+ raise AssertionError("nan or inf in noisy_audios")
104
+ return clean_audios, noisy_audios
105
+
106
+
107
+ collate_fn = CollateFunction()
108
+
109
+
110
+ def main():
111
+ args = get_args()
112
+
113
+ config = FRCRNConfig.from_pretrained(
114
+ pretrained_model_name_or_path=args.config_file,
115
+ )
116
+
117
+ serialization_dir = Path(args.serialization_dir)
118
+ serialization_dir.mkdir(parents=True, exist_ok=True)
119
+
120
+ logger = logging_config(serialization_dir)
121
+
122
+ random.seed(config.seed)
123
+ np.random.seed(config.seed)
124
+ torch.manual_seed(config.seed)
125
+ logger.info(f"set seed: {config.seed}")
126
+
127
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
+ n_gpu = torch.cuda.device_count()
129
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
130
+
131
+ # datasets
132
+ train_dataset = DenoiseJsonlDataset(
133
+ jsonl_file=args.train_dataset,
134
+ expected_sample_rate=config.sample_rate,
135
+ max_wave_value=32768.0,
136
+ min_snr_db=config.min_snr_db,
137
+ max_snr_db=config.max_snr_db,
138
+ # skip=225000,
139
+ )
140
+ valid_dataset = DenoiseJsonlDataset(
141
+ jsonl_file=args.valid_dataset,
142
+ expected_sample_rate=config.sample_rate,
143
+ max_wave_value=32768.0,
144
+ min_snr_db=config.min_snr_db,
145
+ max_snr_db=config.max_snr_db,
146
+ )
147
+ train_data_loader = DataLoader(
148
+ dataset=train_dataset,
149
+ batch_size=config.batch_size,
150
+ # shuffle=True,
151
+ sampler=None,
152
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
153
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
154
+ collate_fn=collate_fn,
155
+ pin_memory=False,
156
+ prefetch_factor=2,
157
+ )
158
+ valid_data_loader = DataLoader(
159
+ dataset=valid_dataset,
160
+ batch_size=config.batch_size,
161
+ # shuffle=True,
162
+ sampler=None,
163
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
164
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
165
+ collate_fn=collate_fn,
166
+ pin_memory=False,
167
+ prefetch_factor=2,
168
+ )
169
+
170
+ # models
171
+ logger.info(f"prepare models. config_file: {args.config_file}")
172
+ model = FRCRNPretrainedModel(config).to(device)
174
+ model.train()
175
+
176
+ # optimizer
177
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
178
+ optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)
179
+
180
+ # resume training
181
+ last_step_idx = -1
182
+ last_epoch = -1
183
+ for step_idx_str in serialization_dir.glob("steps-*"):
184
+ step_idx_str = Path(step_idx_str)
185
+ step_idx = step_idx_str.stem.split("-")[1]
186
+ step_idx = int(step_idx)
187
+ if step_idx > last_step_idx:
188
+ last_step_idx = step_idx
189
+ # last_epoch = 0
190
+
191
+ if last_step_idx != -1:
192
+ logger.info(f"resume from steps-{last_step_idx}.")
193
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
194
+ # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
195
+
196
+ logger.info(f"load state dict for model.")
197
+ with open(model_pt.as_posix(), "rb") as f:
198
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
199
+ model.load_state_dict(state_dict, strict=True)
200
+
201
+ # logger.info(f"load state dict for optimizer.")
202
+ # with open(optimizer_pth.as_posix(), "rb") as f:
203
+ # state_dict = torch.load(f, map_location="cpu", weights_only=True)
204
+ # optimizer.load_state_dict(state_dict)
205
+
206
+ if config.lr_scheduler == "CosineAnnealingLR":
207
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
208
+ optimizer,
209
+ last_epoch=last_epoch,
210
+ # T_max=10 * config.eval_steps,
211
+ # eta_min=0.01 * config.lr,
212
+ **config.lr_scheduler_kwargs,
213
+ )
214
+ elif config.lr_scheduler == "MultiStepLR":
215
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
216
+ optimizer,
217
+ last_epoch=last_epoch,
218
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
219
+ )
220
+ else:
221
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
222
+
223
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
224
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
225
+ fft_size_list=[256, 512, 1024],
226
+ win_size_list=[256, 512, 1024],
227
+ hop_size_list=[128, 256, 512],
228
+ factor_sc=1.5,
229
+ factor_mag=1.0,
230
+ reduction="mean"
231
+ ).to(device)
232
+
233
+ # training loop
234
+
235
+ # state
236
+ average_pesq_score = 1000000000
237
+ average_loss = 1000000000
238
+ average_neg_si_snr_loss = 1000000000
239
+ average_mask_loss = 1000000000
240
+
241
+ model_list = list()
242
+ best_epoch_idx = None
243
+ best_step_idx = None
244
+ best_metric = None
245
+ patience_count = 0
246
+
247
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
248
+
249
+ logger.info("training")
250
+ early_stop_flag = False
251
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
252
+ if early_stop_flag:
253
+ break
254
+
255
+ # train
256
+ model.train()
257
+
258
+ total_pesq_score = 0.
259
+ total_loss = 0.
260
+ total_mr_stft_loss = 0.
261
+ total_neg_si_snr_loss = 0.
262
+ total_mask_loss = 0.
263
+ total_batches = 0.
264
+
265
+ progress_bar_train = tqdm(
266
+ initial=step_idx,
267
+ desc="Training; epoch-{}".format(epoch_idx),
268
+ )
269
+ for train_batch in train_data_loader:
270
+ clean_audios, noisy_audios = train_batch
271
+ clean_audios: torch.Tensor = clean_audios.to(device)
272
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
273
+
274
+ est_spec, est_wav, est_mask = model.forward(noisy_audios)
275
+ denoise_audios = est_wav
276
+
277
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
278
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
279
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
280
+
281
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
282
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
283
+ logger.info(f"find nan or inf in loss.")
284
+ continue
285
+
286
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
287
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
288
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
289
+
290
+ optimizer.zero_grad()
291
+ loss.backward()
292
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
293
+ optimizer.step()
294
+ lr_scheduler.step()
295
+
296
+ total_pesq_score += pesq_score
297
+ total_loss += loss.item()
298
+ total_mr_stft_loss += mr_stft_loss.item()
299
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
300
+ total_mask_loss += mask_loss.item()
301
+ total_batches += 1
302
+
303
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
304
+ average_loss = round(total_loss / total_batches, 4)
305
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
306
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
307
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
308
+
309
+ progress_bar_train.update(1)
310
+ progress_bar_train.set_postfix({
311
+ "lr": lr_scheduler.get_last_lr()[0],
312
+ "pesq_score": average_pesq_score,
313
+ "loss": average_loss,
314
+ "mr_stft_loss": average_mr_stft_loss,
315
+ "neg_si_snr_loss": average_neg_si_snr_loss,
316
+ "mask_loss": average_mask_loss,
317
+ })
318
+
319
+ # evaluation
320
+ step_idx += 1
321
+ if step_idx % config.eval_steps == 0:
322
+ model.eval()
323
+ with torch.no_grad():
324
+ torch.cuda.empty_cache()
325
+
326
+ total_pesq_score = 0.
327
+ total_loss = 0.
328
+ total_mr_stft_loss = 0.
329
+ total_neg_si_snr_loss = 0.
330
+ total_mask_loss = 0.
331
+ total_batches = 0.
332
+
333
+ progress_bar_train.close()
334
+ progress_bar_eval = tqdm(
335
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
336
+ )
337
+ for eval_batch in valid_data_loader:
338
+ clean_audios, noisy_audios = eval_batch
339
+ clean_audios = clean_audios.to(device)
340
+ noisy_audios = noisy_audios.to(device)
341
+
342
+ est_spec, est_wav, est_mask = model.forward(noisy_audios)
343
+ denoise_audios = est_wav
344
+
345
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
346
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
347
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
348
+
349
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
350
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
351
+ logger.info(f"find nan or inf in loss.")
352
+ continue
353
+
354
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
355
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
356
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
357
+
358
+ total_pesq_score += pesq_score
359
+ total_loss += loss.item()
360
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
361
+ total_mask_loss += mask_loss.item()
362
+ total_batches += 1
363
+
364
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
365
+ average_loss = round(total_loss / total_batches, 4)
366
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
367
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
368
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
369
+
370
+ progress_bar_eval.update(1)
371
+ progress_bar_eval.set_postfix({
372
+ "lr": lr_scheduler.get_last_lr()[0],
373
+ "pesq_score": average_pesq_score,
374
+ "loss": average_loss,
375
+ "mr_stft_loss": average_mr_stft_loss,
376
+ "neg_si_snr_loss": average_neg_si_snr_loss,
377
+ "mask_loss": average_mask_loss,
378
+ })
379
+
380
+ total_pesq_score = 0.
381
+ total_loss = 0.
382
+ total_mr_stft_loss = 0.
383
+ total_neg_si_snr_loss = 0.
384
+ total_mask_loss = 0.
385
+ total_batches = 0.
386
+
387
+ progress_bar_eval.close()
388
+ progress_bar_train = tqdm(
389
+ initial=progress_bar_train.n,
390
+ postfix=progress_bar_train.postfix,
391
+ desc=progress_bar_train.desc,
392
+ )
393
+
394
+ # save path
395
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
396
+ save_dir.mkdir(parents=True, exist_ok=False)
397
+
398
+ # save models
399
+ model.save_pretrained(save_dir.as_posix())
400
+
401
+ model_list.append(save_dir)
402
+ if len(model_list) > args.num_serialized_models_to_keep:
403
+ model_to_delete: Path = model_list.pop(0)
404
+ shutil.rmtree(model_to_delete.as_posix())
405
+
406
+ # save metric
407
+ if best_metric is None:
408
+ best_epoch_idx = epoch_idx
409
+ best_step_idx = step_idx
410
+ best_metric = average_pesq_score
411
+ elif average_pesq_score >= best_metric:
412
+ # great is better.
413
+ best_epoch_idx = epoch_idx
414
+ best_step_idx = step_idx
415
+ best_metric = average_pesq_score
416
+ else:
417
+ pass
418
+
419
+ metrics = {
420
+ "epoch_idx": epoch_idx,
421
+ "best_epoch_idx": best_epoch_idx,
422
+ "best_step_idx": best_step_idx,
423
+ "pesq_score": average_pesq_score,
424
+ "loss": average_loss,
425
+ "neg_si_snr_loss": average_neg_si_snr_loss,
426
+ "mask_loss": average_mask_loss,
427
+ }
428
+ metrics_filename = save_dir / "metrics_epoch.json"
429
+ with open(metrics_filename, "w", encoding="utf-8") as f:
430
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
431
+
432
+ # save best
433
+ best_dir = serialization_dir / "best"
434
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
435
+ if best_dir.exists():
436
+ shutil.rmtree(best_dir)
437
+ shutil.copytree(save_dir, best_dir)
438
+
439
+ # early stop
440
+ early_stop_flag = False
441
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
442
+ patience_count = 0
443
+ else:
444
+ patience_count += 1
445
+ if patience_count >= args.patience:
446
+ early_stop_flag = True
447
+
448
+ # early stop
449
+ if early_stop_flag:
450
+ break
451
+ model.train()
452
+
453
+ return
454
+
455
+
456
+ if __name__ == "__main__":
457
+ main()
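The evaluation block above interleaves checkpoint rotation with a patience-based early stop. A minimal sketch of the stopping policy in isolation (illustrative helper, not part of the repo): a checkpoint becomes "best" when its average PESQ ties or beats the best so far, and training halts after `patience` consecutive evaluations without a new best.

# illustrative sketch of the patience logic used above; names are invented.
def should_stop(history, patience=30):
    best = None
    stale = 0
    for score in history:                 # one average PESQ per evaluation
        if best is None or score >= best:
            best, stale = score, 0        # a new (or tied) best resets the counter
        else:
            stale += 1
        if stale >= patience:
            return True
    return False

assert should_stop([1.2, 1.3, 1.3, 1.1], patience=1)
assert not should_stop([1.2, 1.3, 1.4], patience=1)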
examples/frcrn/yaml/config-10.yaml ADDED
@@ -0,0 +1,31 @@
+ model_name: "frcrn"
+
+ sample_rate: 8000
+ segment_size: 32000
+ nfft: 128
+ win_size: 128
+ hop_size: 64
+ win_type: hann
+
+ use_complex_networks: true
+ model_depth: 10
+ model_complexity: -1
+
+ min_snr_db: -10
+ max_snr_db: 20
+
+ num_workers: 8
+ batch_size: 32
+ eval_steps: 20000
+
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.0001
+
+ max_epochs: 100
+ weight_decay: 1.0e-05
+ clip_grad_norm: 10.0
+ seed: 1234
+ num_gpus: -1
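This file is consumed by step_2_train_model.py through FRCRNConfig.from_pretrained, which exposes the YAML keys as attributes (the training script reads config.sample_rate, config.batch_size, and so on). A minimal loading sketch, assuming the repo root containing toolbox/ is on sys.path:

# a sketch mirroring how step_2_train_model.py loads this config.
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig

config = FRCRNConfig.from_pretrained(pretrained_model_name_or_path="yaml/config-10.yaml")
print(config.sample_rate, config.nfft, config.hop_size)  # expected: 8000 128 64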
examples/frcrn/yaml/config-14.yaml ADDED
@@ -0,0 +1,31 @@
+ model_name: "frcrn"
+
+ sample_rate: 8000
+ segment_size: 32000
+ nfft: 640
+ win_size: 640
+ hop_size: 320
+ win_type: hann
+
+ use_complex_networks: true
+ model_depth: 14
+ model_complexity: -1
+
+ min_snr_db: -10
+ max_snr_db: 20
+
+ num_workers: 8
+ batch_size: 32
+ eval_steps: 10000
+
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.0001
+
+ max_epochs: 100
+ weight_decay: 1.0e-05
+ clip_grad_norm: 10.0
+ seed: 1234
+ num_gpus: -1
examples/frcrn/yaml/config-20.yaml ADDED
@@ -0,0 +1,31 @@
+ model_name: "frcrn"
+
+ sample_rate: 8000
+ segment_size: 32000
+ nfft: 512
+ win_size: 512
+ hop_size: 256
+ win_type: hann
+
+ use_complex_networks: true
+ model_depth: 20
+ model_complexity: 45
+
+ min_snr_db: -10
+ max_snr_db: 20
+
+ num_workers: 8
+ batch_size: 32
+ eval_steps: 10000
+
+ lr: 0.001
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.0001
+
+ max_epochs: 100
+ weight_decay: 1.0e-05
+ clip_grad_norm: 10.0
+ seed: 1234
+ num_gpus: -1
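The three FRCRN variants trade spectral resolution against depth: config-10 pairs a 128-point FFT with a depth-10 model, config-14 a 640-point FFT with depth 14, and config-20 a 512-point FFT with depth 20 and an explicit model_complexity of 45. As a quick sanity check on the spectrogram shape each one sees per 4 s training segment (a sketch, assuming a center-padded STFT so frames = segment_size // hop_size + 1):

# rough shape arithmetic for one segment_size = 32000 clip at 8 kHz;
# assumes a center-padded STFT: bins = nfft // 2 + 1, frames = segment_size // hop_size + 1.
segment_size = 32000
for name, nfft, hop in [("config-10", 128, 64), ("config-14", 640, 320), ("config-20", 512, 256)]:
    print(name, nfft // 2 + 1, "bins x", segment_size // hop + 1, "frames")
# config-10 -> 65 bins x 501 frames
# config-14 -> 321 bins x 101 frames
# config-20 -> 257 bins x 126 frames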
examples/frcrn_mp3_to_wav/run.sh ADDED
@@ -0,0 +1,156 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
+ --config_file "yaml/config-10.yaml" \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
+ --config_file "yaml/config-10.yaml" \
+ --audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0  # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data
+
+ max_count=10000000
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ train_dataset="${file_dir}/train.jsonl"
+ valid_dataset="${file_dir}/valid.jsonl"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ "${system_version}" == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ "${system_version}" == "centos" ] || [ "${system_version}" == "ubuntu" ]; then
+   #source /data/local/bin/nx_denoise/bin/activate
+   alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --file_dir "${file_dir}" \
+     --audio_dir "${audio_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --max_count "${max_count}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: train model"
+   cd "${work_dir}" || exit 1
+   python3 step_2_train_model.py \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}" \
+     --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: test model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_evaluation.py \
+     --valid_dataset "${valid_dataset}" \
+     --model_dir "${file_dir}/best" \
+     --evaluation_audio_dir "${evaluation_audio_dir}" \
+     --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: collect files"
+   cd "${work_dir}" || exit 1
+
+   mkdir -p "${final_model_dir}"
+
+   cp "${file_dir}/best"/* "${final_model_dir}"
+   cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
+
+   cd "${final_model_dir}/.." || exit 1;
+
+   if [ -e "${final_model_name}.zip" ]; then
+     rm -rf "${final_model_name}_backup.zip"
+     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+   fi
+
+   zip -r "${final_model_name}.zip" "${final_model_name}"
+   rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+   $verbose && echo "stage 5: clear file_dir"
+   cd "${work_dir}" || exit 1
+
+   rm -rf "${file_dir}";
+
+ fi
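The stages chain in order, so one command can run data preparation through cleanup. An illustrative invocation in the same style as the heredoc at the top of the script (the final_model_name here is made up for the example):

sh run.sh --stage 1 --stop_stage 5 --system_version centos \
  --file_folder_name file_dir \
  --final_model_name frcrn-mp3-to-wav \
  --config_file "yaml/config-10.yaml" \
  --audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data"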
examples/frcrn_mp3_to_wav/step_1_prepare_data.py ADDED
@@ -0,0 +1,127 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import librosa
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--file_dir", default="./", type=str)
+
+     parser.add_argument(
+         "--audio_dir",
+         default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
+         type=str
+     )
+
+     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+     parser.add_argument("--duration", default=4.0, type=float)
+
+     parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+     parser.add_argument("--max_count", default=-1, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000, max_epoch: int = 1):
+     data_dir = Path(data_dir)
+     for epoch_idx in range(max_epoch):
+         for filename in data_dir.glob("**/*.wav"):
+             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+             if raw_duration < duration:
+                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                 continue
+             if signal.ndim != 1:
+                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+             signal_length = len(signal)
+             win_size = int(duration * sample_rate)
+             for begin in range(0, signal_length - win_size, win_size):
+                 # skip windows whose samples sum to zero (e.g. digital silence).
+                 if np.sum(signal[begin: begin+win_size]) == 0:
+                     continue
+                 row = {
+                     "epoch_idx": epoch_idx,
+                     "filename": filename.as_posix(),
+                     "raw_duration": round(raw_duration, 4),
+                     "offset": round(begin / sample_rate, 4),
+                     "duration": round(duration, 4),
+                 }
+                 yield row
+
+
+ def main():
+     args = get_args()
+
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     audio_dir = Path(args.audio_dir)
+
+     audio_generator = target_second_signal_generator(
+         audio_dir.as_posix(),
+         duration=args.duration,
+         sample_rate=args.target_sample_rate,
+         max_epoch=1,
+     )
+     count = 0
+     process_bar = tqdm(desc="build dataset jsonl")
+     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
+         for audio in audio_generator:
+             # stop once a positive --max_count is reached.
+             if count >= args.max_count > 0:
+                 break
+
+             filename = audio["filename"]
+             raw_duration = audio["raw_duration"]
+             offset = audio["offset"]
+             duration = audio["duration"]
+
+             random1 = random.random()
+             random2 = random.random()
+
+             row = {
+                 "count": count,
+
+                 "filename": filename,
+                 "raw_duration": raw_duration,
+                 "offset": offset,
+                 "duration": duration,
+
+                 "random1": random1,
+             }
+             row = json.dumps(row, ensure_ascii=False)
+             # roughly one window in ten goes to the validation split.
+             if random2 < (1 / 10):
+                 fvalid.write(f"{row}\n")
+             else:
+                 ftrain.write(f"{row}\n")
+
+             count += 1
+             duration_seconds = count * args.duration
+             duration_hours = duration_seconds / 3600
+
+             process_bar.update(n=1)
+             process_bar.set_postfix({
+                 "duration_hours": round(duration_hours, 4),
+             })
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
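Each JSONL line describes one fixed-length window of a source wav rather than the audio itself; downstream, Mp3ToWavJsonlDataset presumably re-reads the file at the stored offset. An illustrative row (values and path invented for the example):

{"count": 0, "filename": "/data/tianxing/HuggingDatasets/nx_noise/data/speech/example.wav", "raw_duration": 12.5, "offset": 4.0, "duration": 4.0, "random1": 0.5488}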
examples/frcrn_mp3_to_wav/step_2_train_model.py ADDED
@@ -0,0 +1,442 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
+ from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
+ from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+
+     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
+     parser.add_argument("--patience", default=30, type=int)
+     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+
+     parser.add_argument("--config_file", default="config.yaml", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.INFO)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, batch: List[dict]):
+         mp3_waveform_list = list()
+         wav_waveform_list = list()
+
+         for sample in batch:
+             mp3_waveform: torch.Tensor = sample["mp3_waveform"]
+             wav_waveform: torch.Tensor = sample["wav_waveform"]
+
+             mp3_waveform_list.append(mp3_waveform)
+             wav_waveform_list.append(wav_waveform)
+
+         mp3_waveform_list = torch.stack(mp3_waveform_list)
+         wav_waveform_list = torch.stack(wav_waveform_list)
+
+         # assert
+         if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)):
+             raise AssertionError("nan or inf in mp3_waveform_list")
+         if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)):
+             raise AssertionError("nan or inf in wav_waveform_list")
+
+         return mp3_waveform_list, wav_waveform_list
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+     args = get_args()
+
+     config = FRCRNConfig.from_pretrained(
+         pretrained_model_name_or_path=args.config_file,
+     )
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(serialization_dir)
+
+     random.seed(config.seed)
+     np.random.seed(config.seed)
+     torch.manual_seed(config.seed)
+     logger.info(f"set seed: {config.seed}")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+     # datasets
+     train_dataset = Mp3ToWavJsonlDataset(
+         jsonl_file=args.train_dataset,
+         expected_sample_rate=config.sample_rate,
+         max_wave_value=32768.0,
+         # skip=225000,
+     )
+     valid_dataset = Mp3ToWavJsonlDataset(
+         jsonl_file=args.valid_dataset,
+         expected_sample_rate=config.sample_rate,
+         max_wave_value=32768.0,
+     )
+     train_data_loader = DataLoader(
+         dataset=train_dataset,
+         batch_size=config.batch_size,
+         # shuffle=True,
+         sampler=None,
+         # multiple worker subprocesses can be used to load data on Linux, but not on Windows.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         prefetch_factor=2,
+     )
+     valid_data_loader = DataLoader(
+         dataset=valid_dataset,
+         batch_size=config.batch_size,
+         # shuffle=True,
+         sampler=None,
+         # multiple worker subprocesses can be used to load data on Linux, but not on Windows.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         prefetch_factor=2,
+     )
+
+     # models
+     logger.info(f"prepare models. config_file: {args.config_file}")
+     model = FRCRNPretrainedModel(config).to(device)
+     model.train()
+
+     # optimizer
+     logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+     optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)
+
+     # resume training
+     last_step_idx = -1
+     last_epoch = -1
+     for step_idx_str in serialization_dir.glob("steps-*"):
+         step_idx_str = Path(step_idx_str)
+         step_idx = step_idx_str.stem.split("-")[1]
+         step_idx = int(step_idx)
+         if step_idx > last_step_idx:
+             last_step_idx = step_idx
+             # last_epoch = 0
+
+     if last_step_idx != -1:
+         logger.info(f"resume from steps-{last_step_idx}.")
+         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+         # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
+
+         logger.info("load state dict for model.")
+         with open(model_pt.as_posix(), "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         model.load_state_dict(state_dict, strict=True)
+
+         # logger.info("load state dict for optimizer.")
+         # with open(optimizer_pth.as_posix(), "rb") as f:
+         #     state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         # optimizer.load_state_dict(state_dict)
+
+     if config.lr_scheduler == "CosineAnnealingLR":
+         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer,
+             last_epoch=last_epoch,
+             # T_max=10 * config.eval_steps,
+             # eta_min=0.01 * config.lr,
+             **config.lr_scheduler_kwargs,
+         )
+     elif config.lr_scheduler == "MultiStepLR":
+         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+             optimizer,
+             last_epoch=last_epoch,
+             milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+         )
+     else:
+         raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
+     mr_stft_loss_fn = MultiResolutionSTFTLoss(
+         fft_size_list=[256, 512, 1024],
+         win_size_list=[256, 512, 1024],
+         hop_size_list=[128, 256, 512],
+         factor_sc=1.5,
+         factor_mag=1.0,
+         reduction="mean"
+     ).to(device)
+
+     # training loop
+
+     # state (sentinel values; overwritten at the first evaluation)
+     average_pesq_score = 1000000000
+     average_loss = 1000000000
+     average_neg_si_snr_loss = 1000000000
+     average_mask_loss = 1000000000
+
+     model_list = list()
+     best_epoch_idx = None
+     best_step_idx = None
+     best_metric = None
+     patience_count = 0
+
+     step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+     logger.info("training")
+     early_stop_flag = False
+     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+         if early_stop_flag:
+             break
+
+         # train
+         model.train()
+
+         total_pesq_score = 0.
+         total_loss = 0.
+         total_mr_stft_loss = 0.
+         total_neg_si_snr_loss = 0.
+         total_mask_loss = 0.
+         total_batches = 0.
+
+         progress_bar_train = tqdm(
+             initial=step_idx,
+             desc="Training; epoch-{}".format(epoch_idx),
+         )
+         for train_batch in train_data_loader:
+             mp3_audios, wav_audios = train_batch
+             noisy_audios: torch.Tensor = mp3_audios.to(device)
+             clean_audios: torch.Tensor = wav_audios.to(device)
+
+             est_spec, est_wav, est_mask = model.forward(noisy_audios)
+             denoise_audios = est_wav
+
+             mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+             neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
+             mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+
+             loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                 logger.info("found nan or inf in loss; skipping batch.")
+                 continue
+
+             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+             pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+             optimizer.zero_grad()
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+             optimizer.step()
+             lr_scheduler.step()
+
+             total_pesq_score += pesq_score
+             total_loss += loss.item()
+             total_mr_stft_loss += mr_stft_loss.item()
+             total_neg_si_snr_loss += neg_si_snr_loss.item()
+             total_mask_loss += mask_loss.item()
+             total_batches += 1
+
+             average_pesq_score = round(total_pesq_score / total_batches, 4)
+             average_loss = round(total_loss / total_batches, 4)
+             average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+             average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+             average_mask_loss = round(total_mask_loss / total_batches, 4)
+
+             progress_bar_train.update(1)
+             progress_bar_train.set_postfix({
+                 "lr": lr_scheduler.get_last_lr()[0],
+                 "pesq_score": average_pesq_score,
+                 "loss": average_loss,
+                 "mr_stft_loss": average_mr_stft_loss,
+                 "neg_si_snr_loss": average_neg_si_snr_loss,
+                 "mask_loss": average_mask_loss,
+             })
+
+             # evaluation
+             step_idx += 1
+             if step_idx % config.eval_steps == 0:
+                 model.eval()
+                 with torch.no_grad():
+                     torch.cuda.empty_cache()
+
+                     total_pesq_score = 0.
+                     total_loss = 0.
+                     total_mr_stft_loss = 0.
+                     total_neg_si_snr_loss = 0.
+                     total_mask_loss = 0.
+                     total_batches = 0.
+
+                     progress_bar_train.close()
+                     progress_bar_eval = tqdm(
+                         desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
+                     )
+                     for eval_batch in valid_data_loader:
+                         mp3_audios, wav_audios = eval_batch
+                         noisy_audios: torch.Tensor = mp3_audios.to(device)
+                         clean_audios: torch.Tensor = wav_audios.to(device)
+
+                         est_spec, est_wav, est_mask = model.forward(noisy_audios)
+                         denoise_audios = est_wav
+
+                         mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
+                         neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
+                         mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+
+                         loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+                         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
+                             logger.info("found nan or inf in loss; skipping evaluation batch.")
+                             continue
+
+                         denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+                         clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+                         pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+                         total_pesq_score += pesq_score
+                         total_loss += loss.item()
+                         # accumulate mr_stft_loss as well, so average_mr_stft_loss below is meaningful.
+                         total_mr_stft_loss += mr_stft_loss.item()
+                         total_neg_si_snr_loss += neg_si_snr_loss.item()
+                         total_mask_loss += mask_loss.item()
+                         total_batches += 1
+
+                         average_pesq_score = round(total_pesq_score / total_batches, 4)
+                         average_loss = round(total_loss / total_batches, 4)
+                         average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
+                         average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
+                         average_mask_loss = round(total_mask_loss / total_batches, 4)
+
+                         progress_bar_eval.update(1)
+                         progress_bar_eval.set_postfix({
+                             "lr": lr_scheduler.get_last_lr()[0],
+                             "pesq_score": average_pesq_score,
+                             "loss": average_loss,
+                             "mr_stft_loss": average_mr_stft_loss,
+                             "neg_si_snr_loss": average_neg_si_snr_loss,
+                             "mask_loss": average_mask_loss,
+                         })
+
+                     total_pesq_score = 0.
+                     total_loss = 0.
+                     total_mr_stft_loss = 0.
+                     total_neg_si_snr_loss = 0.
+                     total_mask_loss = 0.
+                     total_batches = 0.
+
+                     progress_bar_eval.close()
+                     progress_bar_train = tqdm(
+                         initial=progress_bar_train.n,
+                         postfix=progress_bar_train.postfix,
+                         desc=progress_bar_train.desc,
+                     )
+
+                     # save path
+                     save_dir = serialization_dir / "steps-{}".format(step_idx)
+                     save_dir.mkdir(parents=True, exist_ok=False)
+
+                     # save models
+                     model.save_pretrained(save_dir.as_posix())
+
+                     # keep at most num_serialized_models_to_keep checkpoints on disk.
+                     model_list.append(save_dir)
+                     if len(model_list) > args.num_serialized_models_to_keep:
+                         model_to_delete: Path = model_list.pop(0)
+                         shutil.rmtree(model_to_delete.as_posix())
+
+                     # save metric
+                     if best_metric is None:
+                         best_epoch_idx = epoch_idx
+                         best_step_idx = step_idx
+                         best_metric = average_pesq_score
+                     elif average_pesq_score >= best_metric:
+                         # greater is better.
+                         best_epoch_idx = epoch_idx
+                         best_step_idx = step_idx
+                         best_metric = average_pesq_score
+                     else:
+                         pass
+
+                     metrics = {
+                         "epoch_idx": epoch_idx,
+                         "best_epoch_idx": best_epoch_idx,
+                         "best_step_idx": best_step_idx,
+                         "pesq_score": average_pesq_score,
+                         "loss": average_loss,
+                         "neg_si_snr_loss": average_neg_si_snr_loss,
+                         "mask_loss": average_mask_loss,
+                     }
+                     metrics_filename = save_dir / "metrics_epoch.json"
+                     with open(metrics_filename, "w", encoding="utf-8") as f:
+                         json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                     # save best
+                     best_dir = serialization_dir / "best"
+                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                         if best_dir.exists():
+                             shutil.rmtree(best_dir)
+                         shutil.copytree(save_dir, best_dir)
+
+                     # early stop
+                     early_stop_flag = False
+                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                         patience_count = 0
+                     else:
+                         patience_count += 1
+                         if patience_count >= args.patience:
+                             early_stop_flag = True
+
+                     # break out of the batch loop; the epoch loop re-checks the flag.
+                     if early_stop_flag:
+                         break
+                     model.train()
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
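To close the loop, a minimal denoising sketch (not the repo's step_3 evaluation script): it reuses the forward signature seen above, est_spec, est_wav, est_mask = model.forward(noisy), and assumes FRCRNPretrainedModel.from_pretrained can load the best/ directory produced by save_pretrained; "file_dir/best" and "noisy.mp3" are illustrative.

# illustrative inference sketch under the assumptions stated above.
import librosa
import torch

from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRNPretrainedModel

model = FRCRNPretrainedModel.from_pretrained("file_dir/best")  # assumed counterpart of save_pretrained
model.eval()

signal, _ = librosa.load("noisy.mp3", sr=8000)                 # match config.sample_rate
noisy = torch.tensor(signal, dtype=torch.float32).unsqueeze(0) # (batch=1, time)

with torch.no_grad():
    est_spec, est_wav, est_mask = model.forward(noisy)         # same outputs as in training
restored = est_wav.squeeze(0).cpu().numpy()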