HoneyTian committed
Commit 5b68ebd · 1 Parent(s): 6ded3e4
Files changed (30)
  1. README.md +2 -2
  2. examples/dtln_mp3_to_wav/run.sh +0 -168
  3. examples/dtln_mp3_to_wav/step_1_prepare_data.py +0 -127
  4. examples/dtln_mp3_to_wav/step_2_train_model.py +0 -445
  5. examples/dtln_mp3_to_wav/yaml/config-1024.yaml +0 -29
  6. examples/dtln_mp3_to_wav/yaml/config-256.yaml +0 -29
  7. examples/dtln_mp3_to_wav/yaml/config-512.yaml +0 -29
  8. examples/frcrn_mp3_to_wav/run.sh +0 -156
  9. examples/frcrn_mp3_to_wav/step_1_prepare_data.py +0 -127
  10. examples/frcrn_mp3_to_wav/step_2_train_model.py +0 -442
  11. examples/frcrn_mp3_to_wav/yaml/config-10.yaml +0 -31
  12. examples/frcrn_mp3_to_wav/yaml/config-14.yaml +0 -31
  13. examples/frcrn_mp3_to_wav/yaml/config-20.yaml +0 -31
  14. examples/simple_linear_irm_aishell/run.sh +0 -172
  15. examples/simple_linear_irm_aishell/step_1_prepare_data.py +0 -196
  16. examples/simple_linear_irm_aishell/step_2_train_model.py +0 -348
  17. examples/simple_linear_irm_aishell/step_3_evaluation.py +0 -239
  18. examples/simple_linear_irm_aishell/yaml/config.yaml +0 -13
  19. examples/spectrum_dfnet_aishell/run.sh +0 -178
  20. examples/spectrum_dfnet_aishell/step_1_prepare_data.py +0 -197
  21. examples/spectrum_dfnet_aishell/step_2_train_model.py +0 -440
  22. examples/spectrum_dfnet_aishell/step_3_evaluation.py +0 -302
  23. examples/spectrum_dfnet_aishell/yaml/config.yaml +0 -53
  24. examples/spectrum_unet_irm_aishell/run.sh +0 -178
  25. examples/spectrum_unet_irm_aishell/step_1_prepare_data.py +0 -197
  26. examples/spectrum_unet_irm_aishell/step_2_train_model.py +0 -420
  27. examples/spectrum_unet_irm_aishell/step_3_evaluation.py +0 -270
  28. examples/spectrum_unet_irm_aishell/yaml/config.yaml +0 -38
  29. main.py +1 -1
  30. toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py +0 -197
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: NX Denoise
+title: CC Denoise
 emoji: 🐢
 colorFrom: purple
 colorTo: blue
@@ -9,7 +9,7 @@ license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-## NX Denoise
+## CC Denoise
 
 
 ### datasets
examples/dtln_mp3_to_wav/run.sh DELETED
@@ -1,168 +0,0 @@
- #!/usr/bin/env bash
-
- : <<'END'
-
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \
- --config_file "yaml/config-256.yaml" \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
-
-
- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
- --config_file "yaml/config-512.yaml" \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
-
-
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
- --config_file "yaml/config-1024.yaml" \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
-
-
- bash run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3-mp3 --final_model_name dtln-256-nx2-dns3-mp3 \
- --config_file "yaml/config-256.yaml" \
- --audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \
-
-
- END
-
-
- # params
- system_version="windows";
- verbose=true;
- stage=0 # start from 0 if you need to start from data preparation
- stop_stage=9
-
- work_dir="$(pwd)"
- file_folder_name=file_folder_name
- final_model_name=final_model_name
- config_file="yaml/config.yaml"
- limit=10
-
- audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data
-
- max_count=-1
-
- nohup_name=nohup.out
-
- # model params
- batch_size=64
- max_epochs=200
- save_top_k=10
- patience=5
-
-
- # parse options
- while true; do
-   [ -z "${1:-}" ] && break; # break if there are no arguments
-   case "$1" in
-     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-       old_value="$(eval echo \\$$name)";
-       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-         was_bool=true;
-       else
-         was_bool=false;
-       fi
-
-       # Set the variable to the right value-- the escaped quotes make it work if
-       # the option had spaces, like --cmd "queue.pl -sync y"
-       eval "${name}=\"$2\"";
-
-       # Check that Boolean-valued arguments are really Boolean.
-       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-         exit 1;
-       fi
-       shift 2;
-       ;;
-
-     *) break;
-   esac
- done
-
- file_dir="${work_dir}/${file_folder_name}"
- final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
- evaluation_audio_dir="${file_dir}/evaluation_audio"
-
- train_dataset="${file_dir}/train.jsonl"
- valid_dataset="${file_dir}/valid.jsonl"
-
- $verbose && echo "system_version: ${system_version}"
- $verbose && echo "file_folder_name: ${file_folder_name}"
-
- if [ $system_version == "windows" ]; then
-   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
- elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-   #source /data/local/bin/nx_denoise/bin/activate
-   alias python3='/data/local/bin/nx_denoise/bin/python3'
- fi
-
-
- if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-   $verbose && echo "stage 1: prepare data"
-   cd "${work_dir}" || exit 1
-   python3 step_1_prepare_data.py \
-     --file_dir "${file_dir}" \
-     --audio_dir "${audio_dir}" \
-     --train_dataset "${train_dataset}" \
-     --valid_dataset "${valid_dataset}" \
-     --max_count "${max_count}" \
-
- fi
-
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-   $verbose && echo "stage 2: train model"
-   cd "${work_dir}" || exit 1
-   python3 step_2_train_model.py \
-     --train_dataset "${train_dataset}" \
-     --valid_dataset "${valid_dataset}" \
-     --serialization_dir "${file_dir}" \
-     --config_file "${config_file}" \
-
- fi
-
-
- if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-   $verbose && echo "stage 3: test model"
-   cd "${work_dir}" || exit 1
-   python3 step_3_evaluation.py \
-     --valid_dataset "${valid_dataset}" \
-     --model_dir "${file_dir}/best" \
-     --evaluation_audio_dir "${evaluation_audio_dir}" \
-     --limit "${limit}" \
-
- fi
-
-
- if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-   $verbose && echo "stage 4: collect files"
-   cd "${work_dir}" || exit 1
-
-   mkdir -p ${final_model_dir}
-
-   cp "${file_dir}/best"/* "${final_model_dir}"
-   cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
-
-   cd "${final_model_dir}/.." || exit 1;
-
-   if [ -e "${final_model_name}.zip" ]; then
-     rm -rf "${final_model_name}_backup.zip"
-     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-   fi
-
-   zip -r "${final_model_name}.zip" "${final_model_name}"
-   rm -rf "${final_model_name}"
-
- fi
-
-
- if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-   $verbose && echo "stage 5: clear file_dir"
-   cd "${work_dir}" || exit 1
-
-   rm -rf "${file_dir}";
-
- fi
examples/dtln_mp3_to_wav/step_1_prepare_data.py DELETED
@@ -1,127 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import json
- import os
- from pathlib import Path
- import random
- import sys
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import librosa
- import numpy as np
- from tqdm import tqdm
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--file_dir", default="./", type=str)
-
-     parser.add_argument(
-         "--audio_dir",
-         default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
-         type=str
-     )
-
-     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-     parser.add_argument("--duration", default=4.0, type=float)
-
-     parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-     parser.add_argument("--max_count", default=-1, type=int)
-
-     args = parser.parse_args()
-     return args
-
-
- def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 1):
-     data_dir = Path(data_dir)
-     for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if raw_duration < duration:
-                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                 continue
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             signal_length = len(signal)
-             win_size = int(duration * sample_rate)
-             for begin in range(0, signal_length - win_size, win_size):
-                 if np.sum(signal[begin: begin+win_size]) == 0:
-                     continue
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": round(begin / sample_rate, 4),
-                     "duration": round(duration, 4),
-                 }
-                 yield row
-
-
- def main():
-     args = get_args()
-
-     file_dir = Path(args.file_dir)
-     file_dir.mkdir(exist_ok=True)
-
-     audio_dir = Path(args.audio_dir)
-
-     audio_generator = target_second_signal_generator(
-         audio_dir.as_posix(),
-         duration=args.duration,
-         sample_rate=args.target_sample_rate,
-         max_epoch=1,
-     )
-     count = 0
-     process_bar = tqdm(desc="build dataset jsonl")
-     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
-         for audio in audio_generator:
-             if count >= args.max_count > 0:
-                 break
-
-             filename = audio["filename"]
-             raw_duration = audio["raw_duration"]
-             offset = audio["offset"]
-             duration = audio["duration"]
-
-             random1 = random.random()
-             random2 = random.random()
-
-             row = {
-                 "count": count,
-
-                 "filename": filename,
-                 "raw_duration": raw_duration,
-                 "offset": offset,
-                 "duration": duration,
-
-                 "random1": random1,
-             }
-             row = json.dumps(row, ensure_ascii=False)
-             if random2 < (1 / 300):
-                 fvalid.write(f"{row}\n")
-             else:
-                 ftrain.write(f"{row}\n")
-
-             count += 1
-             duration_seconds = count * args.duration
-             duration_hours = duration_seconds / 3600
-
-             process_bar.update(n=1)
-             process_bar.set_postfix({
-                 "duration_hours": round(duration_hours, 4),
-             })
-
-     return
-
-
- if __name__ == "__main__":
-     main()
examples/dtln_mp3_to_wav/step_2_train_model.py DELETED
@@ -1,445 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- """
- https://github.com/breizhn/DTLN
-
- """
- import argparse
- import json
- import logging
- from logging.handlers import TimedRotatingFileHandler
- import os
- import platform
- from pathlib import Path
- import random
- import sys
- import shutil
- from typing import List
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import numpy as np
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from torch.utils.data.dataloader import DataLoader
- from tqdm import tqdm
-
- from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
- from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
- from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
- from toolbox.torchaudio.metrics.pesq import run_pesq_score
- from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
- from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-     parser.add_argument("--patience", default=30, type=int)
-     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-
-     parser.add_argument("--config_file", default="config.yaml", type=str)
-
-     args = parser.parse_args()
-     return args
-
-
- def logging_config(file_dir: str):
-     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-     logging.basicConfig(format=fmt,
-                         datefmt="%m/%d/%Y %H:%M:%S",
-                         level=logging.INFO)
-     file_handler = TimedRotatingFileHandler(
-         filename=os.path.join(file_dir, "main.log"),
-         encoding="utf-8",
-         when="D",
-         interval=1,
-         backupCount=7
-     )
-     file_handler.setLevel(logging.INFO)
-     file_handler.setFormatter(logging.Formatter(fmt))
-     logger = logging.getLogger(__name__)
-     logger.addHandler(file_handler)
-
-     return logger
-
-
- class CollateFunction(object):
-     def __init__(self):
-         pass
-
-     def __call__(self, batch: List[dict]):
-         mp3_waveform_list = list()
-         wav_waveform_list = list()
-
-         for sample in batch:
-             mp3_waveform: torch.Tensor = sample["mp3_waveform"]
-             wav_waveform: torch.Tensor = sample["wav_waveform"]
-
-             mp3_waveform_list.append(mp3_waveform)
-             wav_waveform_list.append(wav_waveform)
-
-         mp3_waveform_list = torch.stack(mp3_waveform_list)
-         wav_waveform_list = torch.stack(wav_waveform_list)
-
-         # assert
-         if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)):
-             raise AssertionError("nan or inf in mp3_waveform_list")
-         if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)):
-             raise AssertionError("nan or inf in wav_waveform_list")
-
-         return mp3_waveform_list, wav_waveform_list
-
-
- collate_fn = CollateFunction()
-
-
- def main():
-     args = get_args()
-
-     config = DTLNConfig.from_pretrained(
-         pretrained_model_name_or_path=args.config_file,
-     )
-
-     serialization_dir = Path(args.serialization_dir)
-     serialization_dir.mkdir(parents=True, exist_ok=True)
-
-     logger = logging_config(serialization_dir)
-
-     random.seed(config.seed)
-     np.random.seed(config.seed)
-     torch.manual_seed(config.seed)
-     logger.info(f"set seed: {config.seed}")
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     n_gpu = torch.cuda.device_count()
-     logger.info(f"GPU available count: {n_gpu}; device: {device}")
-
-     # datasets
-     train_dataset = Mp3ToWavJsonlDataset(
-         jsonl_file=args.train_dataset,
-         expected_sample_rate=config.sample_rate,
-         max_wave_value=32768.0,
-         # skip=225000,
-     )
-     valid_dataset = Mp3ToWavJsonlDataset(
-         jsonl_file=args.valid_dataset,
-         expected_sample_rate=config.sample_rate,
-         max_wave_value=32768.0,
-     )
-     train_data_loader = DataLoader(
-         dataset=train_dataset,
-         batch_size=config.batch_size,
-         # shuffle=True,
-         sampler=None,
-         # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-         collate_fn=collate_fn,
-         pin_memory=False,
-         prefetch_factor=None if platform.system() == "Windows" else 2,
-     )
-     valid_data_loader = DataLoader(
-         dataset=valid_dataset,
-         batch_size=config.batch_size,
-         # shuffle=True,
-         sampler=None,
-         # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-         collate_fn=collate_fn,
-         pin_memory=False,
-         prefetch_factor=None if platform.system() == "Windows" else 2,
-     )
-
-     # models
-     logger.info(f"prepare models. config_file: {args.config_file}")
-     model = DTLNPretrainedModel(config).to(device)
-     model.to(device)
-     model.train()
-
-     # optimizer
-     logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
-     optimizer = torch.optim.AdamW(model.parameters(), config.lr)
-
-     # resume training
-     last_step_idx = -1
-     last_epoch = -1
-     for step_idx_str in serialization_dir.glob("steps-*"):
-         step_idx_str = Path(step_idx_str)
-         step_idx = step_idx_str.stem.split("-")[1]
-         step_idx = int(step_idx)
-         if step_idx > last_step_idx:
-             last_step_idx = step_idx
-             # last_epoch = 1
-
-     if last_step_idx != -1:
-         logger.info(f"resume from steps-{last_step_idx}.")
-         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-
-         logger.info(f"load state dict for model.")
-         with open(model_pt.as_posix(), "rb") as f:
-             state_dict = torch.load(f, map_location="cpu", weights_only=True)
-         model.load_state_dict(state_dict, strict=True)
-
-     if config.lr_scheduler == "CosineAnnealingLR":
-         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-             optimizer,
-             last_epoch=last_epoch,
-             # T_max=10 * config.eval_steps,
-             # eta_min=0.01 * config.lr,
-             **config.lr_scheduler_kwargs,
-         )
-     elif config.lr_scheduler == "MultiStepLR":
-         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-             optimizer,
-             last_epoch=last_epoch,
-             milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-         )
-     else:
-         raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
-
-     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
-     mr_stft_loss_fn = MultiResolutionSTFTLoss(
-         fft_size_list=[256, 512, 1024],
-         win_size_list=[256, 512, 1024],
-         hop_size_list=[128, 256, 512],
-         factor_sc=1.5,
-         factor_mag=1.0,
-         reduction="mean"
-     ).to(device)
-     audio_l1_loss_fn = nn.L1Loss(reduction="mean")
-
-     # training loop
-
-     # state
-     average_pesq_score = 1000000000
-     average_loss = 1000000000
-     average_mr_stft_loss = 1000000000
-     average_audio_l1_loss = 1000000000
-     average_neg_si_snr_loss = 1000000000
-
-     model_list = list()
-     best_epoch_idx = None
-     best_step_idx = None
-     best_metric = None
-     patience_count = 0
-
-     step_idx = 0 if last_step_idx == -1 else last_step_idx
-
-     logger.info("training")
-     early_stop_flag = False
-     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
-         if early_stop_flag:
-             break
-
-         # train
-         model.train()
-
-         total_pesq_score = 0.
-         total_loss = 0.
-         total_mr_stft_loss = 0.
-         total_audio_l1_loss = 0.
-         total_neg_si_snr_loss = 0.
-         total_batches = 0.
-
-         progress_bar_train = tqdm(
-             initial=step_idx,
-             desc="Training; epoch-{}".format(epoch_idx),
-         )
-         for train_batch in train_data_loader:
-             mp3_audios, wav_audios = train_batch
-             noisy_audios: torch.Tensor = mp3_audios.to(device)
-             clean_audios: torch.Tensor = wav_audios.to(device)
-
-             denoise_audios = model.forward(noisy_audios)
-             denoise_audios = torch.squeeze(denoise_audios, dim=1)
-
-             mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-             audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
-             neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-
-             loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
-             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                 logger.info(f"find nan or inf in loss.")
-                 continue
-
-             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-             pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-             optimizer.zero_grad()
-             loss.backward()
-             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
-             optimizer.step()
-             lr_scheduler.step()
-
-             total_pesq_score += pesq_score
-             total_loss += loss.item()
-             total_mr_stft_loss += mr_stft_loss.item()
-             total_audio_l1_loss += audio_l1_loss.item()
-             total_neg_si_snr_loss += neg_si_snr_loss.item()
-             total_batches += 1
-
-             average_pesq_score = round(total_pesq_score / total_batches, 4)
-             average_loss = round(total_loss / total_batches, 4)
-             average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-             average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
-             average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-
-             progress_bar_train.update(1)
-             progress_bar_train.set_postfix({
-                 "lr": lr_scheduler.get_last_lr()[0],
-                 "pesq_score": average_pesq_score,
-                 "loss": average_loss,
-                 "mr_stft_loss": average_mr_stft_loss,
-                 "audio_l1_loss": average_audio_l1_loss,
-                 "neg_si_snr_loss": average_neg_si_snr_loss,
-             })
-
-             # evaluation
-             step_idx += 1
-             if step_idx % config.eval_steps == 0:
-                 model.eval()
-                 with torch.no_grad():
-                     torch.cuda.empty_cache()
-
-                     total_pesq_score = 0.
-                     total_loss = 0.
-                     total_mr_stft_loss = 0.
-                     total_audio_l1_loss = 0.
-                     total_neg_si_snr_loss = 0.
-                     total_batches = 0.
-
-                     progress_bar_train.close()
-                     progress_bar_eval = tqdm(
-                         desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
-                     )
-                     for eval_batch in valid_data_loader:
-                         mp3_audios, wav_audios = eval_batch
-                         noisy_audios: torch.Tensor = mp3_audios.to(device)
-                         clean_audios: torch.Tensor = wav_audios.to(device)
-
-                         denoise_audios = model.forward(noisy_audios)
-                         denoise_audios = torch.squeeze(denoise_audios, dim=1)
-
-                         mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-                         audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
-                         neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-
-                         loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
-                         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                             logger.info(f"find nan or inf in loss.")
-                             continue
-
-                         denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-                         clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-                         pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-                         total_pesq_score += pesq_score
-                         total_loss += loss.item()
-                         total_mr_stft_loss += mr_stft_loss.item()
-                         total_audio_l1_loss += audio_l1_loss.item()
-                         total_neg_si_snr_loss += neg_si_snr_loss.item()
-                         total_batches += 1
-
-                         average_pesq_score = round(total_pesq_score / total_batches, 4)
-                         average_loss = round(total_loss / total_batches, 4)
-                         average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-                         average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
-                         average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-
-                         progress_bar_eval.update(1)
-                         progress_bar_eval.set_postfix({
-                             "lr": lr_scheduler.get_last_lr()[0],
-                             "pesq_score": average_pesq_score,
-                             "loss": average_loss,
-                             "mr_stft_loss": average_mr_stft_loss,
-                             "audio_l1_loss": average_audio_l1_loss,
-                             "neg_si_snr_loss": average_neg_si_snr_loss,
-
-                         })
-
-                     total_pesq_score = 0.
-                     total_loss = 0.
-                     total_mr_stft_loss = 0.
-                     total_audio_l1_loss = 0.
-                     total_neg_si_snr_loss = 0.
-                     total_batches = 0.
-
-                     progress_bar_eval.close()
-                     progress_bar_train = tqdm(
-                         initial=progress_bar_train.n,
-                         postfix=progress_bar_train.postfix,
-                         desc=progress_bar_train.desc,
-                     )
-
-                     # save path
-                     save_dir = serialization_dir / "steps-{}".format(step_idx)
-                     save_dir.mkdir(parents=True, exist_ok=False)
-
-                     # save models
-                     model.save_pretrained(save_dir.as_posix())
-
-                     model_list.append(save_dir)
-                     if len(model_list) >= args.num_serialized_models_to_keep:
-                         model_to_delete: Path = model_list.pop(0)
-                         shutil.rmtree(model_to_delete.as_posix())
-
-                     # save metric
-                     if best_metric is None:
-                         best_epoch_idx = epoch_idx
-                         best_step_idx = step_idx
-                         best_metric = average_pesq_score
-                     elif average_pesq_score >= best_metric:
-                         # greater is better.
-                         best_epoch_idx = epoch_idx
-                         best_step_idx = step_idx
-                         best_metric = average_pesq_score
-                     else:
-                         pass
-
-                     metrics = {
-                         "epoch_idx": epoch_idx,
-                         "best_epoch_idx": best_epoch_idx,
-                         "best_step_idx": best_step_idx,
-                         "pesq_score": average_pesq_score,
-                         "loss": average_loss,
-                         "mr_stft_loss": average_mr_stft_loss,
-                         "audio_l1_loss": average_audio_l1_loss,
-                         "neg_si_snr_loss": average_neg_si_snr_loss,
-                     }
-                     metrics_filename = save_dir / "metrics_epoch.json"
-                     with open(metrics_filename, "w", encoding="utf-8") as f:
-                         json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-                     # save best
-                     best_dir = serialization_dir / "best"
-                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                         if best_dir.exists():
-                             shutil.rmtree(best_dir)
-                         shutil.copytree(save_dir, best_dir)
-
-                     # early stop
-                     early_stop_flag = False
-                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                         patience_count = 0
-                     else:
-                         patience_count += 1
-                         if patience_count >= args.patience:
-                             early_stop_flag = True
-
-                     # early stop
-                     if early_stop_flag:
-                         break
-                 model.train()
-
-     return
-
-
- if __name__ == "__main__":
-     main()
examples/dtln_mp3_to_wav/yaml/config-1024.yaml DELETED
@@ -1,29 +0,0 @@
- model_name: "DTLN"
-
- # spec
- sample_rate: 8000
- fft_size: 512
- hop_size: 128
- win_type: hann
-
- # data
- min_snr_db: -5
- max_snr_db: 25
-
- # model
- encoder_size: 1024
-
- # train
- lr: 0.001
- lr_scheduler: "CosineAnnealingLR"
- lr_scheduler_kwargs:
-   T_max: 250000
-   eta_min: 0.0001
-
- max_epochs: 100
- clip_grad_norm: 10.0
- seed: 1234
-
- num_workers: 4
- batch_size: 64
- eval_steps: 15000
examples/dtln_mp3_to_wav/yaml/config-256.yaml DELETED
@@ -1,29 +0,0 @@
- model_name: "DTLN"
-
- # spec
- sample_rate: 8000
- fft_size: 256
- hop_size: 128
- win_type: hann
-
- # data
- min_snr_db: -5
- max_snr_db: 25
-
- # model
- encoder_size: 256
-
- # train
- lr: 0.001
- lr_scheduler: "CosineAnnealingLR"
- lr_scheduler_kwargs:
-   T_max: 250000
-   eta_min: 0.0001
-
- max_epochs: 100
- clip_grad_norm: 10.0
- seed: 1234
-
- num_workers: 4
- batch_size: 64
- eval_steps: 15000
examples/dtln_mp3_to_wav/yaml/config-512.yaml DELETED
@@ -1,29 +0,0 @@
- model_name: "DTLN"
-
- # spec
- sample_rate: 8000
- fft_size: 512
- hop_size: 128
- win_type: hann
-
- # data
- min_snr_db: -5
- max_snr_db: 25
-
- # model
- encoder_size: 512
-
- # train
- lr: 0.001
- lr_scheduler: "CosineAnnealingLR"
- lr_scheduler_kwargs:
-   T_max: 250000
-   eta_min: 0.0001
-
- max_epochs: 100
- clip_grad_norm: 10.0
- seed: 1234
-
- num_workers: 4
- batch_size: 64
- eval_steps: 15000
examples/frcrn_mp3_to_wav/run.sh DELETED
@@ -1,156 +0,0 @@
- #!/usr/bin/env bash
-
- : <<'END'
-
-
- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
- --config_file "yaml/config-10.yaml" \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
-
-
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
- --config_file "yaml/config-10.yaml" \
- --audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \
-
- END
-
-
- # params
- system_version="windows";
- verbose=true;
- stage=0 # start from 0 if you need to start from data preparation
- stop_stage=9
-
- work_dir="$(pwd)"
- file_folder_name=file_folder_name
- final_model_name=final_model_name
- config_file="yaml/config.yaml"
- limit=10
-
- audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data
-
- max_count=10000000
-
- nohup_name=nohup.out
-
- # model params
- batch_size=64
- max_epochs=200
- save_top_k=10
- patience=5
-
-
- # parse options
- while true; do
-   [ -z "${1:-}" ] && break; # break if there are no arguments
-   case "$1" in
-     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-       old_value="$(eval echo \\$$name)";
-       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-         was_bool=true;
-       else
-         was_bool=false;
-       fi
-
-       # Set the variable to the right value-- the escaped quotes make it work if
-       # the option had spaces, like --cmd "queue.pl -sync y"
-       eval "${name}=\"$2\"";
-
-       # Check that Boolean-valued arguments are really Boolean.
-       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-         exit 1;
-       fi
-       shift 2;
-       ;;
-
-     *) break;
-   esac
- done
-
- file_dir="${work_dir}/${file_folder_name}"
- final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
- evaluation_audio_dir="${file_dir}/evaluation_audio"
-
- train_dataset="${file_dir}/train.jsonl"
- valid_dataset="${file_dir}/valid.jsonl"
-
- $verbose && echo "system_version: ${system_version}"
- $verbose && echo "file_folder_name: ${file_folder_name}"
-
- if [ $system_version == "windows" ]; then
-   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
- elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-   #source /data/local/bin/nx_denoise/bin/activate
-   alias python3='/data/local/bin/nx_denoise/bin/python3'
- fi
-
-
- if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-   $verbose && echo "stage 1: prepare data"
-   cd "${work_dir}" || exit 1
-   python3 step_1_prepare_data.py \
-     --file_dir "${file_dir}" \
-     --audio_dir "${audio_dir}" \
-     --train_dataset "${train_dataset}" \
-     --valid_dataset "${valid_dataset}" \
-     --max_count "${max_count}" \
-
- fi
-
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-   $verbose && echo "stage 2: train model"
-   cd "${work_dir}" || exit 1
-   python3 step_2_train_model.py \
-     --train_dataset "${train_dataset}" \
-     --valid_dataset "${valid_dataset}" \
-     --serialization_dir "${file_dir}" \
-     --config_file "${config_file}" \
-
- fi
-
-
- if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-   $verbose && echo "stage 3: test model"
-   cd "${work_dir}" || exit 1
-   python3 step_3_evaluation.py \
-     --valid_dataset "${valid_dataset}" \
-     --model_dir "${file_dir}/best" \
-     --evaluation_audio_dir "${evaluation_audio_dir}" \
-     --limit "${limit}" \
-
- fi
-
-
- if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-   $verbose && echo "stage 4: collect files"
-   cd "${work_dir}" || exit 1
-
-   mkdir -p ${final_model_dir}
-
-   cp "${file_dir}/best"/* "${final_model_dir}"
-   cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
-
-   cd "${final_model_dir}/.." || exit 1;
-
-   if [ -e "${final_model_name}.zip" ]; then
-     rm -rf "${final_model_name}_backup.zip"
-     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-   fi
-
-   zip -r "${final_model_name}.zip" "${final_model_name}"
-   rm -rf "${final_model_name}"
-
- fi
-
-
- if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-   $verbose && echo "stage 5: clear file_dir"
-   cd "${work_dir}" || exit 1
-
-   rm -rf "${file_dir}";
-
- fi
examples/frcrn_mp3_to_wav/step_1_prepare_data.py DELETED
@@ -1,127 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import json
- import os
- from pathlib import Path
- import random
- import sys
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import librosa
- import numpy as np
- from tqdm import tqdm
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--file_dir", default="./", type=str)
-
-     parser.add_argument(
-         "--audio_dir",
-         default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
-         type=str
-     )
-
-     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-     parser.add_argument("--duration", default=4.0, type=float)
-
-     parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-     parser.add_argument("--max_count", default=-1, type=int)
-
-     args = parser.parse_args()
-     return args
-
-
- def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 1):
-     data_dir = Path(data_dir)
-     for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if raw_duration < duration:
-                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                 continue
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             signal_length = len(signal)
-             win_size = int(duration * sample_rate)
-             for begin in range(0, signal_length - win_size, win_size):
-                 if np.sum(signal[begin: begin+win_size]) == 0:
-                     continue
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": round(begin / sample_rate, 4),
-                     "duration": round(duration, 4),
-                 }
-                 yield row
-
-
- def main():
-     args = get_args()
-
-     file_dir = Path(args.file_dir)
-     file_dir.mkdir(exist_ok=True)
-
-     audio_dir = Path(args.audio_dir)
-
-     audio_generator = target_second_signal_generator(
-         audio_dir.as_posix(),
-         duration=args.duration,
-         sample_rate=args.target_sample_rate,
-         max_epoch=1,
-     )
-     count = 0
-     process_bar = tqdm(desc="build dataset jsonl")
-     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
-         for audio in audio_generator:
-             if count >= args.max_count > 0:
-                 break
-
-             filename = audio["filename"]
-             raw_duration = audio["raw_duration"]
-             offset = audio["offset"]
-             duration = audio["duration"]
-
-             random1 = random.random()
-             random2 = random.random()
-
-             row = {
-                 "count": count,
-
-                 "filename": filename,
-                 "raw_duration": raw_duration,
-                 "offset": offset,
-                 "duration": duration,
-
-                 "random1": random1,
-             }
-             row = json.dumps(row, ensure_ascii=False)
-             if random2 < (1 / 10):
-                 fvalid.write(f"{row}\n")
-             else:
-                 ftrain.write(f"{row}\n")
-
-             count += 1
-             duration_seconds = count * args.duration
-             duration_hours = duration_seconds / 3600
-
-             process_bar.update(n=1)
-             process_bar.set_postfix({
-                 "duration_hours": round(duration_hours, 4),
-             })
-
-     return
-
-
- if __name__ == "__main__":
-     main()
examples/frcrn_mp3_to_wav/step_2_train_model.py DELETED
@@ -1,442 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import json
- import logging
- from logging.handlers import TimedRotatingFileHandler
- import os
- import platform
- from pathlib import Path
- import random
- import sys
- import shutil
- from typing import List
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import numpy as np
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from torch.utils.data.dataloader import DataLoader
- from tqdm import tqdm
-
- from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
- from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
- from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
- from toolbox.torchaudio.metrics.pesq import run_pesq_score
- from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
- from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-     parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-     parser.add_argument("--patience", default=30, type=int)
-     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-
-     parser.add_argument("--config_file", default="config.yaml", type=str)
-
-     args = parser.parse_args()
-     return args
-
-
- def logging_config(file_dir: str):
-     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-     logging.basicConfig(format=fmt,
-                         datefmt="%m/%d/%Y %H:%M:%S",
-                         level=logging.INFO)
-     file_handler = TimedRotatingFileHandler(
-         filename=os.path.join(file_dir, "main.log"),
-         encoding="utf-8",
-         when="D",
-         interval=1,
-         backupCount=7
-     )
-     file_handler.setLevel(logging.INFO)
-     file_handler.setFormatter(logging.Formatter(fmt))
-     logger = logging.getLogger(__name__)
-     logger.addHandler(file_handler)
-
-     return logger
-
-
- class CollateFunction(object):
-     def __init__(self):
-         pass
-
-     def __call__(self, batch: List[dict]):
-         mp3_waveform_list = list()
-         wav_waveform_list = list()
-
-         for sample in batch:
-             mp3_waveform: torch.Tensor = sample["mp3_waveform"]
-             wav_waveform: torch.Tensor = sample["wav_waveform"]
-
-             mp3_waveform_list.append(mp3_waveform)
-             wav_waveform_list.append(wav_waveform)
-
-         mp3_waveform_list = torch.stack(mp3_waveform_list)
-         wav_waveform_list = torch.stack(wav_waveform_list)
-
-         # assert
-         if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)):
-             raise AssertionError("nan or inf in mp3_waveform_list")
-         if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)):
-             raise AssertionError("nan or inf in wav_waveform_list")
-
-         return mp3_waveform_list, wav_waveform_list
-
-
- collate_fn = CollateFunction()
-
-
- def main():
-     args = get_args()
-
-     config = FRCRNConfig.from_pretrained(
-         pretrained_model_name_or_path=args.config_file,
-     )
-
-     serialization_dir = Path(args.serialization_dir)
-     serialization_dir.mkdir(parents=True, exist_ok=True)
-
-     logger = logging_config(serialization_dir)
-
-     random.seed(config.seed)
-     np.random.seed(config.seed)
-     torch.manual_seed(config.seed)
-     logger.info(f"set seed: {config.seed}")
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     n_gpu = torch.cuda.device_count()
-     logger.info(f"GPU available count: {n_gpu}; device: {device}")
-
-     # datasets
-     train_dataset = Mp3ToWavJsonlDataset(
-         jsonl_file=args.train_dataset,
-         expected_sample_rate=config.sample_rate,
-         max_wave_value=32768.0,
-         # skip=225000,
-     )
-     valid_dataset = Mp3ToWavJsonlDataset(
-         jsonl_file=args.valid_dataset,
-         expected_sample_rate=config.sample_rate,
-         max_wave_value=32768.0,
-     )
-     train_data_loader = DataLoader(
-         dataset=train_dataset,
-         batch_size=config.batch_size,
-         # shuffle=True,
-         sampler=None,
-         # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-         collate_fn=collate_fn,
-         pin_memory=False,
-         prefetch_factor=2,
-     )
-     valid_data_loader = DataLoader(
-         dataset=valid_dataset,
-         batch_size=config.batch_size,
-         # shuffle=True,
-         sampler=None,
-         # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-         collate_fn=collate_fn,
-         pin_memory=False,
-         prefetch_factor=2,
-     )
-
-     # models
-     logger.info(f"prepare models. config_file: {args.config_file}")
-     model = FRCRNPretrainedModel(config).to(device)
-     model.to(device)
-     model.train()
-
-     # optimizer
-     logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
-     optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)
-
-     # resume training
-     last_step_idx = -1
-     last_epoch = -1
-     for step_idx_str in serialization_dir.glob("steps-*"):
-         step_idx_str = Path(step_idx_str)
-         step_idx = step_idx_str.stem.split("-")[1]
-         step_idx = int(step_idx)
-         if step_idx > last_step_idx:
-             last_step_idx = step_idx
-             # last_epoch = 0
-
-     if last_step_idx != -1:
-         logger.info(f"resume from steps-{last_step_idx}.")
-         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-         # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
-
-         logger.info(f"load state dict for model.")
-         with open(model_pt.as_posix(), "rb") as f:
-             state_dict = torch.load(f, map_location="cpu", weights_only=True)
-         model.load_state_dict(state_dict, strict=True)
-
-         # logger.info(f"load state dict for optimizer.")
-         # with open(optimizer_pth.as_posix(), "rb") as f:
-         #     state_dict = torch.load(f, map_location="cpu", weights_only=True)
-         # optimizer.load_state_dict(state_dict)
-
-     if config.lr_scheduler == "CosineAnnealingLR":
-         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-             optimizer,
-             last_epoch=last_epoch,
-             # T_max=10 * config.eval_steps,
-             # eta_min=0.01 * config.lr,
-             **config.lr_scheduler_kwargs,
-         )
-     elif config.lr_scheduler == "MultiStepLR":
-         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-             optimizer,
-             last_epoch=last_epoch,
-             milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-         )
-     else:
-         raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
-
-     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
-     mr_stft_loss_fn = MultiResolutionSTFTLoss(
-         fft_size_list=[256, 512, 1024],
-         win_size_list=[256, 512, 1024],
-         hop_size_list=[128, 256, 512],
-         factor_sc=1.5,
-         factor_mag=1.0,
-         reduction="mean"
-     ).to(device)
-
-     # training loop
-
-     # state
-     average_pesq_score = 1000000000
-     average_loss = 1000000000
-     average_neg_si_snr_loss = 1000000000
-     average_mask_loss = 1000000000
-
-     model_list = list()
-     best_epoch_idx = None
-     best_step_idx = None
-     best_metric = None
-     patience_count = 0
-
-     step_idx = 0 if last_step_idx == -1 else last_step_idx
-
-     logger.info("training")
-     early_stop_flag = False
-     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
-         if early_stop_flag:
-             break
-
-         # train
-         model.train()
-
-         total_pesq_score = 0.
-         total_loss = 0.
-         total_mr_stft_loss = 0.
-         total_neg_si_snr_loss = 0.
-         total_mask_loss = 0.
-         total_batches = 0.
-
-         progress_bar_train = tqdm(
-             initial=step_idx,
-             desc="Training; epoch-{}".format(epoch_idx),
-         )
-         for train_batch in train_data_loader:
-             mp3_audios, wav_audios = train_batch
-             noisy_audios: torch.Tensor = mp3_audios.to(device)
-             clean_audios: torch.Tensor = wav_audios.to(device)
-
-             est_spec, est_wav, est_mask = model.forward(noisy_audios)
-             denoise_audios = est_wav
-
-             mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-             neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-             mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
-
-             loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
-             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                 logger.info(f"find nan or inf in loss.")
-                 continue
-
-             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-             pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-             optimizer.zero_grad()
-             loss.backward()
-             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
-             optimizer.step()
-             lr_scheduler.step()
-
-             total_pesq_score += pesq_score
-             total_loss += loss.item()
-             total_mr_stft_loss += mr_stft_loss.item()
-             total_neg_si_snr_loss += neg_si_snr_loss.item()
-             total_mask_loss += mask_loss.item()
-             total_batches += 1
-
-             average_pesq_score = round(total_pesq_score / total_batches, 4)
-             average_loss = round(total_loss / total_batches, 4)
-             average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-             average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-             average_mask_loss = round(total_mask_loss / total_batches, 4)
-
-             progress_bar_train.update(1)
-             progress_bar_train.set_postfix({
-                 "lr": lr_scheduler.get_last_lr()[0],
-                 "pesq_score": average_pesq_score,
-                 "loss": average_loss,
-                 "mr_stft_loss": average_mr_stft_loss,
-                 "neg_si_snr_loss": average_neg_si_snr_loss,
-                 "mask_loss": average_mask_loss,
-             })
-
-             # evaluation
-             step_idx += 1
-             if step_idx % config.eval_steps == 0:
-                 model.eval()
-                 with torch.no_grad():
-                     torch.cuda.empty_cache()
-
-                     total_pesq_score = 0.
-                     total_loss = 0.
-                     total_mr_stft_loss = 0.
-                     total_neg_si_snr_loss = 0.
-                     total_mask_loss = 0.
-                     total_batches = 0.
-
-                     progress_bar_train.close()
-                     progress_bar_eval = tqdm(
-                         desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
-                     )
-                     for eval_batch in valid_data_loader:
-                         mp3_audios, wav_audios = eval_batch
-                         noisy_audios: torch.Tensor = mp3_audios.to(device)
-                         clean_audios: torch.Tensor = wav_audios.to(device)
-
-                         est_spec, est_wav, est_mask = model.forward(noisy_audios)
-                         denoise_audios = est_wav
-
-                         mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-                         neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-                         mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
-
-                         loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
-                         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                             logger.info(f"find nan or inf in loss.")
-                             continue
-
-                         denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-                         clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-                         pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-                         total_pesq_score += pesq_score
-                         total_loss += loss.item()
-                         total_neg_si_snr_loss += neg_si_snr_loss.item()
-                         total_mask_loss += mask_loss.item()
-                         total_batches += 1
-
-                         average_pesq_score = round(total_pesq_score / total_batches, 4)
-                         average_loss = round(total_loss / total_batches, 4)
-                         average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-                         average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-                         average_mask_loss = round(total_mask_loss / total_batches, 4)
-
-                         progress_bar_eval.update(1)
-                         progress_bar_eval.set_postfix({
-                             "lr": lr_scheduler.get_last_lr()[0],
-                             "pesq_score": average_pesq_score,
-                             "loss": average_loss,
-                             "mr_stft_loss": average_mr_stft_loss,
-                             "neg_si_snr_loss": average_neg_si_snr_loss,
-                             "mask_loss": average_mask_loss,
-                         })
-
-                     total_pesq_score = 0.
-                     total_loss = 0.
-                     total_mr_stft_loss = 0.
-                     total_neg_si_snr_loss = 0.
-                     total_mask_loss = 0.
-                     total_batches = 0.
-
-                     progress_bar_eval.close()
-                     progress_bar_train = tqdm(
-                         initial=progress_bar_train.n,
-                         postfix=progress_bar_train.postfix,
-                         desc=progress_bar_train.desc,
-                     )
-
-                     # save path
-                     save_dir = serialization_dir / "steps-{}".format(step_idx)
-                     save_dir.mkdir(parents=True, exist_ok=False)
-
-                     # save models
-                     model.save_pretrained(save_dir.as_posix())
-
-                     model_list.append(save_dir)
-                     if len(model_list) >= args.num_serialized_models_to_keep:
-                         model_to_delete: Path = model_list.pop(0)
-                         shutil.rmtree(model_to_delete.as_posix())
-
-                     # save metric
-                     if best_metric is None:
-                         best_epoch_idx = epoch_idx
-                         best_step_idx = step_idx
-                         best_metric = average_pesq_score
-                     elif average_pesq_score >= best_metric:
-                         # greater is better.
-                         best_epoch_idx = epoch_idx
-                         best_step_idx = step_idx
-                         best_metric = average_pesq_score
-                     else:
-                         pass
-
-                     metrics = {
-                         "epoch_idx": epoch_idx,
-                         "best_epoch_idx": best_epoch_idx,
-                         "best_step_idx": best_step_idx,
-                         "pesq_score": average_pesq_score,
-                         "loss": average_loss,
-                         "neg_si_snr_loss": average_neg_si_snr_loss,
-                         "mask_loss": average_mask_loss,
-                     }
-                     metrics_filename = save_dir / "metrics_epoch.json"
-                     with open(metrics_filename, "w", encoding="utf-8") as f:
-                         json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-                     # save best
-                     best_dir = serialization_dir / "best"
-                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                         if best_dir.exists():
-                             shutil.rmtree(best_dir)
-                         shutil.copytree(save_dir, best_dir)
-
-                     # early stop
-                     early_stop_flag = False
-                     if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                         patience_count = 0
-                     else:
-                         patience_count += 1
-                         if patience_count >= args.patience:
-                             early_stop_flag = True
-
-                     # early stop
-                     if early_stop_flag:
-                         break
-                 model.train()
-
-     return
-
-
- if __name__ == "__main__":
-     main()
examples/frcrn_mp3_to_wav/yaml/config-10.yaml DELETED
@@ -1,31 +0,0 @@
- model_name: "frcrn"
-
- sample_rate: 8000
- segment_size: 32000
- nfft: 128
- win_size: 128
- hop_size: 64
- win_type: hann
-
- use_complex_networks: true
- model_depth: 10
- model_complexity: -1
-
- min_snr_db: -10
- max_snr_db: 20
-
- num_workers: 8
- batch_size: 32
- eval_steps: 20000
-
- lr: 0.001
- lr_scheduler: "CosineAnnealingLR"
- lr_scheduler_kwargs:
-   T_max: 250000
-   eta_min: 0.0001
-
- max_epochs: 100
- weight_decay: 1.0e-05
- clip_grad_norm: 10.0
- seed: 1234
- num_gpus: -1
examples/frcrn_mp3_to_wav/yaml/config-14.yaml DELETED
@@ -1,31 +0,0 @@
- model_name: "frcrn"
-
- sample_rate: 8000
- segment_size: 32000
- nfft: 640
- win_size: 640
- hop_size: 320
- win_type: hann
-
- use_complex_networks: true
- model_depth: 14
- model_complexity: -1
-
- min_snr_db: -10
- max_snr_db: 20
-
- num_workers: 8
- batch_size: 32
- eval_steps: 10000
-
- lr: 0.001
- lr_scheduler: "CosineAnnealingLR"
- lr_scheduler_kwargs:
-   T_max: 250000
-   eta_min: 0.0001
-
- max_epochs: 100
- weight_decay: 1.0e-05
- clip_grad_norm: 10.0
- seed: 1234
- num_gpus: -1
examples/frcrn_mp3_to_wav/yaml/config-20.yaml DELETED
@@ -1,31 +0,0 @@
-model_name: "frcrn"
-
-sample_rate: 8000
-segment_size: 32000
-nfft: 512
-win_size: 512
-hop_size: 256
-win_type: hann
-
-use_complex_networks: true
-model_depth: 20
-model_complexity: 45
-
-min_snr_db: -10
-max_snr_db: 20
-
-num_workers: 8
-batch_size: 32
-eval_steps: 10000
-
-lr: 0.001
-lr_scheduler: "CosineAnnealingLR"
-lr_scheduler_kwargs:
-  T_max: 250000
-  eta_min: 0.0001
-
-max_epochs: 100
-weight_decay: 1.0e-05
-clip_grad_norm: 10.0
-seed: 1234
-num_gpus: -1
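The three FRCRN configs deleted above differ only in STFT resolution and model depth; all use a 50% hop. A quick sanity check of the frame geometry these settings imply (a sketch, assuming the 8 kHz sample rate set in each file):

    sample_rate = 8000
    for nfft, win_size, hop_size in [(128, 128, 64), (640, 640, 320), (512, 512, 256)]:
        frame_ms = 1000.0 * win_size / sample_rate
        overlap = 1.0 - hop_size / win_size
        freq_bins = nfft // 2 + 1
        print(f"win={win_size}: {frame_ms:.0f} ms frames, {overlap:.0%} overlap, {freq_bins} bins")
    # win=128: 16 ms frames, 50% overlap, 65 bins
    # win=640: 80 ms frames, 50% overlap, 321 bins
    # win=512: 64 ms frames, 50% overlap, 257 bins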
examples/simple_linear_irm_aishell/run.sh DELETED
@@ -1,172 +0,0 @@
-#!/usr/bin/env bash
-
-: <<'END'
-
-sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir
-
-sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir
-
-sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-
-END
-
-
-# params
-system_version="windows";
-verbose=true;
-stage=0  # start from 0 if you need to start from data preparation
-stop_stage=9
-
-work_dir="$(pwd)"
-file_folder_name=file_folder_name
-final_model_name=final_model_name
-config_file="yaml/config.yaml"
-limit=10
-
-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
-
-nohup_name=nohup.out
-
-# model params
-batch_size=64
-max_epochs=200
-save_top_k=10
-patience=5
-
-
-# parse options
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-      old_value="(eval echo \\$$name)";
-      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval "${name}=\"$2\"";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-
-    *) break;
-  esac
-done
-
-file_dir="${work_dir}/${file_folder_name}"
-final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
-evaluation_audio_dir="${file_dir}/evaluation_audio"
-
-dataset="${file_dir}/dataset.xlsx"
-train_dataset="${file_dir}/train.xlsx"
-valid_dataset="${file_dir}/valid.xlsx"
-
-$verbose && echo "system_version: ${system_version}"
-$verbose && echo "file_folder_name: ${file_folder_name}"
-
-if [ $system_version == "windows" ]; then
-  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
-elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-  #source /data/local/bin/nx_denoise/bin/activate
-  alias python3='/data/local/bin/nx_denoise/bin/python3'
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1: prepare data"
-  cd "${work_dir}" || exit 1
-  python3 step_1_prepare_data.py \
-  --file_dir "${file_dir}" \
-  --noise_dir "${noise_dir}" \
-  --speech_dir "${speech_dir}" \
-  --train_dataset "${train_dataset}" \
-  --valid_dataset "${valid_dataset}" \
-
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
-  cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
-  --train_dataset "${train_dataset}" \
-  --valid_dataset "${valid_dataset}" \
-  --serialization_dir "${file_dir}" \
-  --config_file "${config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-  --valid_dataset "${valid_dataset}" \
-  --model_dir "${file_dir}/best" \
-  --evaluation_audio_dir "${evaluation_audio_dir}" \
-  --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: export model"
-  cd "${work_dir}" || exit 1
-  python3 step_5_export_models.py \
-  --vocabulary_dir "${vocabulary_dir}" \
-  --model_dir "${file_dir}/best" \
-  --serialization_dir "${file_dir}" \
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/vocabulary" "${final_model_dir}"
-
-  cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
-
-  cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
-  cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
-  cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
-  cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-  $verbose && echo "stage 6: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
-
-fi
examples/simple_linear_irm_aishell/step_1_prepare_data.py DELETED
@@ -1,196 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import os
-from pathlib import Path
-import random
-import sys
-import shutil
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import pandas as pd
-from scipy.io import wavfile
-from tqdm import tqdm
-import librosa
-
-from project_settings import project_path
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
-    parser.add_argument(
-        "--noise_dir",
-        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
-        type=str
-    )
-    parser.add_argument(
-        "--speech_dir",
-        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
-        type=str
-    )
-
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--duration", default=2.0, type=float)
-    parser.add_argument("--min_nsr_db", default=-20, type=float)
-    parser.add_argument("--max_nsr_db", default=5, type=float)
-
-    parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def filename_generator(data_dir: str):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        yield filename.as_posix()
-
-
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-        if raw_duration < duration:
-            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-            continue
-        if signal.ndim != 1:
-            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-        signal_length = len(signal)
-        win_size = int(duration * sample_rate)
-        for begin in range(0, signal_length - win_size, win_size):
-            row = {
-                "filename": filename.as_posix(),
-                "raw_duration": round(raw_duration, 4),
-                "offset": round(begin / sample_rate, 4),
-                "duration": round(duration, 4),
-            }
-            yield row
-
-
-def get_dataset(args):
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
-
-    noise_generator = target_second_signal_generator(
-        noise_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate
-    )
-    speech_generator = target_second_signal_generator(
-        speech_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate
-    )
-
-    dataset = list()
-
-    count = 0
-    process_bar = tqdm(desc="build dataset excel")
-    for noise, speech in zip(noise_generator, speech_generator):
-
-        noise_filename = noise["filename"]
-        noise_raw_duration = noise["raw_duration"]
-        noise_offset = noise["offset"]
-        noise_duration = noise["duration"]
-
-        speech_filename = speech["filename"]
-        speech_raw_duration = speech["raw_duration"]
-        speech_offset = speech["offset"]
-        speech_duration = speech["duration"]
-
-        random1 = random.random()
-        random2 = random.random()
-
-        row = {
-            "noise_filename": noise_filename,
-            "noise_raw_duration": noise_raw_duration,
-            "noise_offset": noise_offset,
-            "noise_duration": noise_duration,
-
-            "speech_filename": speech_filename,
-            "speech_raw_duration": speech_raw_duration,
-            "speech_offset": speech_offset,
-            "speech_duration": speech_duration,
-
-            "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db),
-
-            "random1": random1,
-            "random2": random2,
-            "flag": "TRAIN" if random2 < 0.8 else "TEST",
-        }
-        dataset.append(row)
-        count += 1
-        duration_seconds = count * args.duration
-        duration_hours = duration_seconds / 3600
-
-        process_bar.update(n=1)
-        process_bar.set_postfix({
-            # "duration_seconds": round(duration_seconds, 4),
-            "duration_hours": round(duration_hours, 4),
-        })
-
-    dataset = pd.DataFrame(dataset)
-    dataset = dataset.sort_values(by=["random1"], ascending=False)
-    dataset.to_excel(
-        file_dir / "dataset.xlsx",
-        index=False,
-    )
-    return
-
-
-def split_dataset(args):
-    """Split into train and test sets."""
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    df = pd.read_excel(file_dir / "dataset.xlsx")
-
-    train = list()
-    test = list()
-
-    for i, row in df.iterrows():
-        flag = row["flag"]
-        if flag == "TRAIN":
-            train.append(row)
-        else:
-            test.append(row)
-
-    train = pd.DataFrame(train)
-    train.to_excel(
-        args.train_dataset,
-        index=False,
-        # encoding="utf_8_sig"
-    )
-    test = pd.DataFrame(test)
-    test.to_excel(
-        args.valid_dataset,
-        index=False,
-        # encoding="utf_8_sig"
-    )
-
-    return
-
-
-def main():
-    args = get_args()
-
-    get_dataset(args)
-    split_dataset(args)
-    return
-
-
-if __name__ == "__main__":
-    main()
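target_second_signal_generator above slices each file into fixed-duration, non-overlapping windows and drops the tail that cannot fill a window. A worked example of the offsets it yields (illustrative values only):

    sample_rate, duration = 8000, 2.0
    signal_length = int(5.3 * sample_rate)   # 42400 samples
    win_size = int(duration * sample_rate)   # 16000 samples
    offsets = [begin / sample_rate for begin in range(0, signal_length - win_size, win_size)]
    print(offsets)  # [0.0, 2.0] -- the trailing 1.3 s remainder is discarded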
examples/simple_linear_irm_aishell/step_2_train_model.py DELETED
@@ -1,348 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
-"""
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-from torch import dtype
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.utils.data.dataloader import DataLoader
-import torchaudio
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
-from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig
-from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--max_epochs", default=100, type=int)
-
-    parser.add_argument("--batch_size", default=64, type=int)
-    parser.add_argument("--learning_rate", default=1e-3, type=float)
-    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
-    parser.add_argument("--patience", default=5, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-    parser.add_argument("--seed", default=0, type=int)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self,
-                 n_fft: int = 512,
-                 win_length: int = 200,
-                 hop_length: int = 80,
-                 window_fn: str = "hamming",
-                 irm_beta: float = 1.0,
-                 epsilon: float = 1e-8,
-                 ):
-        self.n_fft = n_fft
-        self.win_length = win_length
-        self.hop_length = hop_length
-        self.window_fn = window_fn
-        self.irm_beta = irm_beta
-        self.epsilon = epsilon
-
-        self.transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=2.0,
-            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-
-    def __call__(self, batch: List[dict]):
-        mix_spec_list = list()
-        speech_irm_list = list()
-        snr_db_list = list()
-        for sample in batch:
-            noise_wave: torch.Tensor = sample["noise_wave"]
-            speech_wave: torch.Tensor = sample["speech_wave"]
-            mix_wave: torch.Tensor = sample["mix_wave"]
-            snr_db: float = sample["snr_db"]
-
-            noise_spec = self.transform.forward(noise_wave)
-            speech_spec = self.transform.forward(speech_wave)
-            mix_spec = self.transform.forward(mix_wave)
-
-            # noise_irm = noise_spec / (noise_spec + speech_spec)
-            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
-            speech_irm = torch.pow(speech_irm, self.irm_beta)
-
-            mix_spec_list.append(mix_spec)
-            speech_irm_list.append(speech_irm)
-            snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32))
-
-        mix_spec_list = torch.stack(mix_spec_list)
-        speech_irm_list = torch.stack(speech_irm_list)
-        snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size,)
-
-        # assert
-        if torch.any(torch.isnan(mix_spec_list)):
-            raise AssertionError("nan in mix_spec Tensor")
-        if torch.any(torch.isnan(speech_irm_list)):
-            raise AssertionError("nan in speech_irm Tensor")
-        if torch.any(torch.isnan(snr_db_list)):
-            raise AssertionError("nan in snr_db Tensor")
-
-        return mix_spec_list, speech_irm_list, snr_db_list
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    logger.info("set seed: {}".format(args.seed))
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    # datasets
-    logger.info("prepare datasets")
-    train_dataset = DenoiseExcelDataset(
-        excel_file=args.train_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    valid_dataset = DenoiseExcelDataset(
-        excel_file=args.valid_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    config = SimpleLinearIRMConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-        # num_labels=vocabulary.get_vocab_size(namespace="labels")
-    )
-    model = SimpleLinearIRMPretrainedModel(
-        config=config,
-    )
-    model.to(device)
-    model.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    param_optimizer = model.parameters()
-    optimizer = torch.optim.Adam(
-        param_optimizer,
-        lr=args.learning_rate,
-    )
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=2000
-    # )
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer,
-        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-    )
-    mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    # training loop
-    logger.info("training")
-
-    training_loss = 10000000000
-    evaluation_loss = 10000000000
-
-    model_list = list()
-    best_idx_epoch = None
-    best_metric = None
-    patience_count = 0
-
-    for idx_epoch in range(args.max_epochs):
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(train_data_loader),
-            desc="Training; epoch: {}".format(idx_epoch),
-        )
-
-        for batch in train_data_loader:
-            mix_spec, speech_irm, snr_db = batch
-            mix_spec = mix_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            speech_irm_prediction = model.forward(mix_spec)
-            loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-            total_loss += loss.item()
-            total_examples += mix_spec.size(0)
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-
-            training_loss = total_loss / total_examples
-            training_loss = round(training_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "training_loss": training_loss,
-            })
-
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(valid_data_loader),
-            desc="Evaluation; epoch: {}".format(idx_epoch),
-        )
-        for batch in valid_data_loader:
-            mix_spec, speech_irm, snr_db = batch
-            mix_spec = mix_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            with torch.no_grad():
-                speech_irm_prediction = model.forward(mix_spec)
-                loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-            total_loss += loss.item()
-            total_examples += mix_spec.size(0)
-
-            evaluation_loss = total_loss / total_examples
-            evaluation_loss = round(evaluation_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "evaluation_loss": evaluation_loss,
-            })
-
-        # save path
-        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
-        epoch_dir.mkdir(parents=True, exist_ok=False)
-
-        # save models
-        model.save_pretrained(epoch_dir.as_posix())
-
-        model_list.append(epoch_dir)
-        if len(model_list) >= args.num_serialized_models_to_keep:
-            model_to_delete: Path = model_list.pop(0)
-            shutil.rmtree(model_to_delete.as_posix())
-
-        # save metric
-        if best_metric is None:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        elif evaluation_loss < best_metric:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        else:
-            pass
-
-        metrics = {
-            "idx_epoch": idx_epoch,
-            "best_idx_epoch": best_idx_epoch,
-            "training_loss": training_loss,
-            "evaluation_loss": evaluation_loss,
-            "learning_rate": optimizer.param_groups[0]["lr"],
-        }
-        metrics_filename = epoch_dir / "metrics_epoch.json"
-        with open(metrics_filename, "w", encoding="utf-8") as f:
-            json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-        # save best
-        best_dir = serialization_dir / "best"
-        if best_idx_epoch == idx_epoch:
-            if best_dir.exists():
-                shutil.rmtree(best_dir)
-            shutil.copytree(epoch_dir, best_dir)
-
-        # early stop
-        early_stop_flag = False
-        if best_idx_epoch == idx_epoch:
-            patience_count = 0
-        else:
-            patience_count += 1
-            if patience_count >= args.patience:
-                early_stop_flag = True
-
-        # early stop
-        if early_stop_flag:
-            break
-    return
-
-
-if __name__ == '__main__':
-    main()
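The collate function above builds its training target as an ideal ratio mask over power spectrograms. The same formula in isolation (a minimal torch sketch with toy values, not repo code):

    import torch

    epsilon, irm_beta = 1e-8, 1.0
    speech_spec = torch.tensor([[4.0, 1.0], [0.0, 9.0]])  # |S|^2 per (freq, frame)
    noise_spec = torch.tensor([[1.0, 1.0], [1.0, 1.0]])   # |N|^2 per (freq, frame)
    speech_irm = (speech_spec / (noise_spec + speech_spec + epsilon)) ** irm_beta
    print(speech_irm)  # tensor([[0.8000, 0.5000], [0.0000, 0.9000]])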
examples/simple_linear_irm_aishell/step_3_evaluation.py DELETED
@@ -1,239 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import logging
-import os
-from pathlib import Path
-import sys
-import uuid
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import librosa
-import numpy as np
-import pandas as pd
-from scipy.io import wavfile
-import torch
-import torch.nn as nn
-import torchaudio
-from tqdm import tqdm
-
-from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
-    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
-
-    parser.add_argument("--limit", default=10, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config():
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    stream_handler = logging.StreamHandler()
-    stream_handler.setLevel(logging.INFO)
-    stream_handler.setFormatter(logging.Formatter(fmt))
-
-    logger = logging.getLogger(__name__)
-
-    return logger
-
-
-def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
-    l1 = len(speech)
-    l2 = len(noise)
-    l = min(l1, l2)
-    speech = speech[:l]
-    noise = noise[:l]
-
-    # np.float32, value between (-1, 1).
-
-    speech_power = np.mean(np.square(speech))
-    noise_power = speech_power / (10 ** (snr_db / 10))
-
-    noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
-
-    noisy_signal = speech + noise_adjusted
-
-    return noisy_signal
-
-
-stft_power = torchaudio.transforms.Spectrogram(
-    n_fft=512,
-    win_length=200,
-    hop_length=80,
-    power=2.0,
-    window_fn=torch.hamming_window,
-)
-
-
-stft_complex = torchaudio.transforms.Spectrogram(
-    n_fft=512,
-    win_length=200,
-    hop_length=80,
-    power=None,
-    window_fn=torch.hamming_window,
-)
-
-
-istft = torchaudio.transforms.InverseSpectrogram(
-    n_fft=512,
-    win_length=200,
-    hop_length=80,
-    window_fn=torch.hamming_window,
-)
-
-
-def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
-    mix_spec_complex = mix_spec_complex.detach().cpu()
-    speech_irm_prediction = speech_irm_prediction.detach().cpu()
-
-    mask_speech = speech_irm_prediction
-    mask_noise = 1.0 - speech_irm_prediction
-
-    speech_spec = mix_spec_complex * mask_speech
-    noise_spec = mix_spec_complex * mask_noise
-
-    speech_wave = istft.forward(speech_spec)
-    noise_wave = istft.forward(noise_spec)
-
-    return speech_wave, noise_wave
-
-
-def save_audios(noise_wave: torch.Tensor,
-                speech_wave: torch.Tensor,
-                mix_wave: torch.Tensor,
-                speech_wave_enhanced: torch.Tensor,
-                noise_wave_enhanced: torch.Tensor,
-                output_dir: str,
-                sample_rate: int = 8000,
-                ):
-    basename = uuid.uuid4().__str__()
-    output_dir = Path(output_dir) / basename
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    filename = output_dir / "noise_wave.wav"
-    torchaudio.save(filename, noise_wave, sample_rate)
-    filename = output_dir / "speech_wave.wav"
-    torchaudio.save(filename, speech_wave, sample_rate)
-    filename = output_dir / "mix_wave.wav"
-    torchaudio.save(filename, mix_wave, sample_rate)
-
-    filename = output_dir / "speech_wave_enhanced.wav"
-    torchaudio.save(filename, speech_wave_enhanced, sample_rate)
-    filename = output_dir / "noise_wave_enhanced.wav"
-    torchaudio.save(filename, noise_wave_enhanced, sample_rate)
-
-    return output_dir.as_posix()
-
-
-def main():
-    args = get_args()
-
-    logger = logging_config()
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    logger.info("prepare model")
-    model = SimpleLinearIRMPretrainedModel.from_pretrained(
-        pretrained_model_name_or_path=args.model_dir,
-    )
-    model.to(device)
-    model.eval()
-
-    # optimizer
-    logger.info("prepare loss_fn")
-    mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    logger.info("read excel")
-    df = pd.read_excel(args.valid_dataset)
-
-    total_loss = 0.
-    total_examples = 0.
-    progress_bar = tqdm(total=len(df), desc="Evaluation")
-    for idx, row in df.iterrows():
-        noise_filename = row["noise_filename"]
-        noise_offset = row["noise_offset"]
-        noise_duration = row["noise_duration"]
-
-        speech_filename = row["speech_filename"]
-        speech_offset = row["speech_offset"]
-        speech_duration = row["speech_duration"]
-
-        snr_db = row["snr_db"]
-
-        noise_wave, _ = librosa.load(
-            noise_filename,
-            sr=8000,
-            offset=noise_offset,
-            duration=noise_duration,
-        )
-        speech_wave, _ = librosa.load(
-            speech_filename,
-            sr=8000,
-            offset=speech_offset,
-            duration=speech_duration,
-        )
-        mix_wave: np.ndarray = mix_speech_and_noise(
-            speech=speech_wave,
-            noise=noise_wave,
-            snr_db=snr_db,
-        )
-        noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
-        speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
-        mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
-
-        noise_wave = noise_wave.unsqueeze(dim=0)
-        speech_wave = speech_wave.unsqueeze(dim=0)
-        mix_wave = mix_wave.unsqueeze(dim=0)
-
-        noise_spec: torch.Tensor = stft_power.forward(noise_wave)
-        speech_spec: torch.Tensor = stft_power.forward(speech_wave)
-        mix_spec: torch.Tensor = stft_power.forward(mix_wave)
-        mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
-
-        speech_irm = speech_spec / (noise_spec + speech_spec)
-        speech_irm = torch.pow(speech_irm, 1.0)
-
-        mix_spec = mix_spec.to(device)
-        speech_irm_target = speech_irm.to(device)
-        with torch.no_grad():
-            speech_irm_prediction = model.forward(mix_spec)
-            loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-        speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
-        save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
-
-        total_loss += loss.item()
-        total_examples += mix_spec.size(0)
-
-        evaluation_loss = total_loss / total_examples
-        evaluation_loss = round(evaluation_loss, 4)
-
-        progress_bar.update(1)
-        progress_bar.set_postfix({
-            "evaluation_loss": evaluation_loss,
-        })
-
-        if idx > args.limit:
-            break
-
-    return
-
-
-if __name__ == '__main__':
-    main()
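mix_speech_and_noise above rescales the noise so that the speech-to-noise power ratio equals the requested SNR before summing. A numeric check of that scaling (a sketch; the random signals are placeholders):

    import numpy as np

    rng = np.random.default_rng(0)
    speech = rng.standard_normal(8000)
    noise = 0.3 * rng.standard_normal(8000)
    snr_db = 5.0

    speech_power = np.mean(np.square(speech))
    noise_power = speech_power / (10 ** (snr_db / 10))
    noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

    measured_snr = 10 * np.log10(speech_power / np.mean(noise_adjusted ** 2))
    print(round(measured_snr, 6))  # 5.0, matching the requested snr_db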
examples/simple_linear_irm_aishell/yaml/config.yaml DELETED
@@ -1,13 +0,0 @@
-model_name: "simple_linear_irm"
-
-# spec
-sample_rate: 8000
-n_fft: 512
-win_length: 200
-hop_length: 80
-
-# model
-num_bins: 257
-hidden_size: 2048
-lookback: 3
-lookahead: 3
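Note the coupling between the spec and model blocks above: with n_fft 512 the one-sided STFT has n_fft // 2 + 1 frequency bins, which is where num_bins: 257 comes from.

    n_fft = 512
    num_bins = n_fft // 2 + 1
    assert num_bins == 257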
examples/spectrum_dfnet_aishell/run.sh DELETED
@@ -1,178 +0,0 @@
-#!/usr/bin/env bash
-
-: <<'END'
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
---noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-
-END
-
-
-# params
-system_version="windows";
-verbose=true;
-stage=0  # start from 0 if you need to start from data preparation
-stop_stage=9
-
-work_dir="$(pwd)"
-file_folder_name=file_folder_name
-final_model_name=final_model_name
-config_file="yaml/config.yaml"
-limit=10
-
-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
-
-nohup_name=nohup.out
-
-# model params
-batch_size=64
-max_epochs=200
-save_top_k=10
-patience=5
-
-
-# parse options
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-      old_value="(eval echo \\$$name)";
-      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval "${name}=\"$2\"";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-
-    *) break;
-  esac
-done
-
-file_dir="${work_dir}/${file_folder_name}"
-final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
-evaluation_audio_dir="${file_dir}/evaluation_audio"
-
-dataset="${file_dir}/dataset.xlsx"
-train_dataset="${file_dir}/train.xlsx"
-valid_dataset="${file_dir}/valid.xlsx"
-
-$verbose && echo "system_version: ${system_version}"
-$verbose && echo "file_folder_name: ${file_folder_name}"
-
-if [ $system_version == "windows" ]; then
-  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
-elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-  #source /data/local/bin/nx_denoise/bin/activate
-  alias python3='/data/local/bin/nx_denoise/bin/python3'
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1: prepare data"
-  cd "${work_dir}" || exit 1
-  python3 step_1_prepare_data.py \
-  --file_dir "${file_dir}" \
-  --noise_dir "${noise_dir}" \
-  --speech_dir "${speech_dir}" \
-  --train_dataset "${train_dataset}" \
-  --valid_dataset "${valid_dataset}" \
-
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
-  cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
-  --train_dataset "${train_dataset}" \
-  --valid_dataset "${valid_dataset}" \
-  --serialization_dir "${file_dir}" \
-  --config_file "${config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-  --valid_dataset "${valid_dataset}" \
-  --model_dir "${file_dir}/best" \
-  --evaluation_audio_dir "${evaluation_audio_dir}" \
-  --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: export model"
-  cd "${work_dir}" || exit 1
-  python3 step_5_export_models.py \
-  --vocabulary_dir "${vocabulary_dir}" \
-  --model_dir "${file_dir}/best" \
-  --serialization_dir "${file_dir}" \
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/vocabulary" "${final_model_dir}"
-
-  cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
-
-  cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
-  cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
-  cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
-  cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-  $verbose && echo "stage 6: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
-
-fi
examples/spectrum_dfnet_aishell/step_1_prepare_data.py DELETED
@@ -1,197 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import os
-from pathlib import Path
-import random
-import sys
-import shutil
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import pandas as pd
-from scipy.io import wavfile
-from tqdm import tqdm
-import librosa
-
-from project_settings import project_path
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
-    parser.add_argument(
-        "--noise_dir",
-        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
-        type=str
-    )
-    parser.add_argument(
-        "--speech_dir",
-        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
-        type=str
-    )
-
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--duration", default=2.0, type=float)
-    parser.add_argument("--min_snr_db", default=-10, type=float)
-    parser.add_argument("--max_snr_db", default=20, type=float)
-
-    parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def filename_generator(data_dir: str):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        yield filename.as_posix()
-
-
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-        if raw_duration < duration:
-            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-            continue
-        if signal.ndim != 1:
-            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-        signal_length = len(signal)
-        win_size = int(duration * sample_rate)
-        for begin in range(0, signal_length - win_size, win_size):
-            row = {
-                "filename": filename.as_posix(),
-                "raw_duration": round(raw_duration, 4),
-                "offset": round(begin / sample_rate, 4),
-                "duration": round(duration, 4),
-            }
-            yield row
-
-
-def get_dataset(args):
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
-
-    noise_generator = target_second_signal_generator(
-        noise_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate
-    )
-    speech_generator = target_second_signal_generator(
-        speech_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate
-    )
-
-    dataset = list()
-
-    count = 0
-    process_bar = tqdm(desc="build dataset excel")
-    for noise, speech in zip(noise_generator, speech_generator):
-
-        noise_filename = noise["filename"]
-        noise_raw_duration = noise["raw_duration"]
-        noise_offset = noise["offset"]
-        noise_duration = noise["duration"]
-
-        speech_filename = speech["filename"]
-        speech_raw_duration = speech["raw_duration"]
-        speech_offset = speech["offset"]
-        speech_duration = speech["duration"]
-
-        random1 = random.random()
-        random2 = random.random()
-
-        row = {
-            "noise_filename": noise_filename,
-            "noise_raw_duration": noise_raw_duration,
-            "noise_offset": noise_offset,
-            "noise_duration": noise_duration,
-
-            "speech_filename": speech_filename,
-            "speech_raw_duration": speech_raw_duration,
-            "speech_offset": speech_offset,
-            "speech_duration": speech_duration,
-
-            "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
-
-            "random1": random1,
-            "random2": random2,
-            "flag": "TRAIN" if random2 < 0.8 else "TEST",
-        }
-        dataset.append(row)
-        count += 1
-        duration_seconds = count * args.duration
-        duration_hours = duration_seconds / 3600
-
-        process_bar.update(n=1)
-        process_bar.set_postfix({
-            # "duration_seconds": round(duration_seconds, 4),
-            "duration_hours": round(duration_hours, 4),
-        })
-
-    dataset = pd.DataFrame(dataset)
-    dataset = dataset.sort_values(by=["random1"], ascending=False)
-    dataset.to_excel(
-        file_dir / "dataset.xlsx",
-        index=False,
-    )
-    return
-
-
-def split_dataset(args):
-    """Split into train and test sets."""
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    df = pd.read_excel(file_dir / "dataset.xlsx")
-
-    train = list()
-    test = list()
-
-    for i, row in df.iterrows():
-        flag = row["flag"]
-        if flag == "TRAIN":
-            train.append(row)
-        else:
-            test.append(row)
-
-    train = pd.DataFrame(train)
-    train.to_excel(
-        args.train_dataset,
-        index=False,
-        # encoding="utf_8_sig"
-    )
-    test = pd.DataFrame(test)
-    test.to_excel(
-        args.valid_dataset,
-        index=False,
-        # encoding="utf_8_sig"
-    )
-
-    return
-
-
-def main():
-    args = get_args()
-
-    get_dataset(args)
-    split_dataset(args)
-    return
-
-
-if __name__ == "__main__":
-    main()
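As in the simple_linear example, each row above draws random2 uniformly from [0, 1) and is flagged TRAIN when random2 < 0.8, so the split is 80/20 in expectation rather than exactly. A quick illustration (seed chosen arbitrarily):

    import random

    random.seed(1234)
    flags = ["TRAIN" if random.random() < 0.8 else "TEST" for _ in range(10000)]
    print(flags.count("TRAIN") / len(flags))  # close to 0.8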
examples/spectrum_dfnet_aishell/step_2_train_model.py DELETED
@@ -1,440 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
-"""
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from torch.utils.data.dataloader import DataLoader
-import torchaudio
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
-from toolbox.torchaudio.models.spectrum_dfnet.configuration_spectrum_dfnet import SpectrumDfNetConfig
-from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--max_epochs", default=100, type=int)
-
-    parser.add_argument("--batch_size", default=16, type=int)
-    parser.add_argument("--learning_rate", default=1e-4, type=float)
-    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
-    parser.add_argument("--patience", default=5, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-    parser.add_argument("--seed", default=0, type=int)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self,
-                 n_fft: int = 512,
-                 win_length: int = 200,
-                 hop_length: int = 80,
-                 window_fn: str = "hamming",
-                 irm_beta: float = 1.0,
-                 epsilon: float = 1e-8,
-                 ):
-        self.n_fft = n_fft
-        self.win_length = win_length
-        self.hop_length = hop_length
-        self.window_fn = window_fn
-        self.irm_beta = irm_beta
-        self.epsilon = epsilon
-
-        self.complex_transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=None,
-            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-        self.transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=2.0,
-            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-
-    @staticmethod
-    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
-        batch_size, channels, freq_dim, time_steps = x.shape
-
-        # kernel: [freq_dim, n_time_step]
-        kernel_size = (freq_dim, n_time_steps)
-
-        # pad
-        pad = n_time_steps // 2
-        x = torch.concat(tensors=[
-            x[:, :, :, :pad],
-            x,
-            x[:, :, :, -pad:],
-        ], dim=-1)
-
-        x = F.unfold(
-            input=x,
-            kernel_size=kernel_size,
-        )
-        # x shape: [batch_size, fold, time_steps]
-        return x
-
-    def __call__(self, batch: List[dict]):
-        speech_complex_spec_list = list()
-        mix_complex_spec_list = list()
-        speech_irm_list = list()
-        snr_db_list = list()
-        for sample in batch:
-            noise_wave: torch.Tensor = sample["noise_wave"]
-            speech_wave: torch.Tensor = sample["speech_wave"]
-            mix_wave: torch.Tensor = sample["mix_wave"]
-            # snr_db: float = sample["snr_db"]
-
-            noise_spec = self.transform.forward(noise_wave)
-            speech_spec = self.transform.forward(speech_wave)
-
-            speech_complex_spec = self.complex_transform.forward(speech_wave)
-            mix_complex_spec = self.complex_transform.forward(mix_wave)
-
-            # noise_irm = noise_spec / (noise_spec + speech_spec)
-            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
-            speech_irm = torch.pow(speech_irm, self.irm_beta)
-
-            # noise_spec, speech_spec, mix_spec, speech_irm
-            # shape: [freq_dim, time_steps]
-
-            snr_db: torch.Tensor = 10 * torch.log10(
-                speech_spec / (noise_spec + self.epsilon)
-            )
-            snr_db = torch.clamp(snr_db, min=self.epsilon)
-
-            snr_db_ = torch.unsqueeze(snr_db, dim=0)
-            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
-            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
-            snr_db_ = torch.squeeze(snr_db_, dim=0)
-            # snr_db_ shape: [fold, time_steps]
-
-            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
-            # snr_db shape: [1, time_steps]
-
-            speech_complex_spec_list.append(speech_complex_spec)
-            mix_complex_spec_list.append(mix_complex_spec)
-            speech_irm_list.append(speech_irm)
-            snr_db_list.append(snr_db)
-
-        speech_complex_spec_list = torch.stack(speech_complex_spec_list)
-        mix_complex_spec_list = torch.stack(mix_complex_spec_list)
-        speech_irm_list = torch.stack(speech_irm_list)
-        snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size, time_steps, 1)
-
-        speech_complex_spec_list = speech_complex_spec_list[:, :-1, :]
-        mix_complex_spec_list = mix_complex_spec_list[:, :-1, :]
-        speech_irm_list = speech_irm_list[:, :-1, :]
-
-        # speech_complex_spec_list shape: [batch_size, freq_dim, time_steps]
-        # mix_complex_spec_list shape: [batch_size, freq_dim, time_steps]
-        # speech_irm_list shape: [batch_size, freq_dim, time_steps]
-        # snr_db shape: [batch_size, 1, time_steps]
-
-        # assert
-        if torch.any(torch.isnan(speech_complex_spec_list)) or torch.any(torch.isinf(speech_complex_spec_list)):
-            raise AssertionError("nan or inf in speech_complex_spec_list")
-        if torch.any(torch.isnan(mix_complex_spec_list)) or torch.any(torch.isinf(mix_complex_spec_list)):
-            raise AssertionError("nan or inf in mix_complex_spec_list")
-        if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
-            raise AssertionError("nan or inf in speech_irm_list")
-        if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
-            raise AssertionError("nan or inf in snr_db_list")
-
-        return speech_complex_spec_list, mix_complex_spec_list, speech_irm_list, snr_db_list
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    logger.info("set seed: {}".format(args.seed))
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    # datasets
-    logger.info("prepare datasets")
-    train_dataset = DenoiseExcelDataset(
-        excel_file=args.train_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    valid_dataset = DenoiseExcelDataset(
-        excel_file=args.valid_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # On Linux, multiple subprocesses can be used to load data; on Windows they cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    config = SpectrumDfNetConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-        # num_labels=vocabulary.get_vocab_size(namespace="labels")
-    )
-    model = SpectrumDfNetPretrainedModel(
-        config=config,
-    )
-    model.to(device)
-    model.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    param_optimizer = model.parameters()
-    optimizer = torch.optim.Adam(
-        param_optimizer,
-        lr=args.learning_rate,
-    )
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=2000
-    # )
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer,
-        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-    )
-
-    speech_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-    irm_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-    snr_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    # training loop
-    logger.info("training")
-
-    training_loss = 10000000000
-    evaluation_loss = 10000000000
-
-    model_list = list()
-    best_idx_epoch = None
-    best_metric = None
-    patience_count = 0
-
-    for idx_epoch in range(args.max_epochs):
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(train_data_loader),
-            desc="Training; epoch: {}".format(idx_epoch),
-        )
-
-        for batch in train_data_loader:
-            speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
-            speech_complex_spec = speech_complex_spec.to(device)
-            mix_complex_spec = mix_complex_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
-            if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
-                raise AssertionError("nan or inf in speech_spec_prediction")
-            if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
-                raise AssertionError("nan or inf in speech_irm_prediction")
-            if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
-                raise AssertionError("nan or inf in lsnr_prediction")
-
-            speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
-            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-            snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-
-            loss = speech_loss + irm_loss + snr_loss
-
-            total_loss += loss.item()
-            total_examples += mix_complex_spec.size(0)
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-
-            training_loss = total_loss / total_examples
-            training_loss = round(training_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "training_loss": training_loss,
-            })
-
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(valid_data_loader),
-            desc="Evaluation; epoch: {}".format(idx_epoch),
-        )
-        for batch in valid_data_loader:
-            speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
-            speech_complex_spec = speech_complex_spec.to(device)
-            mix_complex_spec = mix_complex_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            with torch.no_grad():
-                speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
-                if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
-                    raise AssertionError("nan or inf in speech_spec_prediction")
-                if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
-                    raise AssertionError("nan or inf in speech_irm_prediction")
-                if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
-                    raise AssertionError("nan or inf in lsnr_prediction")
-
-                speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
-                irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-                snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-
-                loss = speech_loss + irm_loss + snr_loss
-
-            total_loss += loss.item()
-            total_examples += mix_complex_spec.size(0)
-
-            evaluation_loss = total_loss / total_examples
-            evaluation_loss = round(evaluation_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "evaluation_loss": evaluation_loss,
-            })
-
-        # save path
-        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
-        epoch_dir.mkdir(parents=True, exist_ok=False)
-
-        # save models
-        model.save_pretrained(epoch_dir.as_posix())
-
-        model_list.append(epoch_dir)
-        if len(model_list) >= args.num_serialized_models_to_keep:
-            model_to_delete: Path = model_list.pop(0)
-            shutil.rmtree(model_to_delete.as_posix())
-
-        # save metric
-        if best_metric is None:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        elif evaluation_loss < best_metric:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        else:
-            pass
-
-        metrics = {
-            "idx_epoch": idx_epoch,
-            "best_idx_epoch": best_idx_epoch,
-            "training_loss": training_loss,
-            "evaluation_loss": evaluation_loss,
-            "learning_rate": optimizer.param_groups[0]["lr"],
-        }
-        metrics_filename = epoch_dir / "metrics_epoch.json"
-        with open(metrics_filename, "w", encoding="utf-8") as f:
-            json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-        # save best
-        best_dir = serialization_dir / "best"
-        if best_idx_epoch == idx_epoch:
-            if best_dir.exists():
-                shutil.rmtree(best_dir)
-            shutil.copytree(epoch_dir, best_dir)
-
-        # early stop
-        early_stop_flag = False
-        if best_idx_epoch == idx_epoch:
-            patience_count = 0
-        else:
-            patience_count += 1
-            if patience_count >= args.patience:
431
- early_stop_flag = True
432
-
433
- # early stop
434
- if early_stop_flag:
435
- break
436
- return
437
-
438
-
439
- if __name__ == '__main__':
440
- main()
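
A note on the schedule in the deleted script above: `lr_scheduler.step()` is called once per batch, so the `MultiStepLR` milestones (10000, 20000, ...) are counted in optimizer steps, not epochs. A minimal sketch with scaled-down, hypothetical milestones showing the same mechanics:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)
# Same scheduler as the training script, toy milestones.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6], gamma=0.5)

for step in range(8):
    optimizer.step()      # stands in for one training batch
    lr_scheduler.step()   # the milestone counter advances per batch
    print(step, optimizer.param_groups[0]["lr"])
# The lr halves after the 3rd scheduler step and again after the 6th.
```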
 
examples/spectrum_dfnet_aishell/step_3_evaluation.py DELETED
@@ -1,302 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import logging
- import os
- from pathlib import Path
- import sys
- import uuid
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import librosa
- import numpy as np
- import pandas as pd
- from scipy.io import wavfile
- import torch
- import torch.nn as nn
- import torchaudio
- from tqdm import tqdm
-
- from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-     parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
-     parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
-
-     parser.add_argument("--limit", default=10, type=int)
-
-     args = parser.parse_args()
-     return args
-
-
- def logging_config():
-     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-     logging.basicConfig(format=fmt,
-                         datefmt="%m/%d/%Y %H:%M:%S",
-                         level=logging.INFO)
-     stream_handler = logging.StreamHandler()
-     stream_handler.setLevel(logging.INFO)
-     stream_handler.setFormatter(logging.Formatter(fmt))
-
-     logger = logging.getLogger(__name__)
-
-     return logger
-
-
- def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
-     l1 = len(speech)
-     l2 = len(noise)
-     l = min(l1, l2)
-     speech = speech[:l]
-     noise = noise[:l]
-
-     # np.float32, value between (-1, 1).
-
-     speech_power = np.mean(np.square(speech))
-     noise_power = speech_power / (10 ** (snr_db / 10))
-
-     noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
-
-     noisy_signal = speech + noise_adjusted
-
-     return noisy_signal
-
-
- stft_power = torchaudio.transforms.Spectrogram(
-     n_fft=512,
-     win_length=200,
-     hop_length=80,
-     power=2.0,
-     window_fn=torch.hamming_window,
- )
-
-
- stft_complex = torchaudio.transforms.Spectrogram(
-     n_fft=512,
-     win_length=200,
-     hop_length=80,
-     power=None,
-     window_fn=torch.hamming_window,
- )
-
-
- istft = torchaudio.transforms.InverseSpectrogram(
-     n_fft=512,
-     win_length=200,
-     hop_length=80,
-     window_fn=torch.hamming_window,
- )
-
-
- def enhance(mix_spec_complex: torch.Tensor,
-             speech_spec_prediction: torch.Tensor,
-             speech_irm_prediction: torch.Tensor,
-             ):
-     mix_spec_complex = mix_spec_complex.detach().cpu()
-     speech_spec_prediction = speech_spec_prediction.detach().cpu()
-     speech_irm_prediction = speech_irm_prediction.detach().cpu()
-
-     mask_speech = speech_irm_prediction
-     mask_noise = 1.0 - speech_irm_prediction
-
-     speech_spec = mix_spec_complex * mask_speech
-     noise_spec = mix_spec_complex * mask_noise
-
-     # print(f"speech_spec_prediction: {speech_spec_prediction.shape}")
-     # print(f"noise_spec: {noise_spec.shape}")
-
-     speech_wave = istft.forward(speech_spec_prediction)
-     # speech_wave = istft.forward(speech_spec)
-     noise_wave = istft.forward(noise_spec)
-
-     return speech_wave, noise_wave
-
-
- def save_audios(noise_wave: torch.Tensor,
-                 speech_wave: torch.Tensor,
-                 mix_wave: torch.Tensor,
-                 speech_wave_enhanced: torch.Tensor,
-                 noise_wave_enhanced: torch.Tensor,
-                 output_dir: str,
-                 sample_rate: int = 8000,
-                 ):
-     basename = uuid.uuid4().__str__()
-     output_dir = Path(output_dir) / basename
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     filename = output_dir / "noise_wave.wav"
-     torchaudio.save(filename, noise_wave, sample_rate)
-     filename = output_dir / "speech_wave.wav"
-     torchaudio.save(filename, speech_wave, sample_rate)
-     filename = output_dir / "mix_wave.wav"
-     torchaudio.save(filename, mix_wave, sample_rate)
-
-     filename = output_dir / "speech_wave_enhanced.wav"
-     torchaudio.save(filename, speech_wave_enhanced, sample_rate)
-     filename = output_dir / "noise_wave_enhanced.wav"
-     torchaudio.save(filename, noise_wave_enhanced, sample_rate)
-
-     return output_dir.as_posix()
-
-
- def main():
-     args = get_args()
-
-     logger = logging_config()
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     n_gpu = torch.cuda.device_count()
-     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-     logger.info("prepare model")
-     model = SpectrumDfNetPretrainedModel.from_pretrained(
-         pretrained_model_name_or_path=args.model_dir,
-     )
-     model.to(device)
-     model.eval()
-
-     # optimizer
-     logger.info("prepare loss_fn")
-     irm_mse_loss = nn.MSELoss(
-         reduction="mean",
-     )
-     snr_mse_loss = nn.MSELoss(
-         reduction="mean",
-     )
-
-     logger.info("read excel")
-     df = pd.read_excel(args.valid_dataset)
-
-     total_loss = 0.
-     total_examples = 0.
-     progress_bar = tqdm(total=len(df), desc="Evaluation")
-     for idx, row in df.iterrows():
-         noise_filename = row["noise_filename"]
-         noise_offset = row["noise_offset"]
-         noise_duration = row["noise_duration"]
-
-         speech_filename = row["speech_filename"]
-         speech_offset = row["speech_offset"]
-         speech_duration = row["speech_duration"]
-
-         snr_db = row["snr_db"]
-
-         noise_wave, _ = librosa.load(
-             noise_filename,
-             sr=8000,
-             offset=noise_offset,
-             duration=noise_duration,
-         )
-         speech_wave, _ = librosa.load(
-             speech_filename,
-             sr=8000,
-             offset=speech_offset,
-             duration=speech_duration,
-         )
-         mix_wave: np.ndarray = mix_speech_and_noise(
-             speech=speech_wave,
-             noise=noise_wave,
-             snr_db=snr_db,
-         )
-         noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
-         speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
-         mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
-
-         noise_wave = noise_wave.unsqueeze(dim=0)
-         speech_wave = speech_wave.unsqueeze(dim=0)
-         mix_wave = mix_wave.unsqueeze(dim=0)
-
-         noise_spec: torch.Tensor = stft_power.forward(noise_wave)
-         speech_spec: torch.Tensor = stft_power.forward(speech_wave)
-         mix_spec: torch.Tensor = stft_power.forward(mix_wave)
-
-         speech_spec_complex: torch.Tensor = stft_complex.forward(speech_wave)
-         mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
-         # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
-
-         noise_spec = noise_spec[:, :-1, :]
-         speech_spec = speech_spec[:, :-1, :]
-         mix_spec = mix_spec[:, :-1, :]
-         speech_spec_complex = speech_spec_complex[:, :-1, :]
-         mix_spec_complex = mix_spec_complex[:, :-1, :]
-
-         speech_irm = speech_spec / (noise_spec + speech_spec)
-         speech_irm = torch.pow(speech_irm, 1.0)
-
-         snr_db: torch.Tensor = 10 * torch.log10(
-             speech_spec / (noise_spec + 1e-8)
-         )
-         snr_db = torch.clamp(snr_db, min=1e-8)
-         snr_db = torch.mean(snr_db, dim=1, keepdim=True)
-         # snr_db shape: [batch_size, 1, time_steps]
-
-         speech_spec_complex = speech_spec_complex.to(device)
-         mix_spec_complex = mix_spec_complex.to(device)
-         mix_spec = mix_spec.to(device)
-         speech_irm_target = speech_irm.to(device)
-         snr_db_target = snr_db.to(device)
-
-         with torch.no_grad():
-             speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex)
-             speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
-
-             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-             # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-             # loss = irm_loss + 0.1 * snr_loss
-             loss = irm_loss
-
-         # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
-         # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
-         batch_size, _, time_steps = speech_irm_prediction.shape
-
-         mix_spec_complex = torch.concat(
-             [
-                 mix_spec_complex,
-                 torch.zeros(size=(batch_size, 1, time_steps), dtype=mix_spec_complex.dtype).to(device)
-             ],
-             dim=1,
-         )
-         speech_spec_prediction = torch.concat(
-             [
-                 speech_spec_prediction,
-                 torch.zeros(size=(batch_size, 1, time_steps), dtype=speech_spec_prediction.dtype).to(device)
-             ],
-             dim=1,
-         )
-         speech_irm_prediction = torch.concat(
-             [
-                 speech_irm_prediction,
-                 0.5 * torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
-             ],
-             dim=1,
-         )
-
-         # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
-         speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
-         save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
-
-         total_loss += loss.item()
-         total_examples += mix_spec.size(0)
-
-         evaluation_loss = total_loss / total_examples
-         evaluation_loss = round(evaluation_loss, 4)
-
-         progress_bar.update(1)
-         progress_bar.set_postfix({
-             "evaluation_loss": evaluation_loss,
-         })
-
-         if idx > args.limit:
-             break
-
-     return
-
-
- if __name__ == '__main__':
-     main()
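
The `mix_speech_and_noise` helper in the deleted evaluation script scales the noise so that the speech-to-noise power ratio equals `snr_db`. A quick numeric check on synthetic signals (self-contained, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
speech = rng.standard_normal(8000).astype(np.float32)
noise = rng.standard_normal(8000).astype(np.float32)

snr_db = 10.0
speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (snr_db / 10))
noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

# The measured SNR of the mixture components recovers the target.
measured = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
print(round(float(measured), 2))  # ~10.0
```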
 
examples/spectrum_dfnet_aishell/yaml/config.yaml DELETED
@@ -1,53 +0,0 @@
- model_name: "spectrum_unet_irm"
-
- # spec
- sample_rate: 8000
- n_fft: 512
- win_length: 200
- hop_length: 80
-
- spec_bins: 256
-
- # model
- conv_channels: 64
- conv_kernel_size_input:
-   - 3
-   - 3
- conv_kernel_size_inner:
-   - 1
-   - 3
- conv_lookahead: 0
-
- convt_kernel_size_inner:
-   - 1
-   - 3
-
- embedding_hidden_size: 256
- encoder_combine_op: "concat"
-
- encoder_emb_skip_op: "none"
- encoder_emb_linear_groups: 16
- encoder_emb_hidden_size: 256
-
- encoder_linear_groups: 32
-
- lsnr_max: 30
- lsnr_min: -15
- norm_tau: 1.
-
- decoder_emb_num_layers: 3
- decoder_emb_skip_op: "none"
- decoder_emb_linear_groups: 16
- decoder_emb_hidden_size: 256
-
- df_decoder_hidden_size: 256
- df_num_layers: 2
- df_order: 5
- df_bins: 96
- df_gru_skip: "grouped_linear"
- df_decoder_linear_groups: 16
- df_pathway_kernel_size_t: 5
- df_lookahead: 2
-
- # runtime
- use_post_filter: true
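
Note that `spec_bins: 256` is one less than the 257 bins an `n_fft` of 512 produces; the training and evaluation scripts drop the last STFT bin (`[:, :-1, :]`) to match. A small check of that relationship:

```python
import torch
import torchaudio

spec = torchaudio.transforms.Spectrogram(n_fft=512, win_length=200, hop_length=80, power=2.0)
x = torch.randn(1, 8000)                   # one second of audio at 8 kHz
s = spec(x)                                # shape: [1, 257, frames]
print(s.shape[1], s[:, :-1, :].shape[1])   # 257 256
```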
 
examples/spectrum_unet_irm_aishell/run.sh DELETED
@@ -1,178 +0,0 @@
- #!/usr/bin/env bash
-
- : <<'END'
-
-
- sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
- --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
- --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
-
-
- sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
- sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-
- END
-
-
- # params
- system_version="windows";
- verbose=true;
- stage=0  # start from 0 if you need to start from data preparation
- stop_stage=9
-
- work_dir="$(pwd)"
- file_folder_name=file_folder_name
- final_model_name=final_model_name
- config_file="yaml/config.yaml"
- limit=10
-
- noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
- speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
-
- nohup_name=nohup.out
-
- # model params
- batch_size=64
- max_epochs=200
- save_top_k=10
- patience=5
-
-
- # parse options
- while true; do
-     [ -z "${1:-}" ] && break;  # break if there are no arguments
-     case "$1" in
-         --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-             eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-             old_value="$(eval echo \\$$name)";
-             if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-                 was_bool=true;
-             else
-                 was_bool=false;
-             fi
-
-             # Set the variable to the right value -- the escaped quotes make it work if
-             # the option had spaces, like --cmd "queue.pl -sync y"
-             eval "${name}=\"$2\"";
-
-             # Check that Boolean-valued arguments are really Boolean.
-             if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-                 echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-                 exit 1;
-             fi
-             shift 2;
-             ;;
-
-         *) break;
-     esac
- done
-
- file_dir="${work_dir}/${file_folder_name}"
- final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
- evaluation_audio_dir="${file_dir}/evaluation_audio"
-
- dataset="${file_dir}/dataset.xlsx"
- train_dataset="${file_dir}/train.xlsx"
- valid_dataset="${file_dir}/valid.xlsx"
-
- $verbose && echo "system_version: ${system_version}"
- $verbose && echo "file_folder_name: ${file_folder_name}"
-
- if [ $system_version == "windows" ]; then
-     alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
- elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-     #source /data/local/bin/nx_denoise/bin/activate
-     alias python3='/data/local/bin/nx_denoise/bin/python3'
- fi
-
-
- if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-     $verbose && echo "stage 1: prepare data"
-     cd "${work_dir}" || exit 1
-     python3 step_1_prepare_data.py \
-         --file_dir "${file_dir}" \
-         --noise_dir "${noise_dir}" \
-         --speech_dir "${speech_dir}" \
-         --train_dataset "${train_dataset}" \
-         --valid_dataset "${valid_dataset}" \
-
- fi
-
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-     $verbose && echo "stage 2: train model"
-     cd "${work_dir}" || exit 1
-     python3 step_2_train_model.py \
-         --train_dataset "${train_dataset}" \
-         --valid_dataset "${valid_dataset}" \
-         --serialization_dir "${file_dir}" \
-         --config_file "${config_file}" \
-
- fi
-
-
- if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-     $verbose && echo "stage 3: test model"
-     cd "${work_dir}" || exit 1
-     python3 step_3_evaluation.py \
-         --valid_dataset "${valid_dataset}" \
-         --model_dir "${file_dir}/best" \
-         --evaluation_audio_dir "${evaluation_audio_dir}" \
-         --limit "${limit}" \
-
- fi
-
-
- if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-     $verbose && echo "stage 4: export model"
-     cd "${work_dir}" || exit 1
-     python3 step_5_export_models.py \
-         --vocabulary_dir "${vocabulary_dir}" \
-         --model_dir "${file_dir}/best" \
-         --serialization_dir "${file_dir}" \
-
- fi
-
-
- if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-     $verbose && echo "stage 5: collect files"
-     cd "${work_dir}" || exit 1
-
-     mkdir -p ${final_model_dir}
-
-     cp "${file_dir}/best"/* "${final_model_dir}"
-     cp -r "${file_dir}/vocabulary" "${final_model_dir}"
-
-     cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
-
-     cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
-     cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
-     cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
-     cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
-
-     cd "${final_model_dir}/.." || exit 1;
-
-     if [ -e "${final_model_name}.zip" ]; then
-         rm -rf "${final_model_name}_backup.zip"
-         mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-     fi
-
-     zip -r "${final_model_name}.zip" "${final_model_name}"
-     rm -rf "${final_model_name}"
-
- fi
-
-
- if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-     $verbose && echo "stage 6: clear file_dir"
-     cd "${work_dir}" || exit 1
-
-     rm -rf "${file_dir}";
-
- fi
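
The `--stage/--stop_stage` pair in the script above runs every stage whose index falls in the inclusive range. A minimal Python sketch of the same control flow (the stage bodies here are placeholders, not the repo's scripts):

```python
def run_pipeline(stage: int, stop_stage: int) -> None:
    stages = {
        1: lambda: print("stage 1: prepare data"),
        2: lambda: print("stage 2: train model"),
        3: lambda: print("stage 3: test model"),
    }
    for idx in sorted(stages):
        if stage <= idx <= stop_stage:
            stages[idx]()

run_pipeline(stage=1, stop_stage=2)  # runs stages 1 and 2, skips 3
```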
 
examples/spectrum_unet_irm_aishell/step_1_prepare_data.py DELETED
@@ -1,197 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import os
- from pathlib import Path
- import random
- import sys
- import shutil
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import pandas as pd
- from scipy.io import wavfile
- from tqdm import tqdm
- import librosa
-
- from project_settings import project_path
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--file_dir", default="./", type=str)
-
-     parser.add_argument(
-         "--noise_dir",
-         default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
-         type=str
-     )
-     parser.add_argument(
-         "--speech_dir",
-         default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
-         type=str
-     )
-
-     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-     parser.add_argument("--duration", default=2.0, type=float)
-     parser.add_argument("--min_snr_db", default=-10, type=float)
-     parser.add_argument("--max_snr_db", default=20, type=float)
-
-     parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-     args = parser.parse_args()
-     return args
-
-
- def filename_generator(data_dir: str):
-     data_dir = Path(data_dir)
-     for filename in data_dir.glob("**/*.wav"):
-         yield filename.as_posix()
-
-
- def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
-     data_dir = Path(data_dir)
-     for filename in data_dir.glob("**/*.wav"):
-         signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-         raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-         if raw_duration < duration:
-             # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-             continue
-         if signal.ndim != 1:
-             raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-         signal_length = len(signal)
-         win_size = int(duration * sample_rate)
-         for begin in range(0, signal_length - win_size, win_size):
-             row = {
-                 "filename": filename.as_posix(),
-                 "raw_duration": round(raw_duration, 4),
-                 "offset": round(begin / sample_rate, 4),
-                 "duration": round(duration, 4),
-             }
-             yield row
-
-
- def get_dataset(args):
-     file_dir = Path(args.file_dir)
-     file_dir.mkdir(exist_ok=True)
-
-     noise_dir = Path(args.noise_dir)
-     speech_dir = Path(args.speech_dir)
-
-     noise_generator = target_second_signal_generator(
-         noise_dir.as_posix(),
-         duration=args.duration,
-         sample_rate=args.target_sample_rate
-     )
-     speech_generator = target_second_signal_generator(
-         speech_dir.as_posix(),
-         duration=args.duration,
-         sample_rate=args.target_sample_rate
-     )
-
-     dataset = list()
-
-     count = 0
-     process_bar = tqdm(desc="build dataset excel")
-     for noise, speech in zip(noise_generator, speech_generator):
-
-         noise_filename = noise["filename"]
-         noise_raw_duration = noise["raw_duration"]
-         noise_offset = noise["offset"]
-         noise_duration = noise["duration"]
-
-         speech_filename = speech["filename"]
-         speech_raw_duration = speech["raw_duration"]
-         speech_offset = speech["offset"]
-         speech_duration = speech["duration"]
-
-         random1 = random.random()
-         random2 = random.random()
-
-         row = {
-             "noise_filename": noise_filename,
-             "noise_raw_duration": noise_raw_duration,
-             "noise_offset": noise_offset,
-             "noise_duration": noise_duration,
-
-             "speech_filename": speech_filename,
-             "speech_raw_duration": speech_raw_duration,
-             "speech_offset": speech_offset,
-             "speech_duration": speech_duration,
-
-             "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
-
-             "random1": random1,
-             "random2": random2,
-             "flag": "TRAIN" if random2 < 0.8 else "TEST",
-         }
-         dataset.append(row)
-         count += 1
-         duration_seconds = count * args.duration
-         duration_hours = duration_seconds / 3600
-
-         process_bar.update(n=1)
-         process_bar.set_postfix({
-             # "duration_seconds": round(duration_seconds, 4),
-             "duration_hours": round(duration_hours, 4),
-         })
-
-     dataset = pd.DataFrame(dataset)
-     dataset = dataset.sort_values(by=["random1"], ascending=False)
-     dataset.to_excel(
-         file_dir / "dataset.xlsx",
-         index=False,
-     )
-     return
-
-
- def split_dataset(args):
-     """Split the dataset into train and test sets."""
-     file_dir = Path(args.file_dir)
-     file_dir.mkdir(exist_ok=True)
-
-     df = pd.read_excel(file_dir / "dataset.xlsx")
-
-     train = list()
-     test = list()
-
-     for i, row in df.iterrows():
-         flag = row["flag"]
-         if flag == "TRAIN":
-             train.append(row)
-         else:
-             test.append(row)
-
-     train = pd.DataFrame(train)
-     train.to_excel(
-         args.train_dataset,
-         index=False,
-         # encoding="utf_8_sig"
-     )
-     test = pd.DataFrame(test)
-     test.to_excel(
-         args.valid_dataset,
-         index=False,
-         # encoding="utf_8_sig"
-     )
-
-     return
-
-
- def main():
-     args = get_args()
-
-     get_dataset(args)
-     split_dataset(args)
-     return
-
-
- if __name__ == "__main__":
-     main()
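
`target_second_signal_generator` slices each file into non-overlapping `duration`-second windows and records their offsets. A worked example for a hypothetical 5-second file:

```python
sample_rate = 8000
duration = 2.0
signal_length = int(5.0 * sample_rate)   # a 5 s recording
win_size = int(duration * sample_rate)

offsets = [begin / sample_rate for begin in range(0, signal_length - win_size, win_size)]
print(offsets)  # [0.0, 2.0] -- the trailing remainder shorter than one window is dropped
```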
 
examples/spectrum_unet_irm_aishell/step_2_train_model.py DELETED
@@ -1,420 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- """
- https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
- """
- import argparse
- import json
- import logging
- from logging.handlers import TimedRotatingFileHandler
- import os
- import platform
- from pathlib import Path
- import random
- import sys
- import shutil
- from typing import List
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import numpy as np
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from torch.utils.data.dataloader import DataLoader
- import torchaudio
- from tqdm import tqdm
-
- from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
- from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig
- from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-     parser.add_argument("--max_epochs", default=100, type=int)
-
-     parser.add_argument("--batch_size", default=64, type=int)
-     parser.add_argument("--learning_rate", default=1e-4, type=float)
-     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
-     parser.add_argument("--patience", default=5, type=int)
-     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-     parser.add_argument("--seed", default=0, type=int)
-
-     parser.add_argument("--config_file", default="config.yaml", type=str)
-
-     args = parser.parse_args()
-     return args
-
-
- def logging_config(file_dir: str):
-     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-     logging.basicConfig(format=fmt,
-                         datefmt="%m/%d/%Y %H:%M:%S",
-                         level=logging.INFO)
-     file_handler = TimedRotatingFileHandler(
-         filename=os.path.join(file_dir, "main.log"),
-         encoding="utf-8",
-         when="D",
-         interval=1,
-         backupCount=7
-     )
-     file_handler.setLevel(logging.INFO)
-     file_handler.setFormatter(logging.Formatter(fmt))
-     logger = logging.getLogger(__name__)
-     logger.addHandler(file_handler)
-
-     return logger
-
-
- class CollateFunction(object):
-     def __init__(self,
-                  n_fft: int = 512,
-                  win_length: int = 200,
-                  hop_length: int = 80,
-                  window_fn: str = "hamming",
-                  irm_beta: float = 1.0,
-                  epsilon: float = 1e-8,
-                  ):
-         self.n_fft = n_fft
-         self.win_length = win_length
-         self.hop_length = hop_length
-         self.window_fn = window_fn
-         self.irm_beta = irm_beta
-         self.epsilon = epsilon
-
-         self.transform = torchaudio.transforms.Spectrogram(
-             n_fft=self.n_fft,
-             win_length=self.win_length,
-             hop_length=self.hop_length,
-             power=2.0,
-             window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-         )
-
-     @staticmethod
-     def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
-         batch_size, channels, freq_dim, time_steps = x.shape
-
-         # kernel: [freq_dim, n_time_step]
-         kernel_size = (freq_dim, n_time_steps)
-
-         # pad
-         pad = n_time_steps // 2
-         x = torch.concat(tensors=[
-             x[:, :, :, :pad],
-             x,
-             x[:, :, :, -pad:],
-         ], dim=-1)
-
-         x = F.unfold(
-             input=x,
-             kernel_size=kernel_size,
-         )
-         # x shape: [batch_size, fold, time_steps]
-         return x
-
-     def __call__(self, batch: List[dict]):
-         mix_spec_list = list()
-         speech_irm_list = list()
-         snr_db_list = list()
-         for sample in batch:
-             noise_wave: torch.Tensor = sample["noise_wave"]
-             speech_wave: torch.Tensor = sample["speech_wave"]
-             mix_wave: torch.Tensor = sample["mix_wave"]
-             # snr_db: float = sample["snr_db"]
-
-             noise_spec = self.transform.forward(noise_wave)
-             speech_spec = self.transform.forward(speech_wave)
-             mix_spec = self.transform.forward(mix_wave)
-
-             # noise_irm = noise_spec / (noise_spec + speech_spec)
-             speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
-             speech_irm = torch.pow(speech_irm, self.irm_beta)
-
-             # noise_spec, speech_spec, mix_spec, speech_irm
-             # shape: [freq_dim, time_steps]
-
-             snr_db: torch.Tensor = 10 * torch.log10(
-                 speech_spec / (noise_spec + self.epsilon)
-             )
-             snr_db = torch.clamp(snr_db, min=self.epsilon)
-
-             snr_db_ = torch.unsqueeze(snr_db, dim=0)
-             snr_db_ = torch.unsqueeze(snr_db_, dim=0)
-             snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
-             snr_db_ = torch.squeeze(snr_db_, dim=0)
-             # snr_db_ shape: [fold, time_steps]
-
-             snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
-             # snr_db shape: [1, time_steps]
-
-             mix_spec_list.append(mix_spec)
-             speech_irm_list.append(speech_irm)
-             snr_db_list.append(snr_db)
-
-         mix_spec_list = torch.stack(mix_spec_list)
-         speech_irm_list = torch.stack(speech_irm_list)
-         snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size, time_steps, 1)
-
-         mix_spec_list = mix_spec_list[:, :-1, :]
-         speech_irm_list = speech_irm_list[:, :-1, :]
-
-         # mix_spec_list shape: [batch_size, freq_dim, time_steps]
-         # speech_irm_list shape: [batch_size, freq_dim, time_steps]
-         # snr_db shape: [batch_size, 1, time_steps]
-
-         # assert
-         if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
-             raise AssertionError("nan or inf in mix_spec_list")
-         if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
-             raise AssertionError("nan or inf in speech_irm_list")
-         if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
-             raise AssertionError("nan or inf in snr_db_list")
-
-         return mix_spec_list, speech_irm_list, snr_db_list
-
-
- collate_fn = CollateFunction()
-
-
- def main():
-     args = get_args()
-
-     serialization_dir = Path(args.serialization_dir)
-     serialization_dir.mkdir(parents=True, exist_ok=True)
-
-     logger = logging_config(serialization_dir)
-
-     random.seed(args.seed)
-     np.random.seed(args.seed)
-     torch.manual_seed(args.seed)
-     logger.info("set seed: {}".format(args.seed))
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     n_gpu = torch.cuda.device_count()
-     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-     # datasets
-     logger.info("prepare datasets")
-     train_dataset = DenoiseExcelDataset(
-         excel_file=args.train_dataset,
-         expected_sample_rate=8000,
-         max_wave_value=32768.0,
-     )
-     valid_dataset = DenoiseExcelDataset(
-         excel_file=args.valid_dataset,
-         expected_sample_rate=8000,
-         max_wave_value=32768.0,
-     )
-     train_data_loader = DataLoader(
-         dataset=train_dataset,
-         batch_size=args.batch_size,
-         shuffle=True,
-         # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-         collate_fn=collate_fn,
-         pin_memory=False,
-         # prefetch_factor=64,
-     )
-     valid_data_loader = DataLoader(
-         dataset=valid_dataset,
-         batch_size=args.batch_size,
-         shuffle=True,
-         # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-         collate_fn=collate_fn,
-         pin_memory=False,
-         # prefetch_factor=64,
-     )
-
-     # models
-     logger.info(f"prepare models. config_file: {args.config_file}")
-     config = SpectrumUnetIRMConfig.from_pretrained(
-         pretrained_model_name_or_path=args.config_file,
-         # num_labels=vocabulary.get_vocab_size(namespace="labels")
-     )
-     model = SpectrumUnetIRMPretrainedModel(
-         config=config,
-     )
-     model.to(device)
-     model.train()
-
-     # optimizer
-     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-     param_optimizer = model.parameters()
-     optimizer = torch.optim.Adam(
-         param_optimizer,
-         lr=args.learning_rate,
-     )
-     # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-     #     optimizer,
-     #     step_size=2000
-     # )
-     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-         optimizer,
-         milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-     )
-     irm_mse_loss = nn.MSELoss(
-         reduction="mean",
-     )
-     snr_mse_loss = nn.MSELoss(
-         reduction="mean",
-     )
-
-     # training loop
-     logger.info("training")
-
-     training_loss = 10000000000
-     evaluation_loss = 10000000000
-
-     model_list = list()
-     best_idx_epoch = None
-     best_metric = None
-     patience_count = 0
-
-     for idx_epoch in range(args.max_epochs):
-         total_loss = 0.
-         total_examples = 0.
-         progress_bar = tqdm(
-             total=len(train_data_loader),
-             desc="Training; epoch: {}".format(idx_epoch),
-         )
-
-         for batch in train_data_loader:
-             mix_spec, speech_irm, snr_db = batch
-             mix_spec = mix_spec.to(device)
-             speech_irm_target = speech_irm.to(device)
-             snr_db_target = snr_db.to(device)
-
-             speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
-             if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
-                 raise AssertionError("nan or inf in speech_irm_prediction")
-             if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
-                 raise AssertionError("nan or inf in lsnr_prediction")
-             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-             lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
-             if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
-                 raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
-             snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-             if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
-                 raise AssertionError("nan or inf in snr_loss")
-             # loss = irm_loss + 0.1 * snr_loss
-             loss = 10.0 * irm_loss + 0.05 * snr_loss
-             # loss = irm_loss
-
-             total_loss += loss.item()
-             total_examples += mix_spec.size(0)
-
-             optimizer.zero_grad()
-             loss.backward()
-             optimizer.step()
-             lr_scheduler.step()
-
-             training_loss = total_loss / total_examples
-             training_loss = round(training_loss, 4)
-
-             progress_bar.update(1)
-             progress_bar.set_postfix({
-                 "training_loss": training_loss,
-             })
-
-         total_loss = 0.
-         total_examples = 0.
-         progress_bar = tqdm(
-             total=len(valid_data_loader),
-             desc="Evaluation; epoch: {}".format(idx_epoch),
-         )
-         for batch in valid_data_loader:
-             mix_spec, speech_irm, snr_db = batch
-             mix_spec = mix_spec.to(device)
-             speech_irm_target = speech_irm.to(device)
-             snr_db_target = snr_db.to(device)
-
-             with torch.no_grad():
-                 speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
-                 if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
-                     raise AssertionError("nan or inf in speech_irm_prediction")
-                 if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
-                     raise AssertionError("nan or inf in lsnr_prediction")
-                 irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-                 lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
-                 if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
-                     raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
-                 snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-                 # loss = irm_loss + 0.1 * snr_loss
-                 loss = 10.0 * irm_loss + 0.05 * snr_loss
-                 # loss = irm_loss
-
-             total_loss += loss.item()
-             total_examples += mix_spec.size(0)
-
-             evaluation_loss = total_loss / total_examples
-             evaluation_loss = round(evaluation_loss, 4)
-
-             progress_bar.update(1)
-             progress_bar.set_postfix({
-                 "evaluation_loss": evaluation_loss,
-             })
-
-         # save path
-         epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
-         epoch_dir.mkdir(parents=True, exist_ok=False)
-
-         # save models
-         model.save_pretrained(epoch_dir.as_posix())
-
-         model_list.append(epoch_dir)
-         if len(model_list) >= args.num_serialized_models_to_keep:
-             model_to_delete: Path = model_list.pop(0)
-             shutil.rmtree(model_to_delete.as_posix())
-
-         # save metric
-         if best_metric is None:
-             best_idx_epoch = idx_epoch
-             best_metric = evaluation_loss
-         elif evaluation_loss < best_metric:
-             best_idx_epoch = idx_epoch
-             best_metric = evaluation_loss
-         else:
-             pass
-
-         metrics = {
-             "idx_epoch": idx_epoch,
-             "best_idx_epoch": best_idx_epoch,
-             "training_loss": training_loss,
-             "evaluation_loss": evaluation_loss,
-             "learning_rate": optimizer.param_groups[0]["lr"],
-         }
-         metrics_filename = epoch_dir / "metrics_epoch.json"
-         with open(metrics_filename, "w", encoding="utf-8") as f:
-             json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-         # save best
-         best_dir = serialization_dir / "best"
-         if best_idx_epoch == idx_epoch:
-             if best_dir.exists():
-                 shutil.rmtree(best_dir)
-             shutil.copytree(epoch_dir, best_dir)
-
-         # early stop
-         early_stop_flag = False
-         if best_idx_epoch == idx_epoch:
-             patience_count = 0
-         else:
-             patience_count += 1
-             if patience_count >= args.patience:
-                 early_stop_flag = True
-
-         # early stop
-         if early_stop_flag:
-             break
-     return
-
-
- if __name__ == '__main__':
-     main()
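
The IRM target built in `CollateFunction` is the speech share of total power per bin, so it lives in [0, 1]. A toy illustration with hand-picked power values:

```python
import torch

epsilon = 1e-8
speech_spec = torch.tensor([4.0, 1.0, 0.0])  # per-bin speech power
noise_spec = torch.tensor([1.0, 1.0, 2.0])   # per-bin noise power

speech_irm = speech_spec / (noise_spec + speech_spec + epsilon)
print(speech_irm)  # tensor([0.8000, 0.5000, 0.0000]) -- near 1 where speech dominates
```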
 
examples/spectrum_unet_irm_aishell/step_3_evaluation.py DELETED
@@ -1,270 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import logging
- import os
- from pathlib import Path
- import sys
- import uuid
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import librosa
- import numpy as np
- import pandas as pd
- from scipy.io import wavfile
- import torch
- import torch.nn as nn
- import torchaudio
- from tqdm import tqdm
-
- from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel
-
-
- def get_args():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-     parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
-     parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
-
-     parser.add_argument("--limit", default=10, type=int)
-
-     args = parser.parse_args()
-     return args
-
-
- def logging_config():
-     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-     logging.basicConfig(format=fmt,
-                         datefmt="%m/%d/%Y %H:%M:%S",
-                         level=logging.INFO)
-     stream_handler = logging.StreamHandler()
-     stream_handler.setLevel(logging.INFO)
-     stream_handler.setFormatter(logging.Formatter(fmt))
-
-     logger = logging.getLogger(__name__)
-
-     return logger
-
-
- def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
-     l1 = len(speech)
-     l2 = len(noise)
-     l = min(l1, l2)
-     speech = speech[:l]
-     noise = noise[:l]
-
-     # np.float32, value between (-1, 1).
-
-     speech_power = np.mean(np.square(speech))
-     noise_power = speech_power / (10 ** (snr_db / 10))
-
-     noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
-
-     noisy_signal = speech + noise_adjusted
-
-     return noisy_signal
-
-
- stft_power = torchaudio.transforms.Spectrogram(
-     n_fft=512,
-     win_length=200,
-     hop_length=80,
-     power=2.0,
-     window_fn=torch.hamming_window,
- )
-
-
- stft_complex = torchaudio.transforms.Spectrogram(
-     n_fft=512,
-     win_length=200,
-     hop_length=80,
-     power=None,
-     window_fn=torch.hamming_window,
- )
-
-
- istft = torchaudio.transforms.InverseSpectrogram(
-     n_fft=512,
-     win_length=200,
-     hop_length=80,
-     window_fn=torch.hamming_window,
- )
-
-
- def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
-     mix_spec_complex = mix_spec_complex.detach().cpu()
-     speech_irm_prediction = speech_irm_prediction.detach().cpu()
-
-     mask_speech = speech_irm_prediction
-     mask_noise = 1.0 - speech_irm_prediction
-
-     speech_spec = mix_spec_complex * mask_speech
-     noise_spec = mix_spec_complex * mask_noise
-
-     speech_wave = istft.forward(speech_spec)
-     noise_wave = istft.forward(noise_spec)
-
-     return speech_wave, noise_wave
-
-
- def save_audios(noise_wave: torch.Tensor,
-                 speech_wave: torch.Tensor,
-                 mix_wave: torch.Tensor,
-                 speech_wave_enhanced: torch.Tensor,
-                 noise_wave_enhanced: torch.Tensor,
-                 output_dir: str,
-                 sample_rate: int = 8000,
-                 ):
-     basename = uuid.uuid4().__str__()
-     output_dir = Path(output_dir) / basename
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     filename = output_dir / "noise_wave.wav"
-     torchaudio.save(filename, noise_wave, sample_rate)
-     filename = output_dir / "speech_wave.wav"
-     torchaudio.save(filename, speech_wave, sample_rate)
-     filename = output_dir / "mix_wave.wav"
-     torchaudio.save(filename, mix_wave, sample_rate)
-
-     filename = output_dir / "speech_wave_enhanced.wav"
-     torchaudio.save(filename, speech_wave_enhanced, sample_rate)
-     filename = output_dir / "noise_wave_enhanced.wav"
-     torchaudio.save(filename, noise_wave_enhanced, sample_rate)
-
-     return output_dir.as_posix()
-
-
- def main():
-     args = get_args()
-
-     logger = logging_config()
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     n_gpu = torch.cuda.device_count()
-     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-     logger.info("prepare model")
-     model = SpectrumUnetIRMPretrainedModel.from_pretrained(
-         pretrained_model_name_or_path=args.model_dir,
-     )
-     model.to(device)
-     model.eval()
-
-     # optimizer
-     logger.info("prepare loss_fn")
-     irm_mse_loss = nn.MSELoss(
-         reduction="mean",
-     )
-     snr_mse_loss = nn.MSELoss(
-         reduction="mean",
-     )
-
-     logger.info("read excel")
-     df = pd.read_excel(args.valid_dataset)
-
-     total_loss = 0.
-     total_examples = 0.
-     progress_bar = tqdm(total=len(df), desc="Evaluation")
-     for idx, row in df.iterrows():
-         noise_filename = row["noise_filename"]
-         noise_offset = row["noise_offset"]
-         noise_duration = row["noise_duration"]
-
-         speech_filename = row["speech_filename"]
-         speech_offset = row["speech_offset"]
-         speech_duration = row["speech_duration"]
-
-         snr_db = row["snr_db"]
-
-         noise_wave, _ = librosa.load(
-             noise_filename,
-             sr=8000,
-             offset=noise_offset,
-             duration=noise_duration,
-         )
-         speech_wave, _ = librosa.load(
-             speech_filename,
-             sr=8000,
-             offset=speech_offset,
-             duration=speech_duration,
-         )
-         mix_wave: np.ndarray = mix_speech_and_noise(
-             speech=speech_wave,
-             noise=noise_wave,
-             snr_db=snr_db,
-         )
-         noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
-         speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
-         mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
-
-         noise_wave = noise_wave.unsqueeze(dim=0)
-         speech_wave = speech_wave.unsqueeze(dim=0)
-         mix_wave = mix_wave.unsqueeze(dim=0)
-
-         noise_spec: torch.Tensor = stft_power.forward(noise_wave)
-         speech_spec: torch.Tensor = stft_power.forward(speech_wave)
-         mix_spec: torch.Tensor = stft_power.forward(mix_wave)
-
-         noise_spec = noise_spec[:, :-1, :]
-         speech_spec = speech_spec[:, :-1, :]
-         mix_spec = mix_spec[:, :-1, :]
-
-         mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
-         # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
-
-         speech_irm = speech_spec / (noise_spec + speech_spec)
-         speech_irm = torch.pow(speech_irm, 1.0)
-
-         snr_db: torch.Tensor = 10 * torch.log10(
-             speech_spec / (noise_spec + 1e-8)
-         )
-         snr_db = torch.mean(snr_db, dim=1, keepdim=True)
-         # snr_db shape: [batch_size, 1, time_steps]
-
-         mix_spec = mix_spec.to(device)
-         speech_irm_target = speech_irm.to(device)
-         snr_db_target = snr_db.to(device)
-
-         with torch.no_grad():
-             speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
-             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-             # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-             # loss = irm_loss + 0.1 * snr_loss
-             loss = irm_loss
-
-         # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
-         # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
-         batch_size, _, time_steps = speech_irm_prediction.shape
-         speech_irm_prediction = torch.concat(
-             [
-                 speech_irm_prediction,
-                 0.5 * torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
-             ],
-             dim=1,
-         )
-         # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
-         speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
-         save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
-
-         total_loss += loss.item()
-         total_examples += mix_spec.size(0)
-
-         evaluation_loss = total_loss / total_examples
-         evaluation_loss = round(evaluation_loss, 4)
-
-         progress_bar.update(1)
-         progress_bar.set_postfix({
-             "evaluation_loss": evaluation_loss,
-         })
-
-         if idx > args.limit:
-             break
-
-     return
-
-
- if __name__ == '__main__':
-     main()
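
`enhance` applies the predicted real-valued mask to the complex mixture spectrogram and inverts it, after the 257th bin (dropped during training) is padded back. A minimal sketch on random data, mirroring the transforms above (the random `mask` stands in for the model output):

```python
import torch
import torchaudio

stft_complex = torchaudio.transforms.Spectrogram(
    n_fft=512, win_length=200, hop_length=80, power=None, window_fn=torch.hamming_window)
istft = torchaudio.transforms.InverseSpectrogram(
    n_fft=512, win_length=200, hop_length=80, window_fn=torch.hamming_window)

mix_wave = torch.randn(1, 8000)
mix_spec_complex = stft_complex(mix_wave)              # [1, 257, frames], complex
mask = torch.rand(1, 256, mix_spec_complex.shape[-1])  # stand-in for the IRM prediction
mask = torch.cat([mask, 0.5 * torch.ones(1, 1, mask.shape[-1])], dim=1)  # pad bin 257

speech_wave = istft(mix_spec_complex * mask)
print(speech_wave.shape)  # roughly the original length, e.g. torch.Size([1, 8000])
```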
 
examples/spectrum_unet_irm_aishell/yaml/config.yaml DELETED
@@ -1,38 +0,0 @@
- model_name: "spectrum_unet_irm"
-
- # spec
- sample_rate: 8000
- n_fft: 512
- win_length: 200
- hop_length: 80
-
- spec_bins: 256
-
- # model
- conv_channels: 64
- conv_kernel_size_input:
-   - 3
-   - 3
- conv_kernel_size_inner:
-   - 1
-   - 3
- conv_lookahead: 0
-
- convt_kernel_size_inner:
-   - 1
-   - 3
-
- encoder_emb_skip_op: "none"
- encoder_emb_linear_groups: 16
- encoder_emb_hidden_size: 256
-
- lsnr_max: 30
- lsnr_min: -15
-
- decoder_emb_num_layers: 3
- decoder_emb_skip_op: "none"
- decoder_emb_linear_groups: 16
- decoder_emb_hidden_size: 256
-
- # runtime
- use_post_filter: true
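
These keys map one-to-one onto the model's config class. Assuming `from_pretrained` parses a bare `.yaml` path with a standard YAML loader, the values can be inspected directly (a sketch, not the repo's actual loader):

```python
import yaml  # PyYAML

with open("examples/spectrum_unet_irm_aishell/yaml/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["model_name"])                    # spectrum_unet_irm
print(config["lsnr_min"], config["lsnr_max"])  # -15 30, the range used to normalize lsnr_prediction
```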
 
main.py CHANGED
@@ -1,7 +1,7 @@
  #!/usr/bin/python3
  # -*- coding: utf-8 -*-
  """
- docker build -t denoise:v20250609_1919 .
+ docker build -t denoise:v20250626_1616 .
  docker stop denoise_7865 && docker rm denoise_7865
  docker run -itd \
  --name denoise_7865 \
toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py DELETED
@@ -1,197 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import json
- import os
- import random
- from typing import List
- from pathlib import Path
- import tempfile
- import uuid
-
- from pydub import AudioSegment
- from scipy.io import wavfile
- import librosa
- import numpy as np
- import torch
- from torch.utils.data import Dataset, IterableDataset
-
-
- class Mp3ToWavJsonlDataset(IterableDataset):
-     def __init__(self,
-                  jsonl_file: str,
-                  expected_sample_rate: int,
-                  resample: bool = False,
-                  max_wave_value: float = 1.0,
-                  buffer_size: int = 1000,
-                  eps: float = 1e-8,
-                  skip: int = 0,
-                  ):
-         self.jsonl_file = jsonl_file
-         self.expected_sample_rate = expected_sample_rate
-         self.resample = resample
-         self.max_wave_value = max_wave_value
-         self.eps = eps
-         self.skip = skip
-
-         self.buffer_size = buffer_size
-         self.buffer_samples: List[dict] = list()
-
-     def __iter__(self):
-         self.buffer_samples = list()
-
-         iterable_source = self.iterable_source()
-
-         try:
-             for _ in range(self.skip):
-                 next(iterable_source)
-         except StopIteration:
-             pass
-
-         # Initially fill the buffer.
-         try:
-             for _ in range(self.buffer_size):
-                 self.buffer_samples.append(next(iterable_source))
-         except StopIteration:
-             pass
-
-         # Dynamic replacement: swap each incoming item into a random buffer slot.
-         while True:
-             try:
-                 item = next(iterable_source)
-                 # Replace a random element of the buffer.
-                 replace_idx = random.randint(0, len(self.buffer_samples) - 1)
-                 sample = self.buffer_samples[replace_idx]
-                 self.buffer_samples[replace_idx] = item
-                 yield self.convert_sample(sample)
-             except StopIteration:
-                 break
-
-         # Flush the remaining elements.
-         random.shuffle(self.buffer_samples)
-         for sample in self.buffer_samples:
-             yield self.convert_sample(sample)
-
-     def iterable_source(self):
-         last_sample = None
-         with open(self.jsonl_file, "r", encoding="utf-8") as f:
-             for row in f:
-                 row = json.loads(row)
-                 filename = row["filename"]
-                 raw_duration = row["raw_duration"]
-                 offset = row["offset"]
-                 duration = row["duration"]
-
-                 sample = {
-                     "filename": filename,
-                     "raw_duration": raw_duration,
-                     "offset": offset,
-                     "duration": duration,
-                 }
-                 if last_sample is None:
-                     last_sample = sample
-                     continue
-                 yield sample
-         yield last_sample
-
-     def convert_sample(self, sample: dict):
-         filename = sample["filename"]
-         offset = sample["offset"]
-         duration = sample["duration"]
-
-         wav_waveform = self.filename_to_waveform(filename, offset, duration)
-         mp3_waveform = self.filename_to_mp3_waveform(filename, offset, duration)
-
-         if wav_waveform.shape != mp3_waveform.shape:
-             raise AssertionError(f"wav_waveform: {wav_waveform.shape}, mp3_waveform: {mp3_waveform.shape}")
-
-         result = {
-             "mp3_waveform": mp3_waveform,
-             "wav_waveform": wav_waveform,
-         }
-         return result
-
-     @staticmethod
-     def filename_to_waveform(filename: str, offset: float, duration: float, expected_sample_rate: int = 8000):
-         try:
-             waveform, sample_rate = librosa.load(
-                 filename,
-                 sr=expected_sample_rate,
-                 offset=offset,
-                 duration=duration,
-             )
-         except ValueError as e:
-             print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
-             raise e
-         waveform = torch.tensor(waveform, dtype=torch.float32)
-         return waveform
-
-     @staticmethod
-     def get_temporary_file(suffix: str = ".wav"):
-         temp_audio_dir = Path(tempfile.gettempdir()) / "mp3_to_wav_jsonl_dataset"
-         temp_audio_dir.mkdir(parents=True, exist_ok=True)
-         filename = temp_audio_dir / f"{uuid.uuid4()}{suffix}"
-         filename = filename.as_posix()
-         return filename
-
-     @staticmethod
-     def filename_to_mp3_waveform(filename: str, offset: float, duration: float, expected_sample_rate: int = 8000):
-         try:
-             waveform, sample_rate = librosa.load(
-                 filename,
-                 sr=expected_sample_rate,
-                 offset=offset,
-                 duration=duration,
-             )
-             waveform = np.array(waveform * (1 << 15), dtype=np.int16)
-         except ValueError as e:
-             print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
-             raise e
-
-         wav_temporary_file = Mp3ToWavJsonlDataset.get_temporary_file(suffix=".wav")
-         wavfile.write(
-             wav_temporary_file,
-             rate=sample_rate,
-             data=waveform,
-         )
-
-         mp3_temporary_file = Mp3ToWavJsonlDataset.get_temporary_file(suffix=".mp3")
-
-         audio = AudioSegment.from_wav(wav_temporary_file)
-         audio.export(mp3_temporary_file,
-                      format="mp3",
-                      bitrate="64k",  # 64 kbps is a reasonable bitrate for 8 kHz audio.
-                      # parameters=["-ar", "8000"]
-                      parameters=["-ar", f"{expected_sample_rate}"]
-                      )
-
-         try:
-             waveform, sample_rate = librosa.load(mp3_temporary_file, sr=expected_sample_rate)
-         except ValueError as e:
-             print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
-             raise e
-
-         os.remove(wav_temporary_file)
-         os.remove(mp3_temporary_file)
-
-         waveform = torch.tensor(waveform, dtype=torch.float32)
-         return waveform
-
-
- def main():
-     filename = r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-PH\2025-06-13\active_media_r_2e6e6303-4a2e-4bc9-b814-98ceddc59e9d_23.wav"
-
-     waveform = Mp3ToWavJsonlDataset.filename_to_mp3_waveform(filename, offset=0, duration=15)
-     print(waveform.shape)
-
-     signal = np.array(waveform.numpy() * (1 << 15), dtype=np.int16)
-
-     wavfile.write(
-         "temp.wav",
-         8000,
-         signal,
-     )
-     return
-
-
- if __name__ == "__main__":
-     main()
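
`__iter__` above approximates shuffling for an `IterableDataset` with a fixed-size buffer: fill it, then swap each incoming item into a random slot and yield the evicted one. The same idea as a standalone generator:

```python
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def buffered_shuffle(source: Iterable[T], buffer_size: int = 1000) -> Iterator[T]:
    buffer = []
    for item in source:
        if len(buffer) < buffer_size:
            buffer.append(item)                   # initial fill
            continue
        idx = random.randint(0, len(buffer) - 1)
        out, buffer[idx] = buffer[idx], item      # evict a random element
        yield out
    random.shuffle(buffer)                        # flush the remainder
    yield from buffer

print(list(buffered_shuffle(range(10), buffer_size=4)))
```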